"""FastAPI server that transcribes uploaded audio files with an OpenAI Whisper model."""

import asyncio
import json
import os
import tempfile
import threading
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional

import librosa
import numpy as np
import torch
from fastapi import FastAPI, File, HTTPException, UploadFile
from pydantic import BaseModel
from transformers import pipeline

# --- Configuration ---
# Size suffix of the Hugging Face checkpoint "openai/whisper-<size>", e.g. small, medium, large.
WHISPER_MODEL = os.getenv("WHISPER_MODEL", "small")
WHISPER_PORT = int(os.getenv("WHISPER_PORT", 8000))
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Half precision is only used when a GPU is available.
TORCH_DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32

# File extensions accepted by the upload endpoints (compared lower-cased).
ALLOWED_EXTENSIONS = {'.mp3', '.wav', '.m4a', '.flac', '.ogg', '.aac'}

# Global model cache, populated lazily by get_whisper_pipeline().
_whisper_pipeline = None

# Transcription runs in worker threads (see transcribe_audio) so it does not
# block the event loop. This lock lets only one thread load or run the model at
# a time, matching the one-request-at-a-time behavior the server had when
# inference ran on the event loop itself.
# NOTE(review): sharing one HF pipeline across threads concurrently is not
# documented as safe, and parallel GPU inference could exhaust memory, hence
# the serialization.
_inference_lock = threading.Lock()

# Static description of the loaded configuration, reported by /health.
_model_info = {
    "model_name": WHISPER_MODEL,
    "device": DEVICE,
    "dtype": str(TORCH_DTYPE),
    "cuda_available": torch.cuda.is_available()
}


# --- Models ---
class TranscriptionResponse(BaseModel):
    """Transcription result schema (not currently used as any endpoint's response_model)."""

    text: str
    language: str = "en"
    confidence: Optional[float] = None
    duration: float = 0.0  # seconds
    timestamp: str = ""  # ISO-8601, local naive time


# --- Utility Functions ---
def get_whisper_pipeline():
    """Return the cached Whisper ASR pipeline, loading it on first use.

    Not internally synchronized: calls from worker threads must hold
    ``_inference_lock`` so the model is never loaded twice concurrently.
    """
    global _whisper_pipeline
    if _whisper_pipeline is not None:
        return _whisper_pipeline

    print(f"🔄 Loading Whisper model: {WHISPER_MODEL} on {DEVICE} with dtype {TORCH_DTYPE}")
    _whisper_pipeline = pipeline(
        "automatic-speech-recognition",
        model=f"openai/whisper-{WHISPER_MODEL}",
        device=DEVICE,
        torch_dtype=TORCH_DTYPE
    )
    print("✅ Whisper model loaded successfully")
    return _whisper_pipeline


def load_and_resample_audio(audio_path: str, target_sr: int = 16000) -> tuple:
    """Load an audio file as mono and resample it to ``target_sr`` Hz.

    The default of 16 kHz is the rate Whisper expects.

    Args:
        audio_path: Path to the audio file on disk.
        target_sr: Sampling rate to resample to.

    Returns:
        ``(audio, sr, duration)``: the mono sample array, its sampling rate and
        its duration in seconds.

    Raises:
        Whatever librosa raises when the file cannot be decoded (re-raised after logging).
    """
    try:
        # librosa decodes with soundfile and falls back to audioread for formats
        # libsndfile cannot read. NOTE(review): compressed formats such as
        # m4a/aac may therefore still need an ffmpeg backend installed.
        audio, sr = librosa.load(audio_path, sr=target_sr, mono=True)
        duration = librosa.get_duration(y=audio, sr=sr)
        print(f"📁 Loaded audio: {Path(audio_path).name} | Duration: {duration:.2f}s | SR: {sr}Hz")
        return audio, sr, duration
    except Exception as e:
        print(f"❌ Error loading audio: {e}")
        raise


def _transcribe_audio_sync(audio_path: str) -> Dict:
    """Blocking implementation of transcribe_audio; meant to run in a worker thread."""
    try:
        # Decoding happens outside the lock so it can overlap with another
        # request's inference.
        audio, sr, duration = load_and_resample_audio(audio_path)

        with _inference_lock:
            pipeline_model = get_whisper_pipeline()
            print(f"🎤 Transcribing {Path(audio_path).name}...")
            # Long audio is processed in 30 s chunks with (left, right) stride
            # overlap, batched more aggressively on GPU.
            result = pipeline_model(
                audio,
                chunk_length_s=30,
                stride_length_s=(4, 2),
                batch_size=24 if torch.cuda.is_available() else 4
            )
        print("✅ Transcription complete")

        return {
            "text": result.get("text", "").strip(),
            # Hardcoded placeholder: no language detection is performed here.
            "language": "en",
            # Confidence is not computed.
            "confidence": None,
            "duration": duration,
            "timestamp": datetime.now().isoformat()
        }
    except Exception as e:
        print(f"❌ Transcription error: {e}")
        raise


