rahul7star committed on
Commit
f036d34
·
verified ·
1 Parent(s): 2883c2a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +78 -141
app.py CHANGED
@@ -1,162 +1,99 @@
1
  import os
2
- import io
3
- import random
4
- import numpy as np
5
  import torch
6
- from scipy.io import wavfile
7
- from fastapi import FastAPI, Form
8
- from fastapi.responses import StreamingResponse, JSONResponse
9
 
10
  from src.chatterbox.mtl_tts import ChatterboxMultilingualTTS, SUPPORTED_LANGUAGES
11
 
12
- # ===============================
13
- # CPU-ONLY HARD PATCH
14
- # ===============================
15
- os.environ["CUDA_VISIBLE_DEVICES"] = ""
16
 
17
- _original_torch_load = torch.load
18
- def _cpu_only_torch_load(*args, **kwargs):
19
- kwargs.setdefault("map_location", torch.device("cpu"))
20
- return _original_torch_load(*args, **kwargs)
21
- torch.load = _cpu_only_torch_load
 
22
 
23
- # ===============================
24
- # LANGUAGE CONFIG
25
- # ===============================
26
- LANGUAGE_CONFIG = {
27
- "en": {
28
- "audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/en_f1.flac"
29
- },
30
- "hi": {
31
- "audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/hi_f1.flac"
32
- },
33
- "fr": {
34
- "audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/fr_f1.flac"
35
- },
36
- "he": {
37
- "audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/he_m1.flac"
38
- },
39
- }
40
 
41
- # ===============================
42
- # MODEL LOADING (SAFE)
43
- # ===============================
44
- MODEL = None
45
 
 
 
 
46
  def get_or_load_model():
47
  global MODEL
48
  if MODEL is None:
49
- print("🔄 Loading Chatterbox model (CPU-only)")
50
- MODEL = ChatterboxMultilingualTTS.from_pretrained("cpu")
51
- MODEL.eval()
52
- print("✅ Model loaded")
53
- return MODEL
54
-
55
- # Load at startup
56
- get_or_load_model()
57
-
58
- # ===============================
59
- # UTILITIES
60
- # ===============================
61
- def set_seed(seed: int):
62
- torch.manual_seed(seed)
63
- random.seed(seed)
64
- np.random.seed(seed)
65
-
66
- def format_for_singing(lyrics: str) -> str:
67
- lines = []
68
- for line in lyrics.splitlines():
69
- line = line.strip()
70
- if not line:
71
- continue
72
- line = (
73
- line.replace("a", "aa")
74
- .replace("e", "ee")
75
- .replace("i", "ii")
76
- .replace("o", "oo")
77
- .replace("u", "uu")
78
  )
79
- lines.append(f"{line} ♪ ...")
80
- return "\n".join(lines)
81
-
82
- # ===============================
83
- # FASTAPI APP
84
- # ===============================
85
- app = FastAPI(
86
- title="Chatterbox Multilingual TTS",
87
- version="1.0"
88
- )
89
 
90
- # ===============================
91
- # HEALTH API
92
- # ===============================
93
- @app.get("/health")
94
- def health():
95
- return {
96
- "status": "ok",
97
- "device": "cpu",
98
- "languages": list(SUPPORTED_LANGUAGES.keys())
99
- }
100
 
101
- # ===============================
102
- # TTS API
103
- # ===============================
104
  @app.post("/tts")
105
- def tts(
106
- mode: str = Form("Speak"), # Speak | Sing
107
- text: str = Form(""),
108
- lyrics: str = Form(""),
109
- language_id: str = Form("hi"),
110
- exaggeration: float = Form(0.5),
111
- temperature: float = Form(0.8),
112
- cfg_weight: float = Form(0.5),
113
- seed: int = Form(0),
114
- ):
115
- model = get_or_load_model()
116
 
