"""FastAPI service that transcribes uploaded audio with a fine-tuned Whisper model.

Endpoints:
    GET  /            -- status message.
    POST /transcribe  -- upload an audio file, receive its transcription.
"""

import os
import tempfile
import threading

import numpy as np
import soundfile as sf
import torch
from fastapi import FastAPI, File, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import JSONResponse
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor

app = FastAPI(
    title="Whisper Urdu ASR API",
    description="Transcribe Urdu audio using a fine-tuned Whisper model",
)

# Hugging Face Hub identifier of the checkpoint served by this API.
MODEL_ID = "Abdul145/whisper-medium-urdu-custom"
# Sample rate (Hz) every upload is converted to before feature extraction.
TARGET_SAMPLE_RATE = 16000

# Globals, populated lazily by load_model().
model = None
processor = None
device = "cpu"  # Force CPU for Hugging Face Spaces

# Blocking work runs on worker threads (see transcribe_audio), so the lazy
# load must not run twice concurrently, and inference is kept one-at-a-time
# exactly as it was when everything ran on the event loop.
_load_lock = threading.Lock()
_inference_lock = threading.Lock()


@app.get("/")
def home():
    """Return a static status payload describing how to use the API."""
    return {
        "status": "✅ Urdu Whisper API is running",
        "message": "Use /docs or /transcribe endpoint to upload a .wav file for Urdu transcription.",
    }


def load_model() -> None:
    """Load the processor and model into the module globals, at most once.

    Safe to call repeatedly and from several threads: after the first
    successful load it returns immediately. If loading raises, the globals
    are left untouched so a later call retries from scratch.
    """
    global model, processor
    if model is not None and processor is not None:
        return
    with _load_lock:
        # Re-check under the lock: another thread may have finished loading
        # while this one was waiting.
        if model is not None and processor is not None:
            return
        print("🔄 Loading model...")
        loaded_processor = AutoProcessor.from_pretrained(MODEL_ID)
        loaded_model = AutoModelForSpeechSeq2Seq.from_pretrained(MODEL_ID)
        loaded_model.to(device)
        loaded_model.eval()
        # Publish only after both objects are fully constructed.
        processor, model = loaded_processor, loaded_model
        print("✅ Model loaded on CPU")


def _load_waveform(audio_bytes: bytes) -> np.ndarray:
    """Decode raw upload bytes into a mono float32 array at TARGET_SAMPLE_RATE.

    The bytes are written to a temporary file and read back with soundfile.
    The temporary file is always removed, including when decoding fails.

    Raises:
        Exception: whatever soundfile/librosa raise for unreadable audio.
    """
    tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
    try:
        with tmp:
            tmp.write(audio_bytes)
        # Read only after the handle is closed so the data is flushed.
        speech_array, sampling_rate = sf.read(tmp.name)
    finally:
        # delete=False means nothing else cleans this up; do it on every path.
        os.remove(tmp.name)

    # Convert stereo -> mono by averaging across axis 1 (a 2-D result from
    # soundfile has one column per channel).
    if speech_array.ndim > 1:
        speech_array = np.mean(speech_array, axis=1)

    # Resample to 16 kHz if needed.
    if sampling_rate != TARGET_SAMPLE_RATE:
        # Imported lazily: only needed for uploads that are not already 16 kHz.
        import librosa

        speech_array = librosa.resample(
            speech_array.astype(np.float32),
            orig_sr=sampling_rate,
            target_sr=TARGET_SAMPLE_RATE,
        )

    return np.asarray(speech_array, dtype=np.float32)


def _transcribe_bytes(audio_bytes: bytes) -> str:
    """Run the full blocking pipeline: load model, decode audio, generate text.

    Returns the first decoded sequence with special tokens removed and
    surrounding whitespace stripped.
    """
    load_model()
    speech_array = _load_waveform(audio_bytes)

    with _inference_lock:
        input_features = processor(
            speech_array,
            sampling_rate=TARGET_SAMPLE_RATE,
            return_tensors="pt",
        ).input_features.to(device)

        # No gradients are needed for generation.
        with torch.no_grad():
            predicted_ids = model.generate(input_features)

        transcription = processor.batch_decode(
            predicted_ids, skip_special_tokens=True
        )[0]

    return transcription.strip()


@app.post("/transcribe")
async def transcribe_audio(file: UploadFile = File(...)):
    """Transcribe an uploaded audio file.

    Args:
        file: The uploaded audio (multipart form field ``file``).

    Returns:
        ``{"transcription": <text>}`` on success, or a 500 JSON response of
        the form ``{"error": <message>}`` if any step fails.
    """
    try:
        audio_bytes = await file.read()
        # Decoding and inference are blocking and CPU-heavy; run them on a
        # worker thread so the event loop keeps serving other requests.
        transcription = await run_in_threadpool(_transcribe_bytes, audio_bytes)
        return {"transcription": transcription}
    except Exception as e:
        print("❌ Transcription error:", e)
        return JSONResponse(content={"error": str(e)}, status_code=500)