Spaces:
Paused
Paused
import tempfile
import threading
import time
import traceback
from pathlib import Path

from fastapi import FastAPI, HTTPException
from fastapi.responses import JSONResponse, Response
from pydantic import BaseModel, Field
# Shared Dia2 model instance; stays None until get_model() loads it on first use.
model = None

# FastAPI application object; the title is shown in the generated API docs.
app = FastAPI(title="iamcodio Dia2 TTS")
# Guards the one-time model load. If the endpoints are plain ``def`` routes,
# FastAPI runs them in a worker thread pool, so two early requests could both
# observe ``model is None`` and load the checkpoint onto the GPU twice.
_model_lock = threading.Lock()


def get_model():
    """Return the shared Dia2 model, loading it on first use.

    The ``nari-labs/Dia2-2B`` checkpoint is loaded onto ``cuda`` in
    ``bfloat16`` and cached in the module-level ``model`` global. Later calls
    return that same object without reloading.

    Returns:
        The loaded ``Dia2`` instance.

    Raises:
        Whatever importing ``dia2`` or ``Dia2.from_repo`` raises. ``model`` is
        only assigned after a successful load, so a later call retries.
    """
    global model
    # Fast path: once loaded, skip the lock entirely.
    if model is None:
        with _model_lock:
            # Re-check under the lock: another thread may have completed the
            # load while this one was waiting to acquire it.
            if model is None:
                from dia2 import Dia2  # imported lazily, only on first load

                print("[dia2] Loading Dia2-2B model...", flush=True)
                start = time.perf_counter()
                model = Dia2.from_repo("nari-labs/Dia2-2B", device="cuda", dtype="bfloat16")
                print(f"[dia2] Model loaded in {time.perf_counter() - start:.1f}s", flush=True)
    return model
class GenerateRequest(BaseModel):
    # Script to synthesize; speaker turns are marked inline with [S1]/[S2].
    text: str = Field(..., description="Text with [S1]/[S2] speaker tags")

    # Sampling knobs forwarded to dia2's GenerationConfig / SamplingConfig;
    # out-of-range values are rejected by pydantic validation (ge/le bounds).
    cfg_scale: float = Field(6.0, ge=1.0, le=10.0)
    temperature: float = Field(0.8, ge=0.1, le=2.0)
    top_k: int = Field(50, ge=1, le=200)

    # Runtime toggles.
    use_cuda_graph: bool = Field(True)
    use_voice_clone: bool = Field(
        True,
        description="Use iamcodio voice reference for cloning",
    )
def health():
    """Report service status and whether the Dia2 model has been loaded yet.

    Returns:
        ``{"status": "ok", "model_loaded": bool}``. ``model_loaded`` is True
        once ``get_model()`` has populated the module-level ``model``.
    """
    # NOTE(review): no @app route decorator is attached to this function;
    # confirm how this endpoint is registered.
    loaded = model is not None
    return {"status": "ok", "model_loaded": loaded}
def _voice_clone_prefix(use_voice_clone, voice_ref=Path("/app/voice-reference.wav")):
    """Return the path to pass as Dia2's ``prefix_speaker_1``, or None.

    Cloning is used only when the caller requested it *and* the reference WAV
    exists on disk. Otherwise generation proceeds without a speaker prefix.
    """
    if use_voice_clone and voice_ref.exists():
        prefix = str(voice_ref)
        print(f"[dia2] Voice cloning from {prefix}", flush=True)
        return prefix
    print("[dia2] No voice reference — generating without clone", flush=True)
    return None


def generate(req: GenerateRequest):
    """Synthesize speech for ``req.text`` and return it as a WAV response.

    Args:
        req: Request body holding the text (with [S1]/[S2] speaker tags) and
            the sampling / voice-cloning options.

    Returns:
        On success, a ``Response`` with ``audio/wav`` bytes and an
        ``X-Generation-Time`` header (seconds, two decimals). If model loading
        or generation fails, a 500 ``JSONResponse`` of the form
        ``{"error": <message>, "type": <exception class name>}``.

    Raises:
        HTTPException: 400 when ``req.text`` is empty or whitespace-only.
    """
    # NOTE(review): no @app route decorator is attached to this function;
    # confirm how this endpoint is registered.
    if not req.text or req.text.isspace():
        raise HTTPException(status_code=400, detail="Text input cannot be empty")

    tmp_path = None
    try:
        from dia2 import GenerationConfig, SamplingConfig

        dia = get_model()
        config = GenerationConfig(
            cfg_scale=req.cfg_scale,
            audio=SamplingConfig(temperature=req.temperature, top_k=req.top_k),
            use_cuda_graph=req.use_cuda_graph,
        )

        # Dia2 writes its output to a file path, so reserve a unique temp name.
        # delete=False keeps the file after the handle closes; cleanup is done
        # in the finally clause below.
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
            tmp_path = Path(f.name)

        # Voice cloning via prefix_speaker_1 (Dia2 API)
        prefix_speaker_1 = _voice_clone_prefix(req.use_voice_clone)

        # NOTE(review): the model object is shared across requests; confirm
        # dia2 tolerates concurrent generate() calls, or serialize them.
        start = time.perf_counter()
        dia.generate(
            req.text,
            config=config,
            output_wav=str(tmp_path),
            prefix_speaker_1=prefix_speaker_1,
            include_prefix=False,
            verbose=True,
        )
        elapsed = time.perf_counter() - start
        print(f"[dia2] Generated in {elapsed:.2f}s", flush=True)

        wav_bytes = tmp_path.read_bytes()
        return Response(
            content=wav_bytes,
            media_type="audio/wav",
            headers={
                "X-Generation-Time": f"{elapsed:.2f}",
            },
        )
    except Exception as e:
        # Top-level boundary: log the full traceback server-side and return a
        # structured 500 instead of letting the exception escape.
        traceback.print_exc()
        return JSONResponse(
            status_code=500,
            content={"error": str(e), "type": type(e).__name__},
        )
    finally:
        # Always remove the temp WAV, including when generation or the read
        # raised, so failed requests do not accumulate files on disk.
        if tmp_path is not None:
            tmp_path.unlink(missing_ok=True)