Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
|
@@ -5,11 +5,12 @@ os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf"
|
|
| 5 |
os.environ["HF_DATASETS_CACHE"] = "/tmp/hf"
|
| 6 |
os.makedirs("/tmp/hf", exist_ok=True)
|
| 7 |
|
| 8 |
-
from fastapi import FastAPI, Query
|
| 9 |
from fastapi.responses import StreamingResponse
|
| 10 |
-
from transformers import VitsModel, AutoTokenizer
|
| 11 |
import torch, scipy.io.wavfile as wavfile
|
| 12 |
import io
|
|
|
|
| 13 |
import edge_tts
|
| 14 |
|
| 15 |
|
|
@@ -19,6 +20,10 @@ app = FastAPI(title="Bambara TTS API")
|
|
| 19 |
model = VitsModel.from_pretrained("facebook/mms-tts-bam")
|
| 20 |
tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-bam")
|
| 21 |
sampling_rate = model.config.sampling_rate
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
|
| 23 |
|
| 24 |
@app.get("/tts/")
|
|
@@ -69,3 +74,37 @@ async def noneBmTts(
|
|
| 69 |
except Exception as e:
|
| 70 |
# Catch errors like invalid voice names
|
| 71 |
raise HTTPException(status_code=400, detail=str(e))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
os.environ["HF_DATASETS_CACHE"] = "/tmp/hf"
|
| 6 |
os.makedirs("/tmp/hf", exist_ok=True)
|
| 7 |
|
| 8 |
+
from fastapi import FastAPI, Query, File, UploadFile, HTTPException
|
| 9 |
from fastapi.responses import StreamingResponse
|
| 10 |
+
from transformers import VitsModel, AutoTokenizer, Wav2Vec2ForCTC, AutoProcessor
|
| 11 |
import torch, scipy.io.wavfile as wavfile
|
| 12 |
import io
|
| 13 |
+
import librosa
|
| 14 |
import edge_tts
|
| 15 |
|
| 16 |
|
|
|
|
| 20 |
# --- Model initialisation (runs once, at server startup/import time) ---
# TTS: Meta MMS text-to-speech VITS model for Bambara ("bam").
model = VitsModel.from_pretrained("facebook/mms-tts-bam")
tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-bam")
# Sample rate the VITS model synthesises audio at (used when writing WAVs).
sampling_rate = model.config.sampling_rate
# ASR: Meta MMS multilingual speech-recognition model (adapter-based).
# Loaded once when the server starts so requests don't pay the download cost.
speech_model_id = "facebook/mms-1b-all"
processor = AutoProcessor.from_pretrained(speech_model_id)
speech_model = Wav2Vec2ForCTC.from_pretrained(speech_model_id)
|
| 27 |
|
| 28 |
|
| 29 |
@app.get("/tts/")
|
|
|
|
| 74 |
except Exception as e:
|
| 75 |
# Catch errors like invalid voice names
|
| 76 |
raise HTTPException(status_code=400, detail=str(e))
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
@app.post("/transcribe")
async def transcribe(audio_file: UploadFile = File(...)):
    """Transcribe an uploaded audio file to Bambara text with MMS ASR.

    Accepts any audio format librosa can decode, resamples it to 16 kHz
    mono, and runs CTC inference with the "bam" language adapter.

    Returns:
        dict: {"text": <transcription string>} on success.

    Raises:
        HTTPException: 400 if no file was uploaded,
                       500 for any decoding/inference failure.
    """
    # 1. Check if a file was actually uploaded.
    if not audio_file:
        raise HTTPException(status_code=400, detail="No file uploaded")

    try:
        # 2. Read the whole upload into memory; io.BytesIO lets librosa
        #    treat the raw bytes like a file object.
        audio_bytes = await audio_file.read()

        # 3. MMS expects 16,000 Hz input; librosa resamples while loading.
        audio_data, _ = librosa.load(io.BytesIO(audio_bytes), sr=16000)

        # 4. Select the Bambara adapter on the *ASR* model.
        #    BUG FIX: the adapter must be loaded into `speech_model`
        #    (Wav2Vec2ForCTC), not the VITS TTS `model`, which has no
        #    MMS ASR adapters and is not the model used for inference below.
        # NOTE(review): this runs on every request even though the target
        # language never changes; it could be hoisted to startup.
        processor.tokenizer.set_target_lang("bam")
        speech_model.load_adapter("bam")

        # 5. CTC inference: greedy argmax over logits, then decode to text.
        inputs = processor(audio_data, sampling_rate=16_000, return_tensors="pt")
        with torch.no_grad():
            logits = speech_model(**inputs).logits

        predicted_ids = torch.argmax(logits, dim=-1)
        transcription = processor.batch_decode(predicted_ids)[0]

        return {"text": transcription}

    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing audio: {str(e)}")
|