# STT_Api / app.py
# Somalitts's picture
# Update app.py
# 727d35b verified
# (Hugging Face Spaces page residue: "raw / history / blame / 1.28 kB")
from fastapi import FastAPI, UploadFile, File
from fastapi.middleware.cors import CORSMiddleware
import torchaudio
import torch
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
import uvicorn
import io
app = FastAPI()
# Allow requests from Flutter (localhost or any domain)
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Load model and processor once at startup
processor = Wav2Vec2Processor.from_pretrained("Mustafaa4a/ASR-Somali")
model = Wav2Vec2ForCTC.from_pretrained("Mustafaa4a/ASR-Somali")
@app.post("/transcribe")
async def transcribe_audio(file: UploadFile = File(...)):
    """Transcribe an uploaded audio file to Somali text.

    Reads the uploaded bytes, decodes them with torchaudio, down-mixes
    multi-channel audio to mono, resamples to the 16 kHz rate the
    Wav2Vec2 model expects, and runs greedy CTC decoding.

    Args:
        file: The uploaded audio file (any format torchaudio can decode).

    Returns:
        dict: {"transcription": <decoded text>}
    """
    contents = await file.read()
    waveform, sample_rate = torchaudio.load(io.BytesIO(contents))

    # Fix: the original passed `waveform.squeeze()` straight to the
    # processor, so stereo input stayed a (2, n) tensor — treated as a
    # batch of two clips, with only channel 0 ever decoded. Down-mix
    # multi-channel audio to a single mono channel instead.
    if waveform.size(0) > 1:
        waveform = waveform.mean(dim=0, keepdim=True)

    # The model was trained on 16 kHz audio; resample anything else.
    if sample_rate != 16000:
        resampler = torchaudio.transforms.Resample(
            orig_freq=sample_rate, new_freq=16000
        )
        waveform = resampler(waveform)

    inputs = processor(waveform.squeeze(), sampling_rate=16000, return_tensors="pt")
    with torch.no_grad():  # inference only — no gradient bookkeeping
        logits = model(**inputs).logits
    predicted_ids = torch.argmax(logits, dim=-1)  # greedy CTC decoding
    transcription = processor.decode(predicted_ids[0])
    return {"transcription": transcription}