# WIN_21 / server.py
# Uploaded by ArtemisTAO with huggingface_hub ("Upload folder using huggingface_hub"), commit b024d42 (verified).
import base64
import io
import logging
import os
import tempfile
import threading

import numpy as np
import soundfile as sf
import torch
from fastapi import FastAPI, HTTPException
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from qwen_omni_utils import process_mm_info
from transformers import Qwen2_5OmniModel, Qwen2_5OmniProcessor
# Set up logging: INFO-level root configuration plus a module-scoped logger.
logging.basicConfig(level="INFO")
logger = logging.getLogger(name=__name__)
app = FastAPI()

# CORS: accept every origin, method and header (adjust as needed).
# NOTE(review): the FastAPI/Starlette CORS docs state that wildcard
# origins/methods/headers should not be combined with allow_credentials=True;
# list explicit origins (or drop credentials) before exposing this publicly.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)
# Request/response models
# JSON request body accepted by POST /api/v1/inference.
class AudioRequest(BaseModel):
    audio_data: str  # Base64-encoded WAV file bytes (voice input)
    # Client-declared sample rate of the input audio.
    # NOTE(review): not read anywhere in this file -- the handler hands the model
    # only a path to the decoded WAV file; confirm whether this field is needed.
    sample_rate: int
# JSON response body returned by POST /api/v1/inference.
class AudioResponse(BaseModel):
    audio_data: str  # Base64-encoded WAV file bytes (voice output)
    text: str = ""  # Decoded text reply; empty string when there is none
# Global model and processor, populated by initialize_model() at startup.
# Both stay None until loading succeeds; the endpoints treat None as "not loaded".
model = processor = None
def initialize_model(model_path="./model/Qwen2.5-Omni-7B"):
    """Load the Qwen2.5-Omni model and processor into the module globals.

    Args:
        model_path: Path handed to both ``from_pretrained`` calls. Defaults to the
            local ``./model/Qwen2.5-Omni-7B`` checkout used by this service.

    Raises:
        Exception: Any loading error is logged with its traceback and re-raised
            to the caller (the startup hook).
    """
    global model, processor
    try:
        loaded_model = Qwen2_5OmniModel.from_pretrained(
            model_path,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            # flash_attention_2 requires the flash-attn package and a supported GPU.
            attn_implementation="flash_attention_2",
        )
        loaded_processor = Qwen2_5OmniProcessor.from_pretrained(model_path)
    except Exception as e:
        # logger.exception keeps the traceback, which logger.error dropped.
        logger.exception("Error loading model: %s", e)
        raise
    # Publish both objects together so the globals never hold a model without
    # its processor (e.g. if the processor load failed after the model loaded).
    model, processor = loaded_model, loaded_processor
    logger.info("Model loaded successfully")
@app.on_event("startup")
async def startup_event():
    """Load the model and processor once, when the application starts.

    ``initialize_model`` logs and re-raises any loading error, so a failure
    propagates out of this startup hook.
    """
    # NOTE(review): newer FastAPI releases deprecate ``on_event`` in favour of a
    # ``lifespan`` handler passed to ``FastAPI(...)``.
    initialize_model()
def inference(audio_file_path: str):
    """Run the Qwen2.5-Omni model on one WAV file and return its text and speech reply.

    Uses the module-level ``model`` and ``processor`` loaded by ``initialize_model``.

    Args:
        audio_file_path: Path to a WAV file containing the user's spoken input.

    Returns:
        A ``(text_out, audio_out)`` tuple:
        - ``text_out``: list of strings from ``processor.batch_decode`` covering only
          the newly generated tokens. The chat prompt is stripped.
        - ``audio_out``: the generated speech as a flat float32 NumPy array on the CPU.
    """
    # NOTE(review): the Qwen2.5-Omni model card says speech output is only reliable
    # with its stock system prompt ("You are Qwen, a virtual human developed by the
    # Qwen Team, Alibaba Group, capable of perceiving auditory and visual inputs, as
    # well as generating text and speech."); this customised prompt may degrade the
    # audio reply -- confirm.
    messages = [
        {
            "role": "system",
            "content": (
                "You are Qwen, an advanced multimodal AI assistant developed by the Qwen Team at Alibaba Group. "
                "Your job is to listen carefully to the provided audio input and generate a clear, concise, and friendly response. "
                "Please analyze the audio content, understand the user's intent, and provide an informative answer both in text and in speech."
            )
        },
        {
            "role": "user",
            "content": [
                {"type": "audio", "audio": audio_file_path},
            ]
        },
    ]
    # Generate the text prompt and process the audio input
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    audios, images, videos = process_mm_info(messages, use_audio_in_video=True)
    inputs = processor(
        text=text,
        audios=audios,
        images=images,
        videos=videos,
        return_tensors="pt",
        padding=True
    )
    inputs = inputs.to(model.device).to(model.dtype)
    # Generate output from the model. Return both text and audio.
    output = model.generate(**inputs, use_audio_in_video=True, return_audio=True)
    # The returned token ids are the prompt followed by the reply; decoding them all
    # would echo the system prompt and chat-template role markers back to the client.
    prompt_length = inputs["input_ids"].shape[1]
    text_out = processor.batch_decode(
        output[0][:, prompt_length:],
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )
    audio = output[1]
    if torch.is_tensor(audio):
        # The waveform comes back as a torch tensor that may live on the GPU and/or
        # use a dtype NumPy lacks (bfloat16); move it to CPU float32 first.
        audio = audio.detach().to(device="cpu", dtype=torch.float32).numpy()
    # Flatten to a single mono channel (the Qwen example writes audio.reshape(-1)).
    audio_out = np.asarray(audio, dtype=np.float32).reshape(-1)
    return text_out, audio_out
