# SussurroXRest / app.py
# Author: LucaR84 — "fixed ready endpoint" (commit d36bbaa)
from fastapi import FastAPI, UploadFile, File, HTTPException, status, WebSocket, WebSocketDisconnect
from fastapi.responses import JSONResponse
import tempfile
import os
import sync_hc
import logging
from pydantic import BaseModel
from speech_models.whisperx_manager import WhisperXModelManager
from utils.audio_buffer import AudioBuffer
# Configure root logging once at import time so library and app messages appear.
logging.basicConfig(level=logging.INFO)
# Module-level logger named after this module, per stdlib convention.
logger = logging.getLogger(__name__)
class TranscriptionRequest(BaseModel):
    """Request schema for a transcription job.

    NOTE(review): field-for-field identical to StreamingConfig and not
    referenced by any endpoint in this file — confirm whether it is used
    elsewhere before removing.
    """
    language: str = "en"          # language code forwarded to the model
    task: str = "transcribe"      # model task name — TODO confirm accepted values
    chunk_duration: float = 30.0  # seconds of audio per processing chunk
class StreamingConfig(BaseModel):
    """Configuration schema for a streaming (WebSocket) transcription session.

    Mirrors the keys read from the initial WebSocket config message
    (language, task, chunk_duration).
    """
    language: str = "en"          # language code forwarded to the model
    task: str = "transcribe"      # model task name — TODO confirm accepted values
    chunk_duration: float = 30.0  # seconds of audio per processing chunk
# Initialize FastAPI app
app = FastAPI(title="Speech Transcription API", version="1.0.0")
# Global model manager instance shared by every endpoint; constructed at
# import time with the "tiny" model (load behavior lives in WhisperXModelManager).
model_manager = WhisperXModelManager(model_name="tiny")
@app.on_event("startup")
async def startup_event():
    """Fail fast: abort application startup if the model is not loaded."""
    if model_manager.is_loaded:
        return
    raise RuntimeError("Failed to load model at startup")
@app.get("/health", status_code=status.HTTP_200_OK)
async def health_check():
    """Liveness probe: always 200 while the process is up; reports model state."""
    payload = {"status": "healthy", "model_loaded": model_manager.is_loaded}
    return JSONResponse(content=payload)
@app.get("/ready", status_code=status.HTTP_200_OK)
async def readiness_check():
    """Readiness probe: 200 once the model is loaded, 503 otherwise."""
    # Guard clause: report 503 until the model manager finishes loading.
    if not model_manager.is_loaded:
        return JSONResponse(
            content={"status": "not ready", "model_loaded": False},
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
        )
    return JSONResponse(
        content={"status": "ready", "model_name": model_manager.model_name}
    )
@app.post("/transcribe", status_code=status.HTTP_200_OK)
async def transcribe_audio(
    file: UploadFile = File(...),
    language: str = "en",
    chunk_duration: float = 30.0
):
    """Transcribe an uploaded audio file in fixed-size chunks.

    Args:
        file: Uploaded audio file; extension must be in the allowed set.
        language: Language code forwarded to the model (default "en").
        chunk_duration: Seconds of audio per chunk; must satisfy 0 < d <= 60.

    Returns:
        JSONResponse wrapping the result dict from model_manager.transcribe.

    Raises:
        HTTPException: 503 when the model is not loaded, 400 for an invalid
            file type or chunk duration, 500 when transcription fails.
    """
    if not model_manager.is_loaded:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Model not loaded"
        )

    # Validate file extension; filename may be None for some clients, so
    # guard the splitext call.
    allowed_extensions = {'.mp3', '.wav', '.m4a', '.flac', '.ogg', '.aac'}
    file_extension = os.path.splitext(file.filename or "")[1].lower()
    if file_extension not in allowed_extensions:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            # sorted() keeps the message deterministic (set order is arbitrary)
            detail=f"Invalid file type. Allowed types: {', '.join(sorted(allowed_extensions))}"
        )

    # Validate chunk duration; message now matches the actual check
    # (the old text claimed "between 1 and 60" while the check allowed 0 < d <= 60).
    if chunk_duration <= 0 or chunk_duration > 60:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Chunk duration must be greater than 0 and at most 60 seconds"
        )

    temp_file_path = None
    try:
        # Persist the upload to disk because the model manager loads audio
        # from a file path.
        with tempfile.NamedTemporaryFile(delete=False, suffix=file_extension) as temp_file:
            temp_file.write(await file.read())
            temp_file_path = temp_file.name

        result = model_manager.transcribe(temp_file_path, language, chunk_duration)
        return JSONResponse(content=result)
    except Exception as e:
        # Log the full traceback server-side; clients get a concise 500.
        logger.exception("Transcription failed")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Transcription failed: {str(e)}"
        ) from e
    finally:
        # Always remove the temp file — success or failure — instead of the
        # old duplicated cleanup with a bare `except: pass`.
        if temp_file_path is not None:
            try:
                os.unlink(temp_file_path)
            except OSError:
                pass
