# STT_FARMLINGUA — app/main.py
# (Hugging Face page header retained as a comment: uploaded by drrobot9,
#  commit 66075f1 "Update app/main.py", verified)
# main.py
import json
import logging
import os
import tempfile
import time
from contextlib import asynccontextmanager
from typing import Optional, Dict, Any

import numpy as np
import requests
import torch
import torchaudio
from fastapi import FastAPI, UploadFile, File, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from pydantic_settings import BaseSettings
from pydub import AudioSegment
from transformers import AutoProcessor, AutoModelForCTC
# Configure root logging once at import time; all module loggers inherit this.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
class Settings(BaseSettings):
    """Service configuration.

    Every field can be overridden via an environment variable prefixed with
    ``STT_`` (e.g. ``STT_MODEL_NAME``) or a local ``.env`` file.
    """
    # Hugging Face model id for the CTC acoustic model.
    model_name: str = "facebook/mms-1b-all"
    # Resolved once at class-definition time; override with STT_DEVICE if needed.
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    # Downstream conversational model that receives the transcript.
    text_model_url: str = "https://remostartdev-farmlingua-ai-conversational.hf.space/ask"
    # NOTE(review): 3000 s (~50 min) HTTP timeout for the downstream call —
    # presumably intentional for a slow Space; confirm.
    timeout_sec: int = 3000
    # Audio processing settings
    sample_rate: int = 16000       # model's expected input rate (Hz)
    max_audio_seconds: int = 300   # uploads longer than this are rejected
    chunk_seconds: int = 20        # transcription window length (seconds)
    overlap_seconds: int = 2       # overlap between adjacent windows (seconds)
    host: str = "0.0.0.0"
    port: int = 7860
    workers: int = 1
    # CORS settings (defaults allow everything)
    cors_origins: list = ["*"]
    cors_methods: list = ["*"]
    cors_headers: list = ["*"]

    class Config:
        env_file = ".env"
        env_prefix = "STT_"
