Gaoussin committed on
Commit
e578353
·
verified ·
1 Parent(s): a799512

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -30
app.py CHANGED
@@ -9,13 +9,10 @@ from fastapi.middleware.cors import CORSMiddleware
9
  from fastapi.responses import StreamingResponse
10
  from transformers import VitsModel, AutoTokenizer, Wav2Vec2ForCTC, AutoProcessor
11
 
12
- # 1. Set cache before importing/loading models
13
  os.environ["HF_HOME"] = "/tmp/hf"
14
- os.makedirs("/tmp/hf", exist_ok=True)
15
-
16
  app = FastAPI(title="Bambara AI API")
17
 
18
- # CRITICAL: Allow your frontend to talk to your HF Space
19
  app.add_middleware(
20
  CORSMiddleware,
21
  allow_origins=["*"],
@@ -24,55 +21,63 @@ app.add_middleware(
24
  allow_headers=["*"],
25
  )
26
 
27
- # 2. Load Models (Memory Efficient)
28
- # Use .to("cpu") explicitly if you don't have a GPU on the free tier
29
- device = "cuda" if torch.cuda.is_available() else "cpu"
30
-
31
- # TTS Model
32
- tts_model_id = "facebook/mms-tts-bam"
33
- tts_tokenizer = AutoTokenizer.from_pretrained(tts_model_id)
34
- tts_model = VitsModel.from_pretrained(tts_model_id).to(device)
35
 
36
- # ASR (Speech-to-Text) Model
37
- asr_model_id = "facebook/mms-1b-all"
 
38
  asr_processor = AutoProcessor.from_pretrained(asr_model_id)
39
  asr_model = Wav2Vec2ForCTC.from_pretrained(asr_model_id).to(device)
40
 
41
- # Pre-load the Bambara adapter so it doesn't slow down the first request
42
  asr_processor.tokenizer.set_target_lang("bam")
43
  asr_model.load_adapter("bam")
44
 
45
- @app.get("/tts/")
46
- async def tts(text: str = Query(..., description="Bambara text")):
47
- inputs = tts_tokenizer(text, return_tensors="pt").to(device)
48
- with torch.no_grad():
49
- output = tts_model(**inputs).waveform
50
-
51
- buffer = io.BytesIO()
52
- wavfile.write(buffer, rate=tts_model.config.sampling_rate, data=output[0].cpu().numpy())
53
- buffer.seek(0)
54
- return StreamingResponse(buffer, media_type="audio/wav")
55
 
56
  @app.post("/transcribe")
57
  async def transcribe(audio_file: UploadFile = File(...)):
58
  try:
59
- # Read and load audio
60
- audio_bytes = await audio_file.read()
61
- audio_data, _ = librosa.load(io.BytesIO(audio_bytes), sr=16000)
 
 
 
 
 
62
 
63
- # Prepare inputs
64
  inputs = asr_processor(audio_data, sampling_rate=16000, return_tensors="pt").to(device)
65
 
66
- with torch.no_grad():
 
67
  logits = asr_model(**inputs).logits
68
 
 
69
  predicted_ids = torch.argmax(logits, dim=-1)
70
  transcription = asr_processor.batch_decode(predicted_ids)[0]
71
 
72
  return {"text": transcription}
 
73
  except Exception as e:
 
74
  raise HTTPException(status_code=500, detail=str(e))
75
 
 
 
 
 
 
 
 
 
 
 
 
76
  @app.get("/noneBmTts/")
77
  async def noneBmTts(text: str, voice: str = "fr-FR-DeniseNeural"):
78
  communicate = edge_tts.Communicate(text, voice)
 
9
  from fastapi.responses import StreamingResponse
10
  from transformers import VitsModel, AutoTokenizer, Wav2Vec2ForCTC, AutoProcessor
11
 
12
+ # 1. Environment and App Setup
13
  os.environ["HF_HOME"] = "/tmp/hf"
 
 
14
  app = FastAPI(title="Bambara AI API")
15
 
 
16
  app.add_middleware(
17
  CORSMiddleware,
18
  allow_origins=["*"],
 
21
  allow_headers=["*"],
22
  )
23
 
24
+ device = "cpu"
 
 
 
 
 
 
 
25
 
26
+ # 2. Load Models (Switching to 300M for stability)
27
+ # ASR Model
28
+ asr_model_id = "facebook/mms-300m-1107" # Smaller, faster, more stable
29
  asr_processor = AutoProcessor.from_pretrained(asr_model_id)
30
  asr_model = Wav2Vec2ForCTC.from_pretrained(asr_model_id).to(device)
31
 
32
+ # Load Bambara Adapter
33
  asr_processor.tokenizer.set_target_lang("bam")
34
  asr_model.load_adapter("bam")
35
 
36
+ # TTS Model
37
+ tts_model_id = "facebook/mms-tts-bam"
38
+ tts_tokenizer = AutoTokenizer.from_pretrained(tts_model_id)
39
+ tts_model = VitsModel.from_pretrained(tts_model_id).to(device)
 
 
 
 
 
 
40
 
@app.post("/transcribe")
async def transcribe(audio_file: UploadFile = File(...)):
    """Transcribe an uploaded audio file with the module-level ASR model.

    The upload is read fully into memory, decoded and resampled to 16 kHz
    by librosa, passed through ``asr_processor`` / ``asr_model``, and the
    per-frame argmax of the logits is decoded into text.

    Args:
        audio_file: Multipart upload containing the audio to transcribe.

    Returns:
        A dict of the form ``{"text": <transcription>}``.

    Raises:
        HTTPException: 400 when the upload is empty; 500 (with the
            underlying error message as ``detail``) when decoding the
            audio or running the model fails.
    """
    try:
        content = await audio_file.read()
        if not content:
            raise HTTPException(status_code=400, detail="Empty audio file")

        # Resample while loading so the audio matches the sampling_rate
        # declared to the processor on the next line.
        audio_data, _ = librosa.load(io.BytesIO(content), sr=16000)

        inputs = asr_processor(audio_data, sampling_rate=16000, return_tensors="pt").to(device)

        # No gradients are needed for inference.
        with torch.inference_mode():
            logits = asr_model(**inputs).logits

        predicted_ids = torch.argmax(logits, dim=-1)
        transcription = asr_processor.batch_decode(predicted_ids)[0]

        return {"text": transcription}

    except HTTPException:
        # Deliberate HTTP errors (such as the 400 above) must reach the
        # client unchanged instead of being rewrapped as a 500 below.
        raise
    except Exception as e:
        print(f"Error: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
69
 
@app.get("/tts/")
async def tts(text: str = Query(..., description="Bambara text")):
    """Synthesize speech for ``text`` and stream it back as WAV audio.

    The text is tokenized with ``tts_tokenizer``, run through ``tts_model``
    without gradient tracking, and the first waveform of the output is
    written to an in-memory WAV file at the model's configured sampling rate.

    Args:
        text: Required query parameter holding the text to speak.

    Returns:
        A ``StreamingResponse`` with media type ``audio/wav``.
    """
    encoded = tts_tokenizer(text, return_tensors="pt").to(device)
    with torch.inference_mode():
        waveform = tts_model(**encoded).waveform

    # Take the first waveform of the output and move it to NumPy for scipy.
    samples = waveform[0].cpu().numpy()

    wav_bytes = io.BytesIO()
    wavfile.write(wav_bytes, rate=tts_model.config.sampling_rate, data=samples)
    # Rewind so the response streams from the start of the buffer.
    wav_bytes.seek(0)
    return StreamingResponse(wav_bytes, media_type="audio/wav")
80
+
81
  @app.get("/noneBmTts/")
82
  async def noneBmTts(text: str, voice: str = "fr-FR-DeniseNeural"):
83
  communicate = edge_tts.Communicate(text, voice)