# Snapshot metadata: file size 5,056 bytes, revision a032fae.
"""
SALMONN FastAPI Server
HTTP API for audio understanding and transcription.
"""
import os
import tempfile
import shutil
from pathlib import Path
from typing import Optional
import yaml
import uvicorn
from fastapi import FastAPI, File, UploadFile, Form, HTTPException
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from omegaconf import OmegaConf
from inference import SALMONNInference
# Load config
# The config path can be overridden with the SALMONN_CONFIG environment
# variable; it falls back to "config.yaml" relative to the working directory.
CONFIG_PATH = os.environ.get("SALMONN_CONFIG", "config.yaml")
# Decode explicitly as UTF-8 so parsing does not depend on the platform's
# default locale encoding.
with open(CONFIG_PATH, "r", encoding="utf-8") as f:
    config = OmegaConf.create(yaml.safe_load(f))
# FastAPI application object carrying the service title, description and version.
app = FastAPI(
    title="SALMONN API",
    description=(
        "Audio Language Model for Speech, Audio Events, and Music Understanding"
    ),
    version="1.0.0",
)

# Module-wide inference engine; stays None until the startup hook assigns it.
model: Optional[SALMONNInference] = None
# Response body returned by the /transcribe endpoint.
class TranscribeResponse(BaseModel):
    # Text produced by the model for the uploaded audio.
    text: str
    # Outcome marker; defaults to "success" when not set explicitly.
    status: str = "success"
# Response body returned by the /chat endpoint.
class ChatResponse(BaseModel):
    # The question that was submitted with the request.
    question: str
    # The model's answer to that question.
    answer: str
    # Outcome marker; defaults to "success" when not set explicitly.
    status: str = "success"
# Response body returned by the /health endpoint.
class HealthResponse(BaseModel):
    # "healthy" once the model reports itself loaded, otherwise "loading".
    status: str
    # Whether the model instance exists and reports itself loaded.
    model_loaded: bool
    # String form of the model's device, or "unknown" when no model exists yet.
    device: str
# NOTE(review): newer FastAPI releases favour lifespan handlers over
# on_event — confirm against the pinned FastAPI version before migrating.
@app.on_event("startup")
async def startup_event():
    """Create the shared SALMONN inference object and load it."""
    global model
    print("Starting SALMONN server...")
    inference = SALMONNInference(CONFIG_PATH)
    # Publish the instance before loading, so the global is set even while
    # load() is still running.
    model = inference
    inference.load()
    print("Server ready!")
@app.get("/", response_model=dict)
async def root():
    """Root endpoint with API info."""
    # Name and version are read from the app object so this listing cannot
    # drift from the metadata passed to FastAPI(...).
    return {
        "name": app.title,
        "version": app.version,
        "endpoints": {
            "/health": "Health check",
            "/transcribe": "Transcribe audio (POST)",
            "/chat": "Ask questions about audio (POST)",
            # /describe is registered in this module but was missing here.
            "/describe": "Describe audio content (POST)",
        },
    }
@app.get("/health", response_model=HealthResponse)
async def health():
    """Health check endpoint."""
    # No model object yet: report the fixed "still loading" answer.
    if not model:
        return HealthResponse(
            status="loading",
            model_loaded=False,
            device="unknown",
        )

    is_loaded = model._loaded
    current_status = "healthy" if is_loaded else "loading"
    return HealthResponse(
        status=current_status,
        model_loaded=is_loaded,
        device=str(model.device),
    )
