"""FastAPI service exposing the nari-labs Dia2-2B text-to-speech model."""

import tempfile
import threading
import time
import traceback
from pathlib import Path

from fastapi import FastAPI, HTTPException
from fastapi.responses import Response, JSONResponse
from pydantic import BaseModel, Field

app = FastAPI(title="iamcodio Dia2 TTS")

# Lazily-loaded Dia2 model instance; stays None until the first call to get_model().
model = None
# Plain `def` endpoints are executed by FastAPI in a worker thread pool, so two
# concurrent first requests could otherwise both load the model onto the GPU.
_model_lock = threading.Lock()

# Reference clip used for voice cloning when it exists on disk.
VOICE_REFERENCE_PATH = Path("/app/voice-reference.wav")


def get_model():
    """Return the shared Dia2 model, loading it on first use.

    The ``dia2`` import is deferred so the app can start (and serve /health)
    before the model is loaded. Loading is serialized with a lock and
    re-checked inside it so the model is only ever loaded once.
    """
    global model
    if model is None:
        with _model_lock:
            # Re-check: another thread may have finished loading while we waited.
            if model is None:
                from dia2 import Dia2

                print("[dia2] Loading Dia2-2B model...", flush=True)
                start = time.perf_counter()
                model = Dia2.from_repo("nari-labs/Dia2-2B", device="cuda", dtype="bfloat16")
                print(f"[dia2] Model loaded in {time.perf_counter() - start:.1f}s", flush=True)
    return model


class GenerateRequest(BaseModel):
    """Request body for POST /generate; ranges are enforced by pydantic Field bounds."""

    text: str = Field(..., description="Text with [S1]/[S2] speaker tags")
    cfg_scale: float = Field(default=6.0, ge=1.0, le=10.0)
    temperature: float = Field(default=0.8, ge=0.1, le=2.0)
    top_k: int = Field(default=50, ge=1, le=200)
    use_cuda_graph: bool = Field(default=True)
    use_voice_clone: bool = Field(default=True, description="Use iamcodio voice reference for cloning")


@app.get("/health")
def health():
    """Report liveness and whether the model has been loaded yet."""
    return {"status": "ok", "model_loaded": model is not None}


@app.post("/generate")
def generate(req: GenerateRequest):
    """Synthesize speech for ``req.text`` and return it as a WAV response.

    Returns:
        ``Response`` with ``audio/wav`` content and an ``X-Generation-Time``
        header (seconds, two decimals) on success.

    Raises:
        HTTPException: 400 if ``req.text`` is empty or whitespace-only.

    Any other failure is logged with its traceback and returned as a 500
    ``JSONResponse`` of the form ``{"error": ..., "type": ...}``.
    """
    if not req.text or req.text.isspace():
        raise HTTPException(status_code=400, detail="Text input cannot be empty")

    tmp_path = None
    try:
        from dia2 import GenerationConfig, SamplingConfig

        dia = get_model()
        config = GenerationConfig(
            cfg_scale=req.cfg_scale,
            audio=SamplingConfig(temperature=req.temperature, top_k=req.top_k),
            use_cuda_graph=req.use_cuda_graph,
        )

        # Dia2 writes its output to a file path, so reserve a named temp file.
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
            tmp_path = f.name

        # Voice cloning via prefix_speaker_1 (Dia2 API)
        prefix_speaker_1 = None
        if req.use_voice_clone and VOICE_REFERENCE_PATH.exists():
            prefix_speaker_1 = str(VOICE_REFERENCE_PATH)
            print(f"[dia2] Voice cloning from {prefix_speaker_1}", flush=True)
        else:
            print("[dia2] No voice reference — generating without clone", flush=True)

        start = time.perf_counter()
        dia.generate(
            req.text,
            config=config,
            output_wav=tmp_path,
            prefix_speaker_1=prefix_speaker_1,
            include_prefix=False,
            verbose=True,
        )
        elapsed = time.perf_counter() - start
        print(f"[dia2] Generated in {elapsed:.2f}s", flush=True)

        wav_bytes = Path(tmp_path).read_bytes()
        return Response(
            content=wav_bytes,
            media_type="audio/wav",
            headers={
                "X-Generation-Time": f"{elapsed:.2f}",
            },
        )
    except Exception as e:
        # Top-level request boundary: log the full traceback, report a 500.
        traceback.print_exc()
        return JSONResponse(
            status_code=500,
            content={"error": str(e), "type": type(e).__name__},
        )
    finally:
        # Always remove the temp WAV, including when generation or reading fails.
        if tmp_path is not None:
            Path(tmp_path).unlink(missing_ok=True)