Somalitts commited on
Commit
0f29908
·
verified ·
1 Parent(s): f95c80f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +41 -23
app.py CHANGED
@@ -1,6 +1,4 @@
1
  import os
2
- os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf-cache" # Important for Docker
3
-
4
  from fastapi import FastAPI, UploadFile, File
5
  from fastapi.middleware.cors import CORSMiddleware
6
  import torchaudio
@@ -8,9 +6,13 @@ import torch
8
  from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
9
  import io
10
 
 
 
 
 
11
  app = FastAPI()
12
 
13
- # Allow all origins (for Flutter)
14
  app.add_middleware(
15
  CORSMiddleware,
16
  allow_origins=["*"],
@@ -18,9 +20,19 @@ app.add_middleware(
18
  allow_headers=["*"],
19
  )
20
 
21
- # Load model
22
- processor = Wav2Vec2Processor.from_pretrained("Mustafaa4a/ASR-Somali")
23
- model = Wav2Vec2ForCTC.from_pretrained("Mustafaa4a/ASR-Somali")
 
 
 
 
 
 
 
 
 
 
24
 
25
  @app.get("/")
26
  async def root():
@@ -28,20 +40,26 @@ async def root():
28
 
29
  @app.post("/transcribe")
30
  async def transcribe(file: UploadFile = File(...)):
31
- audio_bytes = await file.read()
32
- audio_stream = io.BytesIO(audio_bytes)
33
-
34
- waveform, sample_rate = torchaudio.load(audio_stream)
35
-
36
- if sample_rate != 16000:
37
- resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
38
- waveform = resampler(waveform)
39
-
40
- inputs = processor(waveform.squeeze(), sampling_rate=16000, return_tensors="pt")
41
-
42
- with torch.no_grad():
43
- logits = model(**inputs).logits
44
-
45
- predicted_ids = torch.argmax(logits, dim=-1)
46
- transcription = processor.decode(predicted_ids[0])
47
- return {"transcription": transcription}
 
 
 
 
 
 
 
1
  import os
 
 
2
  from fastapi import FastAPI, UploadFile, File
3
  from fastapi.middleware.cors import CORSMiddleware
4
  import torchaudio
 
6
  from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
7
  import io
8
 
9
+ # DO NOT set the cache directory here anymore.
10
+ # Let the Dockerfile's ENV variables handle it.
11
+ # REMOVED: os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf-cache"
12
+
13
  app = FastAPI()
14
 
15
+ # Allow all origins
16
  app.add_middleware(
17
  CORSMiddleware,
18
  allow_origins=["*"],
 
20
  allow_headers=["*"],
21
  )
22
 
23
+ # --- Model Loading ---
24
+ # This will now use the cache path set by the Dockerfile's ENV variables (/app/hf-cache)
25
+ print("Loading model and processor...")
26
+ try:
27
+ processor = Wav2Vec2Processor.from_pretrained("Mustafaa4a/ASR-Somali")
28
+ model = Wav2Vec2ForCTC.from_pretrained("Mustafaa4a/ASR-Somali")
29
+ print("Model and processor loaded successfully.")
30
+ except Exception as e:
31
+ print(f"FATAL: Could not load model. Error: {e}")
32
+ # In a real app, you might want to exit or handle this gracefully
33
+ processor = None
34
+ model = None
35
+
36
 
37
  @app.get("/")
38
  async def root():
 
40
 
41
  @app.post("/transcribe")
42
  async def transcribe(file: UploadFile = File(...)):
43
+ if not model or not processor:
44
+ return {"error": "Model is not loaded, please check server logs for errors."}
45
+
46
+ try:
47
+ audio_bytes = await file.read()
48
+ audio_stream = io.BytesIO(audio_bytes)
49
+
50
+ waveform, sample_rate = torchaudio.load(audio_stream)
51
+
52
+ if sample_rate != 16000:
53
+ resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
54
+ waveform = resampler(waveform)
55
+
56
+ inputs = processor(waveform.squeeze(), sampling_rate=16000, return_tensors="pt")
57
+
58
+ with torch.no_grad():
59
+ logits = model(**inputs).logits
60
+
61
+ predicted_ids = torch.argmax(logits, dim=-1)
62
+ transcription = processor.decode(predicted_ids[0])
63
+ return {"transcription": transcription}
64
+ except Exception as e:
65
+ return {"error": f"An error occurred during transcription: {str(e)}"}