# NOTE(review): the lines below were Hugging Face Spaces UI residue
# ("Spaces: Sleeping"), not program text; kept as a comment so the module parses.
# audio_api.py
import base64
import io
from typing import Optional

import torch
import torchaudio
from fastapi import FastAPI, HTTPException
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel, Field

from boson_multimodal.data_types import AudioContent, ChatMLSample, Message
from boson_multimodal.serve.serve_engine import HiggsAudioResponse, HiggsAudioServeEngine
# -------------------- Model loading --------------------
MODEL_PATH = "bosonai/higgs-audio-v2-generation-3B-base"
AUDIO_TOKENIZER_PATH = "bosonai/higgs-audio-v2-tokenizer"

# Prefer GPU when available; CPU generation works but is much slower.
device = "cuda" if torch.cuda.is_available() else "cpu"
# NOTE: the engine (and its weights) are loaded once at import time, so the
# first request does not pay the model-loading cost.
serve_engine = HiggsAudioServeEngine(MODEL_PATH, AUDIO_TOKENIZER_PATH, device=device)

# -------------------- FastAPI app --------------------
app = FastAPI(title="Higgs Audio Generation API", version="0.1.0")
class AudioRequest(BaseModel):
    """Request payload for text-to-audio generation."""

    # Text to synthesize into speech audio.
    user_prompt: str = Field(..., description="Text to generate audio from")
    # Sampling parameters forwarded verbatim to HiggsAudioServeEngine.generate.
    max_new_tokens: Optional[int] = Field(1024, ge=1, le=2048)
    temperature: Optional[float] = Field(0.3, ge=0.0, le=2.0)
    top_p: Optional[float] = Field(0.95, ge=0.0, le=1.0)
    top_k: Optional[int] = Field(50, ge=1, le=100)
class AudioResponse(BaseModel):
    """Response payload: a WAV file encoded as base64 plus its sample rate."""

    # Base64-encoded bytes of a complete WAV file.
    audio_base64: str
    # Sampling rate of the encoded audio, in Hz.
    sample_rate: int
@app.post("/generate-audio", response_model=AudioResponse)
def generate_audio(req: AudioRequest) -> AudioResponse:
    """Generate speech audio for ``req.user_prompt``.

    Builds a two-message ChatML sample (fixed scene-description system prompt
    plus the user's text), runs the serve engine, and returns the waveform as
    a base64-encoded WAV file.

    Raises:
        HTTPException: 500 with the underlying error message if generation fails.
    """
    system_prompt = (
        "Generate audio following instruction.\n\n<|scene_desc_start|>\n"
        "Audio is recorded from a quiet room.\n<|scene_desc_end|>"
    )
    messages = [
        Message(role="system", content=system_prompt),
        Message(role="user", content=req.user_prompt),
    ]
    try:
        output: HiggsAudioResponse = serve_engine.generate(
            chat_ml_sample=ChatMLSample(messages=messages),
            max_new_tokens=req.max_new_tokens,
            temperature=req.temperature,
            top_p=req.top_p,
            top_k=req.top_k,
            stop_strings=["<|end_of_text|>", "<|eot_id|>"],
        )
    except Exception as e:
        # Surface engine failures as an HTTP 500 instead of crashing the worker;
        # chain the cause so the traceback stays inspectable in server logs.
        raise HTTPException(status_code=500, detail=str(e)) from e

    # Convert the numpy waveform to a torch tensor and encode it as WAV bytes.
    # assumes output.audio is a 1-D numpy array of samples — TODO confirm
    waveform = torch.from_numpy(output.audio)[None, :]  # shape = (1, T)
    buf = io.BytesIO()
    torchaudio.save(buf, waveform, output.sampling_rate, format="wav")
    audio_b64 = base64.b64encode(buf.getvalue()).decode("utf-8")
    return AudioResponse(audio_base64=audio_b64, sample_rate=output.sampling_rate)
# -------------------- Static frontend --------------------
# Serve the bundled frontend assets from ./static, and the index page at "/".
app.mount("/static", StaticFiles(directory="static"), name="static")


@app.get("/")
async def index():
    """Serve the single-page frontend (static/index.html)."""
    return FileResponse("static/index.html")