# Singleton settings instance, loaded from the environment / .env at import time.
settings = Settings()
# Calculate derived constants (all expressed in samples at settings.sample_rate)
MAX_SAMPLES = settings.sample_rate * settings.max_audio_seconds  # hard cap on input length
CHUNK_SIZE = settings.chunk_seconds * settings.sample_rate       # samples per transcription window
OVERLAP = settings.overlap_seconds * settings.sample_rate        # samples shared by adjacent windows
STEP = CHUNK_SIZE - OVERLAP                                      # hop between window start offsets
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the CTC model/processor on startup; release them on shutdown."""
    logger.info(f"Starting STT service with device: {settings.device}")
    logger.info(f"Loading model: {settings.model_name}")
    try:
        processor = AutoProcessor.from_pretrained(settings.model_name)
        model = AutoModelForCTC.from_pretrained(settings.model_name).to(settings.device)
        model.eval()  # inference only — disable dropout etc.
        app.state.processor = processor
        app.state.model = model
        logger.info("Model loaded successfully")
    except Exception as exc:
        logger.error(f"Failed to load model: {str(exc)}")
        raise
    yield
    # --- shutdown: drop the model reference and return GPU memory ---
    logger.info("Shutting down STT service")
    if hasattr(app.state, 'model'):
        del app.state.model
        torch.cuda.empty_cache()
# FastAPI application wired to the lifespan handler above.
app = FastAPI(
    title="Universal Audio STT",
    version="1.5.0",
    description="Speech-to-Text service with support for multiple audio formats",
    lifespan=lifespan
)
# Permissive CORS by default ("*" origins/methods/headers, from settings).
app.add_middleware(
    CORSMiddleware,
    allow_origins=settings.cors_origins,
    allow_credentials=True,
    allow_methods=settings.cors_methods,
    allow_headers=settings.cors_headers,
)
class STTResponse(BaseModel):
    """Response body for POST /stt; failures are reported in-band via `error`."""
    # Transcribed text ("" when decoding or transcription failed).
    transcript: str
    # JSON returned by the downstream text model, when forwarding succeeded.
    downstream_response: Optional[Dict[str, Any]] = None
    # Human-readable failure description; None on success.
    error: Optional[str] = None
    # Wall-clock request duration in milliseconds.
    processing_time_ms: Optional[float] = None
class HealthResponse(BaseModel):
    """Response body for GET /health."""
    # Hardcoded to "healthy" whenever the handler runs.
    status: str
    # Inference device ("cuda" or "cpu").
    device: str
    # Hugging Face model id in use.
    model: str
    # Maximum accepted audio duration (seconds).
    max_audio_seconds: int
    # Seconds since startup, if a start time was recorded.
    uptime: Optional[float] = None
@app.exception_handler(Exception)
async def global_exception_handler(request: Request, exc: Exception):
    """Catch-all handler: log the traceback and answer with an STTResponse-shaped 500."""
    logger.error(f"Unhandled exception: {str(exc)}", exc_info=True)
    body = {
        "transcript": "",
        "downstream_response": None,
        "error": f"Internal server error: {str(exc)}",
    }
    return JSONResponse(status_code=500, content=body)
def load_audio_safe(file_bytes: bytes) -> tuple[np.ndarray | None, str | None]:
    """Decode raw audio bytes into a mono float32 array at the configured rate.

    Tries pydub first (widest container/codec support via ffmpeg), then falls
    back to torchaudio. Returns ``(samples, None)`` on success or
    ``(None, error_message)`` on any failure — callers report errors in-band.
    """
    # Single emptiness check (the original tested `not file_bytes` and then
    # `len(file_bytes) == 0`, which is the same condition twice).
    if not file_bytes:
        return None, "Empty audio file"
    # Both decoders read from a path, so spill the bytes to a temp file.
    with tempfile.NamedTemporaryFile(suffix='.audio', delete=False) as tmp:
        tmp.write(file_bytes)
        tmp_path = tmp.name
    try:
        try:
            audio = AudioSegment.from_file(tmp_path)
            if len(audio) == 0:
                return None, "Audio file is empty"
            # Downmix to mono and resample to the model's expected rate.
            if audio.channels > 1:
                audio = audio.set_channels(1)
            if audio.frame_rate != settings.sample_rate:
                audio = audio.set_frame_rate(settings.sample_rate)
            samples = np.array(audio.get_array_of_samples()).astype(np.float32)
            # Normalize integer PCM into [-1.0, 1.0] according to sample width.
            if audio.sample_width == 1:      # 8-bit (unsigned)
                samples = samples / 127.5 - 1.0
            elif audio.sample_width == 2:    # 16-bit
                samples = samples / 32768.0
            elif audio.sample_width == 3:    # 24-bit
                samples = samples / 8388608.0
            elif audio.sample_width == 4:    # 32-bit
                samples = samples / 2147483648.0
            else:
                # Unknown width: best-effort peak normalization.
                if len(samples) > 0:
                    max_val = np.max(np.abs(samples))
                    if max_val > 0:
                        samples = samples / max_val
        except Exception as pydub_error:
            logger.warning(f"Pydub failed, trying torchaudio: {pydub_error}")
            try:
                waveform, sr = torchaudio.load(tmp_path)
                if waveform.numel() == 0:
                    return None, "Audio contains no samples"
                waveform = waveform.mean(dim=0)  # downmix to mono
                if sr != settings.sample_rate:
                    waveform = torchaudio.functional.resample(
                        waveform, orig_freq=sr, new_freq=settings.sample_rate
                    )
                samples = waveform.numpy()
            except Exception as torchaudio_error:
                logger.error(f"Both pydub and torchaudio failed: {torchaudio_error}")
                # (was an f-string with no placeholders)
                return None, "Unsupported audio format. Supported formats: MP3, WAV, M4A, FLAC, OGG, etc."
    except Exception as e:
        logger.error(f"Failed to load audio: {str(e)}")
        return None, f"Failed to process audio file: {str(e)}"
    finally:
        # Best-effort temp-file cleanup; only swallow filesystem errors
        # (the original used a bare `except:`).
        try:
            os.unlink(tmp_path)
        except OSError:
            pass
    if len(samples) == 0:
        return None, "Audio contains no samples"
    if len(samples) > MAX_SAMPLES:
        return None, f"Audio exceeds {settings.max_audio_seconds // 60} minute limit ({settings.max_audio_seconds} seconds)"
    return samples, None
def chunk_audio(audio: np.ndarray):
    """Yield overlapping CHUNK_SIZE windows of `audio`, stepping by STEP samples.

    Generation stops at the first window shorter than one second, so a short
    trailing remainder (and any input under one second) is dropped.
    """
    total = len(audio)
    for offset in range(0, total, STEP):
        piece = audio[offset:offset + CHUNK_SIZE]
        if len(piece) >= settings.sample_rate:
            yield piece
        else:
            break
def transcribe_chunk(chunk: np.ndarray, processor, model) -> str:
    """Greedy CTC decode of a single audio chunk; returns the stripped text."""
    features = processor(
        chunk,
        sampling_rate=settings.sample_rate,
        return_tensors="pt",
        padding=True,
    )
    input_values = features.input_values.to(settings.device)
    with torch.no_grad():  # inference only — no autograd graph
        logits = model(input_values).logits
        best_ids = torch.argmax(logits, dim=-1)
    decoded = processor.batch_decode(best_ids)
    return decoded[0].strip()
def transcribe_long(audio: np.ndarray, processor, model) -> str:
    """Transcribe long audio chunk-by-chunk, joining non-empty results with spaces."""
    parts = [
        piece_text
        for piece in chunk_audio(audio)
        if (piece_text := transcribe_chunk(piece, processor, model))
    ]
    return " ".join(parts)
def forward_to_text_model(text: str) -> Optional[Dict[str, Any]]:
    """POST the transcript to the downstream text model.

    Returns the downstream JSON, or None when the text is blank or the call
    fails — downstream errors are deliberately non-fatal.
    """
    if not text or not text.strip():
        return None
    payload = {"query": text}
    try:
        resp = requests.post(
            settings.text_model_url,
            json=payload,
            timeout=settings.timeout_sec,
        )
        resp.raise_for_status()
        # json() stays inside the try: requests' JSON decode error is a
        # RequestException subclass and must be swallowed like other failures.
        return resp.json()
    except requests.exceptions.Timeout:
        logger.warning("Downstream text model timeout")
        return None
    except requests.exceptions.RequestException as e:
        logger.warning(f"Downstream text model error: {str(e)}")
        return None
@app.post("/stt", response_model=STTResponse)
async def stt(audio: UploadFile = File(...)):
    """
    Speech-to-Text endpoint.
    Accepts audio files in various formats (MP3, WAV, M4A, FLAC, OGG, etc.)
    and returns transcribed text, plus the downstream text model's response
    when forwarding succeeds. Decode/transcription failures are reported
    in-band via the `error` field (HTTP 200), not as error statuses.
    """
    # Uses the module-level `time` import (was re-imported on every request).
    start_time = time.time()
    # Content type is logged but not enforced: decoding below sniffs the real
    # format, and client-supplied MIME types are unreliable.
    if not audio.content_type or not audio.content_type.startswith('audio/'):
        logger.warning(f"Invalid content type: {audio.content_type}")
    try:
        audio_bytes = await audio.read()
        logger.info(f"Received audio file: {audio.filename}, size: {len(audio_bytes)} bytes")
    except Exception as e:
        logger.error(f"Failed to read audio file: {str(e)}")
        raise HTTPException(status_code=400, detail="Failed to read audio file")
    # Decode and normalize the audio.
    audio_data, error = load_audio_safe(audio_bytes)
    if error:
        logger.warning(f"Audio loading failed: {error}")
        return STTResponse(
            transcript="",
            downstream_response=None,
            error=error,
            processing_time_ms=(time.time() - start_time) * 1000
        )
    # Transcribe with the model loaded in the lifespan handler.
    try:
        transcript = transcribe_long(
            audio_data,
            app.state.processor,
            app.state.model
        )
        logger.info(f"Transcription successful, length: {len(transcript)} chars")
    except Exception as e:
        logger.error(f"Transcription failed: {str(e)}")
        return STTResponse(
            transcript="",
            downstream_response=None,
            error=f"Transcription failed: {str(e)}",
            processing_time_ms=(time.time() - start_time) * 1000
        )
    # Forward to the downstream text model; failures here are non-fatal.
    downstream = None
    if transcript and transcript.strip():
        try:
            downstream = forward_to_text_model(transcript)
        except Exception as e:
            logger.warning(f"Downstream processing failed: {str(e)}")
    processing_time_ms = (time.time() - start_time) * 1000
    logger.info(f"Request completed in {processing_time_ms:.2f}ms")
    return STTResponse(
        transcript=transcript,
        downstream_response=downstream,
        error=None,
        processing_time_ms=processing_time_ms
    )
@app.get("/health", response_model=HealthResponse)
async def health_check():
    """
    Health check endpoint.
    Returns service status and configuration.
    """
    # Fixes: removed an unused `import psutil` (third-party; an ImportError
    # here would turn /health into a 500 wherever psutil isn't installed),
    # and guard against start_time being None — the old startup hook always
    # stored None, so `time.time() - None` raised TypeError.
    started = getattr(app.state, 'start_time', None)
    uptime = time.time() - started if started is not None else None
    return HealthResponse(
        status="healthy",
        device=settings.device,
        model=settings.model_name,
        max_audio_seconds=settings.max_audio_seconds,
        uptime=uptime
    )
@app.get("/config")
async def get_config():
    """Expose the non-sensitive runtime configuration as plain JSON."""
    config_view = {
        "model": settings.model_name,
        "device": settings.device,
        "sample_rate": settings.sample_rate,
        "max_audio_seconds": settings.max_audio_seconds,
        "chunk_seconds": settings.chunk_seconds,
        "overlap_seconds": settings.overlap_seconds,
        "cors_enabled": True,
    }
    return config_view
@app.on_event("startup")
async def startup_event():
    """Record process start time so /health can report uptime.

    Fixes: the original gated on ``'time' in locals()``, which is always False
    inside a function body, so ``start_time`` was unconditionally ``None``.
    Uses the module-level ``time`` import instead.
    """
    app.state.start_time = time.time()
    logger.info("STT service started successfully")
if __name__ == "__main__":
    # Standalone dev entry point; in production run `uvicorn main:app ...`
    # directly so the process manager owns the workers.
    import uvicorn
    logger.info(f"Starting server on {settings.host}:{settings.port}")
    uvicorn.run(
        "main:app",
        host=settings.host,
        port=settings.port,
        workers=settings.workers,
        log_level="info"
    )