File size: 3,013 Bytes
79ea526
2e4af38
79ea526
 
 
 
2e4af38
79ea526
 
 
 
 
 
 
 
 
 
 
2e4af38
79ea526
 
2e4af38
79ea526
 
 
 
 
 
 
 
 
9f5414e
79ea526
 
 
 
 
 
 
 
 
 
 
 
2e4af38
 
79ea526
2e4af38
 
 
 
 
 
79ea526
2e4af38
 
79ea526
9f5414e
cf1ff8b
9f5414e
 
 
 
 
 
cf1ff8b
79ea526
 
 
 
 
9f5414e
 
79ea526
 
 
2e4af38
79ea526
 
2e4af38
 
79ea526
 
 
 
 
 
 
2e4af38
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import time
import traceback
import tempfile
from pathlib import Path

from fastapi import FastAPI, HTTPException
from fastapi.responses import Response, JSONResponse
from pydantic import BaseModel, Field

# ASGI application object; the route decorators in this module register on it.
app = FastAPI(title="iamcodio Dia2 TTS")

# Module-wide cache for the Dia2 instance. Starts as None and is filled in by
# get_model() the first time a model is needed.
model = None


def get_model():
    """Return the shared Dia2 model, loading it on the first call.

    The instance is stored in the module-level ``model`` global so later
    calls return the same object without reloading. ``dia2`` is imported
    inside the function, so the import only happens when a load is needed.
    """
    global model

    # Fast path: a previous call already populated the cache.
    if model is not None:
        return model

    from dia2 import Dia2

    print("[dia2] Loading Dia2-2B model...", flush=True)
    load_began = time.time()
    model = Dia2.from_repo(
        "nari-labs/Dia2-2B",
        device="cuda",
        dtype="bfloat16",
    )
    load_seconds = time.time() - load_began
    print(f"[dia2] Model loaded in {load_seconds:.1f}s", flush=True)
    return model


class GenerateRequest(BaseModel):
    """Request body accepted by the ``/generate`` endpoint.

    Numeric fields carry ``ge``/``le`` bounds on their ``Field`` declarations.
    """

    # Script to synthesize. The /generate handler rejects empty or
    # whitespace-only text with a 400.
    text: str = Field(..., description="Text with [S1]/[S2] speaker tags")
    # Forwarded as GenerationConfig(cfg_scale=...).
    # NOTE(review): presumably classifier-free guidance strength — confirm against Dia2 docs.
    cfg_scale: float = Field(default=6.0, ge=1.0, le=10.0)
    # Forwarded as SamplingConfig(temperature=...) for the audio sampler.
    temperature: float = Field(default=0.8, ge=0.1, le=2.0)
    # Forwarded as SamplingConfig(top_k=...) for the audio sampler.
    top_k: int = Field(default=50, ge=1, le=200)
    # Forwarded as GenerationConfig(use_cuda_graph=...).
    use_cuda_graph: bool = Field(default=True)
    # When true and /app/voice-reference.wav exists, that file is passed to
    # generation as prefix_speaker_1; otherwise no prefix is used.
    use_voice_clone: bool = Field(default=True, description="Use iamcodio voice reference for cloning")


@app.get("/health")
def health():
    """Report service status and whether the Dia2 model has been loaded yet."""
    is_loaded = model is not None
    return dict(status="ok", model_loaded=is_loaded)


@app.post("/generate")
def generate(req: GenerateRequest):
    """Synthesize speech for ``req.text`` and return it as a WAV response.

    The model writes its output to a temporary ``.wav`` file, which is read
    back into memory and returned. The temporary file is removed whether or
    not generation succeeds.

    Args:
        req: Validated request body with the text and sampling settings.

    Returns:
        A ``Response`` with ``audio/wav`` content and an ``X-Generation-Time``
        header (seconds, two decimals) on success. On any exception raised
        during generation, a ``JSONResponse`` with status 500 carrying the
        error message and exception type name.

    Raises:
        HTTPException: 400 when ``req.text`` is empty or only whitespace.
    """
    if not req.text or req.text.isspace():
        raise HTTPException(status_code=400, detail="Text input cannot be empty")

    # Tracked outside the try so the finally clause can clean up even when
    # generation fails part-way through.
    tmp_path = None
    try:
        from dia2 import GenerationConfig, SamplingConfig

        dia = get_model()
        config = GenerationConfig(
            cfg_scale=req.cfg_scale,
            audio=SamplingConfig(temperature=req.temperature, top_k=req.top_k),
            use_cuda_graph=req.use_cuda_graph,
        )

        # delete=False: only a unique path is needed here; the file is handed
        # to dia.generate() by name and removed explicitly in `finally`.
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
            tmp_path = f.name

        # Voice cloning via prefix_speaker_1 (Dia2 API)
        voice_ref = Path("/app/voice-reference.wav")
        prefix_speaker_1 = None
        if req.use_voice_clone and voice_ref.exists():
            prefix_speaker_1 = str(voice_ref)
            print(f"[dia2] Voice cloning from {prefix_speaker_1}", flush=True)
        else:
            print("[dia2] No voice reference — generating without clone", flush=True)

        start = time.time()
        dia.generate(
            req.text,
            config=config,
            output_wav=tmp_path,
            prefix_speaker_1=prefix_speaker_1,
            include_prefix=False,
            verbose=True,
        )
        elapsed = time.time() - start
        print(f"[dia2] Generated in {elapsed:.2f}s", flush=True)

        wav_bytes = Path(tmp_path).read_bytes()

        return Response(
            content=wav_bytes,
            media_type="audio/wav",
            headers={
                "X-Generation-Time": f"{elapsed:.2f}",
            },
        )
    except Exception as e:
        # Boundary handler: log the traceback server-side and report the
        # failure to the client as JSON instead of an unhandled 500.
        traceback.print_exc()
        return JSONResponse(
            status_code=500,
            content={"error": str(e), "type": type(e).__name__},
        )
    finally:
        # Runs on success and failure, so a failed generation no longer
        # leaves an orphaned .wav behind in the temp directory.
        if tmp_path is not None:
            try:
                Path(tmp_path).unlink(missing_ok=True)
            except OSError:
                # Cleanup is best-effort; log it but keep the response that
                # was already chosen above.
                traceback.print_exc()