Spaces:
Sleeping
Sleeping
| from pathlib import Path | |
| from typing import Optional | |
| import tempfile | |
| import subprocess | |
| import os | |
| import torch | |
| import librosa | |
| from fastapi import FastAPI, UploadFile, File, Form | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from transformers import WhisperProcessor, WhisperForConditionalGeneration | |
# Hugging Face checkpoint to load; override with the MODEL_NAME env var.
MODEL_NAME = os.getenv("MODEL_NAME", "openai/whisper-small")

# Comma-separated list of browser origins allowed to call this API, e.g.
# ALLOWED_ORIGINS="https://my-frontend.vercel.app,http://localhost:3000".
# Defaults to "*" (any origin) so existing deployments keep working.
ALLOWED_ORIGINS = [
    origin.strip()
    for origin in os.getenv("ALLOWED_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]

app = FastAPI(
    title="Voice Complaint Transcriber API",
    description="Converts citizen grievance audio into text using Whisper.",
    version="1.0.0",
)

app.add_middleware(
    CORSMiddleware,
    allow_origins=ALLOWED_ORIGINS,
    # FastAPI's CORS documentation states that allow_origins must not be
    # ["*"] when allow_credentials is True, so credentialed requests are only
    # enabled once explicit origins have been configured.
    allow_credentials="*" not in ALLOWED_ORIGINS,
    allow_methods=["*"],
    allow_headers=["*"],
)
class VoiceComplaintTranscriber:
    """Speech-to-text wrapper around a Hugging Face Whisper checkpoint.

    The processor and model are loaded once in ``__init__`` and reused by
    every :meth:`transcribe` call. Inference runs on CUDA in float16 when a
    GPU is available, otherwise on CPU in float32.
    """

    # Sample rate (Hz) used for ffmpeg conversion, loading and feature
    # extraction; kept in one place so the three stay in agreement.
    SAMPLE_RATE = 16000

    def __init__(self, model_name: str = MODEL_NAME):
        """Load the processor and model for *model_name* onto the best device."""
        self.model_name = model_name
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.dtype = torch.float16 if self.device == "cuda" else torch.float32
        self.processor = WhisperProcessor.from_pretrained(self.model_name)
        self.model = WhisperForConditionalGeneration.from_pretrained(
            self.model_name,
            torch_dtype=self.dtype,
        ).to(self.device)
        self.model.eval()

    @staticmethod
    def _remove_quietly(path: Path) -> None:
        """Delete *path*, ignoring the case where it does not exist."""
        try:
            path.unlink()
        except FileNotFoundError:
            pass

    def convert_to_wav(self, input_path: Path) -> Path:
        """Convert *input_path* to a 16 kHz mono WAV file with the ffmpeg CLI.

        The output is written next to the input with its suffix replaced by
        ``.wav``; an existing file at that path is overwritten (``-y``).

        Returns:
            Path of the converted WAV file.

        Raises:
            subprocess.CalledProcessError: ffmpeg exited with a non-zero
                status. The exception's ``stderr`` attribute holds ffmpeg's
                error output, and any partially written output is removed.
        """
        output_path = input_path.with_suffix(".wav")
        try:
            subprocess.run(
                [
                    "ffmpeg",
                    "-y",
                    # Keep the captured stderr limited to real errors.
                    "-loglevel", "error",
                    "-i", str(input_path),
                    "-ar", str(self.SAMPLE_RATE),
                    "-ac", "1",
                    str(output_path),
                ],
                check=True,
                stdout=subprocess.DEVNULL,
                # Captured (not discarded) so a failure can be diagnosed from
                # CalledProcessError.stderr.
                stderr=subprocess.PIPE,
            )
        except subprocess.CalledProcessError:
            # Do not leave a truncated WAV behind after a failed conversion.
            self._remove_quietly(output_path)
            raise
        return output_path

    def transcribe(self, audio_path: Path, language: Optional[str] = "hi") -> str:
        """Transcribe the audio file at *audio_path* and return the text.

        Files whose suffix is not ``.wav`` are first converted with
        :meth:`convert_to_wav`; that intermediate file is deleted again as
        soon as the samples have been loaded.

        Args:
            audio_path: Audio file to transcribe.
            language: Language passed to the decoder prompt (e.g. ``"hi"``,
                ``"en"``). ``None``, ``""`` or ``"auto"`` forces no language
                prompt.

        Returns:
            The decoded transcript with surrounding whitespace stripped, or
            an empty string when the file contains no audio samples.

        Raises:
            subprocess.CalledProcessError: the ffmpeg conversion failed.
        """
        needs_conversion = audio_path.suffix.lower() != ".wav"
        wav_path = self.convert_to_wav(audio_path) if needs_conversion else audio_path
        try:
            audio_array, _ = librosa.load(
                str(wav_path),
                sr=self.SAMPLE_RATE,
                mono=True,
            )
        finally:
            # The converted file is an internal artifact of this call; never
            # delete a WAV that the caller handed in directly.
            if needs_conversion:
                self._remove_quietly(wav_path)

        if len(audio_array) == 0:
            # Nothing to transcribe; skip the model call entirely.
            return ""

        # NOTE(review): Whisper's feature extractor presumably works on fixed
        # 30-second windows, so longer recordings may be truncated here.
        # Confirm, and add chunking if long complaints are expected.
        inputs = self.processor(
            audio_array,
            sampling_rate=self.SAMPLE_RATE,
            return_tensors="pt",
        )
        input_features = inputs.input_features.to(
            device=self.device,
            dtype=self.dtype,
        )
        generate_kwargs = {
            "inputs": input_features,
            "max_new_tokens": 256,
        }
        if language and language != "auto":
            generate_kwargs["forced_decoder_ids"] = self.processor.get_decoder_prompt_ids(
                language=language,
                task="transcribe",
            )
        with torch.no_grad():
            predicted_ids = self.model.generate(**generate_kwargs)
        text = self.processor.batch_decode(
            predicted_ids,
            skip_special_tokens=True,
        )[0].strip()
        return text
# Shared instance created at import time: constructing it loads the Whisper
# processor and model once, and the handlers below all reuse this object.
| transcriber = VoiceComplaintTranscriber() | |
def home():
    """Return a short status payload identifying the running service.

    The payload carries a fixed status message, the configured model name
    and the device chosen by the shared transcriber.
    """
    # NOTE(review): no route decorator is visible on this function in this
    # view; confirm it is registered on `app`.
    info = {"message": "Voice Complaint Transcriber API is running"}
    info["model"] = MODEL_NAME
    info["device"] = transcriber.device
    return info
def health():
    """Return a health-check payload: status, model name and device."""
    # NOTE(review): no route decorator is visible on this function in this
    # view; confirm it is registered on `app`.
    return dict(
        status="ok",
        model=MODEL_NAME,
        device=transcriber.device,
    )
async def transcribe_audio(
    file: UploadFile = File(...),
    language: str = Form("hi"),  # hi / en / auto
):
    """Transcribe an uploaded audio file and return the text as JSON.

    The upload is written to a temporary file (keeping the original file
    extension, or ``.ogg`` when there is none), passed to the shared
    transcriber, and removed again whether or not transcription succeeds.

    Args:
        file: Uploaded audio file.
        language: Language hint forwarded to the transcriber.

    Returns:
        A dict with the transcript, the requested language, the model name
        and a fixed method tag.

    Raises:
        HTTPException: 400 when the upload is empty or ffmpeg cannot decode it.
    """
    # NOTE(review): no route decorator is visible on this function in this
    # view; confirm it is registered on `app`.
    # Imported locally so this handler does not depend on the module's
    # import block being edited.
    from fastapi import HTTPException

    # The uploaded filename may be missing; fall back to the default suffix.
    suffix = Path(file.filename or "").suffix or ".ogg"

    # Read the upload before creating any temp file so a failed read cannot
    # leave a stray file behind.
    payload = await file.read()
    if not payload:
        raise HTTPException(status_code=400, detail="Uploaded audio file is empty")

    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
        temp_path = Path(temp_file.name)
    # Path the transcriber writes to when it has to convert the upload.
    wav_path = temp_path.with_suffix(".wav")

    try:
        temp_path.write_bytes(payload)
        # NOTE(review): this is a blocking, CPU/GPU-bound call inside an async
        # handler, so it stalls the event loop while it runs. Consider a
        # thread pool if concurrent requests matter.
        text = transcriber.transcribe(
            audio_path=temp_path,
            language=language,
        )
    except subprocess.CalledProcessError as exc:
        # ffmpeg rejected the file: report a client error instead of a 500.
        raise HTTPException(
            status_code=400,
            detail="Could not decode the uploaded audio file",
        ) from exc
    finally:
        for leftover in (temp_path, wav_path):
            try:
                leftover.unlink()
            except FileNotFoundError:
                pass

    return {
        "transcribed_text": text,
        "language": language,
        "model": MODEL_NAME,
        "method": "whisper_direct_fastapi",
    }