"""FastAPI service exposing voice-to-voice inference with Qwen2.5-Omni.

Accepts a base64-encoded WAV file, runs it through the multimodal model,
and returns a base64-encoded WAV reply plus the decoded text response.
"""

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import base64
import io
import numpy as np
import logging
import torch
from transformers import Qwen2_5OmniModel, Qwen2_5OmniProcessor
# Ensure you have qwen_omni_utils installed, e.g., pip install qwen-omni-utils
from qwen_omni_utils import process_mm_info
import soundfile as sf

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI()

# Allow CORS for all origins (adjust as needed for production)
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


# ----- Request/response models -------------------------------------------------

class AudioRequest(BaseModel):
    """Inbound payload for the inference endpoint."""
    audio_data: str   # Base64-encoded WAV file bytes (voice input)
    sample_rate: int  # Sample rate the caller reports for the input audio


class AudioResponse(BaseModel):
    """Outbound payload from the inference endpoint."""
    audio_data: str           # Base64-encoded WAV file bytes (voice output)
    text: str = ""            # Decoded text response, if any
    output_sample_rate: int   # Sample rate of the output audio


# ----- Globals -----------------------------------------------------------------

# Model and processor are loaded once at startup and shared by all requests.
model = None
processor = None

# Qwen2.5-Omni typically outputs audio at 24 kHz.
MODEL_OUTPUT_SAMPLE_RATE = 24000


def initialize_model():
    """Load the Qwen2.5-Omni model and processor into the module globals.

    Raises:
        Exception: re-raised after logging if loading fails, so startup aborts
        rather than leaving the service half-initialized.
    """
    global model, processor
    try:
        # Local path or Hugging Face model ID.
        model_path = "Qwen/Qwen2.5-Omni-7B"
        # NOTE: transformers and qwen-omni-utils must be versions that ship
        # Qwen2_5OmniModel. If the import/load fails, install from source:
        #   pip uninstall transformers -y
        #   pip install git+https://github.com/huggingface/transformers@main
        #   pip install qwen-omni-utils -U
        model = Qwen2_5OmniModel.from_pretrained(
            model_path,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            # attn_implementation="flash_attention_2",  # uncomment if flash-attention is installed
        )
        processor = Qwen2_5OmniProcessor.from_pretrained(model_path)
        logger.info("Model loaded successfully")
    except Exception as e:
        logger.error(f"Error loading model: {e}")
        raise


@app.on_event("startup")
async def startup_event():
    """Load the model before serving any requests."""
    initialize_model()


def inference(audio_data: np.ndarray, input_sample_rate: int):
    """Run voice-to-voice inference on a mono float32 audio array.

    Args:
        audio_data: 1-D float32 waveform in [-1, 1].
        input_sample_rate: sample rate of ``audio_data`` in Hz.

    Returns:
        Tuple of (decoded text list, output audio as a 1-D np.ndarray or None).

    Raises:
        ValueError: if the model/processor have not been initialized.
    """
    if processor is None or model is None:
        raise ValueError("Model and processor are not initialized.")

    # The processor expects audio as dicts with "array" and "sampling_rate".
    audio_input_for_processor = [{"array": audio_data, "sampling_rate": input_sample_rate}]

    messages = [
        {
            "role": "system",
            "content": (
                "You are Qwen, an advanced multimodal AI assistant developed by the Qwen Team at Alibaba Group. "
                "Your job is to listen carefully to the provided audio input and generate a clear, concise, and friendly response. "
                "Please analyze the audio content, understand the user's intent, and provide an informative answer both in text and in speech."
            )
        },
        {
            "role": "user",
            "content": [
                {"type": "audio", "audio": audio_input_for_processor[0]},  # Pass the prepared audio dictionary
            ]
        },
    ]

    # Build the chat prompt and extract modalities from the messages.
    text_prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    # process_mm_info expects {"audio": {"array": ..., "sampling_rate": ...}} entries.
    audios, images, videos = process_mm_info(messages, use_audio_in_video=True)

    inputs = processor(
        text=text_prompt,
        audios=audios,
        images=images,
        videos=videos,
        return_tensors="pt",
        padding=True
    )
    # BatchFeature.to(dtype) casts only floating-point tensors, so integer
    # token IDs are left intact.
    inputs = inputs.to(model.device).to(model.dtype)

    # With return_audio=True the model is expected to return (text_ids, audio).
    # Ensure your transformers version supports this flag.
    output = model.generate(**inputs, use_audio_in_video=True, return_audio=True)

    if isinstance(output, tuple) and len(output) >= 2:
        text_out_ids = output[0]
        audio_out_array = output[1]
    else:
        # Fall back gracefully if only text IDs came back.
        logger.warning(f"Unexpected output format from model.generate: {type(output)}")
        text_out_ids = output
        audio_out_array = None

    text_out = processor.batch_decode(text_out_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)

    # The audio output is typically a torch.Tensor; flatten and move to numpy.
    if audio_out_array is not None:
        audio_out_array = audio_out_array.reshape(-1).detach().cpu().numpy()

    return text_out, audio_out_array


