# main.py
"""Universal Audio STT service.

Decodes an uploaded audio file, transcribes it with a CTC speech model in
overlapping chunks, and optionally forwards the transcript to a downstream
text model. Configuration is read from the environment / ``.env`` with the
``STT_`` prefix (see ``Settings``).
"""
import asyncio
import json
import logging
import os
import tempfile
import time
from contextlib import asynccontextmanager
from typing import Optional, Dict, Any

import numpy as np
import requests
import torch
import torchaudio
from fastapi import FastAPI, UploadFile, File, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from pydantic_settings import BaseSettings
from pydub import AudioSegment
from transformers import AutoProcessor, AutoModelForCTC

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)


class Settings(BaseSettings):
    """Service configuration, overridable through ``STT_``-prefixed env vars."""

    model_name: str = "facebook/mms-1b-all"
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    text_model_url: str = "https://remostartdev-farmlingua-ai-conversational.hf.space/ask"
    timeout_sec: int = 3000

    # Audio processing settings
    sample_rate: int = 16000
    max_audio_seconds: int = 300
    chunk_seconds: int = 20
    overlap_seconds: int = 2

    host: str = "0.0.0.0"
    port: int = 7860
    workers: int = 1

    # CORS settings
    # NOTE(review): wildcard origins are combined with allow_credentials=True
    # below; confirm this is intended before exposing the service publicly.
    cors_origins: list = ["*"]
    cors_methods: list = ["*"]
    cors_headers: list = ["*"]

    class Config:
        env_file = ".env"
        env_prefix = "STT_"


settings = Settings()

# Calculate derived constants (all expressed in samples)
MAX_SAMPLES = settings.sample_rate * settings.max_audio_seconds
CHUNK_SIZE = settings.chunk_seconds * settings.sample_rate
OVERLAP = settings.overlap_seconds * settings.sample_rate
STEP = CHUNK_SIZE - OVERLAP

