| | """ |
| | SALMONN FastAPI Server |
| | HTTP API for audio understanding and transcription. |
| | """ |
| |
|
| | import os |
| | import tempfile |
| | import shutil |
| | from pathlib import Path |
| | from typing import Optional |
| |
|
| | import yaml |
| | import uvicorn |
| | from fastapi import FastAPI, File, UploadFile, Form, HTTPException |
| | from fastapi.responses import JSONResponse |
| | from pydantic import BaseModel |
| | from omegaconf import OmegaConf |
| |
|
| | from inference import SALMONNInference |
| |
|
| | |
# Path to the YAML config file; override it with the SALMONN_CONFIG
# environment variable. The same path is handed to SALMONNInference at startup.
CONFIG_PATH = os.environ.get("SALMONN_CONFIG", "config.yaml")

# Parse the config once at import time. The ``__main__`` launcher reads
# ``config.server`` (host / port / reload) from it. Decode explicitly as UTF-8
# so parsing does not depend on the platform's locale encoding.
with open(CONFIG_PATH, "r", encoding="utf-8") as f:
    config = OmegaConf.create(yaml.safe_load(f))
| |
|
| | |
app = FastAPI(
    title="SALMONN API",
    version="1.0.0",
    description="Audio Language Model for Speech, Audio Events, and Music Understanding",
)

# Process-wide inference handle. It stays None until the startup hook
# constructs it. Request handlers check it (and its _loaded flag) before
# doing any work.
model: Optional[SALMONNInference] = None
| |
|
| |
|
class TranscribeResponse(BaseModel):
    """Response body for ``POST /transcribe``."""

    text: str  # transcription produced by model.transcribe
    status: str = "success"
| |
|
| |
|
class ChatResponse(BaseModel):
    """Response body for ``POST /chat``: the submitted question plus the model's answer."""

    question: str  # the question as submitted in the request form
    answer: str  # output of model.chat for the uploaded audio
    status: str = "success"
| |
|
| |
|
class HealthResponse(BaseModel):
    """Response body for ``GET /health``."""

    status: str  # "healthy" once the model reports loaded, otherwise "loading"
    model_loaded: bool
    device: str  # str(model.device), or "unknown" before the model exists
| |
|
| |
|
@app.on_event("startup")
async def startup_event():
    """Create the SALMONN inference wrapper and load its weights before serving requests."""
    global model
    print("Starting SALMONN server...")
    instance = SALMONNInference(CONFIG_PATH)
    # Publish the handle before loading, exactly as before: if load() raises,
    # the global still refers to the (unloaded) instance.
    model = instance
    instance.load()
    print("Server ready!")
| |
|
| |
|
@app.get("/", response_model=dict)
async def root():
    """Root endpoint with API info."""
    # Keep this listing in sync with the routes registered on ``app``.
    # /describe is served by this module but was previously missing here.
    return {
        "name": "SALMONN API",
        "version": "1.0.0",
        "endpoints": {
            "/health": "Health check",
            "/transcribe": "Transcribe audio (POST)",
            "/chat": "Ask questions about audio (POST)",
            "/describe": "Describe audio content (POST)",
        },
    }
| |
|
| |
|
@app.get("/health", response_model=HealthResponse)
async def health():
    """Health check endpoint."""
    # Before the startup hook has created the model there is nothing to
    # inspect, so answer with placeholder values.
    if not model:
        return HealthResponse(status="loading", model_loaded=False, device="unknown")

    # _loaded is a private flag on SALMONNInference; read it once and reuse it.
    is_loaded = model._loaded
    return HealthResponse(
        status="healthy" if is_loaded else "loading",
        model_loaded=is_loaded,
        device=str(model.device),
    )
