Spaces:
Paused
Paused
| from fastapi import FastAPI, UploadFile, File, HTTPException, status, WebSocket, WebSocketDisconnect | |
| from fastapi.responses import JSONResponse | |
| import tempfile | |
| import os | |
| import sync_hc | |
| import logging | |
| from pydantic import BaseModel | |
| from speech_models.whisperx_manager import WhisperXModelManager | |
| from utils.audio_buffer import AudioBuffer | |
# Configure logging
# Root logger at INFO so request/chunk progress messages below are visible;
# module-level logger per stdlib convention.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class TranscriptionRequest(BaseModel):
    """Options for a one-shot (batch) transcription request.

    NOTE(review): field-for-field identical to StreamingConfig and not
    referenced anywhere in this file — confirm against API clients whether
    it is still needed before removing.
    """

    language: str = "en"  # language code passed through to the model
    task: str = "transcribe"  # presumably "transcribe" or "translate" — TODO confirm
    chunk_duration: float = 30.0  # seconds of audio per processing chunk
class StreamingConfig(BaseModel):
    """Per-session options sent as the first JSON message on the
    streaming WebSocket (see websocket_transcribe).

    NOTE(review): the WebSocket handler reads the config dict directly via
    .get() rather than validating through this model — confirm whether this
    class should be wired in.
    """

    language: str = "en"  # language code passed through to the model
    task: str = "transcribe"  # presumably "transcribe" or "translate" — TODO confirm
    chunk_duration: float = 30.0  # seconds of audio buffered per chunk
# Initialize FastAPI app
app = FastAPI(title="Speech Transcription API", version="1.0.0")
# Global model manager instance shared by all endpoints.
# "tiny" is the smallest/fastest Whisper variant — NOTE(review): confirm this
# is intentional rather than a leftover dev setting.
model_manager = WhisperXModelManager(model_name="tiny")
async def startup_event():
    """Fail fast at application startup if the model did not load.

    Raises:
        RuntimeError: when the global model manager reports no loaded model.

    NOTE(review): no ``@app.on_event("startup")`` decorator is visible in
    this file — it may have been lost upstream; confirm the hook is
    actually registered.
    """
    if model_manager.is_loaded:
        return
    raise RuntimeError("Failed to load model at startup")
async def health_check():
    """Liveness probe: always 200, reporting whether the model is resident.

    NOTE(review): no route decorator (e.g. ``@app.get("/health")``) is
    visible here — confirm registration.
    """
    payload = {"status": "healthy", "model_loaded": model_manager.is_loaded}
    return JSONResponse(content=payload)
async def readiness_check():
    """Readiness probe: 200 with the model name once loaded, 503 before.

    NOTE(review): no route decorator (e.g. ``@app.get("/ready")``) is
    visible here — confirm registration.
    """
    if not model_manager.is_loaded:
        # Not ready yet — signal orchestrators to hold traffic.
        return JSONResponse(
            content={"status": "not ready", "model_loaded": False},
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE
        )
    return JSONResponse(
        content={"status": "ready", "model_name": model_manager.model_name}
    )
async def transcribe_audio(
    file: UploadFile = File(...),
    language: str = "en",
    chunk_duration: float = 30.0
):
    """Transcribe an uploaded audio file, processing it in chunks.

    Args:
        file: uploaded audio; extension must be one of the allowed formats.
        language: language code forwarded to the model manager.
        chunk_duration: seconds per chunk; must be > 0 and at most 60.

    Returns:
        JSONResponse wrapping the model manager's transcription result.

    Raises:
        HTTPException: 503 if the model is not loaded, 400 for an invalid
            file type or chunk duration, 500 if transcription fails.

    NOTE(review): no route decorator (e.g. ``@app.post("/transcribe")``) is
    visible here — confirm registration.
    """
    if not model_manager.is_loaded:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Model not loaded"
        )
    # Validate file extension.  `or ""` guards against a missing filename
    # (os.path.splitext(None) would raise TypeError); such uploads are
    # rejected with a clean 400 instead.
    allowed_extensions = {'.mp3', '.wav', '.m4a', '.flac', '.ogg', '.aac'}
    file_extension = os.path.splitext(file.filename or "")[1].lower()
    if file_extension not in allowed_extensions:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Invalid file type. Allowed types: {', '.join(allowed_extensions)}"
        )
    # Validate chunk duration; message matches the bounds actually checked
    # (the previous text claimed a 1-second minimum that was never enforced).
    if chunk_duration <= 0 or chunk_duration > 60:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Chunk duration must be greater than 0 and at most 60 seconds"
        )
    temp_file_path = None
    try:
        # Persist the upload so the model manager can read it from disk.
        with tempfile.NamedTemporaryFile(delete=False, suffix=file_extension) as temp_file:
            temp_file.write(await file.read())
            temp_file_path = temp_file.name
        result = model_manager.transcribe(temp_file_path, language, chunk_duration)
        return JSONResponse(content=result)
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Transcription failed: {str(e)}"
        )
    finally:
        # Always remove the temp file (replaces the old `locals()` check and
        # bare `except:`); only filesystem errors are suppressed.
        if temp_file_path is not None:
            try:
                os.unlink(temp_file_path)
            except OSError:
                pass