# Full-scale divisors used to map signed integer PCM onto [-1.0, 1.0),
# keyed by sample width in bytes.
_PCM_FULL_SCALE = {
    2: 32768.0,        # 16-bit
    3: 8388608.0,      # 24-bit
    4: 2147483648.0,   # 32-bit
}


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the model on startup and release it on shutdown.

    Stores ``processor``, ``model`` and ``start_time`` on ``app.state``.
    The start time is recorded here because startup handlers registered via
    ``@app.on_event`` are not run when a lifespan handler is supplied.

    Raises:
        Exception: whatever the model/processor loading raised; it is logged
            and re-raised so the service does not start without a model.
    """
    # Startup
    logger.info("Starting STT service with device: %s", settings.device)
    logger.info("Loading model: %s", settings.model_name)

    try:
        # Initialize processor and model
        app.state.processor = AutoProcessor.from_pretrained(settings.model_name)
        app.state.model = AutoModelForCTC.from_pretrained(settings.model_name).to(settings.device)
        app.state.model.eval()
        logger.info("Model loaded successfully")
    except Exception as e:
        logger.error("Failed to load model: %s", e)
        raise

    app.state.start_time = time.time()
    logger.info("STT service started successfully")

    yield

    # Shutdown
    logger.info("Shutting down STT service")
    if hasattr(app.state, 'model'):
        del app.state.model
        torch.cuda.empty_cache()


app = FastAPI(
    title="Universal Audio STT",
    version="1.5.0",
    description="Speech-to-Text service with support for multiple audio formats",
    lifespan=lifespan
)

app.add_middleware(
    CORSMiddleware,
    allow_origins=settings.cors_origins,
    allow_credentials=True,
    allow_methods=settings.cors_methods,
    allow_headers=settings.cors_headers,
)


class STTResponse(BaseModel):
    """Response body of the ``/stt`` endpoint."""

    transcript: str
    downstream_response: Optional[Dict[str, Any]] = None
    error: Optional[str] = None
    processing_time_ms: Optional[float] = None


class HealthResponse(BaseModel):
    """Response body of the ``/health`` endpoint."""

    status: str
    device: str
    model: str
    max_audio_seconds: int
    uptime: Optional[float] = None


@app.exception_handler(Exception)
async def global_exception_handler(request: Request, exc: Exception):
    """Log any unhandled exception and answer with a 500 JSON payload."""
    logger.error("Unhandled exception: %s", exc, exc_info=True)
    return JSONResponse(
        status_code=500,
        content={
            "transcript": "",
            "downstream_response": None,
            "error": f"Internal server error: {str(exc)}"
        }
    )


def load_audio_safe(file_bytes: bytes) -> tuple[np.ndarray | None, str | None]:
    """Load audio file using pydub (supports more formats) with torchaudio fallback.

    The bytes are written to a temporary file, decoded, down-mixed to mono and
    resampled to ``settings.sample_rate``. The temporary file is always
    removed before returning.

    Args:
        file_bytes: Raw content of the uploaded audio file.

    Returns:
        ``(samples, None)`` on success, where ``samples`` is a 1-D array at
        ``settings.sample_rate``; ``(None, message)`` when the input is
        empty, cannot be decoded, or exceeds ``settings.max_audio_seconds``.
    """
    if not file_bytes:
        return None, "Empty audio file"

    tmp = tempfile.NamedTemporaryFile(suffix='.audio', delete=False)
    tmp_path = tmp.name

    try:
        # Write inside the try block so the file is removed even if the
        # write itself fails.
        with tmp:
            tmp.write(file_bytes)

        try:
            audio = AudioSegment.from_file(tmp_path)

            if len(audio) == 0:
                return None, "Audio file is empty"

            if audio.channels > 1:
                audio = audio.set_channels(1)

            if audio.frame_rate != settings.sample_rate:
                audio = audio.set_frame_rate(settings.sample_rate)

            samples = np.array(audio.get_array_of_samples()).astype(np.float32)

            width = audio.sample_width
            if width == 1:  # 8-bit
                # NOTE(review): this formula assumes unsigned 8-bit samples
                # (0..255); confirm against the array type pydub returns.
                samples = samples / 127.5 - 1.0
            elif width in _PCM_FULL_SCALE:
                samples = samples / _PCM_FULL_SCALE[width]
            elif len(samples) > 0:
                # Unknown width: fall back to peak normalization.
                max_val = np.max(np.abs(samples))
                if max_val > 0:
                    samples = samples / max_val

        except Exception as pydub_error:
            logger.warning("Pydub failed, trying torchaudio: %s", pydub_error)

            try:
                waveform, sr = torchaudio.load(tmp_path)

                if waveform.numel() == 0:
                    return None, "Audio contains no samples"

                # Average across the first dimension to down-mix to mono.
                waveform = waveform.mean(dim=0)

                if sr != settings.sample_rate:
                    waveform = torchaudio.functional.resample(
                        waveform,
                        orig_freq=sr,
                        new_freq=settings.sample_rate
                    )

                samples = waveform.numpy()

            except Exception as torchaudio_error:
                logger.error("Both pydub and torchaudio failed: %s", torchaudio_error)
                return None, "Unsupported audio format. Supported formats: MP3, WAV, M4A, FLAC, OGG, etc."

    except Exception as e:
        logger.error("Failed to load audio: %s", e)
        return None, f"Failed to process audio file: {str(e)}"

    finally:
        # Best-effort cleanup: a leftover temp file must not fail the request.
        try:
            os.unlink(tmp_path)
        except OSError as cleanup_error:
            logger.debug("Could not remove temp file %s: %s", tmp_path, cleanup_error)

    if len(samples) == 0:
        return None, "Audio contains no samples"

    if len(samples) > MAX_SAMPLES:
        return None, f"Audio exceeds {settings.max_audio_seconds // 60} minute limit ({settings.max_audio_seconds} seconds)"

    return samples, None


def chunk_audio(audio: np.ndarray):
    """Split audio into overlapping chunks for processing.

    Chunks are ``CHUNK_SIZE`` samples long and start every ``STEP`` samples,
    so consecutive chunks share ``OVERLAP`` samples (text in the shared
    region may therefore appear in both chunks' transcripts).

    The first chunk is always yielded, even when the whole clip is shorter
    than one second. Iteration stops as soon as a chunk reaches the end of
    the audio, because any later chunk would lie entirely inside it. A
    trailing chunk shorter than one second is dropped.

    Args:
        audio: 1-D array of samples at ``settings.sample_rate``.

    Yields:
        Consecutive slices of ``audio``.

    Raises:
        ValueError: if the configured overlap is not smaller than the chunk
            length (``STEP <= 0``).
    """
    if STEP <= 0:
        raise ValueError("overlap_seconds must be smaller than chunk_seconds")

    total = len(audio)
    for start in range(0, total, STEP):
        chunk = audio[start:start + CHUNK_SIZE]
        if start > 0 and len(chunk) < settings.sample_rate:  # Less than 1 second
            break
        yield chunk
        if start + CHUNK_SIZE >= total:
            # This chunk already covered the tail of the audio.
            break


def transcribe_chunk(chunk: np.ndarray, processor, model) -> str:
    """Transcribe a single audio chunk.

    Args:
        chunk: 1-D array of samples at ``settings.sample_rate``.
        processor: Processor stored on ``app.state.processor``.
        model: CTC model stored on ``app.state.model``.

    Returns:
        The decoded text for the chunk, stripped of surrounding whitespace.
    """
    inputs = processor(
        chunk,
        sampling_rate=settings.sample_rate,
        return_tensors="pt",
        padding=True,
    )

    with torch.no_grad():
        logits = model(inputs.input_values.to(settings.device)).logits

    # Greedy decoding: take the highest-scoring token at every position.
    predicted_ids = torch.argmax(logits, dim=-1)
    return processor.batch_decode(predicted_ids)[0].strip()


def transcribe_long(audio: np.ndarray, processor, model) -> str:
    """Transcribe long audio by processing in chunks.

    Returns:
        The non-empty chunk transcripts joined with single spaces.
    """
    pieces = (transcribe_chunk(chunk, processor, model) for chunk in chunk_audio(audio))
    return " ".join(piece for piece in pieces if piece)


def forward_to_text_model(text: str) -> Optional[Dict[str, Any]]:
    """Forward transcribed text to downstream text model.

    Args:
        text: Transcript to send as ``{"query": text}``.

    Returns:
        The downstream JSON object, or ``None`` when ``text`` is blank, the
        request fails or times out, or the reply is not a JSON object.
    """
    if not text or not text.strip():
        return None

    try:
        response = requests.post(
            settings.text_model_url,
            json={"query": text},
            timeout=settings.timeout_sec,
        )
        response.raise_for_status()
        payload = response.json()
    except requests.exceptions.Timeout:
        logger.warning("Downstream text model timeout")
        return None
    except requests.exceptions.RequestException as e:
        logger.warning("Downstream text model error: %s", e)
        return None
    except ValueError as e:
        # Body was not valid JSON (older requests versions raise ValueError).
        logger.warning("Downstream text model error: %s", e)
        return None

    if not isinstance(payload, dict):
        # STTResponse.downstream_response can only carry a JSON object.
        logger.warning("Downstream text model error: unexpected JSON type %s", type(payload).__name__)
        return None

    return payload


@app.post("/stt", response_model=STTResponse)
async def stt(audio: UploadFile = File(...)):
    """
    Speech-to-Text endpoint.

    Accepts audio files in various formats (MP3, WAV, M4A, FLAC, OGG, etc.)
    and returns transcribed text.

    Decoding and transcription failures are reported in the ``error`` field
    of a normal response; only a failure to read the upload raises
    ``HTTPException`` (400).
    """
    start_time = time.perf_counter()

    def elapsed_ms() -> float:
        return (time.perf_counter() - start_time) * 1000

    # Validate file (advisory only: decoding decides whether it is usable)
    if not audio.content_type or not audio.content_type.startswith('audio/'):
        logger.warning("Invalid content type: %s", audio.content_type)

    try:
        audio_bytes = await audio.read()
        logger.info("Received audio file: %s, size: %d bytes", audio.filename, len(audio_bytes))
    except Exception as e:
        logger.error("Failed to read audio file: %s", e)
        raise HTTPException(status_code=400, detail="Failed to read audio file") from e

    # Load audio
    audio_data, error = load_audio_safe(audio_bytes)
    if error:
        logger.warning("Audio loading failed: %s", error)
        return STTResponse(
            transcript="",
            downstream_response=None,
            error=error,
            processing_time_ms=elapsed_ms()
        )

    # Transcribe
    try:
        transcript = transcribe_long(
            audio_data,
            app.state.processor,
            app.state.model
        )
        logger.info("Transcription successful, length: %d chars", len(transcript))
    except Exception as e:
        logger.error("Transcription failed: %s", e)
        return STTResponse(
            transcript="",
            downstream_response=None,
            error=f"Transcription failed: {str(e)}",
            processing_time_ms=elapsed_ms()
        )

    downstream = None
    if transcript and transcript.strip():
        try:
            # The downstream call is blocking HTTP with a long timeout; run
            # it in a worker thread so the event loop stays responsive.
            downstream = await asyncio.to_thread(forward_to_text_model, transcript)
        except Exception as e:
            logger.warning("Downstream processing failed: %s", e)

    processing_time_ms = elapsed_ms()
    logger.info("Request completed in %.2fms", processing_time_ms)

    return STTResponse(
        transcript=transcript,
        downstream_response=downstream,
        error=None,
        processing_time_ms=processing_time_ms
    )


@app.get("/health", response_model=HealthResponse)
async def health_check():
    """
    Health check endpoint.

    Returns service status and configuration. ``uptime`` is the number of
    seconds since startup, or ``None`` if no start time was recorded.
    """
    started_at = getattr(app.state, 'start_time', None)
    uptime = time.time() - started_at if started_at is not None else None

    return HealthResponse(
        status="healthy",
        device=settings.device,
        model=settings.model_name,
        max_audio_seconds=settings.max_audio_seconds,
        uptime=uptime
    )


@app.get("/config")
async def get_config():
    """
    Get current configuration (excluding sensitive data).
    """
    return {
        "model": settings.model_name,
        "device": settings.device,
        "sample_rate": settings.sample_rate,
        "max_audio_seconds": settings.max_audio_seconds,
        "chunk_seconds": settings.chunk_seconds,
        "overlap_seconds": settings.overlap_seconds,
        "cors_enabled": True
    }


if __name__ == "__main__":
    import uvicorn

    logger.info("Starting server on %s:%s", settings.host, settings.port)

    uvicorn.run(
        "main:app",
        host=settings.host,
        port=settings.port,
        workers=settings.workers,
        log_level="info"
    )