| |
|
| |
|
@app.post("/transcribe", response_model=TranscribeResponse)
async def transcribe(
    audio: UploadFile = File(..., description="Audio file (wav, mp3, etc.)"),
):
    """
    Transcribe an audio file to text.

    - **audio**: Audio file to transcribe

    Returns transcribed text.
    """
    if not model or not model._loaded:
        raise HTTPException(status_code=503, detail="Model not loaded yet")

    # The inference wrapper takes a filesystem path, so spool the upload to a
    # named temp file, keeping the client's extension (".wav" if no filename).
    suffix = Path(audio.filename).suffix if audio.filename else ".wav"
    tmp_path = None
    try:
        # Bind the path *before* copying, so the file is removed in `finally`
        # even if the copy fails part-way (previously that leaked the file).
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
            tmp_path = tmp.name
            shutil.copyfileobj(audio.file, tmp)
        # NOTE(review): model.transcribe is a synchronous call inside an async
        # endpoint, so it blocks the event loop for the whole inference.
        # Offloading it (e.g. run_in_threadpool) requires confirming the model
        # is safe to call concurrently.
        text = model.transcribe(tmp_path)
        return TranscribeResponse(text=text)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        # The `with` block has already closed the handle, so unlinking is
        # safe on every platform.
        if tmp_path is not None:
            os.unlink(tmp_path)
| |
|
| |
|
@app.post("/chat", response_model=ChatResponse)
async def chat(
    audio: UploadFile = File(..., description="Audio file (wav, mp3, etc.)"),
    question: str = Form(..., description="Question about the audio"),
):
    """
    Ask a question about an audio file.

    - **audio**: Audio file to analyze
    - **question**: Question about the audio content

    Returns the model's answer.
    """
    if not model or not model._loaded:
        raise HTTPException(status_code=503, detail="Model not loaded yet")

    # The inference wrapper takes a filesystem path, so spool the upload to a
    # named temp file, keeping the client's extension (".wav" if no filename).
    suffix = Path(audio.filename).suffix if audio.filename else ".wav"
    tmp_path = None
    try:
        # Bind the path *before* copying, so the file is removed in `finally`
        # even if the copy fails part-way (previously that leaked the file).
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
            tmp_path = tmp.name
            shutil.copyfileobj(audio.file, tmp)
        # NOTE(review): model.chat is a synchronous call inside an async
        # endpoint, so it blocks the event loop for the whole inference.
        # Offloading it requires confirming the model is safe to call
        # concurrently.
        answer = model.chat(tmp_path, question)
        return ChatResponse(question=question, answer=answer)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        # The `with` block has already closed the handle, so unlinking is
        # safe on every platform.
        if tmp_path is not None:
            os.unlink(tmp_path)
| |
|
| |
|
class DescribeResponse(BaseModel):
    """Response body for ``POST /describe``."""

    # NOTE(review): typed as str to match TranscribeResponse/ChatResponse;
    # confirm model.describe always returns a string.
    description: str
    status: str = "success"


@app.post("/describe", response_model=DescribeResponse)
async def describe(
    audio: UploadFile = File(..., description="Audio file (wav, mp3, etc.)"),
):
    """
    Get a detailed description of the audio content.

    - **audio**: Audio file to describe

    Returns description of the audio.
    """
    if not model or not model._loaded:
        raise HTTPException(status_code=503, detail="Model not loaded yet")

    # The inference wrapper takes a filesystem path, so spool the upload to a
    # named temp file, keeping the client's extension (".wav" if no filename).
    suffix = Path(audio.filename).suffix if audio.filename else ".wav"
    tmp_path = None
    try:
        # Bind the path *before* copying, so the file is removed in `finally`
        # even if the copy fails part-way (previously that leaked the file).
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
            tmp_path = tmp.name
            shutil.copyfileobj(audio.file, tmp)
        # NOTE(review): model.describe is a synchronous call inside an async
        # endpoint, so it blocks the event loop for the whole inference.
        description = model.describe(tmp_path)
        # Same JSON shape as before ({"description": ..., "status": "success"}),
        # now declared in the OpenAPI schema like the other endpoints.
        return DescribeResponse(description=description)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        if tmp_path is not None:
            os.unlink(tmp_path)
| |
|
| |
|
if __name__ == "__main__":
    # Host, port and the optional reload flag come from the ``server``
    # section of the YAML config loaded above.
    # NOTE(review): the app is passed as the import string "server:app",
    # which assumes this module is importable as ``server`` — confirm the
    # file name.
    server_settings = config.server
    uvicorn.run(
        "server:app",
        host=server_settings.host,
        port=server_settings.port,
        reload=server_settings.get("reload", False),
    )
| |
|