Somalitts committed on
Commit
4acc86d
·
verified ·
1 Parent(s): a81dacb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +66 -25
app.py CHANGED
@@ -1,16 +1,18 @@
1
  import os
2
- os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf-cache" # Important for Docker
3
-
4
  from fastapi import FastAPI, UploadFile, File
5
  from fastapi.middleware.cors import CORSMiddleware
6
- import torchaudio
7
  import torch
 
8
  from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
9
- import io
 
 
10
 
11
  app = FastAPI()
12
 
13
- # Allow all origins (for Flutter)
14
  app.add_middleware(
15
  CORSMiddleware,
16
  allow_origins=["*"],
@@ -18,30 +20,69 @@ app.add_middleware(
18
  allow_headers=["*"],
19
  )
20
 
21
- # Load model
22
- processor = Wav2Vec2Processor.from_pretrained("Mustafaa4a/ASR-Somali")
23
- model = Wav2Vec2ForCTC.from_pretrained("Mustafaa4a/ASR-Somali")
 
 
 
 
 
 
 
24
 
25
  @app.get("/")
26
  async def root():
 
27
  return {"message": "Somali Speech-to-Text API is running."}
28
 
29
  @app.post("/transcribe")
30
  async def transcribe(file: UploadFile = File(...)):
31
- audio_bytes = await file.read()
32
- audio_stream = io.BytesIO(audio_bytes)
33
-
34
- waveform, sample_rate = torchaudio.load(audio_stream)
35
-
36
- if sample_rate != 16000:
37
- resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
38
- waveform = resampler(waveform)
39
-
40
- inputs = processor(waveform.squeeze(), sampling_rate=16000, return_tensors="pt")
41
-
42
- with torch.no_grad():
43
- logits = model(**inputs).logits
44
-
45
- predicted_ids = torch.argmax(logits, dim=-1)
46
- transcription = processor.decode(predicted_ids[0])
47
- return {"transcription": transcription}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
+ import io
3
+ import uuid
4
  from fastapi import FastAPI, UploadFile, File
5
  from fastapi.middleware.cors import CORSMiddleware
 
6
  import torch
7
+ import torchaudio
8
  from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
9
+
10
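+ # Fallback cache location; less critical now that the updated Dockerfile sets HF_HOME. NOTE(review): transformers is already imported above, so confirm this assignment still takes effect at this point.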
11
+ os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf-cache"
12
 
13
  app = FastAPI()
14
 
15
+ # CORS middleware for allowing requests from your mobile app
16
  app.add_middleware(
17
  CORSMiddleware,
18
  allow_origins=["*"],
 
20
  allow_headers=["*"],
21
  )
22
 
23
+ # Load the ASR model and processor once at startup to save time on each request
24
+ try:
25
+ processor = Wav2Vec2Processor.from_pretrained("Mustafaa4a/ASR-Somali")
26
+ model = Wav2Vec2ForCTC.from_pretrained("Mustafaa4a/ASR-Somali")
27
+ except Exception as e:
28
+ # If the model fails to load, the app can't work.
29
+ # Log this for debugging. In a real app, you might exit or return an error state.
30
+ print(f"FATAL: Could not load model. Error: {e}")
31
+ model = None
32
+ processor = None
33
 
34
  @app.get("/")
35
  async def root():
36
+ """A simple endpoint to check if the API is running."""
37
  return {"message": "Somali Speech-to-Text API is running."}
38
 
39
  @app.post("/transcribe")
40
  async def transcribe(file: UploadFile = File(...)):
41
+ """
42
+ Receives an audio file, transcribes it, and returns the text.
43
+ """
44
+ if not model or not processor:
45
+ return {"error": "ASR model is not available."}
46
+
47
+ # Use a temporary file to reliably load the audio data.
48
+ # This helps torchaudio correctly identify the audio format.
49
+ temp_dir = "/tmp"
50
+ os.makedirs(temp_dir, exist_ok=True) # Ensure the directory exists
51
+ # Use the original filename from the upload to preserve the extension (e.g., .m4a, .wav)
52
+ temp_file_path = os.path.join(temp_dir, f"{uuid.uuid4()}_{file.filename}")
53
+
54
+ try:
55
+ # Save the uploaded file's content to the temporary file
56
+ with open(temp_file_path, "wb") as buffer:
57
+ buffer.write(await file.read())
58
+
59
+ # Load the audio using the file path, which is more reliable
60
+ waveform, sample_rate = torchaudio.load(temp_file_path)
61
+
62
+ # Resample the audio to the 16kHz required by the model
63
+ if sample_rate != 16000:
64
+ resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
65
+ waveform = resampler(waveform)
66
+
67
+ # Process the audio waveform
68
+ # .squeeze() drops size-1 dimensions (e.g. the channel axis of mono audio); multi-channel audio keeps its channel axis
69
+ inputs = processor(waveform.squeeze(), sampling_rate=16000, return_tensors="pt", padding=True)
70
+
71
+ # Perform inference
72
+ with torch.no_grad():
73
+ logits = model(inputs.input_values).logits
74
+
75
+ # Decode the model's output to text
76
+ predicted_ids = torch.argmax(logits, dim=-1)
77
+ transcription = processor.batch_decode(predicted_ids)[0]
78
+
79
+ return {"transcription": transcription.upper()} # Returning in uppercase is common for ASR
80
+
81
+ except Exception as e:
82
+ # If anything goes wrong during processing, return a specific error
83
+ # This helps in debugging on the mobile client side
84
+ return {"error": f"Failed to process audio file. Reason: {str(e)}"}
85
+ finally:
86
+ # Clean up the temporary file after processing is complete
87
+ if os.path.exists(temp_file_path):
88
+ os.remove(temp_file_path)