# File-viewer page residue (not code): fexeak | add static | f203179 | raw | history blame | 2.66 kB
# audio_api.py
import base64
import io
from typing import Optional

import torch
import torchaudio
from fastapi import FastAPI, HTTPException
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel, Field

from boson_multimodal.data_types import ChatMLSample, Message, AudioContent
from boson_multimodal.serve.serve_engine import HiggsAudioServeEngine, HiggsAudioResponse
# -------------------- Model loading --------------------
# Identifiers handed to the serve engine for the generation model and its
# audio tokenizer.
MODEL_PATH = "bosonai/higgs-audio-v2-generation-3B-base"
AUDIO_TOKENIZER_PATH = "bosonai/higgs-audio-v2-tokenizer"

# Run on the GPU when torch reports CUDA as available, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Constructed once at import time; the /generate-audio handler reuses it.
serve_engine = HiggsAudioServeEngine(
    MODEL_PATH,
    AUDIO_TOKENIZER_PATH,
    device=device,
)
# -------------------- FastAPI --------------------
# Application object that the route decorators and the static mount attach to.
app = FastAPI(
    version="0.1.0",
    title="Higgs Audio Generation API",
)
# Request body for POST /generate-audio.
# (Documented with comments rather than a docstring so the generated schema
# description stays exactly as before.)
class AudioRequest(BaseModel):
    # Required prompt text forwarded to the model as the user message.
    # NOTE(review): the description literal is mis-encoded in this copy of the
    # file; it is presumably Chinese for "text to generate audio for" -- confirm
    # against the original before changing it.
    user_prompt: str = Field(
        ...,
        description="ιœ€θ¦η”ŸζˆιŸ³ι’‘ηš„ζ–‡ζœ¬",
    )

    # Generation settings, each with a default and validated bounds. They are
    # passed unchanged to serve_engine.generate().
    max_new_tokens: Optional[int] = Field(default=1024, ge=1, le=2048)
    temperature: Optional[float] = Field(default=0.3, ge=0.0, le=2.0)
    top_p: Optional[float] = Field(default=0.95, ge=0.0, le=1.0)
    top_k: Optional[int] = Field(default=50, ge=1, le=100)
# Response body for POST /generate-audio.
class AudioResponse(BaseModel):
    # Base64 text of the WAV bytes produced by the handler.
    audio_base64: str
    # Sampling rate reported by the engine for the generated audio.
    sample_rate: int
@app.post("/generate-audio", response_model=AudioResponse)
def generate_audio(req: AudioRequest):
    """Generate audio for a text prompt and return it as a base64-encoded WAV.

    The prompt is sent to the serve engine together with a fixed system prompt
    describing a quiet-room recording. The resulting waveform is encoded as WAV
    in memory and returned base64-encoded along with its sampling rate.

    Args:
        req: Validated request carrying the prompt and sampling settings.

    Returns:
        AudioResponse with the base64 WAV payload and the sampling rate.

    Raises:
        HTTPException: 500 if generation fails or yields no audio.
    """
    system_prompt = (
        "Generate audio following instruction.\n\n<|scene_desc_start|>\n"
        "Audio is recorded from a quiet room.\n<|scene_desc_end|>"
    )
    messages = [
        Message(role="system", content=system_prompt),
        Message(role="user", content=req.user_prompt),
    ]

    try:
        output: HiggsAudioResponse = serve_engine.generate(
            chat_ml_sample=ChatMLSample(messages=messages),
            max_new_tokens=req.max_new_tokens,
            temperature=req.temperature,
            top_p=req.top_p,
            top_k=req.top_k,
            stop_strings=["<|end_of_text|>", "<|eot_id|>"],
        )
    except Exception as e:
        # Request boundary: report any engine failure as a 500, and chain the
        # original exception so the real traceback is still available in logs.
        raise HTTPException(status_code=500, detail=str(e)) from e

    # NOTE(review): assumes the engine signals "no audio produced" by leaving
    # `audio` as None -- confirm against HiggsAudioResponse. Without this guard
    # the failure would surface as an opaque TypeError from torch.from_numpy.
    if output.audio is None:
        raise HTTPException(
            status_code=500,
            detail="Generation finished without producing audio.",
        )

    # Convert the numpy array to a torch.Tensor and encode it as a WAV byte
    # stream. A leading axis is added so the tensor has shape (1, T).
    waveform = torch.from_numpy(output.audio)[None, :]
    buf = io.BytesIO()
    torchaudio.save(buf, waveform, output.sampling_rate, format="wav")
    audio_b64 = base64.b64encode(buf.getvalue()).decode("utf-8")

    return AudioResponse(audio_base64=audio_b64, sample_rate=output.sampling_rate)
# Static front-end: files in the local "static" directory are served under
# /static, and "/" returns the landing page from that same directory.
app.mount("/static", StaticFiles(directory="static"), name="static")


@app.get("/", include_in_schema=False)
async def index():
    """Return the static landing page, ``static/index.html``."""
    return FileResponse("static/index.html")