# whisperurdu-api / app.py
# Hugging Face Space by ibrahim145 — "Update app.py" (commit 037e421, verified)
from fastapi import FastAPI, UploadFile, File
from fastapi.responses import JSONResponse
from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq
import torch
import tempfile
import os
import soundfile as sf
import numpy as np
# FastAPI application instance; served by the Space's ASGI runner (uvicorn).
app = FastAPI(
    title="Whisper Urdu ASR API",
    description="Transcribe Urdu audio using a fine-tuned Whisper model",
)

# Globals — populated lazily by load_model() on the first /transcribe request
# so app startup stays fast and memory is only used once needed.
model = None       # AutoModelForSpeechSeq2Seq, loaded once
processor = None   # AutoProcessor (feature extractor + tokenizer), loaded once
device = "cpu"  # βœ… Force CPU for Hugging Face Spaces (no GPU on free tier)
@app.get("/")
def home():
    """Health-check endpoint: confirm the API is up and point users at /transcribe."""
    payload = {
        "status": "βœ… Urdu Whisper API is running",
        "message": "Use /docs or /transcribe endpoint to upload a .wav file for Urdu transcription.",
    }
    return payload
def load_model():
    """Initialize the Whisper model/processor on first call; later calls are no-ops."""
    global model, processor
    # Guard clause: nothing to do once both pieces are initialized.
    if model is not None and processor is not None:
        return

    checkpoint = "Abdul145/whisper-medium-urdu-custom"
    print("πŸ”„ Loading model...")
    processor = AutoProcessor.from_pretrained(checkpoint)
    # .to(device) returns the model, so load + placement chain into one statement.
    model = AutoModelForSpeechSeq2Seq.from_pretrained(checkpoint).to(device)
    model.eval()
    print("βœ… Model loaded on CPU")
@app.post("/transcribe")
async def transcribe_audio(file: UploadFile = File(...)):
    """Transcribe an uploaded Urdu audio file using the fine-tuned Whisper model.

    The upload is written to a temporary .wav file, decoded with soundfile,
    downmixed to mono and resampled to 16 kHz if needed, then run through
    the model on CPU.

    Returns:
        {"transcription": str} on success, or a JSON 500 payload
        {"error": str} if any step fails.
    """
    try:
        load_model()

        # Persist the upload to disk so soundfile can open it by path.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
            tmp.write(await file.read())
            tmp_path = tmp.name

        # FIX: remove the temp file even when decoding raises — the original
        # only deleted it after a successful sf.read, leaking files on bad input.
        try:
            speech_array, sampling_rate = sf.read(tmp_path)
        finally:
            os.remove(tmp_path)

        # Downmix stereo → mono by averaging channels.
        if speech_array.ndim > 1:
            speech_array = np.mean(speech_array, axis=1)

        # Whisper's feature extractor expects 16 kHz audio; resample if needed.
        if sampling_rate != 16000:
            import librosa  # imported lazily: only needed for non-16k inputs

            speech_array = librosa.resample(
                speech_array.astype(np.float32), orig_sr=sampling_rate, target_sr=16000
            )
            sampling_rate = 16000

        # Ensure float32 — sf.read may return float64.
        speech_array = np.asarray(speech_array, dtype=np.float32)

        # Convert raw audio to log-mel input features on the target device.
        input_features = processor(
            speech_array, sampling_rate=sampling_rate, return_tensors="pt"
        ).input_features.to(device)

        # Inference only — no gradients needed.
        with torch.no_grad():
            predicted_ids = model.generate(input_features)

        transcription = processor.batch_decode(
            predicted_ids, skip_special_tokens=True
        )[0]
        return {"transcription": transcription.strip()}

    except Exception as e:
        # Top-level boundary: log and return a structured 500 instead of crashing.
        print("❌ Transcription error:", e)
        return JSONResponse(content={"error": str(e)}, status_code=500)