rahul7star committed on
Commit
4fadd0f
·
verified ·
1 Parent(s): d334bcd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +77 -3
app.py CHANGED
@@ -21,7 +21,11 @@ torch.cuda.is_available = lambda: False
21
  # STANDARD IMPORTS
22
  # ===============================
23
  from fastapi import FastAPI
24
- from contextlib import asynccontextmanager
 
 
 
 
25
 
26
  from src.chatterbox.mtl_tts import ChatterboxMultilingualTTS
27
 
@@ -42,14 +46,36 @@ def get_or_load_model():
42
  return MODEL
43
 
44
  # ===============================
45
- # FASTAPI LIFESPAN
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
  # ===============================
 
 
47
  @asynccontextmanager
48
  async def lifespan(app: FastAPI):
49
  # Warmup on startup
50
  get_or_load_model()
51
  yield
52
- # (no shutdown logic needed)
53
 
54
  app = FastAPI(lifespan=lifespan)
55
 
@@ -63,3 +89,51 @@ def health():
63
  "device": "cpu",
64
  "cuda_available": torch.cuda.is_available()
65
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  # STANDARD IMPORTS
22
  # ===============================
23
  from fastapi import FastAPI
24
+ from pydantic import BaseModel
25
+ import base64
26
+ import numpy as np
27
+ import io
28
+ from scipy.io.wavfile import write as write_wav
29
 
30
  from src.chatterbox.mtl_tts import ChatterboxMultilingualTTS
31
 
 
46
  return MODEL
47
 
48
# ===============================
# SINGING FORMATTER
# ===============================
# One-pass translation table: each lowercase vowel maps to itself doubled.
_VOWEL_STRETCH = str.maketrans({v: v * 2 for v in "aeiou"})

def format_for_singing(lyrics: str) -> str:
    """Reformat plain lyrics so the TTS model leans toward a sung delivery.

    Each non-empty line is stripped, has its lowercase vowels doubled (a
    light "stretch"), and is suffixed with a musical note plus an
    ellipsis to encourage a pause between phrases. Blank lines are
    dropped; the result is newline-joined.
    """
    sung_lines = [
        f"{raw.strip().translate(_VOWEL_STRETCH)} ♪ ..."
        for raw in lyrics.splitlines()
        if raw.strip()
    ]
    return "\n".join(sung_lines)
67
+
68
# ===============================
# FASTAPI APP + LIFESPAN
# ===============================
from contextlib import asynccontextmanager

@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: pre-load the TTS model before serving.

    Warming the cache via get_or_load_model() here means the first /tts
    request does not pay the model-load cost.
    """
    get_or_load_model()
    yield  # no teardown required on shutdown

app = FastAPI(lifespan=lifespan)
81
 
 
89
  "device": "cpu",
90
  "cuda_available": torch.cuda.is_available()
91
  }
92
+
93
# ===============================
# TTS INPUT SCHEMA
# ===============================
class TTSPayload(BaseModel):
    """Request body for POST /tts."""

    # Text to speak — treated as lyrics when mode is "Sing 🎵".
    text: str
    # Language code passed straight through to the model
    # (presumably a short ISO-style code — confirm against
    # ChatterboxMultilingualTTS).
    language_id: str = "en"
    # Delivery mode: "Speak 🗣️" (default) or "Sing 🎵".
    mode: str = "Speak 🗣️"
100
+
101
# ===============================
# TTS ENDPOINT
# ===============================
@app.post("/tts")
def generate_tts(payload: TTSPayload):
    """Synthesize speech (or stylized "singing") for the given payload.

    Returns a JSON object with the sample rate and the WAV audio encoded
    as base64. Validation failures are reported as {"error": ...} in a
    200 response.
    NOTE(review): a 4xx HTTPException would be more idiomatic here —
    confirm no client depends on the current response shape first.
    """
    model = get_or_load_model()

    stripped = payload.text.strip()
    if payload.mode == "Sing 🎵":
        if not stripped:
            return {"error": "Lyrics required for Sing mode."}
        final_text = format_for_singing(payload.text)
    else:
        if not stripped:
            return {"error": "Text required for Speak mode."}
        final_text = payload.text

    # CPU-safe inference; input capped at 300 chars to bound latency.
    with torch.no_grad():
        sr, wav = model.generate(
            final_text[:300],
            language_id=payload.language_id,
        )

    # Serialize the waveform to an in-memory WAV container.
    # NOTE(review): assumes model output is a numpy-compatible float
    # array in [-1, 1] — confirm against ChatterboxMultilingualTTS.generate.
    wav_buffer = io.BytesIO()
    write_wav(wav_buffer, sr, wav.astype(np.float32))

    return {
        "sr": sr,
        "audio_base64": base64.b64encode(wav_buffer.getvalue()).decode("utf-8"),
    }
136
+
137
+ # ===============================
138
+ # RUN: uvicorn app:app --host 0.0.0.0 --port 7860
139
+ # ===============================