""" FastAPI REST API for Whisper German ASR Provides endpoints for audio transcription """ from fastapi import FastAPI, File, UploadFile, HTTPException from fastapi.responses import JSONResponse from fastapi.middleware.cors import CORSMiddleware from pydantic import BaseModel import torch from transformers import WhisperForConditionalGeneration, WhisperProcessor import librosa import numpy as np from pathlib import Path import io from typing import Optional import logging # Setup logging logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) # Initialize FastAPI app app = FastAPI( title="Whisper German ASR API", description="REST API for German speech recognition using fine-tuned Whisper model", version="1.0.0" ) # Add CORS middleware app.add_middleware( CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) # Global variables for model model = None processor = None device = None class TranscriptionResponse(BaseModel): """Response model for transcription""" transcription: str language: str = "de" duration: Optional[float] = None model: str = "whisper-small-german" class HealthResponse(BaseModel): """Response model for health check""" status: str model_loaded: bool device: str def load_model(model_path: str = "./whisper_test_tuned"): """Load the fine-tuned Whisper model""" global model, processor, device logger.info(f"Loading model from: {model_path}") model_path = Path(model_path) # Check for checkpoint directories if model_path.is_dir(): checkpoints = list(model_path.glob('checkpoint-*')) if checkpoints: latest = max(checkpoints, key=lambda p: int(p.name.split('-')[1])) model_path = latest logger.info(f"Using checkpoint: {latest.name}") model = WhisperForConditionalGeneration.from_pretrained(str(model_path)) processor = WhisperProcessor.from_pretrained("openai/whisper-small") # Set German language conditioning model.config.forced_decoder_ids = processor.get_decoder_prompt_ids( language="german", task="transcribe" ) device = "cuda" if torch.cuda.is_available() else "cpu" model = model.to(device) model.eval() logger.info(f"Model loaded successfully on {device}") @app.on_event("startup") async def startup_event(): """Load model on startup""" try: load_model() except Exception as e: logger.error(f"Failed to load model on startup: {e}") # Don't fail startup, allow manual model loading @app.get("/", response_model=dict) async def root(): """Root endpoint""" return { "message": "Whisper German ASR API", "version": "1.0.0", "endpoints": { "health": "/health", "transcribe": "/transcribe (POST)", "docs": "/docs" } } @app.get("/health", response_model=HealthResponse) async def health_check(): """Health check endpoint""" return HealthResponse( status="healthy" if model is not None else "model_not_loaded", model_loaded=model is not None, device=device if device else "unknown" ) @app.post("/transcribe", response_model=TranscriptionResponse) async def transcribe_audio( file: UploadFile = File(...), language: str = "de" ): """ Transcribe audio file to text Args: file: Audio file (wav, mp3, flac, etc.) language: Language code (default: de for German) Returns: TranscriptionResponse with transcription text """ if model is None: raise HTTPException(status_code=503, detail="Model not loaded") try: # Read audio file contents = await file.read() # Load audio with librosa audio, sr = librosa.load(io.BytesIO(contents), sr=16000, mono=True) duration = len(audio) / sr # Process audio input_features = processor( audio, sampling_rate=16000, return_tensors="pt" ).input_features.to(device) # Generate transcription with torch.no_grad(): predicted_ids = model.generate( input_features, max_length=448, num_beams=5, early_stopping=True ) transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0] logger.info(f"Transcribed {file.filename}: {transcription[:50]}...") return TranscriptionResponse( transcription=transcription, language=language, duration=duration ) except Exception as e: logger.error(f"Transcription error: {e}") raise HTTPException(status_code=500, detail=f"Transcription failed: {str(e)}") @app.post("/reload-model") async def reload_model(model_path: str = "./whisper_test_tuned"): """Reload the model (admin endpoint)""" try: load_model(model_path) return {"status": "success", "message": "Model reloaded successfully"} except Exception as e: raise HTTPException(status_code=500, detail=f"Failed to reload model: {str(e)}") if __name__ == "__main__": import uvicorn uvicorn.run(app, host="0.0.0.0", port=8000)