arwin0727 commited on
Commit
9110de1
·
verified ·
1 Parent(s): e2d018c

Upload miner.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. miner.py +426 -0
miner.py ADDED
@@ -0,0 +1,426 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import io
4
+ import json
5
+ import os
6
+ import sys
7
+ import wave
8
+ from pathlib import Path
9
+ from typing import Any, Mapping
10
+
11
+ import numpy as np
12
+
13
+ REPO = Path(__file__).resolve().parent
14
+ _VOCENCE_YAML = "vocence_config.yaml"
15
+ _MAX_AUDIO_SEC = 30
16
+ _VOCENCE_OUTPUT_HZ = 24000
17
+
18
+
19
+ def _resample_to_hz_mono_f32(waveform: np.ndarray, orig_sr: int, target_sr: int) -> np.ndarray:
20
+ """Linear / polyphase resample mono float32 ``[-1, 1]`` to ``target_sr`` (uses librosa)."""
21
+ if orig_sr == target_sr:
22
+ return np.asarray(waveform, dtype=np.float32)
23
+ import librosa
24
+
25
+ y = np.asarray(waveform, dtype=np.float32)
26
+ if y.ndim > 1:
27
+ y = np.mean(y, axis=-1).astype(np.float32)
28
+ return librosa.resample(y, orig_sr=int(orig_sr), target_sr=int(target_sr)).astype(np.float32)
29
+
30
+
31
+ def load_tts_inference_engine(
32
+ *,
33
+ llama_checkpoint_path: str,
34
+ decoder_checkpoint_path: str,
35
+ decoder_config_name: str = "modded_dac_vq",
36
+ device: str = "cuda",
37
+ half: bool = False,
38
+ compile_model: bool = False,
39
+ ) -> Any:
40
+
41
+ from tools.server.model_manager import ModelManager
42
+
43
+ manager = ModelManager(
44
+ mode="tts",
45
+ device=device,
46
+ half=half,
47
+ compile=compile_model,
48
+ llama_checkpoint_path=llama_checkpoint_path,
49
+ decoder_checkpoint_path=decoder_checkpoint_path,
50
+ decoder_config_name=decoder_config_name,
51
+ )
52
+ return manager.tts_inference_engine
53
+
54
+
55
+ def synthesize_wav(
56
+ engine: Any,
57
+ *,
58
+ text: str,
59
+ reference_audio_path: str | None = None,
60
+ reference_text: str | None = None,
61
+ max_new_tokens: int = 1024,
62
+ chunk_length: int = 200,
63
+ top_p: float = 0.8,
64
+ repetition_penalty: float = 1.1,
65
+ temperature: float = 0.8,
66
+ seed: int | None = None,
67
+ ) -> tuple[int, np.ndarray]:
68
+ """One non-streaming TTS request; returns ``(sample_rate_hz, mono float32)``."""
69
+ from fish_speech.utils.schema import ServeReferenceAudio, ServeTTSRequest
70
+
71
+ if bool(reference_audio_path) ^ bool(reference_text):
72
+ raise ValueError("provide both reference_audio_path and reference_text, or neither")
73
+
74
+ references: list[ServeReferenceAudio] = []
75
+ if reference_audio_path:
76
+ ref_path = Path(reference_audio_path)
77
+ if not ref_path.is_file():
78
+ raise FileNotFoundError(f"reference audio not found: {ref_path}")
79
+ references = [
80
+ ServeReferenceAudio(
81
+ audio=ref_path.read_bytes(),
82
+ text=reference_text or "",
83
+ )
84
+ ]
85
+
86
+ req = ServeTTSRequest(
87
+ text=text,
88
+ references=references,
89
+ reference_id=None,
90
+ max_new_tokens=max_new_tokens,
91
+ chunk_length=chunk_length,
92
+ top_p=top_p,
93
+ repetition_penalty=repetition_penalty,
94
+ temperature=temperature,
95
+ format="wav",
96
+ streaming=False,
97
+ seed=seed,
98
+ )
99
+
100
+ sample_rate: int | None = None
101
+ audio: np.ndarray | None = None
102
+ for result in engine.inference(req):
103
+ if result.code == "error":
104
+ err = result.error or "unknown inference error"
105
+ raise RuntimeError(str(err))
106
+ if result.code == "final" and result.audio is not None:
107
+ sample_rate, audio = result.audio
108
+ break
109
+
110
+ if sample_rate is None or audio is None:
111
+ raise RuntimeError("no audio produced")
112
+
113
+ arr = np.asarray(audio, dtype=np.float32)
114
+ if arr.ndim > 1:
115
+ arr = np.mean(arr, axis=-1).astype(np.float32)
116
+ return int(sample_rate), arr
117
+
118
+
119
+ def _read_vocence_yaml(repo: Path) -> dict[str, Any]:
120
+ path = repo / _VOCENCE_YAML
121
+ if not path.is_file():
122
+ return {}
123
+ from yaml import safe_load
124
+
125
+ with path.open("r", encoding="utf-8") as fh:
126
+ data = safe_load(fh)
127
+ return data if isinstance(data, Mapping) else {}
128
+
129
+
130
+ def _f32_to_wav_bytes(waveform: np.ndarray, sample_rate: int) -> bytes:
131
+ w = np.clip(np.asarray(waveform, dtype=np.float32), -1.0, 1.0)
132
+ s16 = (w * 32767.0).astype(np.int16)
133
+ buf = io.BytesIO()
134
+ with wave.open(buf, "wb") as wv:
135
+ wv.setnchannels(1)
136
+ wv.setsampwidth(2)
137
+ wv.setframerate(sample_rate)
138
+ wv.writeframes(s16.tobytes())
139
+ return buf.getvalue()
140
+
141
+
142
+ def _resolve_path(repo: Path, raw: str) -> Path:
143
+ p = Path(raw).expanduser()
144
+ return p.resolve() if p.is_absolute() else (repo / p).resolve()
145
+
146
+
147
+ def _hf_token() -> str | None:
148
+ return (os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN") or "").strip() or None
149
+
150
+
151
+ def _weights_dir_for_repo_id(hf_repo: Path, repo_id: str) -> Path:
152
+ safe = repo_id.replace("/", "__").replace(":", "_")
153
+ return (hf_repo / "_vocence_hf_weights" / safe).resolve()
154
+
155
+
156
+ def download_runtime_hub_model(
157
+ hf_repo: Path,
158
+ repo_id: str,
159
+ *,
160
+ revision: str | None = None,
161
+ ) -> Path:
162
+ """Download ``repo_id`` into ``hf_repo/_vocence_hf_weights/<sanitized>/`` and return that directory."""
163
+ from huggingface_hub import snapshot_download
164
+
165
+ dest = _weights_dir_for_repo_id(hf_repo, repo_id)
166
+ dest.mkdir(parents=True, exist_ok=True)
167
+ snapshot_download(
168
+ repo_id=repo_id,
169
+ revision=revision,
170
+ local_dir=str(dest),
171
+ local_dir_use_symlinks=False,
172
+ token=_hf_token(),
173
+ )
174
+ return dest
175
+
176
+
177
+ def _resolve_checkpoint_path(
178
+ raw: str | None,
179
+ *,
180
+ model_root: Path,
181
+ hf_repo: Path,
182
+ ) -> Path | None:
183
+ if raw is None or not str(raw).strip():
184
+ return None
185
+ s = str(raw).strip()
186
+ p = Path(s).expanduser()
187
+ if p.is_absolute():
188
+ return p.resolve()
189
+ for base in (model_root, hf_repo):
190
+ cand = (base / s).resolve()
191
+ if cand.exists():
192
+ return cand
193
+ return (hf_repo / s).resolve()
194
+
195
+
196
+ def _infer_fish_codec_paths(model_root: Path) -> tuple[str, str]:
197
+ matches = sorted(model_root.rglob("codec.pth"), key=lambda x: len(x.parts))
198
+ if not matches:
199
+ raise FileNotFoundError(
200
+ f"No codec.pth under {model_root}; set fish_speech.llama_checkpoint_path and "
201
+ f"fish_speech.decoder_checkpoint_path in {_VOCENCE_YAML}."
202
+ )
203
+ codec = matches[0]
204
+ parent = codec.parent
205
+ return str(parent), str(codec)
206
+
207
+
208
+ def _instruction_pipes_to_brackets(instruction: str) -> str:
209
+ s = instruction.strip()
210
+ if not s:
211
+ return ""
212
+ parts = [p.strip() for p in s.split("|") if p.strip()]
213
+ return "".join(f"[{p}]" for p in parts)
214
+
215
+
216
+ def _tts_prompt_from_instruction_and_text(instruction: str, text: str) -> str:
217
+ tags = _instruction_pipes_to_brackets(instruction)
218
+ body = text.strip()
219
+ if not tags:
220
+ return body
221
+ if not body:
222
+ return tags
223
+ return f"{tags} {body}"
224
+
225
+
226
+ class Miner:
227
+
228
+ def __init__(self, path_hf_repo: Path) -> None:
229
+ self._repo = Path(path_hf_repo).resolve()
230
+ cfg = _read_vocence_yaml(self._repo)
231
+ limits = cfg.get("limits") or {}
232
+ self._cap_text = int(limits.get("max_text_chars", 2000))
233
+ self._cap_instruction = int(limits.get("max_instruction_chars", 600))
234
+
235
+ gen = cfg.get("generation") or {}
236
+ out_sr = int(gen.get("sample_rate", _VOCENCE_OUTPUT_HZ))
237
+ if out_sr != _VOCENCE_OUTPUT_HZ:
238
+ raise ValueError(
239
+ f"generation.sample_rate must be {_VOCENCE_OUTPUT_HZ} (got {out_sr}); "
240
+ f"edit {self._repo / _VOCENCE_YAML}."
241
+ )
242
+ self._output_sr = out_sr
243
+
244
+ fs = cfg.get("fish_speech") or {}
245
+ rt = cfg.get("runtime") or {}
246
+
247
+ hub_id = (rt.get("hub_model_id") or rt.get("model_id") or "").strip()
248
+ rev_raw = (
249
+ rt.get("model_revision")
250
+ or rt.get("hub_revision")
251
+ or os.environ.get("VOCENCE_MODEL_REVISION")
252
+ or ""
253
+ )
254
+ revision = str(rev_raw).strip() or None
255
+
256
+ model_root = self._repo
257
+ if hub_id:
258
+ model_root = download_runtime_hub_model(self._repo, hub_id, revision=revision)
259
+
260
+ repo_root = (fs.get("repo_root") or os.environ.get("FISH_SPEECH_ROOT") or "").strip()
261
+ if repo_root:
262
+ rr = _resolve_path(self._repo, repo_root)
263
+ if rr.is_dir() and str(rr) not in sys.path:
264
+ sys.path.insert(0, str(rr))
265
+
266
+ llama_raw = (fs.get("llama_checkpoint_path") or os.environ.get("FISH_SPEECH_LLAMA_PATH") or "").strip()
267
+ dec_raw = (fs.get("decoder_checkpoint_path") or os.environ.get("FISH_SPEECH_DECODER_PATH") or "").strip()
268
+
269
+ llama_path = _resolve_checkpoint_path(llama_raw or None, model_root=model_root, hf_repo=self._repo)
270
+ dec_path = _resolve_checkpoint_path(dec_raw or None, model_root=model_root, hf_repo=self._repo)
271
+
272
+ if llama_path is not None and dec_path is not None:
273
+ llama_p, decoder_p = str(llama_path), str(dec_path)
274
+ elif llama_path is not None and dec_path is None:
275
+ llama_p = str(llama_path)
276
+ cand = sorted(Path(llama_p).rglob("codec.pth"), key=lambda x: len(x.parts))
277
+ if not cand:
278
+ raise FileNotFoundError(f"No codec.pth under {llama_p}; set fish_speech.decoder_checkpoint_path.")
279
+ decoder_p = str(cand[0])
280
+ elif dec_path is not None and llama_path is None:
281
+ decoder_p = str(dec_path)
282
+ llama_p = str(dec_path.parent)
283
+ else:
284
+ llama_p, decoder_p = _infer_fish_codec_paths(model_root)
285
+
286
+ device = str(fs.get("device") or rt.get("device_preference") or os.environ.get("FISH_SPEECH_DEVICE") or "cuda")
287
+ half = bool(fs.get("half", False))
288
+ compile_model = bool(fs.get("compile", False))
289
+ decoder_config = str(fs.get("decoder_config_name", "modded_dac_vq"))
290
+
291
+ self._engine = load_tts_inference_engine(
292
+ llama_checkpoint_path=llama_p,
293
+ decoder_checkpoint_path=decoder_p,
294
+ decoder_config_name=decoder_config,
295
+ device=device,
296
+ half=half,
297
+ compile_model=compile_model,
298
+ )
299
+
300
+ self._max_new_tokens = int(fs.get("max_new_tokens", 1024))
301
+ self._chunk_length = int(fs.get("chunk_length", 200))
302
+ self._top_p = float(fs.get("top_p", 0.8))
303
+ self._repetition_penalty = float(fs.get("repetition_penalty", 1.1))
304
+ self._temperature = float(fs.get("temperature", 0.8))
305
+ self._seed = fs.get("seed")
306
+ self._seed_i: int | None = int(self._seed) if self._seed is not None else None
307
+
308
+ self._meta = {
309
+ "adapter": str(rt.get("adapter", "finetuned-tts")),
310
+ "hub_model_id": hub_id or None,
311
+ "model_revision": revision,
312
+ "weights_local_dir": str(model_root) if hub_id else None,
313
+ "llama_checkpoint_path": llama_p,
314
+ "decoder_checkpoint_path": decoder_p,
315
+ "device": device,
316
+ "output_sample_rate": self._output_sr,
317
+ }
318
+
319
+ def get_status(self) -> dict[str, Any]:
320
+ return {"tts_engine": "finetuned-tts", **self._meta}
321
+
322
+ def warmup(self) -> None:
323
+ self.generate_wav(
324
+ "gender: neutral | pitch: mid | speed: normal | age_group: adult | "
325
+ "emotion: neutral | tone: neutral | accent: generic",
326
+ "Warmup complete.",
327
+ )
328
+
329
+ def generate_wav(self, instruction: str, text: str) -> tuple[np.ndarray, int]:
330
+ t = text[: self._cap_text] if self._cap_text else text
331
+ ins = instruction[: self._cap_instruction] if self._cap_instruction else instruction
332
+ prompt = _tts_prompt_from_instruction_and_text(ins, t)
333
+
334
+ sr, wav = synthesize_wav(
335
+ self._engine,
336
+ text=prompt,
337
+ max_new_tokens=self._max_new_tokens,
338
+ chunk_length=self._chunk_length,
339
+ top_p=self._top_p,
340
+ repetition_penalty=self._repetition_penalty,
341
+ temperature=self._temperature,
342
+ seed=self._seed_i,
343
+ )
344
+ wav_out = _resample_to_hz_mono_f32(wav, int(sr), self._output_sr)
345
+ return wav_out, self._output_sr
346
+
347
+
348
+ _engine: Miner | None = None
349
+ _health: dict[str, Any] = {}
350
+
351
+
352
+ def _run_dev_server() -> None:
353
+ from contextlib import asynccontextmanager
354
+
355
+ import uvicorn
356
+ from fastapi import Body, FastAPI, HTTPException, status
357
+ from fastapi.responses import Response
358
+ from pydantic import BaseModel
359
+
360
+ @asynccontextmanager
361
+ async def _lifespan(_app: Any):
362
+ global _engine, _health
363
+ cfg = _read_vocence_yaml(REPO)
364
+ gen = cfg.get("generation") or {}
365
+ _health = {"sample_rate": int(gen.get("sample_rate", _VOCENCE_OUTPUT_HZ))}
366
+ try:
367
+ _engine = Miner(REPO)
368
+ _health["adapter"] = json.dumps(_engine.get_status())
369
+ except Exception as e:
370
+ _engine = None
371
+ _health["adapter"] = json.dumps({"tts_engine": "not loaded"})
372
+ _health["error"] = f"{type(e).__name__}: {e}"
373
+ yield
374
+ _engine = None
375
+
376
+ class HealthResponse(BaseModel):
377
+ status: str
378
+ model_loaded: bool
379
+ sample_rate: int | None = None
380
+ adapter: str | None = None
381
+
382
+ app = FastAPI(title="Vocence finetuned-tts TTS (dev)", lifespan=_lifespan)
383
+
384
+ @app.get("/health", response_model=HealthResponse)
385
+ async def health() -> HealthResponse:
386
+ ok = _engine is not None
387
+ err = _health.get("error")
388
+ return HealthResponse(
389
+ status="healthy" if ok else (f"unhealthy: {err}" if err else "unhealthy"),
390
+ model_loaded=ok,
391
+ sample_rate=_health.get("sample_rate"),
392
+ adapter=_health.get("adapter", "finetuned-tts"),
393
+ )
394
+
395
+ max_text = int((_read_vocence_yaml(REPO).get("limits") or {}).get("max_text_chars", 2000))
396
+ max_inst = int((_read_vocence_yaml(REPO).get("limits") or {}).get("max_instruction_chars", 600))
397
+
398
+ @app.post("/speak", response_class=Response, response_model=None)
399
+ async def speak(
400
+ text: str = Body(..., min_length=1, max_length=max_text, embed=True),
401
+ instruction: str = Body(..., min_length=1, max_length=max_inst, embed=True),
402
+ ) -> Response:
403
+ if _engine is None:
404
+ raise HTTPException(
405
+ status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
406
+ detail=f"TTS engine not loaded: {_health.get('error', 'unknown')}",
407
+ )
408
+ waveform, sample_rate = _engine.generate_wav(instruction=instruction, text=text)
409
+ w = np.asarray(waveform)
410
+ if w.ndim != 1 or w.size == 0:
411
+ raise HTTPException(status_code=400, detail="invalid waveform")
412
+ duration = float(w.shape[0]) / float(sample_rate)
413
+ if duration <= 0 or duration > _MAX_AUDIO_SEC:
414
+ raise HTTPException(status_code=400, detail="invalid duration")
415
+ return Response(content=_f32_to_wav_bytes(w, int(sample_rate)), media_type="audio/wav")
416
+
417
+ import logging
418
+
419
+ logging.basicConfig(level=logging.INFO)
420
+ host = os.environ.get("HOST", "0.0.0.0")
421
+ port = int(os.environ.get("PORT", "8765"))
422
+ uvicorn.run(app, host=host, port=port)
423
+
424
+
425
+ if __name__ == "__main__":
426
+ _run_dev_server()