Gaoussin committed on
Commit
673b955
·
verified ·
1 Parent(s): 5e43fc7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +41 -2
app.py CHANGED
@@ -5,11 +5,12 @@ os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf"
5
  os.environ["HF_DATASETS_CACHE"] = "/tmp/hf"
6
  os.makedirs("/tmp/hf", exist_ok=True)
7
 
8
- from fastapi import FastAPI, Query
9
  from fastapi.responses import StreamingResponse
10
- from transformers import VitsModel, AutoTokenizer
11
  import torch, scipy.io.wavfile as wavfile
12
  import io
 
13
  import edge_tts
14
 
15
 
@@ -19,6 +20,10 @@ app = FastAPI(title="Bambara TTS API")
19
  model = VitsModel.from_pretrained("facebook/mms-tts-bam")
20
  tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-bam")
21
  sampling_rate = model.config.sampling_rate
 
 
 
 
22
 
23
 
24
  @app.get("/tts/")
@@ -69,3 +74,37 @@ async def noneBmTts(
69
  except Exception as e:
70
  # Catch errors like invalid voice names
71
  raise HTTPException(status_code=400, detail=str(e))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  os.environ["HF_DATASETS_CACHE"] = "/tmp/hf"
6
  os.makedirs("/tmp/hf", exist_ok=True)
7
 
8
+ from fastapi import FastAPI, Query, File, UploadFile, HTTPException
9
  from fastapi.responses import StreamingResponse
10
+ from transformers import VitsModel, AutoTokenizer, Wav2Vec2ForCTC, AutoProcessor
11
  import torch, scipy.io.wavfile as wavfile
12
  import io
13
+ import librosa
14
  import edge_tts
15
 
16
 
 
# --- Models are loaded once at process start, not per-request ---

# Bambara TTS (VITS) model + tokenizer.
tts_model_id = "facebook/mms-tts-bam"
model = VitsModel.from_pretrained(tts_model_id)
tokenizer = AutoTokenizer.from_pretrained(tts_model_id)
sampling_rate = model.config.sampling_rate

# MMS multilingual ASR model (adapter-based; Bambara adapter is
# selected at request time in the /transcribe endpoint).
speech_model_id = "facebook/mms-1b-all"
processor = AutoProcessor.from_pretrained(speech_model_id)
speech_model = Wav2Vec2ForCTC.from_pretrained(speech_model_id)
28
 
29
  @app.get("/tts/")
 
74
  except Exception as e:
75
  # Catch errors like invalid voice names
76
  raise HTTPException(status_code=400, detail=str(e))
77
+
78
+
79
+
80
+
81
@app.post("/transcribe")
async def transcribe(audio_file: UploadFile = File(...)):
    """Transcribe an uploaded audio file to Bambara text.

    Uses the MMS multilingual ASR model with its Bambara ("bam")
    language adapter and greedy CTC decoding.

    Raises:
        HTTPException 400: empty/missing upload.
        HTTPException 500: any decoding or inference failure.
    """
    try:
        # Read the whole upload into memory.
        audio_bytes = await audio_file.read()
        if not audio_bytes:
            # File(...) guarantees a file part exists, but it may be empty.
            raise HTTPException(status_code=400, detail="Empty audio file")

        # Decode and resample to the 16 kHz rate Wav2Vec2 expects;
        # BytesIO lets librosa treat the in-memory bytes like a file.
        audio_data, _ = librosa.load(io.BytesIO(audio_bytes), sr=16000)

        # Switch to the Bambara adapter.
        # BUG FIX: the adapter must be loaded on the ASR model
        # (speech_model), not on the VITS TTS `model`.
        processor.tokenizer.set_target_lang("bam")
        speech_model.load_adapter("bam")

        # Greedy CTC inference.
        inputs = processor(audio_data, sampling_rate=16_000, return_tensors="pt")
        with torch.no_grad():
            logits = speech_model(**inputs).logits

        predicted_ids = torch.argmax(logits, dim=-1)
        transcription = processor.batch_decode(predicted_ids)[0]

        return {"text": transcription}

    except HTTPException:
        # Don't let deliberate 4xx responses be re-wrapped as 500s.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing audio: {str(e)}")