def process_audio_output(audio_out: np.ndarray, sample_rate: int) -> str:
    """Encode a float waveform as a base64 WAV string.

    Args:
        audio_out: 1-D float array, nominally in [-1, 1], or None.
        sample_rate: sample rate to write into the WAV header.

    Returns:
        Base64-encoded 16-bit PCM WAV bytes, or "" if ``audio_out`` is None.
    """
    if audio_out is None:
        return ""
    # Clip before scaling: values outside [-1, 1] would otherwise wrap around
    # on int16 conversion and produce loud artifacts.
    audio_int16 = (np.clip(audio_out, -1.0, 1.0) * 32767).astype(np.int16)
    wav_buffer = io.BytesIO()
    # int16 input makes soundfile emit PCM_16 WAV data.
    sf.write(wav_buffer, audio_int16, samplerate=sample_rate, format="WAV")
    wav_buffer.seek(0)
    wav_bytes = wav_buffer.getvalue()
    return base64.b64encode(wav_bytes).decode("utf-8")


@app.post("/api/v1/inference", response_model=AudioResponse)
async def run_inference(request: AudioRequest) -> AudioResponse:
    """Voice-to-voice inference endpoint.

    Expects a JSON payload with a base64-encoded WAV audio file and its
    sample rate. Returns a base64-encoded WAV audio file as output and the
    corresponding text.
    """
    if model is None or processor is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    try:
        # Decode the incoming base64 audio bytes (voice input).
        audio_bytes = base64.b64decode(request.audio_data)
        # Read directly from memory; the WAV header's sample rate is
        # authoritative over the caller-supplied value.
        audio_input_array, file_sample_rate = sf.read(io.BytesIO(audio_bytes), dtype='float32')
        if file_sample_rate != request.sample_rate:
            logger.warning(
                "Request sample_rate %d disagrees with WAV header %d; using header value",
                request.sample_rate, file_sample_rate,
            )
        # Downmix multichannel audio to mono — the model path expects a 1-D waveform.
        if audio_input_array.ndim > 1:
            audio_input_array = audio_input_array.mean(axis=1)

        text_response, audio_out_array = inference(audio_input_array, file_sample_rate)

        # Convert model output audio to WAV bytes and base64-encode them.
        output_audio_b64 = process_audio_output(audio_out_array, MODEL_OUTPUT_SAMPLE_RATE)

        # batch_decode returns a list; use the first entry if present.
        text_to_return = text_response[0] if text_response else ""

        return AudioResponse(
            audio_data=output_audio_b64,
            text=text_to_return,
            output_sample_rate=MODEL_OUTPUT_SAMPLE_RATE,
        )
    except Exception as e:
        logger.error(f"Inference error: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Inference failed: {str(e)}")


@app.get("/api/v1/health")
def health_check():
    """Simple health check endpoint."""
    return {"status": "ok", "model_loaded": model is not None}


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)