# NOTE: removed Hugging Face Spaces page chrome ("Spaces" / "Sleeping" status
# banner) captured by the scrape; it was never part of the source file.
# main.py
import json
import logging
import os
import tempfile
import time
from contextlib import asynccontextmanager
from typing import Optional, Dict, Any

import numpy as np
import requests
import torch
import torchaudio
from fastapi import FastAPI, UploadFile, File, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from pydantic_settings import BaseSettings
from pydub import AudioSegment
from transformers import AutoProcessor, AutoModelForCTC
# Configure root logging once at import time; `logger` is the module-wide
# logger shared by every handler in this service.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(level=logging.INFO, format=_LOG_FORMAT)
logger = logging.getLogger(__name__)
class Settings(BaseSettings):
    """Service configuration; every field can be overridden via a local
    .env file or environment variables with the STT_ prefix
    (e.g. STT_MODEL_NAME, STT_PORT)."""

    # Hugging Face checkpoint used for CTC speech recognition.
    model_name: str = "facebook/mms-1b-all"
    # Prefer GPU when available; evaluated once at import time.
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    # Downstream conversational model that receives the transcript.
    text_model_url: str = "https://remostartdev-farmlingua-ai-conversational.hf.space/ask"
    # Timeout for the downstream request, in seconds.
    # NOTE(review): 3000 s (50 min) looks very high for one HTTP call — confirm intent.
    timeout_sec: int = 3000

    # Audio processing settings
    sample_rate: int = 16000       # model input rate (Hz)
    max_audio_seconds: int = 300   # uploads longer than this are rejected
    chunk_seconds: int = 20        # transcription window length
    overlap_seconds: int = 2       # overlap between adjacent windows

    # Server binding (used by the __main__ uvicorn launcher).
    host: str = "0.0.0.0"
    port: int = 7860
    workers: int = 1

    # CORS settings (wide open by default — tighten for production)
    cors_origins: list = ["*"]
    cors_methods: list = ["*"]
    cors_headers: list = ["*"]

    class Config:
        # Load overrides from .env; environment variables use the STT_ prefix.
        env_file = ".env"
        env_prefix = "STT_"


