Update app.py
Browse files
app.py
CHANGED
|
@@ -34,6 +34,12 @@ from src.chatterbox.mtl_tts import ChatterboxMultilingualTTS
|
|
| 34 |
# ===============================
|
| 35 |
MODEL = None
|
| 36 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
# ===============================
|
| 38 |
# MODEL LOADER
|
| 39 |
# ===============================
|
|
@@ -90,6 +96,17 @@ def health():
|
|
| 90 |
"cuda_available": torch.cuda.is_available()
|
| 91 |
}
|
| 92 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 93 |
# ===============================
|
| 94 |
# TTS INPUT SCHEMA
|
| 95 |
# ===============================
|
|
@@ -103,6 +120,13 @@ class TTSPayload(BaseModel):
|
|
| 103 |
# ===============================
|
| 104 |
@app.post("/tts")
|
| 105 |
def generate_tts(payload: TTSPayload):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 106 |
model = get_or_load_model()
|
| 107 |
|
| 108 |
# Determine final text
|
|
@@ -126,21 +150,22 @@ def generate_tts(payload: TTSPayload):
|
|
| 126 |
sr = model.sr
|
| 127 |
|
| 128 |
# Convert numpy -> WAV bytes
|
| 129 |
-
import io
|
| 130 |
-
from scipy.io.wavfile import write as write_wav
|
| 131 |
buf = io.BytesIO()
|
| 132 |
write_wav(buf, sr, wav.astype(np.float32))
|
| 133 |
buf.seek(0)
|
| 134 |
audio_bytes = buf.read()
|
| 135 |
|
|
|
|
|
|
|
|
|
|
| 136 |
# Return as base64
|
| 137 |
-
import base64
|
| 138 |
return {
|
| 139 |
"sr": sr,
|
| 140 |
-
"audio_base64": base64.b64encode(audio_bytes).decode("utf-8")
|
|
|
|
|
|
|
| 141 |
}
|
| 142 |
|
| 143 |
-
|
| 144 |
# ===============================
|
| 145 |
# RUN: uvicorn app:app --host 0.0.0.0 --port 7860
|
| 146 |
# ===============================
|
|
|
|
| 34 |
# ===============================
|
| 35 |
MODEL = None
|
| 36 |
|
| 37 |
+
# ===============================
|
| 38 |
+
# MAX QUOTA (from ENV)
|
| 39 |
+
# ===============================
|
| 40 |
+
TTS_MAX_QUOTA = int(os.getenv("TTS_MAX_QUOTA", 10)) # default 10 requests/day
|
| 41 |
+
tts_usage = 0 # simple in-memory counter for demo
|
| 42 |
+
|
| 43 |
# ===============================
|
| 44 |
# MODEL LOADER
|
| 45 |
# ===============================
|
|
|
|
| 96 |
"cuda_available": torch.cuda.is_available()
|
| 97 |
}
|
| 98 |
|
| 99 |
+
# ===============================
|
| 100 |
+
# QUOTA INFO
|
| 101 |
+
# ===============================
|
| 102 |
+
@app.get("/quota")
|
| 103 |
+
def get_quota():
|
| 104 |
+
return {
|
| 105 |
+
"used": tts_usage,
|
| 106 |
+
"limit": TTS_MAX_QUOTA,
|
| 107 |
+
"remaining": max(0, TTS_MAX_QUOTA - tts_usage)
|
| 108 |
+
}
|
| 109 |
+
|
| 110 |
# ===============================
|
| 111 |
# TTS INPUT SCHEMA
|
| 112 |
# ===============================
|
|
|
|
| 120 |
# ===============================
|
| 121 |
@app.post("/tts")
|
| 122 |
def generate_tts(payload: TTSPayload):
|
| 123 |
+
global tts_usage
|
| 124 |
+
if tts_usage >= TTS_MAX_QUOTA:
|
| 125 |
+
return {
|
| 126 |
+
"error": "Quota exceeded",
|
| 127 |
+
"message": f"Daily limit of {TTS_MAX_QUOTA} TTS requests reached. Try again tomorrow."
|
| 128 |
+
}
|
| 129 |
+
|
| 130 |
model = get_or_load_model()
|
| 131 |
|
| 132 |
# Determine final text
|
|
|
|
| 150 |
sr = model.sr
|
| 151 |
|
| 152 |
# Convert numpy -> WAV bytes
|
|
|
|
|
|
|
| 153 |
buf = io.BytesIO()
|
| 154 |
write_wav(buf, sr, wav.astype(np.float32))
|
| 155 |
buf.seek(0)
|
| 156 |
audio_bytes = buf.read()
|
| 157 |
|
| 158 |
+
# Increment quota usage
|
| 159 |
+
tts_usage += 1
|
| 160 |
+
|
| 161 |
# Return as base64
|
|
|
|
| 162 |
return {
|
| 163 |
"sr": sr,
|
| 164 |
+
"audio_base64": base64.b64encode(audio_bytes).decode("utf-8"),
|
| 165 |
+
"quota_used": tts_usage,
|
| 166 |
+
"quota_limit": TTS_MAX_QUOTA
|
| 167 |
}
|
| 168 |
|
|
|
|
| 169 |
# ===============================
|
| 170 |
# RUN: uvicorn app:app --host 0.0.0.0 --port 7860
|
| 171 |
# ===============================
|