async def transcribe_audio(audio_path: str) -> Dict:
    """Transcribe an audio file with Whisper without blocking the event loop.

    Decoding and model inference are CPU/GPU-bound and blocking, so they run in
    the event loop's default thread-pool executor.

    Args:
        audio_path: Path to the audio file on disk.

    Returns:
        A dict with keys ``text``, ``language``, ``confidence``, ``duration``
        and ``timestamp``.

    Raises:
        Any exception raised while loading or transcribing the audio.
    """
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, _transcribe_audio_sync, audio_path)


# --- FastAPI App ---
app = FastAPI(
    title="Whisper Transcription Server",
    description="FastAPI server for audio transcription using OpenAI Whisper",
    version="1.0.0"
)


@app.on_event("startup")
async def startup():
    """Log the configuration and pre-load the model so the first request is not slow."""
    print(f"🚀 Whisper Server starting on port {WHISPER_PORT}")
    print("📊 Configuration:")
    print(f"   - Model: {WHISPER_MODEL}")
    print(f"   - Device: {DEVICE}")
    print(f"   - CUDA Available: {torch.cuda.is_available()}")
    print(f"   - Torch Dtype: {TORCH_DTYPE}")

    # Pre-load model (runs before any request is served, so no lock is needed).
    get_whisper_pipeline()


@app.get("/health")
async def health_check():
    """Check server health and model status."""
    return {
        "status": "healthy",
        "model_info": _model_info,
        "cuda_available": torch.cuda.is_available(),
        "device": DEVICE
    }


@app.get("/")
async def root():
    """Root endpoint with server info."""
    return {
        "server": "Whisper Transcription Backend",
        "model": WHISPER_MODEL,
        "device": DEVICE,
        "endpoints": {
            "/health": "Server health check",
            "/transcribe": "POST - Transcribe audio file",
            "/transcribe_file": "POST - Alternative transcribe endpoint",
            "/transcribe_batch": "POST - Transcribe multiple audio files"
        }
    }


@app.post("/transcribe")
async def transcribe(file: UploadFile = File(...)):
    """
    Transcribe an uploaded audio file.

    Accepts: mp3, wav, m4a, flac, ogg, aac

    Raises:
        HTTPException(400): no filename, or an unsupported extension.
        HTTPException(500): saving, decoding or transcribing failed.
    """
    if not file.filename:
        raise HTTPException(status_code=400, detail="No file provided")

    # Check file extension
    file_ext = Path(file.filename).suffix.lower()
    if file_ext not in ALLOWED_EXTENSIONS:
        raise HTTPException(
            status_code=400,
            detail=f"Unsupported file format: {file_ext}. Allowed: {ALLOWED_EXTENSIONS}"
        )

    temp_path: Optional[Path] = None
    try:
        content = await file.read()

        # Never derive a filesystem path from the client-supplied filename: it
        # may contain "../" or collide with a concurrent upload of the same
        # name. Only the validated extension is kept, as a format hint for the
        # audio decoder. The path is recorded before writing so a partially
        # written file is still removed in `finally`.
        with tempfile.NamedTemporaryFile(prefix="whisper_", suffix=file_ext, delete=False) as tmp:
            temp_path = Path(tmp.name)
            tmp.write(content)

        print(f"📤 Processing uploaded file: {file.filename} ({len(content)} bytes)")

        # Transcribe
        result = await transcribe_audio(str(temp_path))

        return {
            "audio_file": file.filename,
            "text": result["text"],
            "language": result["language"],
            "duration": result["duration"],
            "timestamp": result["timestamp"]
        }
    except Exception as e:
        print(f"❌ Transcription failed: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        # Cleanup
        if temp_path is not None and temp_path.exists():
            temp_path.unlink()
            print(f"🧹 Cleaned up temp file: {temp_path}")


@app.post("/transcribe_file")
async def transcribe_file(file: UploadFile = File(...)):
    """Alternative endpoint name for transcription."""
    return await transcribe(file)


@app.post("/transcribe_batch")
async def transcribe_batch(files: List[UploadFile] = File(...)):
    """
    Transcribe multiple audio files one after another.

    A failure on one file does not abort the batch. Each entry in ``results``
    is either ``{"status": "success", "data": ...}`` or
    ``{"status": "error", "filename": ..., "error": ...}``.
    """
    if not files:
        raise HTTPException(status_code=400, detail="No files provided")

    results = []
    for file in files:
        try:
            result = await transcribe(file)
            results.append({
                "status": "success",
                "data": result
            })
        except Exception as e:
            results.append({
                "status": "error",
                "filename": file.filename,
                "error": str(e)
            })

    return {
        "total": len(files),
        "results": results
    }


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=WHISPER_PORT)