# iamcodio's picture
# Fix voice cloning: use prefix_speaker_1 (Dia2 API)
# 9f5414e verified
import tempfile
import threading
import time
import traceback
from pathlib import Path

from fastapi import FastAPI, HTTPException
from fastapi.responses import Response, JSONResponse
from pydantic import BaseModel, Field

# FastAPI application object; the route decorators in this module register on it.
app = FastAPI(title="iamcodio Dia2 TTS")
# Module-level cache for the Dia2 model. Stays None until get_model() performs
# the first load; /health reports whether it has been populated.
model = None
# Serializes the first model load. FastAPI runs plain ``def`` routes in a
# thread pool, so several requests can reach get_model() at the same time.
_model_lock = threading.Lock()


def get_model():
    """Return the shared Dia2 model, loading it on first use.

    The model is cached in the module-level ``model`` variable. The load is
    guarded by a lock so that concurrent first requests trigger exactly one
    ``Dia2.from_repo`` call instead of each loading their own copy.

    Returns:
        The cached Dia2 model instance.
    """
    global model
    # Fast path: once loaded, no locking is needed.
    if model is not None:
        return model

    with _model_lock:
        # Re-check under the lock: another thread may have completed the load
        # while this one was waiting.
        if model is None:
            # Imported lazily so the module can be imported without dia2
            # being touched until a model is actually requested.
            from dia2 import Dia2

            print("[dia2] Loading Dia2-2B model...", flush=True)
            start = time.time()
            model = Dia2.from_repo("nari-labs/Dia2-2B", device="cuda", dtype="bfloat16")
            print(f"[dia2] Model loaded in {time.time() - start:.1f}s", flush=True)
    return model

# Request body for POST /generate.
# NOTE: documented with comments rather than a class docstring so the
# generated JSON schema stays exactly as before.
class GenerateRequest(BaseModel):
    # Required input text.
    text: str = Field(..., description="Text with [S1]/[S2] speaker tags")

    # Forwarded to GenerationConfig(cfg_scale=...); accepted range 1.0-10.0.
    cfg_scale: float = Field(6.0, ge=1.0, le=10.0)

    # Forwarded to SamplingConfig(temperature=...); accepted range 0.1-2.0.
    temperature: float = Field(0.8, ge=0.1, le=2.0)

    # Forwarded to SamplingConfig(top_k=...); accepted range 1-200.
    top_k: int = Field(50, ge=1, le=200)

    # Forwarded to GenerationConfig(use_cuda_graph=...).
    use_cuda_graph: bool = True

    # When true and the reference WAV exists on disk, it is passed to the
    # model as prefix_speaker_1.
    use_voice_clone: bool = Field(
        True,
        description="Use iamcodio voice reference for cloning",
    )

# Liveness probe: always reports "ok", plus whether the lazily loaded model
# has been populated yet. It never triggers a model load itself.
@app.get("/health")
def health():
    is_loaded = model is not None
    payload = {"status": "ok", "model_loaded": is_loaded}
    return payload

@app.post("/generate")
def generate(req: GenerateRequest):
    """Synthesize speech for ``req.text`` and return it as WAV audio.

    The model writes its output to a temporary WAV file, which is read back
    into memory and always removed afterwards, whether generation succeeded
    or failed.

    Returns:
        A ``Response`` with media type ``audio/wav`` and an
        ``X-Generation-Time`` header (seconds, two decimals) on success, or a
        ``JSONResponse`` with status 500 and ``error``/``type`` fields if
        anything raises during generation.

    Raises:
        HTTPException: 400 when the text is empty or whitespace-only.
    """
    if not req.text or req.text.isspace():
        raise HTTPException(status_code=400, detail="Text input cannot be empty")

    # Tracked outside the try so the finally clause can clean up on any path.
    tmp_path = None
    try:
        from dia2 import GenerationConfig, SamplingConfig

        dia = get_model()
        config = GenerationConfig(
            cfg_scale=req.cfg_scale,
            audio=SamplingConfig(temperature=req.temperature, top_k=req.top_k),
            use_cuda_graph=req.use_cuda_graph,
        )

        # Only the path is needed: the handle is closed straight away and the
        # model writes to the path itself (delete=False keeps the file around).
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
            tmp_path = Path(f.name)

        # Voice cloning via prefix_speaker_1 (Dia2 API)
        voice_ref = Path("/app/voice-reference.wav")
        prefix_speaker_1 = None
        if req.use_voice_clone and voice_ref.exists():
            prefix_speaker_1 = str(voice_ref)
            print(f"[dia2] Voice cloning from {prefix_speaker_1}", flush=True)
        else:
            print("[dia2] No voice reference — generating without clone", flush=True)

        start = time.time()
        dia.generate(
            req.text,
            config=config,
            output_wav=str(tmp_path),
            prefix_speaker_1=prefix_speaker_1,
            include_prefix=False,
            verbose=True,
        )
        elapsed = time.time() - start
        print(f"[dia2] Generated in {elapsed:.2f}s", flush=True)

        wav_bytes = tmp_path.read_bytes()
        return Response(
            content=wav_bytes,
            media_type="audio/wav",
            headers={
                "X-Generation-Time": f"{elapsed:.2f}",
            },
        )
    except Exception as e:
        # Top-level boundary: log the traceback and report the failure as JSON.
        traceback.print_exc()
        return JSONResponse(
            status_code=500,
            content={"error": str(e), "type": type(e).__name__},
        )
    finally:
        # Remove the temp file on success and failure alike, so failed
        # generations do not leave WAV files behind.
        if tmp_path is not None:
            try:
                tmp_path.unlink(missing_ok=True)
            except OSError:
                # Best-effort cleanup: never replace the response with a
                # cleanup error.
                traceback.print_exc()