async def websocket_transcribe(websocket: WebSocket):
    """WebSocket endpoint for streaming transcription with chunking.

    Protocol: the client first sends a JSON config message (keys
    ``language``, ``task``, ``chunk_duration``), then streams raw audio
    bytes.  Each time the buffer accumulates a full chunk, a
    ``{"status": "transcription", ...}`` message is sent back; remaining
    buffered audio is flushed as a final result when the session ends.

    NOTE(review): no ``@app.websocket(...)`` decorator is visible here —
    it may have been lost upstream; confirm the endpoint is registered.
    """
    await websocket.accept()
    if not model_manager.is_loaded:
        await websocket.send_json({
            "error": "Model not loaded",
            "status": "error"
        })
        await websocket.close()
        return
    # Bind defaults up front so the `finally` flush can never hit an
    # unbound `language`/`chunk_duration` when the connection fails before
    # the config message arrives (previously a NameError masked the error).
    language = "en"
    chunk_duration = 30.0
    audio_buffer = None
    try:
        # Receive configuration (read directly; see StreamingConfig).
        config_data = await websocket.receive_json()
        language = config_data.get("language", "en")
        task = config_data.get("task", "transcribe")  # read but not used downstream — TODO confirm
        chunk_duration = config_data.get("chunk_duration", 30.0)
        # Initialize audio buffer sized to the requested chunk length.
        audio_buffer = AudioBuffer(chunk_duration=chunk_duration)
        logger.info(f"Starting streaming transcription session - Language: {language}, Chunk Duration: {chunk_duration}s")
        # Send acknowledgment
        await websocket.send_json({
            "status": "connected",
            "message": "Ready to receive audio data",
            "chunk_duration": chunk_duration
        })
        # Process incoming audio chunks
        while True:
            try:
                data = await websocket.receive_bytes()
                if not data:
                    continue
                # add_data returns a complete chunk once enough audio has
                # accumulated, otherwise None.
                ready_chunk = audio_buffer.add_data(data)
                if ready_chunk is not None:
                    logger.info("Processing buffered audio chunk")
                    result = await model_manager.transcribe_stream(
                        ready_chunk.tobytes(),
                        language,
                        chunk_duration
                    )
                    await websocket.send_json({
                        "status": "transcription",
                        "result": result,
                        "chunk_complete": True
                    })
            except WebSocketDisconnect:
                logger.info("WebSocket disconnected")
                break
            except Exception as e:
                logger.error(f"Error processing audio chunk: {str(e)}")
                try:
                    await websocket.send_json({
                        "error": str(e),
                        "status": "error"
                    })
                except Exception:
                    # Socket already gone — stop instead of spinning on a
                    # dead connection (send used to raise out of the loop).
                    break
    except Exception as e:
        logger.error(f"WebSocket error: {str(e)}")
    finally:
        # Flush any buffered tail audio as a final transcription.
        if audio_buffer is not None and len(audio_buffer.get_remaining()) > 0:
            try:
                remaining_audio = audio_buffer.get_remaining()
                if len(remaining_audio) > 16000:  # at least ~1 s assuming 16 kHz — TODO confirm sample rate
                    # NOTE(review): chunk_duration is deliberately not passed
                    # here, matching the original call — confirm that
                    # transcribe_stream's default is right for the tail.
                    result = await model_manager.transcribe_stream(
                        remaining_audio.tobytes(),
                        language
                    )
                    await websocket.send_json({
                        "status": "transcription",
                        "result": result,
                        "final": True
                    })
            except Exception as e:
                logger.error(f"Error processing final chunk: {str(e)}")
        try:
            await websocket.close()
        except Exception:
            # Closing an already-closed socket can raise; the session is
            # over either way.
            pass
if __name__ == "__main__":
    # Dev/standalone entry point; in deployment the app is typically served
    # via an external `uvicorn app:app` command instead.
    # Port 7860 is presumably the Hugging Face Spaces convention — confirm
    # the deployment target.
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)