Somalitts committed on
Commit
727d35b
·
verified ·
1 Parent(s): 615271a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -19
app.py CHANGED
@@ -1,37 +1,41 @@
1
  from fastapi import FastAPI, UploadFile, File
2
- from fastapi.responses import JSONResponse
3
  import torchaudio
4
  import torch
5
- from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
 
6
  import io
7
 
8
- # Load model and processor
9
- processor = Wav2Vec2Processor.from_pretrained("Mustafaa4a/ASR-Somali")
10
- model = Wav2Vec2ForCTC.from_pretrained("Mustafaa4a/ASR-Somali")
11
- model.eval()
12
 
13
- # Initialize FastAPI
14
- app = FastAPI(
15
- title="Somali Speech-to-Text API",
16
- description="Upload a Somali audio file (.wav) and receive text transcription using ASR model.",
17
- version="1.0",
 
 
18
  )
19
 
 
 
 
 
20
  @app.post("/transcribe")
21
- async def transcribe(audio: UploadFile = File(...)):
22
- # Read audio bytes
23
- audio_bytes = await audio.read()
24
- waveform, sample_rate = torchaudio.load(io.BytesIO(audio_bytes))
 
25
 
26
- # Ensure 16kHz sample rate
27
  if sample_rate != 16000:
28
- waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
 
29
 
30
- # Process input
31
  inputs = processor(waveform.squeeze(), sampling_rate=16000, return_tensors="pt")
32
  with torch.no_grad():
33
  logits = model(**inputs).logits
34
 
35
  predicted_ids = torch.argmax(logits, dim=-1)
36
  transcription = processor.decode(predicted_ids[0])
37
- return JSONResponse(content={"transcription": transcription})
 
1
  from fastapi import FastAPI, UploadFile, File
2
+ from fastapi.middleware.cors import CORSMiddleware
3
  import torchaudio
4
  import torch
5
+ from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
6
+ import uvicorn
7
  import io
8
 
9
+ app = FastAPI()
 
 
 
10
 
11
+ # Allow requests from Flutter (localhost or any domain)
12
+ app.add_middleware(
13
+ CORSMiddleware,
14
+ allow_origins=["*"],
15
+ allow_credentials=True,
16
+ allow_methods=["*"],
17
+ allow_headers=["*"],
18
  )
19
 
20
+ # Load model and processor once at startup
21
+ processor = Wav2Vec2Processor.from_pretrained("Mustafaa4a/ASR-Somali")
22
+ model = Wav2Vec2ForCTC.from_pretrained("Mustafaa4a/ASR-Somali")
23
+
24
  @app.post("/transcribe")
25
+ async def transcribe_audio(file: UploadFile = File(...)):
26
+ contents = await file.read()
27
+ audio_bytes = io.BytesIO(contents)
28
+
29
+ waveform, sample_rate = torchaudio.load(audio_bytes)
30
 
 
31
  if sample_rate != 16000:
32
+ resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
33
+ waveform = resampler(waveform)
34
 
 
35
  inputs = processor(waveform.squeeze(), sampling_rate=16000, return_tensors="pt")
36
  with torch.no_grad():
37
  logits = model(**inputs).logits
38
 
39
  predicted_ids = torch.argmax(logits, dim=-1)
40
  transcription = processor.decode(predicted_ids[0])
41
+ return {"transcription": transcription}