Gaoussin committed on
Commit f87e51e · verified · 1 Parent(s): 673b955

Update app.py

Files changed (1)
  app.py +46 -82
app.py CHANGED
@@ -1,110 +1,74 @@
  import os
+ import io
+ import torch
+ import librosa
+ import edge_tts
+ import scipy.io.wavfile as wavfile
+ from fastapi import FastAPI, Query, File, UploadFile, HTTPException
+ from fastapi.responses import StreamingResponse
+ from transformers import VitsModel, AutoTokenizer, Wav2Vec2ForCTC, AutoProcessor

+ # 1. Set cache before importing/loading models
  os.environ["HF_HOME"] = "/tmp/hf"
- os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf"
- os.environ["HF_DATASETS_CACHE"] = "/tmp/hf"
  os.makedirs("/tmp/hf", exist_ok=True)

- from fastapi import FastAPI, Query, File, UploadFile, HTTPException
- from fastapi.responses import StreamingResponse
- from transformers import VitsModel, AutoTokenizer, Wav2Vec2ForCTC, AutoProcessor
- import torch, scipy.io.wavfile as wavfile
- import io
- import librosa
- import edge_tts
-
- app = FastAPI(title="Bambara TTS API")
-
- # Load model once at startup
- model = VitsModel.from_pretrained("facebook/mms-tts-bam")
- tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-bam")
- sampling_rate = model.config.sampling_rate
- # Load model once when the server starts
- speech_model_id = "facebook/mms-1b-all"
- processor = AutoProcessor.from_pretrained(speech_model_id)
- speech_model = Wav2Vec2ForCTC.from_pretrained(speech_model_id)
+ app = FastAPI(title="Bambara AI API")
+
+ # 2. Load Models (Memory Efficient)
+ # Use .to("cpu") explicitly if you don't have a GPU on the free tier
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ # TTS Model
+ tts_model_id = "facebook/mms-tts-bam"
+ tts_tokenizer = AutoTokenizer.from_pretrained(tts_model_id)
+ tts_model = VitsModel.from_pretrained(tts_model_id).to(device)
+
+ # ASR (Speech-to-Text) Model
+ asr_model_id = "facebook/mms-1b-all"
+ asr_processor = AutoProcessor.from_pretrained(asr_model_id)
+ asr_model = Wav2Vec2ForCTC.from_pretrained(asr_model_id).to(device)
+
+ # Pre-load the Bambara adapter so it doesn't slow down the first request
+ asr_processor.tokenizer.set_target_lang("bam")
+ asr_model.load_adapter("bam")

  @app.get("/tts/")
- async def tts(text: str = Query(..., description="Bambara text to synthesize")):
-     inputs = tokenizer(text, return_tensors="pt")
-     inputs = {k: v.to("cpu") for k, v in inputs.items()}
-
+ async def tts(text: str = Query(..., description="Bambara text")):
+     inputs = tts_tokenizer(text, return_tensors="pt").to(device)
      with torch.no_grad():
-         output = model(**inputs).waveform
-
-     waveform = output[0]
-
-     # Stream audio instead of saving to disk
+         output = tts_model(**inputs).waveform
+
      buffer = io.BytesIO()
-     wavfile.write(buffer, rate=sampling_rate, data=waveform.numpy())
+     wavfile.write(buffer, rate=tts_model.config.sampling_rate, data=output[0].cpu().numpy())
      buffer.seek(0)
-
      return StreamingResponse(buffer, media_type="audio/wav")

-
- @app.get("/noneBmTts/")
- async def noneBmTts(
-     text: str = Query(..., description="Text to synthesize"),
-     voice: str = Query(
-         "fr-FR-DeniseNeural", description="Voice ID (e.g., en-US-GuyNeural)"
-     ),
- ):
-     try:
-         # Create the Communicate object with the requested text and voice
-         communicate = edge_tts.Communicate(text, voice)
-
-         buffer = io.BytesIO()
-
-         # Stream the audio chunks into the buffer
-         async for chunk in communicate.stream():
-             if chunk["type"] == "audio":
-                 buffer.write(chunk["data"])
-
-         # Check if we actually got data
-         if buffer.tell() == 0:
-             raise HTTPException(
-                 status_code=400, detail="Synthesis failed to produce audio."
-             )
-
-         buffer.seek(0)
-         return StreamingResponse(buffer, media_type="audio/mpeg")
-
-     except Exception as e:
-         # Catch errors like invalid voice names
-         raise HTTPException(status_code=400, detail=str(e))
-
-
-
-
  @app.post("/transcribe")
  async def transcribe(audio_file: UploadFile = File(...)):
-     # 1. Check if a file was actually uploaded
-     if not audio_file:
-         raise HTTPException(status_code=400, detail="No file uploaded")
-
      try:
-         # 2. Read the file into memory
+         # Read and load audio
          audio_bytes = await audio_file.read()
-
-         # 3. Load and Resample to 16,000 Hz using librosa
-         # io.BytesIO(audio_bytes) lets librosa treat the bytes like a file
          audio_data, _ = librosa.load(io.BytesIO(audio_bytes), sr=16000)

-         # 4. Setup Bambara Adapter
-         processor.tokenizer.set_target_lang("bam")
-         model.load_adapter("bam")
-
-         # 5. Perform Inference
-         inputs = processor(audio_data, sampling_rate=16_000, return_tensors="pt")
+         # Prepare inputs
+         inputs = asr_processor(audio_data, sampling_rate=16000, return_tensors="pt").to(device)
+
          with torch.no_grad():
-             logits = speech_model(**inputs).logits
+             logits = asr_model(**inputs).logits

          predicted_ids = torch.argmax(logits, dim=-1)
-         transcription = processor.batch_decode(predicted_ids)[0]
+         transcription = asr_processor.batch_decode(predicted_ids)[0]

          return {"text": transcription}
-
      except Exception as e:
-         raise HTTPException(status_code=500, detail=f"Error processing audio: {str(e)}")
+         raise HTTPException(status_code=500, detail=str(e))
+
+ @app.get("/noneBmTts/")
+ async def noneBmTts(text: str, voice: str = "fr-FR-DeniseNeural"):
+     communicate = edge_tts.Communicate(text, voice)
+     buffer = io.BytesIO()
+     async for chunk in communicate.stream():
+         if chunk["type"] == "audio":
+             buffer.write(chunk["data"])
+     buffer.seek(0)
+     return StreamingResponse(buffer, media_type="audio/mpeg")
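
For quick verification of the three endpoints in the new app.py, a minimal client sketch follows. It is not part of the commit, and it rests on assumptions: the base URL presumes a local uvicorn app:app --port 8000 run (adjust for your deployment), the file names are hypothetical, and "I ni ce" is simply a sample Bambara phrase.

# client_sketch.py: hypothetical smoke test for the endpoints in app.py
# Assumes the server was started locally, e.g.: uvicorn app:app --port 8000
import requests

BASE_URL = "http://localhost:8000"  # assumption: local dev server

# Bambara TTS: GET /tts/ streams a WAV file
resp = requests.get(f"{BASE_URL}/tts/", params={"text": "I ni ce"})
resp.raise_for_status()
with open("bambara.wav", "wb") as f:
    f.write(resp.content)

# Non-Bambara TTS via edge-tts: GET /noneBmTts/ streams MP3
resp = requests.get(
    f"{BASE_URL}/noneBmTts/",
    params={"text": "Bonjour", "voice": "fr-FR-DeniseNeural"},
)
resp.raise_for_status()
with open("french.mp3", "wb") as f:
    f.write(resp.content)

# Bambara ASR: POST /transcribe with a multipart field named "audio_file"
with open("bambara.wav", "rb") as f:  # reuse the TTS output as test input
    resp = requests.post(f"{BASE_URL}/transcribe", files={"audio_file": f})
resp.raise_for_status()
print(resp.json()["text"])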