rahul7star committed on
Commit
00fb245
·
verified ·
1 Parent(s): cc3d0ec

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -117
app.py CHANGED
@@ -3,17 +3,14 @@ import io
3
  import random
4
  import numpy as np
5
  import torch
6
- import soundfile as sf
7
  from fastapi import FastAPI, Form
8
- from fastapi.responses import HTMLResponse, StreamingResponse, JSONResponse
9
 
10
- from src.chatterbox.mtl_tts import (
11
- ChatterboxMultilingualTTS,
12
- SUPPORTED_LANGUAGES
13
- )
14
 
15
  # ===============================
16
- # CPU ONLY PATCH
17
  # ===============================
18
  os.environ["CUDA_VISIBLE_DEVICES"] = ""
19
 
@@ -23,57 +20,41 @@ def _cpu_only_torch_load(*args, **kwargs):
23
  return _original_torch_load(*args, **kwargs)
24
  torch.load = _cpu_only_torch_load
25
 
26
-
27
  # ===============================
28
  # LANGUAGE CONFIG
29
  # ===============================
30
  LANGUAGE_CONFIG = {
31
  "en": {
32
- "audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/en_f1.flac",
33
  },
34
  "hi": {
35
- "audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/hi_f1.flac",
36
  },
37
  "fr": {
38
- "audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/fr_f1.flac",
39
  },
40
  "he": {
41
- "audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/he_m1.flac",
42
  },
43
  }
44
 
45
  # ===============================
46
- # MODEL LOAD (CPU SAFE)
47
  # ===============================
48
  MODEL = None
49
- DEVICE = "cpu"
50
 
51
  def get_or_load_model():
52
  global MODEL
53
-
54
  if MODEL is None:
55
- print("🔄 Loading Chatterbox model (CPU-only)...")
56
-
57
  MODEL = ChatterboxMultilingualTTS.from_pretrained("cpu")
58
-
59
- # ⚠️ DO NOT call .to()
60
- # ChatterboxMultilingualTTS is NOT a torch.nn.Module
61
-
62
  MODEL.eval()
63
-
64
- # Disable grads if exposed
65
- if hasattr(MODEL, "parameters"):
66
- for p in MODEL.parameters():
67
- p.requires_grad = False
68
-
69
- print("✅ Model loaded successfully")
70
-
71
  return MODEL
72
 
73
-
74
  get_or_load_model()
75
 
76
-
77
  # ===============================
78
  # UTILITIES
79
  # ===============================
@@ -82,7 +63,6 @@ def set_seed(seed: int):
82
  random.seed(seed)
83
  np.random.seed(seed)
84
 
85
-
86
  def format_for_singing(lyrics: str) -> str:
87
  lines = []
88
  for line in lyrics.splitlines():
@@ -99,24 +79,31 @@ def format_for_singing(lyrics: str) -> str:
99
  lines.append(f"{line} ♪ ...")
100
  return "\n".join(lines)
101
 
102
-
103
  # ===============================
104
  # FASTAPI APP
105
  # ===============================
106
- app = FastAPI(title="Chatterbox TTS API", version="1.0")
107
-
 
 
108
 
 
 
 
109
  @app.get("/health")
110
  def health():
111
- return {"status": "ok", "device": DEVICE}
112
-
 
 
 
113
 
114
  # ===============================
115
  # TTS API
116
  # ===============================
117
  @app.post("/tts")
