Gaoussin committed on
Commit
882c6ec
·
verified ·
1 Parent(s): e578353

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +56 -74
app.py CHANGED
@@ -1,89 +1,71 @@
1
  import os
2
- import io
3
- import torch
4
- import librosa
5
- import edge_tts
6
- import scipy.io.wavfile as wavfile
7
- from fastapi import FastAPI, Query, File, UploadFile, HTTPException
8
- from fastapi.middleware.cors import CORSMiddleware
9
- from fastapi.responses import StreamingResponse
10
- from transformers import VitsModel, AutoTokenizer, Wav2Vec2ForCTC, AutoProcessor
11
 
12
- # 1. Environment and App Setup
13
  os.environ["HF_HOME"] = "/tmp/hf"
14
- app = FastAPI(title="Bambara AI API")
15
-
16
- app.add_middleware(
17
- CORSMiddleware,
18
- allow_origins=["*"],
19
- allow_credentials=True,
20
- allow_methods=["*"],
21
- allow_headers=["*"],
22
- )
23
-
24
- device = "cpu"
25
-
26
- # 2. Load Models (Switching to 300M for stability)
27
- # ASR Model
28
- asr_model_id = "facebook/mms-300m-1107" # Smaller, faster, more stable
29
- asr_processor = AutoProcessor.from_pretrained(asr_model_id)
30
- asr_model = Wav2Vec2ForCTC.from_pretrained(asr_model_id).to(device)
31
-
32
- # Load Bambara Adapter
33
- asr_processor.tokenizer.set_target_lang("bam")
34
- asr_model.load_adapter("bam")
35
-
36
- # TTS Model
37
- tts_model_id = "facebook/mms-tts-bam"
38
- tts_tokenizer = AutoTokenizer.from_pretrained(tts_model_id)
39
- tts_model = VitsModel.from_pretrained(tts_model_id).to(device)
40
-
41
- @app.post("/transcribe")
42
- async def transcribe(audio_file: UploadFile = File(...)):
43
- try:
44
- # Read file
45
- content = await audio_file.read()
46
- if not content:
47
- raise HTTPException(status_code=400, detail="Empty audio file")
48
 
49
- # Load audio into memory
50
- # Resampling here to 16kHz is mandatory
51
- audio_data, _ = librosa.load(io.BytesIO(content), sr=16000)
 
 
 
52
 
53
- # Prepare for model
54
- inputs = asr_processor(audio_data, sampling_rate=16000, return_tensors="pt").to(device)
55
-
56
- # Inference
57
- with torch.inference_mode():
58
- logits = asr_model(**inputs).logits
59
 
60
- # Decode
61
- predicted_ids = torch.argmax(logits, dim=-1)
62
- transcription = asr_processor.batch_decode(predicted_ids)[0]
63
 
64
- return {"text": transcription}
 
 
 
65
 
66
- except Exception as e:
67
- print(f"Error: {e}")
68
- raise HTTPException(status_code=500, detail=str(e))
69
 
70
  @app.get("/tts/")
71
- async def tts(text: str = Query(..., description="Bambara text")):
72
- inputs = tts_tokenizer(text, return_tensors="pt").to(device)
73
- with torch.inference_mode():
74
- output = tts_model(**inputs).waveform
75
-
 
 
 
 
 
76
  buffer = io.BytesIO()
77
- wavfile.write(buffer, rate=tts_model.config.sampling_rate, data=output[0].cpu().numpy())
78
  buffer.seek(0)
 
79
  return StreamingResponse(buffer, media_type="audio/wav")
80
 
 
81
  @app.get("/noneBmTts/")
82
- async def noneBmTts(text: str, voice: str = "fr-FR-DeniseNeural"):
83
- communicate = edge_tts.Communicate(text, voice)
84
- buffer = io.BytesIO()
85
- async for chunk in communicate.stream():
86
- if chunk["type"] == "audio":
87
- buffer.write(chunk["data"])
88
- buffer.seek(0)
89
- return StreamingResponse(buffer, media_type="audio/mpeg")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os

# Route every Hugging Face cache (hub, transformers, datasets) to /tmp/hf,
# which stays writable on hosts whose application filesystem is read-only.
# These variables must be set before `transformers` is imported below.
_CACHE_DIR = "/tmp/hf"
for _var in ("HF_HOME", "TRANSFORMERS_CACHE", "HF_DATASETS_CACHE"):
    os.environ[_var] = _CACHE_DIR
os.makedirs(_CACHE_DIR, exist_ok=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
import io

import edge_tts
import scipy.io.wavfile as wavfile
import torch
from fastapi import FastAPI, HTTPException, Query
from fastapi.responses import StreamingResponse
from transformers import AutoTokenizer, VitsModel
14
 
 
 
 
 
 
 
15
 
16
app = FastAPI(title="Bambara TTS API")

# Load the Bambara VITS checkpoint a single time at startup so every
# request shares the same in-memory model and tokenizer.
_TTS_CHECKPOINT = "facebook/mms-tts-bam"
model = VitsModel.from_pretrained(_TTS_CHECKPOINT)
tokenizer = AutoTokenizer.from_pretrained(_TTS_CHECKPOINT)
sampling_rate = model.config.sampling_rate
22
 
 
 
 
23
 
24
@app.get("/tts/")
async def tts(text: str = Query(..., description="Bambara text to synthesize")):
    """Synthesize Bambara speech for *text* and stream it back as WAV."""
    encoded = tokenizer(text, return_tensors="pt")
    tensors = {key: value.to("cpu") for key, value in encoded.items()}

    with torch.no_grad():
        synthesized = model(**tensors).waveform

    # First (and only) batch element holds the generated waveform.
    audio = synthesized[0]

    # Write the WAV into an in-memory buffer instead of touching disk.
    stream = io.BytesIO()
    wavfile.write(stream, rate=sampling_rate, data=audio.numpy())
    stream.seek(0)

    return StreamingResponse(stream, media_type="audio/wav")
40
 
41
+
42
@app.get("/noneBmTts/")
async def noneBmTts(
    text: str = Query(..., description="Text to synthesize"),
    voice: str = Query(
        "fr-FR-DeniseNeural", description="Voice ID (e.g., en-US-GuyNeural)"
    ),
):
    """Synthesize non-Bambara speech via Microsoft Edge TTS.

    Streams the synthesized MP3 back to the client. Raises HTTP 400 when
    the voice is invalid or synthesis produces no audio.
    """
    try:
        # Create the Communicate object with the requested text and voice.
        communicate = edge_tts.Communicate(text, voice)

        buffer = io.BytesIO()

        # Stream the audio chunks into the in-memory buffer.
        async for chunk in communicate.stream():
            if chunk["type"] == "audio":
                buffer.write(chunk["data"])

        # An empty buffer means the service returned no audio at all.
        if buffer.tell() == 0:
            raise HTTPException(
                status_code=400, detail="Synthesis failed to produce audio."
            )

        buffer.seek(0)
        return StreamingResponse(buffer, media_type="audio/mpeg")

    except HTTPException:
        # Fix: re-raise the deliberate "no audio" 400 untouched instead of
        # letting the generic handler below swallow it and re-wrap its repr.
        raise
    except Exception as e:
        # Catch errors such as invalid voice names.
        raise HTTPException(status_code=400, detail=str(e))