Gaoussin committed on
Commit
1945a83
·
verified ·
1 Parent(s): 906c6bf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +46 -26
app.py CHANGED
@@ -1,43 +1,63 @@
1
- from fastapi import FastAPI, File, UploadFile, HTTPException
2
- from transformers import Wav2Vec2ForCTC, AutoProcessor
3
  import torch
4
  import librosa
5
- import io
 
 
6
 
7
- app = FastAPI()
 
 
8
 
9
- # Load model and processor once at startup
10
- MODEL_ID = "facebook/mms-1b-all"
11
- processor = AutoProcessor.from_pretrained(MODEL_ID)
12
- model = Wav2Vec2ForCTC.from_pretrained(MODEL_ID)
13
 
14
- @app.post("/transcribe/")
15
- async def transcribe(audio_file: UploadFile = File(...)):
16
- # 1. Check if a file was actually uploaded
17
- if not audio_file:
18
- raise HTTPException(status_code=400, detail="No file uploaded")
 
 
 
19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  try:
21
- # 2. Read the file into memory
22
- audio_bytes = await audio_file.read()
23
-
24
- # 3. Load and Resample to 16,000 Hz using librosa
25
- # io.BytesIO(audio_bytes) lets librosa treat the bytes like a file
26
- audio_data, _ = librosa.load(io.BytesIO(audio_bytes), sr=16000)
27
 
28
- # 4. Setup Bambara Adapter
29
- processor.tokenizer.set_target_lang("bam")
30
- model.load_adapter("bam")
31
 
32
- # 5. Perform Inference
33
- inputs = processor(audio_data, sampling_rate=16_000, return_tensors="pt")
34
- with torch.no_grad():
 
 
35
  logits = model(**inputs).logits
36
 
 
37
  predicted_ids = torch.argmax(logits, dim=-1)
38
  transcription = processor.batch_decode(predicted_ids)[0]
39
 
40
  return {"text": transcription}
41
 
42
  except Exception as e:
43
- raise HTTPException(status_code=500, detail=f"Error processing audio: {str(e)}")
 
 
1
+ import os
2
+ import io
3
  import torch
4
  import librosa
5
+ from fastapi import FastAPI, File, UploadFile, HTTPException
6
+ from fastapi.middleware.cors import CORSMiddleware
7
+ from transformers import Wav2Vec2ForCTC, AutoProcessor
8
 
9
+ # Set cache to writable directory
10
+ os.environ["HF_HOME"] = "/tmp/hf"
11
+ os.makedirs("/tmp/hf", exist_ok=True)
12
 
13
+ app = FastAPI(title="Bambara ASR Dedicated API")
 
 
 
14
 
15
+ # Enable CORS for your frontend
16
+ app.add_middleware(
17
+ CORSMiddleware,
18
+ allow_origins=["*"],
19
+ allow_credentials=True,
20
+ allow_methods=["*"],
21
+ allow_headers=["*"],
22
+ )
23
 
24
+ # Load ASR components globally
25
+ device = "cpu"
26
+ model_id = "facebook/mms-1b-all"
27
+
28
+ print("Loading processor and model...")
29
+ processor = AutoProcessor.from_pretrained(model_id)
30
+ model = Wav2Vec2ForCTC.from_pretrained(model_id).to(device)
31
+
32
+ # Pre-load Bambara adapter to prevent lag/OOM on first request
33
+ processor.tokenizer.set_target_lang("bam")
34
+ model.load_adapter("bam")
35
+ print("Bambara adapter loaded. System Ready.")
36
+
37
+ @app.post("/transcribe")
38
+ async def transcribe(audio_file: UploadFile = File(...)):
39
  try:
40
+ # Read file stream
41
+ content = await audio_file.read()
42
+ if not content:
43
+ return {"text": "Error: Empty audio file"}
 
 
44
 
45
+ # Load & Resample (Critical: Model expects 16,000Hz)
46
+ audio_data, _ = librosa.load(io.BytesIO(content), sr=16000)
 
47
 
48
+ # Prepare inputs
49
+ inputs = processor(audio_data, sampling_rate=16000, return_tensors="pt").to(device)
50
+
51
+ # Inference (inference_mode is more memory efficient than no_grad)
52
+ with torch.inference_mode():
53
  logits = model(**inputs).logits
54
 
55
+ # Decode output
56
  predicted_ids = torch.argmax(logits, dim=-1)
57
  transcription = processor.batch_decode(predicted_ids)[0]
58
 
59
  return {"text": transcription}
60
 
61
  except Exception as e:
62
+ print(f"Server Error: {e}")
63
+ return {"text": f"Error: {str(e)}"}