# bamara-tts / app.py — Hugging Face Space (commit 882c6ec, author: Gaoussin)
import os

# Redirect every Hugging Face cache to /tmp/hf: the default ~/.cache is not
# writable inside a Spaces container, and this must run BEFORE transformers
# is imported so the library picks the paths up.
_HF_CACHE = "/tmp/hf"
for _var in ("HF_HOME", "TRANSFORMERS_CACHE", "HF_DATASETS_CACHE"):
    os.environ[_var] = _HF_CACHE
os.makedirs(_HF_CACHE, exist_ok=True)
import io

import edge_tts
import scipy.io.wavfile as wavfile
import torch
from fastapi import FastAPI, HTTPException, Query
from fastapi.responses import StreamingResponse
from transformers import VitsModel, AutoTokenizer
app = FastAPI(title="Bambara TTS API")
# Load model once at startup
# Model and tokenizer are loaded at import time (module level), so the first
# request does not pay the download/initialization cost. Weights are cached
# under /tmp/hf per the environment variables set above.
model = VitsModel.from_pretrained("facebook/mms-tts-bam")
tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-bam")
# Output sample rate dictated by the VITS model config (used when writing WAV).
sampling_rate = model.config.sampling_rate
@app.get("/tts/")
async def tts(text: str = Query(..., description="Bambara text to synthesize")):
    """Synthesize Bambara speech with the MMS VITS model.

    Returns the generated audio as a streamed ``audio/wav`` response
    built entirely in memory (nothing is written to disk).
    """
    # Tokenize and pin the tensors to CPU before inference.
    encoded = {
        k: t.to("cpu") for k, t in tokenizer(text, return_tensors="pt").items()
    }
    with torch.no_grad():
        audio = model(**encoded).waveform[0]
    # Serialize the waveform to an in-memory WAV buffer and stream it back.
    wav_io = io.BytesIO()
    wavfile.write(wav_io, rate=sampling_rate, data=audio.numpy())
    wav_io.seek(0)
    return StreamingResponse(wav_io, media_type="audio/wav")
@app.get("/noneBmTts/")
async def noneBmTts(
    text: str = Query(..., description="Text to synthesize"),
    voice: str = Query(
        "fr-FR-DeniseNeural", description="Voice ID (e.g., en-US-GuyNeural)"
    ),
):
    """Synthesize non-Bambara text via Microsoft Edge neural TTS.

    Streams the MP3 audio produced by edge-tts for *text* in *voice*.
    Raises HTTPException(400) when the service produces no audio or
    rejects the request (e.g. an invalid voice name).
    """
    try:
        # Create the Communicate object with the requested text and voice
        communicate = edge_tts.Communicate(text, voice)
        buffer = io.BytesIO()
        # Stream the audio chunks into the buffer; edge-tts also emits
        # non-audio metadata events, which are skipped.
        async for chunk in communicate.stream():
            if chunk["type"] == "audio":
                buffer.write(chunk["data"])
        # An empty buffer means the service returned no audio at all.
        if buffer.tell() == 0:
            raise HTTPException(
                status_code=400, detail="Synthesis failed to produce audio."
            )
        buffer.seek(0)
        return StreamingResponse(buffer, media_type="audio/mpeg")
    except HTTPException:
        # Don't re-wrap the "no audio" error raised above — keep its detail.
        raise
    except Exception as e:
        # Catch errors like invalid voice names or network failures.
        raise HTTPException(status_code=400, detail=str(e)) from e