@app.websocket("/ws/transcribe")
async def websocket_transcribe(websocket: WebSocket):
    """WebSocket endpoint for streaming transcription with chunking.

    Protocol:
        1. Client sends a JSON config: {language, task, chunk_duration}.
        2. Server acknowledges with {"status": "connected", ...}.
        3. Client streams raw audio bytes; each completed buffer chunk is
           answered with a {"status": "transcription", ...} message.
        4. On disconnect/error, any buffered remainder longer than ~1 second
           is transcribed and sent as a final message.
    """
    await websocket.accept()

    if not model_manager.is_loaded:
        await websocket.send_json({
            "error": "Model not loaded",
            "status": "error"
        })
        await websocket.close()
        return

    # Initialize defaults BEFORE the try block: the old code referenced
    # `language`/`chunk_duration` in `finally`, which raised NameError when
    # the connection died before the config message arrived.
    audio_buffer = None
    language = "en"
    chunk_duration = 30.0
    try:
        # Receive configuration
        config_data = await websocket.receive_json()
        language = config_data.get("language", "en")
        task = config_data.get("task", "transcribe")  # NOTE(review): currently unused downstream
        chunk_duration = config_data.get("chunk_duration", 30.0)

        # Initialize audio buffer
        audio_buffer = AudioBuffer(chunk_duration=chunk_duration)
        logger.info(f"Starting streaming transcription session - Language: {language}, Chunk Duration: {chunk_duration}s")

        # Send acknowledgment
        await websocket.send_json({
            "status": "connected",
            "message": "Ready to receive audio data",
            "chunk_duration": chunk_duration
        })

        # Process incoming audio chunks
        while True:
            try:
                data = await websocket.receive_bytes()
                if not data:
                    continue
                # Buffer data; add_data returns a complete chunk when ready.
                ready_chunk = audio_buffer.add_data(data)
                if ready_chunk is not None:
                    logger.info("Processing buffered audio chunk")
                    result = await model_manager.transcribe_stream(
                        ready_chunk.tobytes(),
                        language,
                        chunk_duration
                    )
                    await websocket.send_json({
                        "status": "transcription",
                        "result": result,
                        "chunk_complete": True
                    })
            except WebSocketDisconnect:
                logger.info("WebSocket disconnected")
                break
            except Exception as e:
                # Per-chunk failures are reported to the client; session continues.
                logger.error(f"Error processing audio chunk: {str(e)}")
                await websocket.send_json({
                    "error": str(e),
                    "status": "error"
                })
    except Exception as e:
        logger.error(f"WebSocket error: {str(e)}")
    finally:
        # Flush any buffered remainder as a final transcription.
        if audio_buffer is not None and len(audio_buffer.get_remaining()) > 0:
            try:
                remaining_audio = audio_buffer.get_remaining()
                if len(remaining_audio) > 16000:  # at least ~1 second (assumes 16 kHz — TODO confirm)
                    # Pass chunk_duration for consistency with the in-loop call,
                    # which the original final-chunk call omitted.
                    result = await model_manager.transcribe_stream(
                        remaining_audio.tobytes(),
                        language,
                        chunk_duration
                    )
                    await websocket.send_json({
                        "status": "transcription",
                        "result": result,
                        "final": True
                    })
            except Exception as e:
                logger.error(f"Error processing final chunk: {str(e)}")
        try:
            await websocket.close()
        except Exception:
            # The socket may already be closed by the client; closing twice raises.
            pass
# Script entry point: run the ASGI app directly with uvicorn.
if __name__ == "__main__":
    import uvicorn
    # 0.0.0.0 exposes the server to the container network; 7860 is the
    # conventional Hugging Face Spaces port.
    uvicorn.run(app, host="0.0.0.0", port=7860)