| """ | |
| FastAPI REST API for Whisper German ASR | |
| Provides endpoints for audio transcription | |
| """ | |
| from fastapi import FastAPI, File, UploadFile, HTTPException | |
| from fastapi.responses import JSONResponse | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel | |
| import torch | |
| from transformers import WhisperForConditionalGeneration, WhisperProcessor | |
| import librosa | |
| import numpy as np | |
| from pathlib import Path | |
| import io | |
| from typing import Optional | |
| import logging | |
# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Initialize FastAPI app
app = FastAPI(
    title="Whisper German ASR API",
    description="REST API for German speech recognition using fine-tuned Whisper model",
    version="1.0.0"
)

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Global variables for model, populated by load_model()
model = None
processor = None
device = None

class TranscriptionResponse(BaseModel):
    """Response model for transcription"""
    transcription: str
    language: str = "de"
    duration: Optional[float] = None
    model: str = "whisper-small-german"


class HealthResponse(BaseModel):
    """Response model for health check"""
    status: str
    model_loaded: bool
    device: str

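# Illustrative example of a serialized TranscriptionResponse (the values
# below are made up for demonstration):
# {"transcription": "guten tag", "language": "de", "duration": 1.32,
#  "model": "whisper-small-german"}
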
def load_model(model_path: str = "./whisper_test_tuned"):
    """Load the fine-tuned Whisper model"""
    global model, processor, device
    logger.info(f"Loading model from: {model_path}")
    model_path = Path(model_path)
    # Check for checkpoint directories and use the most recent one
    if model_path.is_dir():
        checkpoints = list(model_path.glob('checkpoint-*'))
        if checkpoints:
            latest = max(checkpoints, key=lambda p: int(p.name.split('-')[1]))
            model_path = latest
            logger.info(f"Using checkpoint: {latest.name}")
    model = WhisperForConditionalGeneration.from_pretrained(str(model_path))
    processor = WhisperProcessor.from_pretrained("openai/whisper-small")
    # Set German language conditioning
    model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(
        language="german",
        task="transcribe"
    )
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = model.to(device)
    model.eval()
    logger.info(f"Model loaded successfully on {device}")

@app.on_event("startup")
async def startup_event():
    """Load model on startup"""
    try:
        load_model()
    except Exception as e:
        logger.error(f"Failed to load model on startup: {e}")
        # Don't fail startup; allow manual model loading via the reload endpoint

@app.get("/")
async def root():
    """Root endpoint"""
    return {
        "message": "Whisper German ASR API",
        "version": "1.0.0",
        "endpoints": {
            "health": "/health",
            "transcribe": "/transcribe (POST)",
            "docs": "/docs"
        }
    }

@app.get("/health", response_model=HealthResponse)
async def health_check():
    """Health check endpoint"""
    return HealthResponse(
        status="healthy" if model is not None else "model_not_loaded",
        model_loaded=model is not None,
        device=device if device else "unknown"
    )

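# Example health check exchange (illustrative values):
#   GET /health  ->  {"status": "healthy", "model_loaded": true, "device": "cuda"}
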
@app.post("/transcribe", response_model=TranscriptionResponse)
async def transcribe_audio(
    file: UploadFile = File(...),
    language: str = "de"
):
    """
    Transcribe audio file to text

    Args:
        file: Audio file (wav, mp3, flac, etc.)
        language: Language code (default: de for German)

    Returns:
        TranscriptionResponse with transcription text
    """
    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    try:
        # Read audio file
        contents = await file.read()
        # Load audio with librosa, resampled to the 16 kHz mono input Whisper expects
        audio, sr = librosa.load(io.BytesIO(contents), sr=16000, mono=True)
        duration = len(audio) / sr
        # Extract log-mel spectrogram features
        input_features = processor(
            audio,
            sampling_rate=16000,
            return_tensors="pt"
        ).input_features.to(device)
        # Generate transcription with beam search
        with torch.no_grad():
            predicted_ids = model.generate(
                input_features,
                max_length=448,
                num_beams=5,
                early_stopping=True
            )
        transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
        logger.info(f"Transcribed {file.filename}: {transcription[:50]}...")
        return TranscriptionResponse(
            transcription=transcription,
            language=language,
            duration=duration
        )
    except Exception as e:
        logger.error(f"Transcription error: {e}")
        raise HTTPException(status_code=500, detail=f"Transcription failed: {str(e)}")

@app.post("/reload_model")  # admin endpoint; route path assumed
async def reload_model(model_path: str = "./whisper_test_tuned"):
    """Reload the model (admin endpoint)"""
    try:
        load_model(model_path)
        return {"status": "success", "message": "Model reloaded successfully"}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to reload model: {str(e)}")

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
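
# Example usage once the server is running (localhost:8000 and sample.wav are
# illustrative; substitute your own host and audio file):
#
#   curl -X POST "http://localhost:8000/transcribe" -F "file=@sample.wav"
#
# or from Python with the requests library:
#
#   import requests
#   with open("sample.wav", "rb") as f:
#       r = requests.post("http://localhost:8000/transcribe", files={"file": f})
#   print(r.json()["transcription"])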