def process_audio_output(audio_out, sample_rate=24000):
    """Encode a mono waveform as a 16-bit PCM WAV file and return it base64-encoded.

    Args:
        audio_out: Waveform samples, nominally in [-1, 1]. Accepts a NumPy array or
            any array-like. Torch tensors, including GPU or bfloat16 ones, are moved
            to CPU float32 first. Multi-dimensional input is flattened to one channel.
        sample_rate: Sample rate written into the WAV header. Defaults to 24000 Hz,
            the value this service has always used for the model's speech.

    Returns:
        The WAV file bytes, base64-encoded, as a ``str``.
    """
    if torch.is_tensor(audio_out):
        audio_out = audio_out.detach().to(device="cpu", dtype=torch.float32).numpy()
    # Flatten: a (1, T) array would otherwise be written as T channels of 1 frame.
    samples = np.asarray(audio_out, dtype=np.float32).reshape(-1)
    # Sanitize NaN/inf, then clip before scaling. Out-of-range floats do not fit in
    # int16, and the cast result is undefined (typically wraps into loud clicks).
    samples = np.nan_to_num(samples, nan=0.0, posinf=1.0, neginf=-1.0)
    audio_int16 = (np.clip(samples, -1.0, 1.0) * 32767).astype(np.int16)
    wav_buffer = io.BytesIO()
    sf.write(wav_buffer, audio_int16, samplerate=sample_rate, format="WAV", subtype="PCM_16")
    return base64.b64encode(wav_buffer.getvalue()).decode("utf-8")
# Serializes model access from worker threads. The handler used to call the
# blocking model directly on the event loop, so two requests never overlapped;
# moving the work into a threadpool must not change that.
_inference_lock = threading.Lock()


def _infer_from_wav_bytes(audio_bytes):
    """Run blocking inference on raw WAV bytes and return ``(text, base64_wav)``.

    The bytes are written to a temporary ``.wav`` file, because ``inference`` takes a
    file path. The file is always removed afterwards, even if inference raises.
    """
    tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
    try:
        with tmp:
            tmp.write(audio_bytes)
        with _inference_lock:
            text_response, audio_out = inference(tmp.name)
    finally:
        os.remove(tmp.name)  # Clean up the temporary file on success and failure
    # Convert model output audio to a WAV file and encode to base64
    output_audio_b64 = process_audio_output(audio_out)
    # Use the first element of the text output (if available)
    text_to_return = text_response[0] if text_response else ""
    return text_to_return, output_audio_b64


@app.post("/api/v1/inference", response_model=AudioResponse)
async def run_inference(request: AudioRequest) -> AudioResponse:
    """Voice-to-voice inference endpoint.

    Expects a JSON payload with a base64-encoded WAV audio file and its sample
    rate. Returns a base64-encoded WAV audio file and the corresponding text.

    Raises:
        HTTPException: 503 if the model is not loaded, 400 if ``audio_data`` is not
            valid base64, 500 for any failure during inference or encoding.
    """
    if model is None or processor is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    try:
        audio_bytes = base64.b64decode(request.audio_data)
    except ValueError as e:
        # Covers binascii.Error (a ValueError subclass, e.g. bad padding) and
        # non-ASCII input. Both are client errors, not server failures.
        raise HTTPException(status_code=400, detail=f"Invalid base64 audio_data: {e}") from e
    try:
        # Inference is blocking GPU work; run it in a worker thread so the event
        # loop (and /api/v1/health) stays responsive while it runs.
        text_to_return, output_audio_b64 = await run_in_threadpool(
            _infer_from_wav_bytes, audio_bytes
        )
        return AudioResponse(audio_data=output_audio_b64, text=text_to_return)
    except Exception as e:
        logger.exception("Inference error: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.get("/api/v1/health")
def health_check():
    """Simple health check endpoint.

    ``model_loaded`` is true only when both the model and the processor are
    present. That matches the condition under which /api/v1/inference does not
    answer 503.
    """
    return {
        "status": "ok",
        "model_loaded": model is not None and processor is not None,
    }
if __name__ == "__main__":
    import uvicorn

    # Bind address and port default to 0.0.0.0:8000. The SERVER_HOST and
    # SERVER_PORT environment variables override them without a code change.
    uvicorn.run(
        app,
        host=os.environ.get("SERVER_HOST", "0.0.0.0"),
        port=int(os.environ.get("SERVER_PORT", "8000")),
    )