# Singleton settings instance used throughout the module.
settings = Settings()
# Calculate derived constants (all in samples at settings.sample_rate)
MAX_SAMPLES = settings.sample_rate * settings.max_audio_seconds  # hard cap on accepted audio length
CHUNK_SIZE = settings.chunk_seconds * settings.sample_rate       # samples per transcription chunk
OVERLAP = settings.overlap_seconds * settings.sample_rate        # samples shared between adjacent chunks
STEP = CHUNK_SIZE - OVERLAP                                      # hop between chunk start positions
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the ASR processor/model on startup and release them on shutdown.

    Fix: the original yielded but lacked the @asynccontextmanager decorator
    (imported at the top of the file yet unused), so FastAPI could not use it
    as a lifespan context manager. Also records the startup timestamp for
    /health uptime reporting and only touches the CUDA cache when CUDA exists.

    Raises:
        Exception: re-raised if the model fails to load, aborting startup.
    """
    logger.info(f"Starting STT service with device: {settings.device}")
    logger.info(f"Loading model: {settings.model_name}")
    try:
        # Initialize processor and model once for the whole process.
        app.state.processor = AutoProcessor.from_pretrained(settings.model_name)
        app.state.model = AutoModelForCTC.from_pretrained(settings.model_name).to(settings.device)
        app.state.model.eval()  # inference only — disable dropout etc.
        logger.info("Model loaded successfully")
    except Exception as e:
        logger.error(f"Failed to load model: {str(e)}")
        raise
    app.state.start_time = time.time()  # consumed by the health endpoint
    yield
    logger.info("Shutting down STT service")
    if hasattr(app.state, 'model'):
        del app.state.model
    if torch.cuda.is_available():
        # empty_cache is a CUDA-only operation; skip it on CPU deployments.
        torch.cuda.empty_cache()
# Application instance; model lifecycle is handled by `lifespan`.
app = FastAPI(
    title="Universal Audio STT",
    version="1.5.0",
    description="Speech-to-Text service with support for multiple audio formats",
    lifespan=lifespan
)

# CORS middleware; defaults (see Settings) allow every origin/method/header.
# NOTE(review): allow_credentials=True combined with wildcard origins is
# rejected by browsers per the CORS spec — confirm the intended origin list.
app.add_middleware(
    CORSMiddleware,
    allow_origins=settings.cors_origins,
    allow_credentials=True,
    allow_methods=settings.cors_methods,
    allow_headers=settings.cors_headers,
)
class STTResponse(BaseModel):
    """Response body returned by the speech-to-text endpoint."""

    # Transcribed text; empty string when transcription failed.
    transcript: str
    # JSON payload from the downstream text model, when forwarding succeeded.
    downstream_response: Optional[Dict[str, Any]] = None
    # Human-readable failure description; None on success.
    error: Optional[str] = None
    # Wall-clock time spent handling the request, in milliseconds.
    processing_time_ms: Optional[float] = None
class HealthResponse(BaseModel):
    """Response body returned by the health-check endpoint."""

    # "healthy" when the service is responding.
    status: str
    # Inference device in use ("cuda" or "cpu").
    device: str
    # Name of the configured ASR checkpoint.
    model: str
    # Maximum accepted audio duration, in seconds.
    max_audio_seconds: int
    # Seconds since startup when the start timestamp was recorded; else None.
    uptime: Optional[float] = None
# Fix: the original defined this handler but never registered it, so unhandled
# exceptions surfaced as FastAPI's default 500 page instead of this JSON shape.
@app.exception_handler(Exception)
async def global_exception_handler(request: Request, exc: Exception):
    """Catch-all handler: log the traceback and return an STTResponse-shaped
    500 JSON body so clients always get the same response schema."""
    logger.error(f"Unhandled exception: {str(exc)}", exc_info=True)
    return JSONResponse(
        status_code=500,
        content={
            "transcript": "",
            "downstream_response": None,
            "error": f"Internal server error: {str(exc)}"
        }
    )
def load_audio_safe(file_bytes: bytes) -> tuple[np.ndarray | None, str | None]:
    """Decode arbitrary audio bytes into a mono float32 waveform at the
    service sample rate.

    Tries pydub/ffmpeg first (widest format support) and falls back to
    torchaudio. Never raises for bad input.

    Args:
        file_bytes: raw uploaded file contents.

    Returns:
        (samples, None) on success, where samples is a 1-D float32 array in
        [-1.0, 1.0]; (None, error_message) on any failure.

    Fixes vs. original: removed the duplicate empty-bytes check, narrowed the
    bare ``except:`` around temp-file cleanup to OSError, and dropped a
    pointless f-prefix on a placeholder-free string.
    """
    if not file_bytes:
        return None, "Empty audio file"

    # Both decoders want a real file path, so spill the bytes to a temp file.
    with tempfile.NamedTemporaryFile(suffix='.audio', delete=False) as tmp:
        tmp.write(file_bytes)
        tmp_path = tmp.name

    try:
        try:
            audio = AudioSegment.from_file(tmp_path)
            if len(audio) == 0:
                return None, "Audio file is empty"
            # Normalize to mono at the model's expected sample rate.
            if audio.channels > 1:
                audio = audio.set_channels(1)
            if audio.frame_rate != settings.sample_rate:
                audio = audio.set_frame_rate(settings.sample_rate)
            samples = np.array(audio.get_array_of_samples()).astype(np.float32)
            # Scale integer PCM to [-1.0, 1.0] according to sample width.
            if audio.sample_width == 1:    # 8-bit
                samples = samples / 127.5 - 1.0
            elif audio.sample_width == 2:  # 16-bit
                samples = samples / 32768.0
            elif audio.sample_width == 3:  # 24-bit
                samples = samples / 8388608.0
            elif audio.sample_width == 4:  # 32-bit
                samples = samples / 2147483648.0
            else:
                # Unknown width: peak-normalize as a best effort.
                if len(samples) > 0:
                    max_val = np.max(np.abs(samples))
                    if max_val > 0:
                        samples = samples / max_val
        except Exception as pydub_error:
            logger.warning(f"Pydub failed, trying torchaudio: {pydub_error}")
            try:
                waveform, sr = torchaudio.load(tmp_path)
                if waveform.numel() == 0:
                    return None, "Audio contains no samples"
                waveform = waveform.mean(dim=0)  # downmix to mono
                if sr != settings.sample_rate:
                    waveform = torchaudio.functional.resample(
                        waveform, orig_freq=sr, new_freq=settings.sample_rate
                    )
                samples = waveform.numpy()
            except Exception as torchaudio_error:
                logger.error(f"Both pydub and torchaudio failed: {torchaudio_error}")
                return None, "Unsupported audio format. Supported formats: MP3, WAV, M4A, FLAC, OGG, etc."
    except Exception as e:
        logger.error(f"Failed to load audio: {str(e)}")
        return None, f"Failed to process audio file: {str(e)}"
    finally:
        # Always remove the temp file, including on the early returns above.
        try:
            os.unlink(tmp_path)
        except OSError:
            pass

    if len(samples) == 0:
        return None, "Audio contains no samples"
    if len(samples) > MAX_SAMPLES:
        return None, f"Audio exceeds {settings.max_audio_seconds // 60} minute limit ({settings.max_audio_seconds} seconds)"
    return samples, None
| def chunk_audio(audio: np.ndarray): | |
| """Split audio into overlapping chunks for processing.""" | |
| for start in range(0, len(audio), STEP): | |
| chunk = audio[start:start + CHUNK_SIZE] | |
| if len(chunk) < settings.sample_rate: # Less than 1 second | |
| break | |
| yield chunk | |
def transcribe_chunk(chunk: np.ndarray, processor, model) -> str:
    """Run CTC inference on one audio window and decode it to text."""
    features = processor(
        chunk,
        sampling_rate=settings.sample_rate,
        return_tensors="pt",
        padding=True,
    )
    input_values = features.input_values.to(settings.device)
    # Inference only — no gradients needed.
    with torch.no_grad():
        logits = model(input_values).logits
    token_ids = logits.argmax(dim=-1)
    decoded = processor.batch_decode(token_ids)
    return decoded[0].strip()
def transcribe_long(audio: np.ndarray, processor, model) -> str:
    """Transcribe long-form audio window by window and join the pieces.

    Empty per-chunk transcripts are skipped so the result has no doubled
    spaces.
    """
    pieces = (
        transcribe_chunk(window, processor, model)
        for window in chunk_audio(audio)
    )
    return " ".join(piece for piece in pieces if piece)
def forward_to_text_model(text: str) -> Optional[Dict[str, Any]]:
    """Send the transcript to the downstream text model and return its JSON.

    Forwarding is strictly best-effort: blank input, timeouts, and any other
    request failure all yield None instead of raising.
    """
    # Nothing to forward for blank transcripts.
    if not (text and text.strip()):
        return None
    try:
        resp = requests.post(
            settings.text_model_url,
            json={"query": text},
            timeout=settings.timeout_sec,
        )
        resp.raise_for_status()
        return resp.json()
    except requests.exceptions.Timeout:
        logger.warning("Downstream text model timeout")
        return None
    except requests.exceptions.RequestException as e:
        logger.warning(f"Downstream text model error: {str(e)}")
        return None
# Fix: the original had no route decorator, so this handler was never
# registered. NOTE(review): the "/stt" path is inferred from the function
# name — confirm against clients.
@app.post("/stt", response_model=STTResponse)
async def stt(audio: UploadFile = File(...)):
    """
    Speech-to-Text endpoint.

    Accepts audio files in various formats (MP3, WAV, M4A, FLAC, OGG, etc.),
    transcribes them, forwards a non-empty transcript to the downstream text
    model, and reports total processing time. Audio-level failures are
    returned in-band via the ``error`` field rather than as HTTP errors;
    only an unreadable upload produces a 400.
    """
    start_time = time.time()

    # Validate file: log suspicious content types but do not reject — many
    # clients upload audio with a generic or missing MIME type.
    if not audio.content_type or not audio.content_type.startswith('audio/'):
        logger.warning(f"Invalid content type: {audio.content_type}")

    try:
        audio_bytes = await audio.read()
        logger.info(f"Received audio file: {audio.filename}, size: {len(audio_bytes)} bytes")
    except Exception as e:
        logger.error(f"Failed to read audio file: {str(e)}")
        raise HTTPException(status_code=400, detail="Failed to read audio file")

    # Load audio (decode, downmix, resample, length-check).
    audio_data, error = load_audio_safe(audio_bytes)
    if error:
        logger.warning(f"Audio loading failed: {error}")
        return STTResponse(
            transcript="",
            downstream_response=None,
            error=error,
            processing_time_ms=(time.time() - start_time) * 1000
        )

    # Transcribe
    try:
        transcript = transcribe_long(
            audio_data,
            app.state.processor,
            app.state.model
        )
        logger.info(f"Transcription successful, length: {len(transcript)} chars")
    except Exception as e:
        logger.error(f"Transcription failed: {str(e)}")
        return STTResponse(
            transcript="",
            downstream_response=None,
            error=f"Transcription failed: {str(e)}",
            processing_time_ms=(time.time() - start_time) * 1000
        )

    # Downstream forwarding is best-effort: failures never fail the request.
    downstream = None
    if transcript and transcript.strip():
        try:
            downstream = forward_to_text_model(transcript)
        except Exception as e:
            logger.warning(f"Downstream processing failed: {str(e)}")

    processing_time_ms = (time.time() - start_time) * 1000
    logger.info(f"Request completed in {processing_time_ms:.2f}ms")
    return STTResponse(
        transcript=transcript,
        downstream_response=downstream,
        error=None,
        processing_time_ms=processing_time_ms
    )
# Fix: no route decorator in the original, so the endpoint was never
# registered. NOTE(review): "/health" is inferred — confirm the path.
@app.get("/health", response_model=HealthResponse)
async def health_check():
    """
    Health check endpoint.

    Returns service status and configuration, plus uptime when the startup
    timestamp was recorded.

    Fixes vs. original: removed an unused ``import psutil`` (third-party,
    never referenced) and the redundant function-local ``import time``.
    """
    # start_time is only present if startup wiring recorded it.
    uptime = time.time() - app.state.start_time if hasattr(app.state, 'start_time') else None
    return HealthResponse(
        status="healthy",
        device=settings.device,
        model=settings.model_name,
        max_audio_seconds=settings.max_audio_seconds,
        uptime=uptime
    )
# Fix: no route decorator in the original, so this endpoint was unreachable.
# NOTE(review): "/config" is inferred from the function name — confirm.
@app.get("/config")
async def get_config():
    """
    Get current configuration (excluding sensitive data).

    Exposes the non-sensitive runtime settings so operators can verify the
    deployed configuration.
    """
    return {
        "model": settings.model_name,
        "device": settings.device,
        "sample_rate": settings.sample_rate,
        "max_audio_seconds": settings.max_audio_seconds,
        "chunk_seconds": settings.chunk_seconds,
        "overlap_seconds": settings.overlap_seconds,
        # CORS middleware is always installed (see app setup).
        "cors_enabled": True
    }
async def startup_event():
    """Record the service start time for /health uptime reporting.

    Fix: the original used ``time.time() if 'time' in locals() else None``,
    which always evaluated to None — 'time' is never in the function's
    ``locals()`` (imports live in module globals) — so uptime never worked.

    NOTE(review): this handler is not registered anywhere, and FastAPI
    ignores on_event-style handlers when a ``lifespan`` is passed to the app;
    wire this logic into ``lifespan`` (or register it) for it to actually run.
    """
    app.state.start_time = time.time()
    logger.info("STT service started successfully")
if __name__ == "__main__":
    # Local/dev entry point: launch uvicorn with the configured binding.
    import uvicorn

    logger.info(f"Starting server on {settings.host}:{settings.port}")
    server_opts = {
        "host": settings.host,
        "port": settings.port,
        "workers": settings.workers,
        "log_level": "info",
    }
    uvicorn.run("main:app", **server_opts)