Somalitts committed on
Commit
c8ea8ab
·
verified ·
1 Parent(s): 900e386

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -66
app.py CHANGED
@@ -1,18 +1,16 @@
1
  import os
2
- import io
3
- import uuid
4
  from fastapi import FastAPI, UploadFile, File
5
  from fastapi.middleware.cors import CORSMiddleware
6
- import torch
7
  import torchaudio
 
8
  from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
9
-
10
- # This line is good practice but less critical now with the updated Dockerfile using HF_HOME
11
- os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf-cache"
12
 
13
  app = FastAPI()
14
 
15
- # CORS middleware for allowing requests from your mobile app
16
  app.add_middleware(
17
  CORSMiddleware,
18
  allow_origins=["*"],
@@ -20,69 +18,30 @@ app.add_middleware(
20
  allow_headers=["*"],
21
  )
22
 
23
- # Load the ASR model and processor once at startup to save time on each request
24
- try:
25
- processor = Wav2Vec2Processor.from_pretrained("Mustafaa4a/ASR-Somali")
26
- model = Wav2Vec2ForCTC.from_pretrained("Mustafaa4a/ASR-Somali")
27
- except Exception as e:
28
- # If the model fails to load, the app can't work.
29
- # Log this for debugging. In a real app, you might exit or return an error state.
30
- print(f"FATAL: Could not load model. Error: {e}")
31
- model = None
32
- processor = None
33
 
34
  @app.get("/")
35
  async def root():
36
- """A simple endpoint to check if the API is running."""
37
  return {"message": "Somali Speech-to-Text API is running."}
38
 
39
  @app.post("/transcribe")
40
  async def transcribe(file: UploadFile = File(...)):
41
- """
42
- Receives an audio file, transcribes it, and returns the text.
43
- """
44
- if not model or not processor:
45
- return {"error": "ASR model is not available."}
46
-
47
- # Use a temporary file to reliably load the audio data.
48
- # This helps torchaudio correctly identify the audio format.
49
- temp_dir = "/tmp"
50
- os.makedirs(temp_dir, exist_ok=True) # Ensure the directory exists
51
- # Use the original filename from the upload to preserve the extension (e.g., .m4a, .wav)
52
- temp_file_path = os.path.join(temp_dir, f"{uuid.uuid4()}_{file.filename}")
53
-
54
- try:
55
- # Save the uploaded file's content to the temporary file
56
- with open(temp_file_path, "wb") as buffer:
57
- buffer.write(await file.read())
58
-
59
- # Load the audio using the file path, which is more reliable
60
- waveform, sample_rate = torchaudio.load(temp_file_path)
61
-
62
- # Resample the audio to the 16kHz required by the model
63
- if sample_rate != 16000:
64
- resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
65
- waveform = resampler(waveform)
66
-
67
- # Process the audio waveform
68
- # .squeeze() removes any redundant channels/dimensions
69
- inputs = processor(waveform.squeeze(), sampling_rate=16000, return_tensors="pt", padding=True)
70
-
71
- # Perform inference
72
- with torch.no_grad():
73
- logits = model(inputs.input_values).logits
74
-
75
- # Decode the model's output to text
76
- predicted_ids = torch.argmax(logits, dim=-1)
77
- transcription = processor.batch_decode(predicted_ids)[0]
78
-
79
- return {"transcription": transcription.upper()} # Returning in uppercase is common for ASR
80
-
81
- except Exception as e:
82
- # If anything goes wrong during processing, return a specific error
83
- # This helps in debugging on the mobile client side
84
- return {"error": f"Failed to process audio file. Reason: {str(e)}"}
85
- finally:
86
- # Clean up the temporary file after processing is complete
87
- if os.path.exists(temp_file_path):
88
- os.remove(temp_file_path)
 
1
  import os
2
+ os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf-cache" # Important for Docker
3
+
4
  from fastapi import FastAPI, UploadFile, File
5
  from fastapi.middleware.cors import CORSMiddleware
 
6
  import torchaudio
7
+ import torch
8
  from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
9
+ import io
 
 
10
 
11
# Create the FastAPI application instance served by uvicorn.
app = FastAPI()

# Allow all cross-origin requests so the Flutter mobile client can call
# this API from any origin. Wide-open CORS is acceptable for a public
# demo endpoint; tighten `allow_origins` for production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    # NOTE(review): the diff collapses one unchanged line here —
    # presumably allow_methods=["*"]; confirm against the full file.
    allow_headers=["*"],
)
20
 
21
# Load the Somali ASR checkpoint and its feature processor ONCE at module
# import, so every request reuses the same in-memory weights instead of
# reloading per call. First run downloads to the cache directory set via
# TRANSFORMERS_CACHE above; `from_pretrained` returns the model already in
# eval mode, so no explicit model.eval() is needed here.
processor = Wav2Vec2Processor.from_pretrained("Mustafaa4a/ASR-Somali")
model = Wav2Vec2ForCTC.from_pretrained("Mustafaa4a/ASR-Somali")
 
 
 
 
 
 
 
24
 
25
@app.get("/")
async def root():
    """Liveness probe: confirms the API process is up and serving."""
    status_payload = {"message": "Somali Speech-to-Text API is running."}
    return status_payload
28
 
29
  @app.post("/transcribe")
30
  async def transcribe(file: UploadFile = File(...)):
31
+ audio_bytes = await file.read()
32
+ audio_stream = io.BytesIO(audio_bytes)
33
+
34
+ waveform, sample_rate = torchaudio.load(audio_stream)
35
+
36
+ if sample_rate != 16000:
37
+ resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
38
+ waveform = resampler(waveform)
39
+
40
+ inputs = processor(waveform.squeeze(), sampling_rate=16000, return_tensors="pt")
41
+
42
+ with torch.no_grad():
43
+ logits = model(**inputs).logits
44
+
45
+ predicted_ids = torch.argmax(logits, dim=-1)
46
+ transcription = processor.decode(predicted_ids[0])
47
+ return {"transcription": transcription}