File size: 2,655 Bytes
57cfcfd
 
 
 
4c90601
0fd89a7
 
57cfcfd
 
4c90601
57cfcfd
 
 
 
0fd89a7
 
 
 
 
 
57cfcfd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f203179
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
# audio_api.py
import base64
import io
from typing import Optional

import torch
import torchaudio
from fastapi import FastAPI, HTTPException
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel, Field

from boson_multimodal.data_types import ChatMLSample, Message, AudioContent
from boson_multimodal.serve.serve_engine import HiggsAudioServeEngine, HiggsAudioResponse

# -------------------- Model loading --------------------
# Hugging Face model IDs for the generation model and its audio tokenizer.
MODEL_PATH = "bosonai/higgs-audio-v2-generation-3B-base"
AUDIO_TOKENIZER_PATH = "bosonai/higgs-audio-v2-tokenizer"
# Prefer GPU when available; generation on CPU will be slow but functional.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Loaded once at import time so every request reuses the same engine instance.
serve_engine = HiggsAudioServeEngine(MODEL_PATH, AUDIO_TOKENIZER_PATH, device=device)

# -------------------- FastAPI --------------------
app = FastAPI(title="Higgs Audio Generation API", version="0.1.0")

class AudioRequest(BaseModel):
    """Request body for POST /generate-audio (text plus sampling parameters)."""

    # Text to be rendered as speech (Field description kept as authored).
    user_prompt: str = Field(..., description="ιœ€θ¦η”ŸζˆιŸ³ι’‘ηš„ζ–‡ζœ¬")
    # Sampling controls forwarded verbatim to the serve engine; bounds are
    # enforced by pydantic (ge/le) before the handler runs.
    max_new_tokens: Optional[int] = Field(1024, ge=1, le=2048)
    temperature: Optional[float] = Field(0.3, ge=0.0, le=2.0)
    top_p: Optional[float] = Field(0.95, ge=0.0, le=1.0)
    top_k: Optional[int] = Field(50, ge=1, le=100)

class AudioResponse(BaseModel):
    """Response body: base64-encoded WAV bytes and their sample rate in Hz."""

    audio_base64: str
    sample_rate: int

@app.post("/generate-audio", response_model=AudioResponse)
def generate_audio(req: AudioRequest):
    """Generate speech for ``req.user_prompt`` and return it as base64 WAV.

    The request's sampling parameters are forwarded unchanged to the serve
    engine; the resulting waveform is encoded to WAV in memory and returned
    base64-encoded together with its sample rate.

    Raises:
        HTTPException: 500 with the engine's error message if generation fails.
    """
    system_prompt = (
        "Generate audio following instruction.\n\n<|scene_desc_start|>\n"
        "Audio is recorded from a quiet room.\n<|scene_desc_end|>"
    )
    messages = [
        Message(role="system", content=system_prompt),
        Message(role="user", content=req.user_prompt),
    ]

    try:
        output: HiggsAudioResponse = serve_engine.generate(
            chat_ml_sample=ChatMLSample(messages=messages),
            max_new_tokens=req.max_new_tokens,
            temperature=req.temperature,
            top_p=req.top_p,
            top_k=req.top_k,
            stop_strings=["<|end_of_text|>", "<|eot_id|>"],
        )
    except Exception as e:
        # Chain the original exception so logs keep the root-cause traceback.
        raise HTTPException(status_code=500, detail=str(e)) from e

    # Convert the numpy waveform to a torch.Tensor (shape (1, T): one channel)
    # and encode it as an in-memory WAV byte stream.
    waveform = torch.from_numpy(output.audio)[None, :]
    buf = io.BytesIO()
    torchaudio.save(buf, waveform, output.sampling_rate, format="wav")
    audio_b64 = base64.b64encode(buf.getvalue()).decode("utf-8")

    return AudioResponse(audio_base64=audio_b64, sample_rate=output.sampling_rate)

# Serve static assets under /static; "/" is routed to the static homepage.
app.mount("/static", StaticFiles(directory="static"), name="static")

@app.get("/", include_in_schema=False)
async def index():
    """Serve the static landing page (hidden from the OpenAPI schema).

    Requires ``FileResponse`` to be imported at module level; the original
    file referenced it without importing it, which raised ``NameError``.
    """
    return FileResponse("static/index.html")