@app.post("/transcribe", response_model=TranscribeResponse)
async def transcribe(
    audio: UploadFile = File(..., description="Audio file (wav, mp3, etc.)"),
):
    """
    Transcribe an audio file to text.

    - **audio**: Audio file to transcribe

    Returns transcribed text.
    """
    # Raises HTTPException 503 while the model is not loaded, and 500 (with
    # the underlying error text as detail) if saving or transcription fails.
    if not model or not model._loaded:
        raise HTTPException(status_code=503, detail="Model not loaded yet")

    # Keep the upload's extension on the temp file; fall back to ".wav" when
    # no filename was supplied.
    suffix = Path(audio.filename).suffix if audio.filename else ".wav"
    tmp_path = None
    try:
        # The copy happens inside the try so a failed or partial upload still
        # reaches the cleanup in `finally` instead of leaking the temp file.
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
            tmp_path = tmp.name
            shutil.copyfileobj(audio.file, tmp)
        # NOTE(review): this call runs synchronously inside an async handler,
        # so other requests wait while it executes — confirm this is intended.
        text = model.transcribe(tmp_path)
        return TranscribeResponse(text=text)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        if tmp_path is not None:
            try:
                os.unlink(tmp_path)
            except FileNotFoundError:
                # Already gone; do not let cleanup mask the real outcome.
                pass
@app.post("/chat", response_model=ChatResponse)
async def chat(
    audio: UploadFile = File(..., description="Audio file (wav, mp3, etc.)"),
    question: str = Form(..., description="Question about the audio"),
):
    """
    Ask a question about an audio file.

    - **audio**: Audio file to analyze
    - **question**: Question about the audio content

    Returns the model's answer.
    """
    # Raises HTTPException 503 while the model is not loaded, and 500 (with
    # the underlying error text as detail) if saving or inference fails.
    if not model or not model._loaded:
        raise HTTPException(status_code=503, detail="Model not loaded yet")

    # Keep the upload's extension on the temp file; fall back to ".wav" when
    # no filename was supplied.
    suffix = Path(audio.filename).suffix if audio.filename else ".wav"
    tmp_path = None
    try:
        # The copy happens inside the try so a failed or partial upload still
        # reaches the cleanup in `finally` instead of leaking the temp file.
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
            tmp_path = tmp.name
            shutil.copyfileobj(audio.file, tmp)
        # NOTE(review): this call runs synchronously inside an async handler,
        # so other requests wait while it executes — confirm this is intended.
        answer = model.chat(tmp_path, question)
        return ChatResponse(question=question, answer=answer)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        if tmp_path is not None:
            try:
                os.unlink(tmp_path)
            except FileNotFoundError:
                # Already gone; do not let cleanup mask the real outcome.
                pass
@app.post("/describe")
async def describe(
    audio: UploadFile = File(..., description="Audio file (wav, mp3, etc.)"),
):
    """
    Get a detailed description of the audio content.

    - **audio**: Audio file to describe

    Returns description of the audio.
    """
    # Raises HTTPException 503 while the model is not loaded, and 500 (with
    # the underlying error text as detail) if saving or inference fails.
    if not model or not model._loaded:
        raise HTTPException(status_code=503, detail="Model not loaded yet")

    # Keep the upload's extension on the temp file; fall back to ".wav" when
    # no filename was supplied.
    suffix = Path(audio.filename).suffix if audio.filename else ".wav"
    tmp_path = None
    try:
        # The copy happens inside the try so a failed or partial upload still
        # reaches the cleanup in `finally` instead of leaking the temp file.
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
            tmp_path = tmp.name
            shutil.copyfileobj(audio.file, tmp)
        # NOTE(review): this call runs synchronously inside an async handler,
        # so other requests wait while it executes — confirm this is intended.
        description = model.describe(tmp_path)
        return {"description": description, "status": "success"}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        if tmp_path is not None:
            try:
                os.unlink(tmp_path)
            except FileNotFoundError:
                # Already gone; do not let cleanup mask the real outcome.
                pass
# Script entry point: serve the app with uvicorn using the "server" section
# of the loaded config.
if __name__ == "__main__":
    server_cfg = config.server
    uvicorn.run(
        # NOTE(review): the import string presumes this module is importable
        # as "server" — confirm against the actual filename.
        "server:app",
        host=server_cfg.host,
        port=server_cfg.port,
        # "reload" is optional in the config and defaults to off.
        reload=server_cfg.get("reload", False),
    )
|