# WIN_21_2 / server.py
# Origin: ArtemisTAO — "Update server.py" (commit 81717af, verified)
import base64
import io
import logging
import os
import tempfile
import wave

import numpy as np
import soundfile as sf
import torch
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from transformers import Qwen2_5OmniModel, Qwen2_5OmniProcessor

# Ensure you have qwen_omni_utils installed, e.g., pip install qwen-omni-utils
from qwen_omni_utils import process_mm_info
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI()

# Allow CORS for all origins (adjust as needed for production).
# NOTE(review): wildcard allow_origins=["*"] combined with
# allow_credentials=True is generally rejected by browsers for credentialed
# requests — pin concrete origins before deploying; confirm against clients.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Request/response models
class AudioRequest(BaseModel):
    """Request payload for the /api/v1/inference endpoint."""

    audio_data: str  # Base64-encoded WAV file bytes (voice input)
    sample_rate: int  # Sample rate (Hz) of the input audio
class AudioResponse(BaseModel):
    """Response payload for the /api/v1/inference endpoint."""

    audio_data: str  # Base64-encoded WAV file bytes (voice output); "" if no audio
    text: str = ""  # Optional transcript of the spoken response
    output_sample_rate: int  # Sample rate (Hz) of the output audio
# Global model and processor, populated by initialize_model() at startup;
# both remain None until loading succeeds (health/inference endpoints check this).
model = None
processor = None

# Sample rate (Hz) stamped into the output WAV header.
# Qwen2.5-Omni typically outputs audio at 24kHz.
MODEL_OUTPUT_SAMPLE_RATE = 24000
def initialize_model():
    """Load the Qwen2.5-Omni checkpoint and processor into module globals.

    Populates the module-level `model` and `processor` used by `inference`.
    Any load failure is logged and re-raised so that server startup aborts
    loudly instead of serving with a half-initialized pipeline.
    """
    global model, processor

    # Full Hugging Face model ID (or a local path to the same checkpoint).
    checkpoint = "Qwen/Qwen2.5-Omni-7B"

    # transformers and qwen-omni-utils must be recent enough for this
    # architecture; if the classes are missing, install from source:
    #   pip uninstall transformers -y
    #   pip install git+https://github.com/huggingface/transformers@main
    #   pip install qwen-omni-utils -U
    try:
        model = Qwen2_5OmniModel.from_pretrained(
            checkpoint,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            # attn_implementation="flash_attention_2",  # enable if flash-attn is installed
        )
        processor = Qwen2_5OmniProcessor.from_pretrained(checkpoint)
        logger.info("Model loaded successfully")
    except Exception as e:
        logger.error(f"Error loading model: {e}")
        raise
@app.on_event("startup")
async def startup_event():
    """Load the model once when the server starts; a failure aborts startup."""
    # NOTE(review): @app.on_event is deprecated in recent FastAPI releases in
    # favor of lifespan handlers — confirm the installed version before migrating.
    initialize_model()
