rahul7star commited on
Commit
d334bcd
·
verified ·
1 Parent(s): 851663b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +38 -104
app.py CHANGED
@@ -1,131 +1,65 @@
1
- import os
2
- import uuid
3
- import torch
4
- from fastapi import FastAPI
5
- from pydantic import BaseModel
6
- from fastapi.responses import FileResponse, HTMLResponse
7
-
8
- from src.chatterbox.mtl_tts import ChatterboxMultilingualTTS, SUPPORTED_LANGUAGES
9
  import os
10
  os.environ["CUDA_VISIBLE_DEVICES"] = ""
11
  os.environ["TORCH_FORCE_CPU"] = "1"
12
 
13
  import torch
14
 
15
- # ===============================
16
- # HARD FORCE CPU torch.load
17
- # ===============================
18
  _original_torch_load = torch.load
19
 
20
  def cpu_only_torch_load(*args, **kwargs):
21
- # Force CPU regardless of how torch.load is called
22
  kwargs["map_location"] = torch.device("cpu")
23
  return _original_torch_load(*args, **kwargs)
24
 
25
  torch.load = cpu_only_torch_load
26
-
27
- # Extra safety: disable CUDA completely
28
  torch.cuda.is_available = lambda: False
29
 
30
- # -------------------------------------------------
31
- # App
32
- # -------------------------------------------------
33
- app = FastAPI(title="Chatterbox Multilingual TTS")
34
-
35
- # -------------------------------------------------
36
- # Globals (model loaded once)
37
- # -------------------------------------------------
38
- MODEL = None
39
- OUTPUT_DIR = "/tmp/tts_outputs"
40
- os.makedirs(OUTPUT_DIR, exist_ok=True)
41
-
42
- # -------------------------------------------------
43
- # Request schema
44
- # -------------------------------------------------
45
- class TTSRequest(BaseModel):
46
- text: str
47
- language: str = "en" # "en" or "hi"
48
- speaker: str | None = None
49
 
 
50
 
51
- # -------------------------------------------------
52
- # Model loader (NO .eval())
53
- # -------------------------------------------------
54
  MODEL = None
55
 
 
 
 
56
  def get_or_load_model():
57
  global MODEL
58
-
59
  if MODEL is None:
60
- print("🔄 Loading ChatterboxMultilingualTTS (CPU-only)")
61
-
62
- # ✅ THIS is the ONLY valid loader
63
  MODEL = ChatterboxMultilingualTTS.from_pretrained("cpu")
64
-
65
- # Chatterbox is NOT torch.nn.Module → no .to()
66
- MODEL.eval()
67
-
68
- print("✅ Chatterbox model loaded successfully")
69
-
70
  return MODEL
71
 
72
-
73
-
74
- # -------------------------------------------------
75
- # API: TTS
76
- # -------------------------------------------------
77
- @app.post("/tts")
78
- def tts(req: TTSRequest):
79
- if req.language not in SUPPORTED_LANGUAGES:
80
- return {
81
- "error": f"Unsupported language. Supported: {SUPPORTED_LANGUAGES}"
82
- }
83
-
84
- model = get_or_load_model()
85
- out_path = os.path.join(OUTPUT_DIR, f"{uuid.uuid4().hex}.wav")
86
-
87
- # ✅ Correct inference pattern
88
- with torch.inference_mode():
89
- audio = model.tts(
90
- text=req.text,
91
- language=req.language,
92
- speaker=req.speaker,
93
- output_path=out_path,
94
- )
95
-
96
- return FileResponse(
97
- out_path,
98
- media_type="audio/wav",
99
- filename="speech.wav",
100
- )
101
-
102
-
103
- # -------------------------------------------------
104
- # Simple UI (for quick testing)
105
- # -------------------------------------------------
106
- @app.get("/", response_class=HTMLResponse)
107
- def ui():
108
- return """
109
- <html>
110
- <body>
111
- <h2>Chatterbox Multilingual TTS</h2>
112
- <form action="/tts" method="post">
113
- <textarea name="text" rows="4" cols="60">Hello, how are you?</textarea><br><br>
114
- <select name="language">
115
- <option value="en">English</option>
116
- <option value="hi">Hindi</option>
117
- </select><br><br>
118
- <button type="submit">Generate Speech</button>
119
- </form>
120
- </body>
121
- </html>
122
- """
123
-
124
-
125
- # -------------------------------------------------
126
- # Warm-up (optional, safe)
127
- # -------------------------------------------------
128
- @app.on_event("startup")
129
- def warmup():
130
  get_or_load_model()
 
 
 
 
131
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ===============================
2
+ # FORCE CPU ONLY (VERY TOP)
3
+ # ===============================
 
 
 
 
 
4
  import os
5
  os.environ["CUDA_VISIBLE_DEVICES"] = ""
6
  os.environ["TORCH_FORCE_CPU"] = "1"
7
 
8
  import torch
9
 
10
+ # ---- HARD FORCE torch.load → CPU ----
 
 
11
  _original_torch_load = torch.load
12
 
13
  def cpu_only_torch_load(*args, **kwargs):
 
14
  kwargs["map_location"] = torch.device("cpu")
15
  return _original_torch_load(*args, **kwargs)
16
 
17
  torch.load = cpu_only_torch_load
 
 
18
  torch.cuda.is_available = lambda: False
19
 
20
+ # ===============================
21
+ # STANDARD IMPORTS
22
+ # ===============================
23
+ from fastapi import FastAPI
24
+ from contextlib import asynccontextmanager
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
+ from src.chatterbox.mtl_tts import ChatterboxMultilingualTTS
27
 
28
+ # ===============================
29
+ # GLOBAL MODEL CACHE
30
+ # ===============================
31
  MODEL = None
32
 
33
+ # ===============================
34
+ # MODEL LOADER
35
+ # ===============================
36
  def get_or_load_model():
37
  global MODEL
 
38
  if MODEL is None:
39
+ print("🔄 Loading ChatterboxMultilingualTTS (CPU ONLY)")
 
 
40
  MODEL = ChatterboxMultilingualTTS.from_pretrained("cpu")
41
+ print("✅ Model loaded on CPU")
 
 
 
 
 
42
  return MODEL
43
 
44
+ # ===============================
45
+ # FASTAPI LIFESPAN
46
+ # ===============================
47
+ @asynccontextmanager
48
+ async def lifespan(app: FastAPI):
49
+ # Warmup on startup
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
  get_or_load_model()
51
+ yield
52
+ # (no shutdown logic needed)
53
+
54
+ app = FastAPI(lifespan=lifespan)
55
 
56
+ # ===============================
57
+ # HEALTH CHECK
58
+ # ===============================
59
+ @app.get("/health")
60
+ def health():
61
+ return {
62
+ "status": "ok",
63
+ "device": "cpu",
64
+ "cuda_available": torch.cuda.is_available()
65
+ }