117
- if seed != 0:
118
- set_seed(seed)
119
-
120
- if mode.lower() == "sing":
121
- if not lyrics.strip():
122
- return JSONResponse(
123
- {"error": "Lyrics required for Sing mode"},
124
- status_code=400
125
- )
126
- final_text = format_for_singing(lyrics)
127
- else:
128
- if not text.strip():
129
- return JSONResponse(
130
- {"error": "Text required for Speak mode"},
131
- status_code=400
132
- )
133
- final_text = text
134
-
135
- kwargs = {
136
- "exaggeration": exaggeration,
137
- "temperature": temperature,
138
- "cfg_weight": cfg_weight,
139
- }
140
-
141
- prompt = LANGUAGE_CONFIG.get(language_id, {}).get("audio")
142
- if prompt:
143
- kwargs["audio_prompt_path"] = prompt
144
-
145
- with torch.no_grad():
146
- wav = model.generate(
147
- final_text[:300],
148
- language_id=language_id,
149
- **kwargs
150
  )
151
 
152
- wav = wav.squeeze(0).cpu().numpy()
153
-
154
- buffer = io.BytesIO()
155
- wavfile.write(buffer, model.sr, wav)
156
- buffer.seek(0)
157
-
158
- return StreamingResponse(
159
- buffer,
160
  media_type="audio/wav",
161
- headers={"Content-Disposition": "inline; filename=output.wav"}
162
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
+ import uuid
 
 
3
  import torch
4
+ from fastapi import FastAPI
5
+ from pydantic import BaseModel
6
+ from fastapi.responses import FileResponse, HTMLResponse
7
 
8
  from src.chatterbox.mtl_tts import ChatterboxMultilingualTTS, SUPPORTED_LANGUAGES
9
 
10
+ # -------------------------------------------------
11
+ # App
12
+ # -------------------------------------------------
13
app = FastAPI(title="Chatterbox Multilingual TTS")

# -------------------------------------------------
# Globals (model loaded once)
# -------------------------------------------------
# Cached model instance; stays None until get_or_load_model() fills it in.
MODEL = None
# Directory that receives the generated WAV files; created at import time.
OUTPUT_DIR = "/tmp/tts_outputs"
os.makedirs(OUTPUT_DIR, exist_ok=True)
21
 
22
+ # -------------------------------------------------
23
+ # Request schema
24
+ # -------------------------------------------------
25
class TTSRequest(BaseModel):
    """JSON request body accepted by ``POST /tts``."""

    # Text to synthesize (required field).
    text: str
    # Language code; the handler rejects codes missing from SUPPORTED_LANGUAGES.
    language: str = "en"
    # Optional speaker / voice identifier.
    speaker: str | None = None
 
 
 
 
 
 
 
 
 
 
29
 
 
 
 
 
30
 
31
+ # -------------------------------------------------
32
+ # Model loader (NO .eval())
33
+ # -------------------------------------------------
34
def get_or_load_model():
    """Return the process-wide TTS model, loading it on first use.

    The model is created once and cached in the module-level ``MODEL``
    global; every later call returns that cached instance.  CUDA is selected
    when ``torch.cuda.is_available()`` is true, otherwise the CPU is used.

    Returns:
        The shared ``ChatterboxMultilingualTTS`` instance.
    """
    global MODEL
    if MODEL is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        # Build the model through the ``from_pretrained`` factory instead of
        # calling the class constructor with only a device.
        # NOTE(review): upstream Chatterbox's constructor expects the already
        # built sub-models, and the factory is what loads the checkpoint —
        # confirm against src.chatterbox.mtl_tts.
        MODEL = ChatterboxMultilingualTTS.from_pretrained(device)
    return MODEL
 
 
 
 
 
 
 
 
 
41
 
 
 
 
 
 
 
 
 
 
 
42
 
43
+ # -------------------------------------------------
44
+ # API: TTS
45
+ # -------------------------------------------------
46
@app.post("/tts")
def tts(req: TTSRequest):
    """Synthesize ``req.text`` and return the audio as a WAV download.

    Args:
        req: Parsed JSON body.  ``req.language`` must be a key/member of
            ``SUPPORTED_LANGUAGES`` and ``req.text`` must contain
            non-whitespace characters.

    Returns:
        A ``FileResponse`` streaming the generated WAV file on success, or a
        ``JSONResponse`` with an ``"error"`` message and HTTP status 400 when
        the request is invalid.
    """
    # Function-scope imports: only this handler needs these names, and it
    # keeps the block usable without touching the file's import section.
    from fastapi import BackgroundTasks
    from fastapi.responses import JSONResponse
    from scipy.io import wavfile

    if req.language not in SUPPORTED_LANGUAGES:
        # Same payload shape as before, but flagged as a client error instead
        # of being sent with a 200 status.
        return JSONResponse(
            {"error": f"Unsupported language. Supported: {SUPPORTED_LANGUAGES}"},
            status_code=400,
        )

    if not req.text.strip():
        return JSONResponse({"error": "Text must not be empty"}, status_code=400)

    model = get_or_load_model()
    out_path = os.path.join(OUTPUT_DIR, f"{uuid.uuid4().hex}.wav")

    with torch.inference_mode():
        # NOTE(review): upstream Chatterbox exposes ``generate(text,
        # language_id=...)`` rather than a ``tts`` method — confirm against
        # src.chatterbox.mtl_tts.
        # NOTE(review): ``req.speaker`` is not forwarded because ``generate``
        # has no speaker argument (voice is chosen via a reference-audio
        # prompt) — decide how speakers should map to prompts.
        wav = model.generate(req.text, language_id=req.language)
        # Assumes ``generate`` returns a (1, num_samples) tensor — TODO confirm.
        samples = wav.squeeze(0).cpu().numpy()

    wavfile.write(out_path, model.sr, samples)

    # Delete the temporary file once the response has been sent so the
    # output directory does not grow without bound.
    cleanup = BackgroundTasks()
    cleanup.add_task(os.remove, out_path)

    return FileResponse(
        out_path,
        media_type="audio/wav",
        filename="speech.wav",
        background=cleanup,
    )
70
+
71
+
72
+ # -------------------------------------------------
73
+ # Simple UI (for quick testing)
74
+ # -------------------------------------------------
75
@app.get("/", response_class=HTMLResponse)
def ui():
    """Serve a minimal HTML page for trying the ``/tts`` endpoint by hand.

    The page sends the form values to ``/tts`` as a JSON body via ``fetch``
    (the endpoint declares a Pydantic model body, so a plain HTML form
    submission, which is URL-encoded, would be rejected) and loads the
    returned audio into an ``<audio>`` player.

    Returns:
        The HTML document as a string.
    """
    return """
<html>
  <body>
    <h2>Chatterbox Multilingual TTS</h2>
    <form id="tts-form">
      <textarea name="text" rows="4" cols="60">Hello, how are you?</textarea><br><br>
      <select name="language">
        <option value="en">English</option>
        <option value="hi">Hindi</option>
      </select><br><br>
      <button type="submit">Generate Speech</button>
    </form>
    <p id="status"></p>
    <audio id="player" controls></audio>
    <script>
      const form = document.getElementById("tts-form");
      const statusLine = document.getElementById("status");
      const player = document.getElementById("player");

      form.addEventListener("submit", async (event) => {
        event.preventDefault();
        statusLine.textContent = "Generating...";
        try {
          const response = await fetch("/tts", {
            method: "POST",
            headers: {"Content-Type": "application/json"},
            body: JSON.stringify({
              text: form.elements["text"].value,
              language: form.elements["language"].value,
            }),
          });
          const contentType = response.headers.get("content-type") || "";
          if (!response.ok || !contentType.startsWith("audio/")) {
            statusLine.textContent = "Error: " + (await response.text());
            return;
          }
          player.src = URL.createObjectURL(await response.blob());
          statusLine.textContent = "";
        } catch (err) {
          statusLine.textContent = "Error: " + err;
        }
      });
    </script>
  </body>
</html>
"""
92
+
93
+
94
+ # -------------------------------------------------
95
+ # Warm-up (optional, safe)
96
+ # -------------------------------------------------
97
@app.on_event("startup")
def warmup():
    """Load the TTS model when the server starts.

    Calling the loader here populates the cached model up front, so the
    first ``/tts`` request does not have to initialise it.  The returned
    model object is not needed at this point.
    """
    get_or_load_model()