File size: 5,056 Bytes
a032fae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
"""
SALMONN FastAPI Server
HTTP API for audio understanding and transcription.
"""

import os
import tempfile
import shutil
from pathlib import Path
from typing import Optional

import yaml
import uvicorn
from fastapi import FastAPI, File, UploadFile, Form, HTTPException
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from omegaconf import OmegaConf

from inference import SALMONNInference

# Load config.
# Config path is overridable via the SALMONN_CONFIG environment variable;
# defaults to "config.yaml" in the current working directory.
CONFIG_PATH = os.environ.get("SALMONN_CONFIG", "config.yaml")

# Parse the YAML with safe_load, then wrap it in OmegaConf so nested keys
# support attribute access (e.g. config.server.host, used at the bottom
# of this file). Runs at import time; a missing file raises immediately.
with open(CONFIG_PATH, "r") as f:
    config = OmegaConf.create(yaml.safe_load(f))

# Initialize FastAPI app
app = FastAPI(
    title="SALMONN API",
    description="Audio Language Model for Speech, Audio Events, and Music Understanding",
    version="1.0.0",
)

# Global model instance. Stays None until startup_event() assigns it, so
# every endpoint must check it before use (they return 503 otherwise).
model: Optional[SALMONNInference] = None


class TranscribeResponse(BaseModel):
    """Response body for POST /transcribe."""
    text: str  # transcribed text produced by the model
    status: str = "success"  # fixed marker; errors go through HTTPException instead


class ChatResponse(BaseModel):
    """Response body for POST /chat."""
    question: str  # the question echoed back from the request form field
    answer: str  # the model's answer about the audio
    status: str = "success"  # fixed marker; errors go through HTTPException instead


class HealthResponse(BaseModel):
    """Response body for GET /health."""
    status: str  # "healthy" once the model is loaded, "loading" before that
    model_loaded: bool  # mirrors the model's _loaded flag
    device: str  # str(model.device), or "unknown" before the model exists


# NOTE(review): @app.on_event is deprecated in newer FastAPI releases in
# favor of lifespan handlers — worth migrating when the FastAPI version
# pinned for this project supports it.
@app.on_event("startup")
async def startup_event():
    """Load model on startup.

    Constructs SALMONNInference from CONFIG_PATH and loads its weights,
    assigning the module-level `model` global. Until load() finishes,
    /health reports "loading" and the other endpoints return 503.
    """
    global model
    print("Starting SALMONN server...")
    model = SALMONNInference(CONFIG_PATH)
    model.load()
    print("Server ready!")


@app.get("/", response_model=dict)
async def root():
    """Root endpoint with API info."""
    # Human-readable catalogue of the available routes.
    endpoints = {
        "/health": "Health check",
        "/transcribe": "Transcribe audio (POST)",
        "/chat": "Ask questions about audio (POST)",
    }
    return {"name": "SALMONN API", "version": "1.0.0", "endpoints": endpoints}


@app.get("/health", response_model=HealthResponse)
async def health():
    """Health check endpoint: reports load state and the model's device."""
    # Before startup_event() has populated the global, report "loading".
    if not model:
        return HealthResponse(status="loading", model_loaded=False, device="unknown")
    loaded = model._loaded
    return HealthResponse(
        status="healthy" if loaded else "loading",
        model_loaded=loaded,
        device=str(model.device),
    )


@app.post("/transcribe", response_model=TranscribeResponse)
async def transcribe(
    audio: UploadFile = File(..., description="Audio file (wav, mp3, etc.)"),
):
    """
    Transcribe an audio file to text.
    
    - **audio**: Audio file to transcribe
    
    Returns transcribed text.

    Raises HTTP 503 while the model is still loading and HTTP 500 if
    inference fails.
    """
    if not model or not model._loaded:
        raise HTTPException(status_code=503, detail="Model not loaded yet")
    
    # Persist the upload to disk: model.transcribe takes a filesystem path,
    # not a stream. Keep the original extension so format detection by the
    # audio loader still works; fall back to ".wav" when no filename is sent.
    suffix = Path(audio.filename).suffix if audio.filename else ".wav"
    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
        shutil.copyfileobj(audio.file, tmp)
        tmp_path = tmp.name
    
    try:
        text = model.transcribe(tmp_path)
        return TranscribeResponse(text=text)
    except Exception as e:
        # Fix: chain the cause ("from e") so the real inference traceback is
        # preserved in server logs instead of being reported as raised
        # "during handling of" the HTTPException.
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        # Always remove the temp file, success or failure.
        os.unlink(tmp_path)


@app.post("/chat", response_model=ChatResponse)
async def chat(
    audio: UploadFile = File(..., description="Audio file (wav, mp3, etc.)"),
    question: str = Form(..., description="Question about the audio"),
):
    """
    Ask a question about an audio file.
    
    - **audio**: Audio file to analyze
    - **question**: Question about the audio content
    
    Returns the model's answer.

    Raises HTTP 503 while the model is still loading and HTTP 500 if
    inference fails.
    """
    if not model or not model._loaded:
        raise HTTPException(status_code=503, detail="Model not loaded yet")
    
    # Persist the upload to disk: model.chat takes a filesystem path, not a
    # stream. Keep the original extension so format detection by the audio
    # loader still works; fall back to ".wav" when no filename is sent.
    suffix = Path(audio.filename).suffix if audio.filename else ".wav"
    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
        shutil.copyfileobj(audio.file, tmp)
        tmp_path = tmp.name
    
    try:
        answer = model.chat(tmp_path, question)
        return ChatResponse(question=question, answer=answer)
    except Exception as e:
        # Fix: chain the cause ("from e") so the real inference traceback is
        # preserved in server logs instead of being reported as raised
        # "during handling of" the HTTPException.
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        # Always remove the temp file, success or failure.
        os.unlink(tmp_path)


@app.post("/describe")
async def describe(
    audio: UploadFile = File(..., description="Audio file (wav, mp3, etc.)"),
):
    """
    Get a detailed description of the audio content.
    
    - **audio**: Audio file to describe
    
    Returns description of the audio.

    Raises HTTP 503 while the model is still loading and HTTP 500 if
    inference fails.
    """
    if not model or not model._loaded:
        raise HTTPException(status_code=503, detail="Model not loaded yet")
    
    # Persist the upload to disk: model.describe takes a filesystem path,
    # not a stream. Keep the original extension so format detection by the
    # audio loader still works; fall back to ".wav" when no filename is sent.
    suffix = Path(audio.filename).suffix if audio.filename else ".wav"
    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
        shutil.copyfileobj(audio.file, tmp)
        tmp_path = tmp.name
    
    try:
        description = model.describe(tmp_path)
        return {"description": description, "status": "success"}
    except Exception as e:
        # Fix: chain the cause ("from e") so the real inference traceback is
        # preserved in server logs instead of being reported as raised
        # "during handling of" the HTTPException.
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        # Always remove the temp file, success or failure.
        os.unlink(tmp_path)


if __name__ == "__main__":
    # Launch uvicorn with host/port from the YAML config. The app is given
    # as an import string ("server:app") rather than the object so the
    # reload option can re-import it; "reload" defaults to False when the
    # key is absent from the config.
    uvicorn.run(
        "server:app",
        host=config.server.host,
        port=config.server.port,
        reload=config.server.get("reload", False),
    )