def inference(audio_data: np.ndarray, input_sample_rate: int):
    """
    Run voice-to-voice inference on the provided audio data (NumPy array).

    Args:
        audio_data: Float waveform of the user's speech (as produced by sf.read).
        input_sample_rate: Sample rate (Hz) of ``audio_data``.

    Returns:
        A tuple ``(text_out, audio_out_array)`` where ``text_out`` is the list of
        decoded strings from ``processor.batch_decode`` and ``audio_out_array``
        is a 1-D NumPy float array of generated speech samples, or ``None`` when
        the model produced no audio (or an unexpected output format).

    Raises:
        ValueError: If the global model/processor have not been initialized.
    """
    if processor is None or model is None:
        raise ValueError("Model and processor are not initialized.")
    # The processor expects audio as a list of dictionaries with "array" and "sampling_rate"
    audio_input_for_processor = [{"array": audio_data, "sampling_rate": input_sample_rate}]
    messages = [
        {
            "role": "system",
            "content": (
                "You are Qwen, an advanced multimodal AI assistant developed by the Qwen Team at Alibaba Group. "
                "Your job is to listen carefully to the provided audio input and generate a clear, concise, and friendly response. "
                "Please analyze the audio content, understand the user's intent, and provide an informative answer both in text and in speech."
            )
        },
        {
            "role": "user",
            "content": [
                {"type": "audio", "audio": audio_input_for_processor[0]},  # Pass the prepared audio dictionary
            ]
        },
    ]
    # Generate the text prompt and process the audio input
    text_prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    # process_mm_info extracts modalities from messages and prepares them for the processor.
    # For audio, it expects {"audio": {"array": ..., "sampling_rate": ...}} in messages.
    # NOTE(review): confirm the installed qwen-omni-utils accepts a dict here —
    # some versions expect a path/URL/ndarray for the "audio" field.
    audios, images, videos = process_mm_info(messages, use_audio_in_video=True)
    inputs = processor(
        text=text_prompt,
        audios=audios,  # This will be the processed audio array from process_mm_info
        images=images,
        videos=videos,
        return_tensors="pt",
        padding=True
    )
    # Move tensors to the model's device, then cast to the model dtype.
    # NOTE(review): assumes BatchFeature.to(dtype) only casts floating-point
    # tensors (leaving integer token ids intact) — confirm for the installed
    # transformers version.
    inputs = inputs.to(model.device).to(model.dtype)
    # Generate output from the model. Return both text and audio.
    # The output from model.generate should be a tuple (text_ids, audio_array).
    # Ensure `return_audio=True` is correctly handled by your `transformers` version.
    output = model.generate(**inputs, use_audio_in_video=True, return_audio=True)
    # Check if the output is a tuple and has the expected number of elements
    if isinstance(output, tuple) and len(output) >= 2:
        text_out_ids = output[0]
        audio_out_array = output[1]
    else:
        # Handle cases where output might just be text_ids or an unexpected format
        logger.warning(f"Unexpected output format from model.generate: {type(output)}")
        text_out_ids = output  # Assume text IDs are the first element
        audio_out_array = None  # No audio generated or not in expected format
    # NOTE(review): generate() output ids typically include the prompt tokens,
    # so the decoded text may contain the chat template — confirm and slice if so.
    text_out = processor.batch_decode(text_out_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
    # Qwen2.5-Omni's audio output is typically a torch.Tensor; flatten and move to CPU NumPy.
    if audio_out_array is not None:
        audio_out_array = audio_out_array.reshape(-1).detach().cpu().numpy()
    return text_out, audio_out_array
def process_audio_output(audio_out: np.ndarray, sample_rate: int) -> str:
    """
    Encode generated speech samples as a base64 WAV string.

    Args:
        audio_out: 1-D float waveform with nominal range [-1, 1], or ``None``
            when the model produced no audio.
        sample_rate: Sample rate (Hz) to stamp into the WAV header.

    Returns:
        Base64-encoded 16-bit mono PCM WAV bytes, or "" when ``audio_out``
        is ``None``.
    """
    if audio_out is None:
        return ""  # Return empty string if no audio output
    # Clip BEFORE scaling to int16: the previous direct cast wrapped samples
    # outside [-1, 1] around the int16 range, producing loud crackling artifacts.
    clipped = np.clip(np.asarray(audio_out, dtype=np.float32), -1.0, 1.0)
    audio_int16 = (clipped * 32767.0).astype(np.int16)
    # Write a standard 16-bit mono PCM WAV entirely in memory; the stdlib
    # `wave` module suffices here and avoids soundfile on the output path.
    wav_buffer = io.BytesIO()
    with wave.open(wav_buffer, "wb") as wav_file:
        wav_file.setnchannels(1)   # mono: audio was flattened to 1-D upstream
        wav_file.setsampwidth(2)   # 16-bit samples
        wav_file.setframerate(sample_rate)
        wav_file.writeframes(audio_int16.tobytes())
    return base64.b64encode(wav_buffer.getvalue()).decode("utf-8")
@app.post("/api/v1/inference", response_model=AudioResponse)
async def run_inference(request: AudioRequest) -> AudioResponse:
    """
    Voice-to-voice inference endpoint.

    Expects a JSON payload with a base64-encoded WAV audio file and its sample
    rate. Returns a base64-encoded WAV audio file as output and the
    corresponding text. Responds 503 while the model is still loading and 500
    on any inference failure.
    """
    if model is None or processor is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    try:
        # Decode the incoming base64 audio bytes (voice input)
        audio_bytes = base64.b64decode(request.audio_data)
        # Load audio data directly into a NumPy array.
        # Using soundfile.read directly from BytesIO is more robust than temporary files.
        # NOTE(review): the WAV header's own sample rate is discarded here in favor
        # of request.sample_rate — confirm callers always send a matching value.
        audio_input_array, _ = sf.read(io.BytesIO(audio_bytes), dtype='float32')  # Sample rate is already in request
        # Run inference using the audio numpy array.
        # NOTE(review): this is a synchronous, potentially long-running call inside
        # an async endpoint — it blocks the event loop; consider a threadpool.
        text_response, audio_out_array = inference(audio_input_array, request.sample_rate)
        # Convert model output audio to a WAV file and encode to base64
        output_audio_b64 = process_audio_output(audio_out_array, MODEL_OUTPUT_SAMPLE_RATE)
        # Use the first element of the text output (if available)
        text_to_return = text_response[0] if text_response else ""
        return AudioResponse(audio_data=output_audio_b64, text=text_to_return, output_sample_rate=MODEL_OUTPUT_SAMPLE_RATE)
    except Exception as e:
        logger.error(f"Inference error: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Inference failed: {str(e)}")
@app.get("/api/v1/health")
def health_check():
    """Report service liveness and whether the model has finished loading."""
    is_loaded = model is not None
    return {"status": "ok", "model_loaded": is_loaded}
if __name__ == "__main__":
    # Standalone dev entry point; in production run via `uvicorn server:app`.
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)