salmonn-inference / server.py
marcosremar2's picture
Upload folder using huggingface_hub
a032fae verified
"""
SALMONN FastAPI Server
HTTP API for audio understanding and transcription.
"""
import logging
import os
import shutil
import tempfile
import threading
from contextlib import contextmanager
from pathlib import Path
from typing import Any, Callable, Iterator, Optional

import uvicorn
import yaml
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import JSONResponse
from omegaconf import OmegaConf
from pydantic import BaseModel

from inference import SALMONNInference
# Load config. The path defaults to "config.yaml" and can be overridden with
# the SALMONN_CONFIG environment variable.
CONFIG_PATH = os.environ.get("SALMONN_CONFIG", "config.yaml")
# Read explicitly as UTF-8: without `encoding`, open() uses the platform's
# locale encoding, so the same config could parse differently across hosts.
with open(CONFIG_PATH, "r", encoding="utf-8") as f:
    config = OmegaConf.create(yaml.safe_load(f))
# The ASGI application object; the __main__ entry point serves it as "server:app".
app = FastAPI(
    version="1.0.0",
    title="SALMONN API",
    description=(
        "Audio Language Model for Speech, Audio Events, and Music Understanding"
    ),
)

# Process-wide model handle. It is None until startup_event() assigns it.
model: Optional[SALMONNInference] = None
class TranscribeResponse(BaseModel):
    """JSON response body returned by POST /transcribe."""

    # Text returned by the model's transcribe() call.
    text: str
    # Defaults to "success"; the /transcribe endpoint never sets it explicitly
    # and reports failures by raising HTTPException instead.
    status: str = "success"
class ChatResponse(BaseModel):
    """JSON response body returned by POST /chat."""

    # The question submitted with the request, echoed back unchanged.
    question: str
    # Answer returned by the model's chat() call.
    answer: str
    # Defaults to "success"; the /chat endpoint never sets it explicitly
    # and reports failures by raising HTTPException instead.
    status: str = "success"
class HealthResponse(BaseModel):
    """JSON response body returned by GET /health."""

    # "healthy" once the model is loaded, otherwise "loading".
    status: str
    # Whether the global model instance reports itself as loaded.
    model_loaded: bool
    # str() of the model's device, or "unknown" when no model exists yet.
    device: str
@app.on_event("startup")
async def startup_event():
    """Build the SALMONN model from CONFIG_PATH and load it before serving."""
    global model
    print("Starting SALMONN server...")
    # Publish the instance to the module global first, then load its weights.
    model = instance = SALMONNInference(CONFIG_PATH)
    instance.load()
    print("Server ready!")
@app.get("/", response_model=dict)
async def root():
    """Root endpoint with API info."""
    # Keep this list in sync with the routes registered in this module;
    # /describe was previously served but not advertised here.
    endpoints = {
        "/health": "Health check",
        "/transcribe": "Transcribe audio (POST)",
        "/chat": "Ask questions about audio (POST)",
        "/describe": "Describe audio content (POST)",
    }
    return {
        "name": "SALMONN API",
        "version": "1.0.0",
        "endpoints": endpoints,
    }
@app.get("/health", response_model=HealthResponse)
async def health():
    """Report whether the model is loaded and which device it uses."""
    if not model:
        # Startup has not created the model instance yet.
        return HealthResponse(status="loading", model_loaded=False, device="unknown")
    is_loaded = model._loaded
    return HealthResponse(
        status="healthy" if is_loaded else "loading",
        model_loaded=is_loaded,
        device=str(model.device),
    )
_logger = logging.getLogger(__name__)

# Model calls now run in a worker thread so the event loop (and /health) stays
# responsive during inference. Previously the blocking call inside `async def`
# serialized all inference. This lock keeps it one-at-a-time, because nothing
# here shows SALMONNInference is safe to call concurrently.
_inference_lock = threading.Lock()


def _require_model() -> SALMONNInference:
    """Return the global model, or raise HTTP 503 if it is not loaded yet."""
    if model is None or not model._loaded:
        raise HTTPException(status_code=503, detail="Model not loaded yet")
    return model


@contextmanager
def _uploaded_audio(audio: UploadFile) -> Iterator[str]:
    """Spool an uploaded file to a named temp file and yield its path.

    The temp file keeps the upload's extension (".wav" if no filename was
    sent). It is removed on exit even when copying the upload fails or the
    caller raises.
    """
    suffix = Path(audio.filename).suffix if audio.filename else ".wav"
    # delete=False: the path must stay valid after the handle is closed so the
    # model can open it by name.
    tmp = tempfile.NamedTemporaryFile(delete=False, suffix=suffix)
    try:
        with tmp:
            shutil.copyfileobj(audio.file, tmp)
        yield tmp.name
    finally:
        try:
            os.unlink(tmp.name)
        except OSError:
            # Best-effort cleanup must not mask the request's real outcome.
            _logger.warning("Could not remove temp file %s", tmp.name, exc_info=True)


async def _run_inference(fn: Callable[..., Any], *args: Any) -> Any:
    """Run a blocking model call in the threadpool, one call at a time.

    Raises:
        HTTPException: status 500 with the error text as detail if *fn* raises.
            The original cause is chained and logged.
    """

    def _locked_call() -> Any:
        with _inference_lock:
            return fn(*args)

    try:
        return await run_in_threadpool(_locked_call)
    except Exception as e:
        _logger.exception("SALMONN inference failed")
        raise HTTPException(status_code=500, detail=str(e)) from e


@app.post("/transcribe", response_model=TranscribeResponse)
async def transcribe(
    audio: UploadFile = File(..., description="Audio file (wav, mp3, etc.)"),
):
    """
    Transcribe an audio file to text.

    - **audio**: Audio file to transcribe

    Returns transcribed text.
    """
    current = _require_model()
    with _uploaded_audio(audio) as tmp_path:
        text = await _run_inference(current.transcribe, tmp_path)
    return TranscribeResponse(text=text)


@app.post("/chat", response_model=ChatResponse)
async def chat(
    audio: UploadFile = File(..., description="Audio file (wav, mp3, etc.)"),
    question: str = Form(..., description="Question about the audio"),
):
    """
    Ask a question about an audio file.

    - **audio**: Audio file to analyze
    - **question**: Question about the audio content

    Returns the model's answer.
    """
    current = _require_model()
    with _uploaded_audio(audio) as tmp_path:
        answer = await _run_inference(current.chat, tmp_path, question)
    return ChatResponse(question=question, answer=answer)


@app.post("/describe")
async def describe(
    audio: UploadFile = File(..., description="Audio file (wav, mp3, etc.)"),
):
    """
    Get a detailed description of the audio content.

    - **audio**: Audio file to describe

    Returns description of the audio.
    """
    current = _require_model()
    with _uploaded_audio(audio) as tmp_path:
        description = await _run_inference(current.describe, tmp_path)
    return {"description": description, "status": "success"}
if __name__ == "__main__":
    # Serve the app with uvicorn, using the "server" section of the loaded config.
    server_settings = config.server
    uvicorn.run(
        "server:app",
        host=server_settings.host,
        port=server_settings.port,
        reload=server_settings.get("reload", False),
    )