Somalitts commited on
Commit
64d5cde
·
verified ·
1 Parent(s): 7248b97

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +63 -16
app.py CHANGED
@@ -1,16 +1,24 @@
1
- import os
2
- os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf-cache" # Important for Docker
3
 
4
- from fastapi import FastAPI, UploadFile, File
5
- from fastapi.middleware.cors import CORSMiddleware
6
- import torchaudio
7
  import torch
 
 
 
8
  from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
9
- import io
10
 
11
- app = FastAPI()
 
 
 
 
 
 
 
 
 
12
 
13
- # Allow all origins (for Flutter)
14
  app.add_middleware(
15
  CORSMiddleware,
16
  allow_origins=["*"],
@@ -18,27 +26,66 @@ app.add_middleware(
18
  allow_headers=["*"],
19
  )
20
 
21
- # Load model
22
- processor = Wav2Vec2Processor.from_pretrained("Mustafaa4a/ASR-Somali")
23
- model = Wav2Vec2ForCTC.from_pretrained("Mustafaa4a/ASR-Somali")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
  @app.get("/")
26
- async def root():
27
  return {"message": "Somali Speech-to-Text API is running."}
28
 
29
  @app.post("/transcribe")
30
  async def transcribe(file: UploadFile = File(...)):
 
 
 
 
31
  audio_bytes = await file.read()
32
- audio_stream = io.BytesIO(audio_bytes)
33
-
34
- waveform, sample_rate = torchaudio.load(audio_stream)
 
 
 
 
 
 
 
 
 
 
 
 
35
 
 
 
 
36
  if sample_rate != 16000:
37
  resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
38
  waveform = resampler(waveform)
39
 
40
  inputs = processor(waveform.squeeze(), sampling_rate=16000, return_tensors="pt")
41
-
42
  with torch.no_grad():
43
  logits = model(**inputs).logits
44
 
 
 
 
1
 
2
+ import os
3
+ import io
 
4
  import torch
5
+ import torchaudio
6
+ from fastapi import FastAPI, UploadFile, File, HTTPException
7
+ from fastapi.middleware.cors import CORSMiddleware
8
  from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
9
+ from huggingface_hub import snapshot_download
10
 
11
+ # ---- Robust HF cache setup (writable in Docker/Spaces) ----
12
+ HF_HOME = os.environ.get("HF_HOME", "/tmp/hf")
13
+ os.environ["HF_HOME"] = HF_HOME
14
+ os.environ["TRANSFORMERS_CACHE"] = os.path.join(HF_HOME, "transformers")
15
+ os.makedirs(os.environ["TRANSFORMERS_CACHE"], exist_ok=True)
16
+
17
+ MODEL_ID = os.environ.get("MODEL_ID", "Mustafaa4a/ASR-Somali")
18
+ HF_TOKEN = os.environ.get("HF_TOKEN") # only needed for private repos
19
+
20
+ app = FastAPI(title="Somali ASR API")
21
 
 
22
  app.add_middleware(
23
  CORSMiddleware,
24
  allow_origins=["*"],
 
26
  allow_headers=["*"],
27
  )
28
 
29
+ processor = None
30
+ model = None
31
+
32
+ @app.on_event("startup")
33
+ def _load_model():
34
+ global processor, model
35
+ try:
36
+ # Download the repo snapshot to a local, writable dir
37
+ local_dir = snapshot_download(
38
+ repo_id=MODEL_ID,
39
+ token=HF_TOKEN,
40
+ cache_dir=HF_HOME,
41
+ )
42
+ processor = Wav2Vec2Processor.from_pretrained(local_dir)
43
+ model = Wav2Vec2ForCTC.from_pretrained(local_dir)
44
+ model.eval()
45
+ except Exception as e:
46
+ # Surface a clear error instead of crashing Uvicorn silently
47
+ raise RuntimeError(f"Failed to load model '{MODEL_ID}': {e}")
48
+
49
+ @app.get("/health")
50
+ def health():
51
+ return {"status": "ok", "model_loaded": model is not None, "model_id": MODEL_ID}
52
 
53
  @app.get("/")
54
+ def root():
55
  return {"message": "Somali Speech-to-Text API is running."}
56
 
57
  @app.post("/transcribe")
58
  async def transcribe(file: UploadFile = File(...)):
59
+ if model is None or processor is None:
60
+ raise HTTPException(status_code=503, detail="Model not loaded yet. Try again shortly.")
61
+
62
+ # Read bytes
63
  audio_bytes = await file.read()
64
+ if not audio_bytes:
65
+ raise HTTPException(status_code=400, detail="Empty file")
66
+
67
+ # Load audio from bytes
68
+ try:
69
+ audio_stream = io.BytesIO(audio_bytes)
70
+ # torchaudio can auto-detect many formats if system codecs are present
71
+ waveform, sample_rate = torchaudio.load(audio_stream)
72
+ except Exception:
73
+ # As a fallback, try forcing WAV (in case the client always sends WAV)
74
+ try:
75
+ audio_stream = io.BytesIO(audio_bytes)
76
+ waveform, sample_rate = torchaudio.load(audio_stream, format="wav")
77
+ except Exception as e:
78
+ raise HTTPException(status_code=400, detail=f"Could not read audio: {e}")
79
 
80
+ # Mono + 16k resample for Wav2Vec2
81
+ if waveform.dim() == 2 and waveform.size(0) > 1:
82
+ waveform = torch.mean(waveform, dim=0, keepdim=True) # convert to mono
83
  if sample_rate != 16000:
84
  resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
85
  waveform = resampler(waveform)
86
 
87
  inputs = processor(waveform.squeeze(), sampling_rate=16000, return_tensors="pt")
88
+
89
  with torch.no_grad():
90
  logits = model(**inputs).logits
91