118
- def tts_api(
119
- mode: str = Form("Speak"),
120
  text: str = Form(""),
121
  lyrics: str = Form(""),
122
  language_id: str = Form("hi"),
@@ -130,13 +117,19 @@ def tts_api(
130
  if seed != 0:
131
  set_seed(seed)
132
 
133
- if mode == "Sing":
134
  if not lyrics.strip():
135
- return JSONResponse({"error": "Lyrics required for Sing mode"}, status_code=400)
 
 
 
136
  final_text = format_for_singing(lyrics)
137
  else:
138
  if not text.strip():
139
- return JSONResponse({"error": "Text required for Speak mode"}, status_code=400)
 
 
 
140
  final_text = text
141
 
142
  kwargs = {
@@ -159,7 +152,7 @@ def tts_api(
159
  wav = wav.squeeze(0).cpu().numpy()
160
 
161
  buffer = io.BytesIO()
162
- sf.write(buffer, wav, model.sr, format="WAV")
163
  buffer.seek(0)
164
 
165
  return StreamingResponse(
@@ -167,77 +160,3 @@ def tts_api(
167
  media_type="audio/wav",
168
  headers={"Content-Disposition": "inline; filename=output.wav"}
169
  )
170
-
171
-
172
- # ===============================
173
- # SIMPLE WEB UI
174
- # ===============================
175
- @app.get("/", response_class=HTMLResponse)
176
- def ui():
177
- langs = "".join(
178
- f"<option value='{k}'>{v}</option>"
179
- for k, v in SUPPORTED_LANGUAGES.items()
180
- )
181
-
182
- return f"""
183
- <!DOCTYPE html>
184
- <html>
185
- <head>
186
- <title>Chatterbox TTS</title>
187
- <style>
188
- body {{ font-family: Arial; max-width: 800px; margin: auto; }}
189
- textarea {{ width: 100%; height: 120px; }}
190
- select, button {{ padding: 6px; }}
191
- </style>
192
- </head>
193
- <body>
194
-
195
- <h2>🎤 Chatterbox Multilingual TTS</h2>
196
-
197
- <label>Mode:</label>
198
- <select id="mode">
199
- <option value="Speak">Speak</option>
200
- <option value="Sing">Sing</option>
201
- </select>
202
-
203
- <br><br>
204
-
205
- <label>Language:</label>
206
- <select id="language">{langs}</select>
207
-
208
- <br><br>
209
-
210
- <label>Text (Speak):</label>
211
- <textarea id="text"></textarea>
212
-
213
- <label>Lyrics (Sing):</label>
214
- <textarea id="lyrics"></textarea>
215
-
216
- <br>
217
-
218
- <button onclick="run()">Generate</button>
219
-
220
- <br><br>
221
- <audio id="player" controls></audio>
222
-
223
- <script>
224
- async function run() {{
225
- const form = new FormData();
226
- form.append("mode", document.getElementById("mode").value);
227
- form.append("language_id", document.getElementById("language").value);
228
- form.append("text", document.getElementById("text").value);
229
- form.append("lyrics", document.getElementById("lyrics").value);
230
-
231
- const res = await fetch("/tts", {{
232
- method: "POST",
233
- body: form
234
- }});
235
-
236
- const blob = await res.blob();
237
- document.getElementById("player").src = URL.createObjectURL(blob);
238
- }}
239
- </script>
240
-
241
- </body>
242
- </html>
243
- """
 
3
  import random
4
  import numpy as np
5
  import torch
6
+ from scipy.io import wavfile
7
  from fastapi import FastAPI, Form
8
+ from fastapi.responses import StreamingResponse, JSONResponse
9
 
10
+ from src.chatterbox.mtl_tts import ChatterboxMultilingualTTS, SUPPORTED_LANGUAGES
 
 
 
11
 
12
  # ===============================
13
+ # CPU-ONLY HARD PATCH
14
  # ===============================
15
  os.environ["CUDA_VISIBLE_DEVICES"] = ""
16
 
 
20
  return _original_torch_load(*args, **kwargs)
21
  torch.load = _cpu_only_torch_load
22
 
 
23
  # ===============================
24
  # LANGUAGE CONFIG
25
  # ===============================
26
  LANGUAGE_CONFIG = {
27
  "en": {
28
+ "audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/en_f1.flac"
29
  },
30
  "hi": {
31
+ "audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/hi_f1.flac"
32
  },
33
  "fr": {
34
+ "audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/fr_f1.flac"
35
  },
36
  "he": {
37
+ "audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/he_m1.flac"
38
  },
39
  }
40
 
41
  # ===============================
42
+ # MODEL LOADING (SAFE)
43
  # ===============================
44
  MODEL = None
 
45
 
46
  def get_or_load_model():
47
  global MODEL
 
48
  if MODEL is None:
49
+ print("🔄 Loading Chatterbox model (CPU-only)")
 
50
  MODEL = ChatterboxMultilingualTTS.from_pretrained("cpu")
 
 
 
 
51
  MODEL.eval()
52
+ print("✅ Model loaded")
 
 
 
 
 
 
 
53
  return MODEL
54
 
55
+ # Load at startup
56
  get_or_load_model()
57
 
 
58
  # ===============================
59
  # UTILITIES
60
  # ===============================
 
63
  random.seed(seed)
64
  np.random.seed(seed)
65
 
 
66
  def format_for_singing(lyrics: str) -> str:
67
  lines = []
68
  for line in lyrics.splitlines():
 
79
  lines.append(f"{line} ♪ ...")
80
  return "\n".join(lines)
81
 
 
82
  # ===============================
83
  # FASTAPI APP
84
  # ===============================
85
+ app = FastAPI(
86
+ title="Chatterbox Multilingual TTS",
87
+ version="1.0"
88
+ )
89
 
90
+ # ===============================
91
+ # HEALTH API
92
+ # ===============================
93
  @app.get("/health")
94
  def health():
95
+ return {
96
+ "status": "ok",
97
+ "device": "cpu",
98
+ "languages": list(SUPPORTED_LANGUAGES.keys())
99
+ }
100
 
101
  # ===============================
102
  # TTS API
103
  # ===============================
104
  @app.post("/tts")
105
+ def tts(
106
+ mode: str = Form("Speak"), # Speak | Sing
107
  text: str = Form(""),
108
  lyrics: str = Form(""),
109
  language_id: str = Form("hi"),
 
117
  if seed != 0:
118
  set_seed(seed)
119
 
120
+ if mode.lower() == "sing":
121
  if not lyrics.strip():
122
+ return JSONResponse(
123
+ {"error": "Lyrics required for Sing mode"},
124
+ status_code=400
125
+ )
126
  final_text = format_for_singing(lyrics)
127
  else:
128
  if not text.strip():
129
+ return JSONResponse(
130
+ {"error": "Text required for Speak mode"},
131
+ status_code=400
132
+ )
133
  final_text = text
134
 
135
  kwargs = {
 
152
  wav = wav.squeeze(0).cpu().numpy()
153
 
154
  buffer = io.BytesIO()
155
+ wavfile.write(buffer, model.sr, wav)
156
  buffer.seek(0)
157
 
158
  return StreamingResponse(
 
160
  media_type="audio/wav",
161
  headers={"Content-Disposition": "inline; filename=output.wav"}
162
  )