"""
SALMONN FastAPI Server

HTTP API for audio understanding and transcription.
"""

import os
import shutil
import tempfile
from contextlib import contextmanager, suppress
from pathlib import Path
from typing import Iterator, Optional

import uvicorn
import yaml
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from fastapi.responses import JSONResponse
from omegaconf import OmegaConf
from pydantic import BaseModel

from inference import SALMONNInference

# Load config. The path can be overridden through the SALMONN_CONFIG
# environment variable; the same path is later handed to SALMONNInference.
CONFIG_PATH = os.environ.get("SALMONN_CONFIG", "config.yaml")
with open(CONFIG_PATH, "r", encoding="utf-8") as f:
    config = OmegaConf.create(yaml.safe_load(f))

# Initialize FastAPI app
app = FastAPI(
    title="SALMONN API",
    description="Audio Language Model for Speech, Audio Events, and Music Understanding",
    version="1.0.0",
)

# Global model instance (loaded on startup)
model: Optional[SALMONNInference] = None


class TranscribeResponse(BaseModel):
    """Response body of ``POST /transcribe``."""

    text: str
    status: str = "success"


class ChatResponse(BaseModel):
    """Response body of ``POST /chat``."""

    question: str
    answer: str
    status: str = "success"


class HealthResponse(BaseModel):
    """Response body of ``GET /health``."""

    status: str
    model_loaded: bool
    device: str


def _require_model() -> SALMONNInference:
    """Return the global model, or raise HTTP 503 if it is not loaded yet."""
    if not model or not model._loaded:
        raise HTTPException(status_code=503, detail="Model not loaded yet")
    return model


@contextmanager
def _uploaded_audio_path(upload: UploadFile) -> Iterator[str]:
    """Spool an uploaded file to a temporary file and yield its path.

    The temporary file keeps the extension of the uploaded filename
    (``.wav`` when the upload carries no filename). It is removed when the
    context exits, including when copying the upload or the body of the
    ``with`` statement raises.
    """
    suffix = Path(upload.filename).suffix if upload.filename else ".wav"
    # delete=False: the file must outlive the handle so the model can open
    # it by path; removal is handled explicitly in the ``finally`` below.
    tmp = tempfile.NamedTemporaryFile(delete=False, suffix=suffix)
    try:
        with tmp:
            shutil.copyfileobj(upload.file, tmp)
        yield tmp.name
    finally:
        # Tolerate the file already being gone so cleanup never masks the
        # real response or the real error.
        with suppress(FileNotFoundError):
            os.unlink(tmp.name)


@app.on_event("startup")
async def startup_event():
    """Load model on startup."""
    global model
    print("Starting SALMONN server...")
    model = SALMONNInference(CONFIG_PATH)
    model.load()
    print("Server ready!")


@app.get("/", response_model=dict)
async def root():
    """Root endpoint with API info."""
    return {
        "name": "SALMONN API",
        "version": "1.0.0",
        "endpoints": {
            "/health": "Health check",
            "/transcribe": "Transcribe audio (POST)",
            "/chat": "Ask questions about audio (POST)",
            "/describe": "Describe audio content (POST)",
        },
    }


@app.get("/health", response_model=HealthResponse)
async def health():
    """Health check endpoint."""
    return HealthResponse(
        status="healthy" if model and model._loaded else "loading",
        model_loaded=model._loaded if model else False,
        device=str(model.device) if model else "unknown",
    )


@app.post("/transcribe", response_model=TranscribeResponse)
async def transcribe(
    audio: UploadFile = File(..., description="Audio file (wav, mp3, etc.)"),
):
    """
    Transcribe an audio file to text.

    - **audio**: Audio file to transcribe

    Returns transcribed text.
    """
    loaded_model = _require_model()

    # Save uploaded file temporarily; it is removed when the block exits.
    with _uploaded_audio_path(audio) as audio_path:
        try:
            text = loaded_model.transcribe(audio_path)
            return TranscribeResponse(text=text)
        except Exception as e:
            raise HTTPException(status_code=500, detail=str(e)) from e


@app.post("/chat", response_model=ChatResponse)
async def chat(
    audio: UploadFile = File(..., description="Audio file (wav, mp3, etc.)"),
    question: str = Form(..., description="Question about the audio"),
):
    """
    Ask a question about an audio file.

    - **audio**: Audio file to analyze
    - **question**: Question about the audio content

    Returns the model's answer.
    """
    loaded_model = _require_model()

    # Save uploaded file temporarily; it is removed when the block exits.
    with _uploaded_audio_path(audio) as audio_path:
        try:
            answer = loaded_model.chat(audio_path, question)
            return ChatResponse(question=question, answer=answer)
        except Exception as e:
            raise HTTPException(status_code=500, detail=str(e)) from e


@app.post("/describe")
async def describe(
    audio: UploadFile = File(..., description="Audio file (wav, mp3, etc.)"),
):
    """
    Get a detailed description of the audio content.

    - **audio**: Audio file to describe

    Returns description of the audio.
    """
    loaded_model = _require_model()

    # Save uploaded file temporarily; it is removed when the block exits.
    with _uploaded_audio_path(audio) as audio_path:
        try:
            description = loaded_model.describe(audio_path)
            return {"description": description, "status": "success"}
        except Exception as e:
            raise HTTPException(status_code=500, detail=str(e)) from e


if __name__ == "__main__":
    uvicorn.run(
        "server:app",
        host=config.server.host,
        port=config.server.port,
        reload=config.server.get("reload", False),
    )