Gaoussin committed on
Commit
8bb18e8
·
verified ·
1 Parent(s): 7642f85

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -59
app.py CHANGED
@@ -1,71 +1,43 @@
1
- import os
2
-
3
- os.environ["HF_HOME"] = "/tmp/hf"
4
- os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf"
5
- os.environ["HF_DATASETS_CACHE"] = "/tmp/hf"
6
- os.makedirs("/tmp/hf", exist_ok=True)
7
-
8
- from fastapi import FastAPI, Query
9
- from fastapi.responses import StreamingResponse
10
- from transformers import VitsModel, AutoTokenizer
11
- import torch, scipy.io.wavfile as wavfile
12
  import io
13
- import edge_tts
14
-
15
-
16
- app = FastAPI(title="Bambara TTS API")
17
-
18
- # Load model once at startup
19
- model = VitsModel.from_pretrained("facebook/mms-tts-bam")
20
- tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-bam")
21
- sampling_rate = model.config.sampling_rate
22
-
23
-
24
@app.get("/tts/")
async def tts(text: str = Query(..., description="Bambara text to synthesize")):
    """Synthesize Bambara speech for *text* and stream it back as WAV audio."""
    # Tokenize the input and keep all tensors on the CPU for inference.
    encoded = tokenizer(text, return_tensors="pt")
    encoded = {key: tensor.to("cpu") for key, tensor in encoded.items()}

    # Inference only — gradients are unnecessary.
    with torch.no_grad():
        speech = model(**encoded).waveform[0]

    # Serialize the waveform into an in-memory WAV file instead of touching disk.
    wav_io = io.BytesIO()
    wavfile.write(wav_io, rate=sampling_rate, data=speech.numpy())
    wav_io.seek(0)

    return StreamingResponse(wav_io, media_type="audio/wav")
 
 
 
40
 
 
 
 
 
 
41
 
42
@app.get("/noneBmTts/")
async def noneBmTts(
    text: str = Query(..., description="Text to synthesize"),
    voice: str = Query(
        "fr-FR-DeniseNeural", description="Voice ID (e.g., en-US-GuyNeural)"
    ),
):
    """Synthesize *text* with Microsoft Edge TTS and stream the MP3 audio.

    Raises HTTPException(400) for invalid voices or when synthesis yields
    no audio data.
    """
    # BUG FIX: HTTPException is raised below but was never imported at module
    # level in this file, so every error path crashed with a NameError instead
    # of returning a clean 400 to the client.
    from fastapi import HTTPException

    try:
        # Build the Edge-TTS client for the requested text/voice pair.
        communicate = edge_tts.Communicate(text, voice)

        buffer = io.BytesIO()

        # Stream the audio chunks into the in-memory buffer.
        async for chunk in communicate.stream():
            if chunk["type"] == "audio":
                buffer.write(chunk["data"])

        # Edge TTS can complete without emitting audio (e.g. empty input).
        if buffer.tell() == 0:
            raise HTTPException(
                status_code=400, detail="Synthesis failed to produce audio."
            )

        buffer.seek(0)
        return StreamingResponse(buffer, media_type="audio/mpeg")

    except HTTPException:
        # BUG FIX: don't re-wrap the deliberate 400 above — the broad handler
        # below would otherwise stringify it into a new exception's detail.
        raise
    except Exception as e:
        # Catch errors like invalid voice names.
        raise HTTPException(status_code=400, detail=str(e))
 
1
+ from fastapi import FastAPI, File, UploadFile, HTTPException
2
+ from transformers import Wav2Vec2ForCTC, AutoProcessor
3
+ import torch
4
+ import librosa
 
 
 
 
 
 
 
5
  import io
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
app = FastAPI()

# Load the MMS speech-recognition checkpoint and its processor a single time
# at startup, so per-request latency only covers inference.
MODEL_ID = "facebook/mms-1b-all"
processor = AutoProcessor.from_pretrained(MODEL_ID)
model = Wav2Vec2ForCTC.from_pretrained(MODEL_ID)
13
 
14
@app.post("/transcribe")
async def transcribe(audio_file: UploadFile = File(...)):
    """Transcribe an uploaded audio file to Bambara text with MMS ASR.

    The upload is decoded and resampled to 16 kHz, run through the
    Wav2Vec2 CTC model with the Bambara ("bam") adapter, and the greedy
    decoding is returned as ``{"text": ...}``.

    Raises:
        HTTPException 400 if no file was uploaded (normally unreachable:
            ``File(...)`` already makes the field required).
        HTTPException 500 for any decoding/inference failure.
    """
    # Defensive check kept from the original implementation.
    if not audio_file:
        raise HTTPException(status_code=400, detail="No file uploaded")

    try:
        # Read the whole upload into memory; io.BytesIO lets librosa treat
        # the bytes like a file and resample to the 16 kHz the model expects.
        audio_bytes = await audio_file.read()
        audio_data, _ = librosa.load(io.BytesIO(audio_bytes), sr=16000)

        # PERF FIX: the original re-ran set_target_lang/load_adapter on every
        # request; load_adapter reloads adapter weights each call, which is
        # expensive. Configure the Bambara adapter exactly once, lazily.
        if not getattr(transcribe, "_bam_ready", False):
            processor.tokenizer.set_target_lang("bam")
            model.load_adapter("bam")
            transcribe._bam_ready = True

        # Greedy CTC decoding; no gradients needed for inference.
        inputs = processor(audio_data, sampling_rate=16_000, return_tensors="pt")
        with torch.no_grad():
            logits = model(**inputs).logits

        predicted_ids = torch.argmax(logits, dim=-1)
        transcription = processor.batch_decode(predicted_ids)[0]

        return {"text": transcription}

    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing audio: {str(e)}")