iamcodio committed on
Commit
2e4af38
·
verified ·
1 Parent(s): 79ea526

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +23 -17
app.py CHANGED
@@ -1,10 +1,10 @@
1
- import io
2
  import time
 
3
  import tempfile
4
  from pathlib import Path
5
 
6
  from fastapi import FastAPI, HTTPException
7
- from fastapi.responses import Response
8
  from pydantic import BaseModel, Field
9
 
10
  app = FastAPI(title="iamcodio Dia2 TTS")
@@ -16,10 +16,10 @@ def get_model():
16
  global model
17
  if model is None:
18
  from dia2 import Dia2
19
- print("[dia2] Loading Dia2-2B model...")
20
  start = time.time()
21
  model = Dia2.from_repo("nari-labs/Dia2-2B", device="cuda", dtype="bfloat16")
22
- print(f"[dia2] Model loaded in {time.time() - start:.1f}s")
23
  return model
24
 
25
 
@@ -41,19 +41,19 @@ def generate(req: GenerateRequest):
41
  if not req.text or req.text.isspace():
42
  raise HTTPException(status_code=400, detail="Text input cannot be empty")
43
 
44
- from dia2 import GenerationConfig, SamplingConfig
 
45
 
46
- dia = get_model()
47
- config = GenerationConfig(
48
- cfg_scale=req.cfg_scale,
49
- audio=SamplingConfig(temperature=req.temperature, top_k=req.top_k),
50
- use_cuda_graph=req.use_cuda_graph,
51
- )
52
 
53
- with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
54
- tmp_path = f.name
55
 
56
- try:
57
  start = time.time()
58
  result = dia.generate(
59
  req.text,
@@ -62,9 +62,11 @@ def generate(req: GenerateRequest):
62
  verbose=True,
63
  )
64
  elapsed = time.time() - start
65
- print(f"[dia2] Generated in {elapsed:.2f}s")
66
 
67
  wav_bytes = Path(tmp_path).read_bytes()
 
 
68
  return Response(
69
  content=wav_bytes,
70
  media_type="audio/wav",
@@ -72,5 +74,9 @@ def generate(req: GenerateRequest):
72
  "X-Generation-Time": f"{elapsed:.2f}",
73
  },
74
  )
75
- finally:
76
- Path(tmp_path).unlink(missing_ok=True)
 
 
 
 
 
 
1
  import time
2
+ import traceback
3
  import tempfile
4
  from pathlib import Path
5
 
6
  from fastapi import FastAPI, HTTPException
7
+ from fastapi.responses import Response, JSONResponse
8
  from pydantic import BaseModel, Field
9
 
10
  app = FastAPI(title="iamcodio Dia2 TTS")
 
16
  global model
17
  if model is None:
18
  from dia2 import Dia2
19
+ print("[dia2] Loading Dia2-2B model...", flush=True)
20
  start = time.time()
21
  model = Dia2.from_repo("nari-labs/Dia2-2B", device="cuda", dtype="bfloat16")
22
+ print(f"[dia2] Model loaded in {time.time() - start:.1f}s", flush=True)
23
  return model
24
 
25
 
 
41
  if not req.text or req.text.isspace():
42
  raise HTTPException(status_code=400, detail="Text input cannot be empty")
43
 
44
+ try:
45
+ from dia2 import GenerationConfig, SamplingConfig
46
 
47
+ dia = get_model()
48
+ config = GenerationConfig(
49
+ cfg_scale=req.cfg_scale,
50
+ audio=SamplingConfig(temperature=req.temperature, top_k=req.top_k),
51
+ use_cuda_graph=req.use_cuda_graph,
52
+ )
53
 
54
+ with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
55
+ tmp_path = f.name
56
 
 
57
  start = time.time()
58
  result = dia.generate(
59
  req.text,
 
62
  verbose=True,
63
  )
64
  elapsed = time.time() - start
65
+ print(f"[dia2] Generated in {elapsed:.2f}s", flush=True)
66
 
67
  wav_bytes = Path(tmp_path).read_bytes()
68
+ Path(tmp_path).unlink(missing_ok=True)
69
+
70
  return Response(
71
  content=wav_bytes,
72
  media_type="audio/wav",
 
74
  "X-Generation-Time": f"{elapsed:.2f}",
75
  },
76
  )
77
+ except Exception as e:
78
+ traceback.print_exc()
79
+ return JSONResponse(
80
+ status_code=500,
81
+ content={"error": str(e), "type": type(e).__name__},
82
+ )