| from fastapi import FastAPI, HTTPException |
| from fastapi.middleware.cors import CORSMiddleware |
| from pydantic import BaseModel |
| import base64 |
| import io |
| import numpy as np |
| import tempfile |
| import os |
| import logging |
|
|
| import torch |
| from transformers import Qwen2_5OmniModel, Qwen2_5OmniProcessor |
| from qwen_omni_utils import process_mm_info |
| import soundfile as sf |
|
|
| |
# Configure the root logger to emit INFO and above, then grab a
# module-scoped logger that every function in this file logs through.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
|
|
# CORS policy applied to every route: any origin, any method and any header
# are accepted, and credentialed requests are allowed.
# NOTE(review): wildcard origins combined with credentials is very permissive;
# confirm this is intended outside of local development.
_CORS_OPTIONS = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}

# The ASGI application object; routes below register themselves on it.
app = FastAPI()
app.add_middleware(CORSMiddleware, **_CORS_OPTIONS)
|
|
| |
class AudioRequest(BaseModel):
    """JSON request body accepted by the inference endpoint."""

    # Base64-encoded audio file contents (the endpoint writes the decoded
    # bytes to a temporary ``.wav`` file before running the model).
    audio_data: str
    # Sample rate declared by the client. It is required by the schema, but
    # the inference endpoint in this file does not read it.
    sample_rate: int
|
|
class AudioResponse(BaseModel):
    """JSON response body returned by the inference endpoint."""

    # Base64-encoded WAV file holding the model's spoken reply.
    audio_data: str
    # Decoded text produced by the model; empty when nothing was decoded.
    text: str = ""
|
|
| |
# Process-wide handles to the loaded model and its processor. Both stay
# ``None`` until ``initialize_model()`` has run.
model, processor = None, None
|
|
def initialize_model(model_path: str = "./model/Qwen2.5-Omni-7B") -> None:
    """Load the Qwen2.5-Omni model and processor into the module globals.

    Args:
        model_path: Location handed to ``from_pretrained`` for both the model
            and the processor. Defaults to the local checkpoint directory
            this service has always used, so existing callers are unaffected.

    Raises:
        Exception: Any error raised while loading is logged together with its
            traceback and then re-raised unchanged.
    """
    global model, processor
    try:
        loaded_model = Qwen2_5OmniModel.from_pretrained(
            model_path,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            attn_implementation="flash_attention_2",
        )
        loaded_processor = Qwen2_5OmniProcessor.from_pretrained(model_path)
    except Exception as e:
        # logger.exception keeps the original message and adds the traceback.
        logger.exception("Error loading model: %s", e)
        raise

    # Publish both handles only after both loads succeeded, so the globals are
    # never left half-initialised (model set, processor still None).
    model, processor = loaded_model, loaded_processor
    logger.info("Model loaded successfully")
|
|
@app.on_event("startup")
async def startup_event() -> None:
    """Startup hook: load the model and processor before serving requests."""
    initialize_model()
|
|
def inference(audio_file_path: str):
    """Run the model on one audio file and return its text and audio reply.

    Args:
        audio_file_path: Path to the audio file (the caller passes a WAV file)
            that is attached to the user turn of the conversation.

    Returns:
        A ``(text_out, audio_out)`` pair: the result of ``batch_decode`` on
        the first element of the ``generate`` output, and the second element
        of that output (the generated audio) untouched.
    """
    # NOTE(review): the Qwen2.5-Omni model card says speech output expects its
    # stock system prompt; confirm this custom prompt still yields audio.
    system_prompt = (
        "You are Qwen, an advanced multimodal AI assistant developed by the Qwen Team at Alibaba Group. "
        "Your job is to listen carefully to the provided audio input and generate a clear, concise, and friendly response. "
        "Please analyze the audio content, understand the user's intent, and provide an informative answer both in text and in speech."
    )
    user_content = [{"type": "audio", "audio": audio_file_path}]
    conversation = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_content},
    ]

    # Render the chat template to a prompt string and collect the multimodal
    # inputs referenced by the conversation.
    prompt = processor.apply_chat_template(
        conversation, tokenize=False, add_generation_prompt=True
    )
    audios, images, videos = process_mm_info(conversation, use_audio_in_video=True)

    model_inputs = processor(
        text=prompt,
        audios=audios,
        images=images,
        videos=videos,
        return_tensors="pt",
        padding=True,
    )
    # Move the batch to the model's device, then cast to the model's dtype.
    model_inputs = model_inputs.to(model.device).to(model.dtype)

    generated = model.generate(
        **model_inputs, use_audio_in_video=True, return_audio=True
    )
    decoded_text = processor.batch_decode(
        generated[0],
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )
    return decoded_text, generated[1]
|
|
def process_audio_output(audio_out, sample_rate: int = 24000) -> str:
    """Encode generated audio as a base64 string of a 16-bit PCM WAV file.

    Args:
        audio_out: Waveform samples, nominally in ``[-1.0, 1.0]``. May be a
            ``torch.Tensor`` (on any device, any float dtype) or anything
            ``numpy.asarray`` accepts.
        sample_rate: Sample rate written into the WAV header. Defaults to
            24000, the value this function previously hard-coded.

    Returns:
        The WAV file bytes, base64-encoded and decoded to a ``str``.
    """
    if isinstance(audio_out, torch.Tensor):
        # NumPy cannot read CUDA or bfloat16 tensors directly: detach, move to
        # CPU and upcast first. Flatten as the model card's usage example does.
        samples = (
            audio_out.detach()
            .to(device="cpu", dtype=torch.float32)
            .reshape(-1)
            .numpy()
        )
    else:
        samples = np.asarray(audio_out)

    # Clip before scaling so out-of-range samples saturate instead of wrapping
    # around when cast to int16.
    pcm_int16 = (np.clip(samples, -1.0, 1.0) * 32767).astype(np.int16)

    wav_buffer = io.BytesIO()
    sf.write(wav_buffer, pcm_int16, samplerate=sample_rate, format="WAV")
    # getvalue() returns the whole buffer regardless of the stream position.
    return base64.b64encode(wav_buffer.getvalue()).decode("utf-8")
|
|
@app.post("/api/v1/inference", response_model=AudioResponse)
async def run_inference(request: AudioRequest) -> AudioResponse:
    """Voice-to-voice inference endpoint.

    Expects a JSON payload with a base64-encoded WAV audio file and its
    sample rate. Returns a base64-encoded WAV audio file and the matching
    text.

    Raises:
        HTTPException: 503 when the model or processor is not loaded, 400 when
            ``audio_data`` is not valid base64, 500 for any failure while
            running inference or encoding the result.
    """
    if model is None or processor is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    # A malformed payload is a client error, so report it as 400, not 500.
    # binascii.Error (raised by b64decode) is a subclass of ValueError.
    try:
        audio_bytes = base64.b64decode(request.audio_data)
    except ValueError as e:
        raise HTTPException(
            status_code=400, detail="audio_data is not valid base64"
        ) from e

    tmp_file_path = None
    try:
        # delete=False: the file must outlive the ``with`` block so that
        # inference() can reopen it by path.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
            tmp_file_path = tmp.name
            tmp.write(audio_bytes)

        # NOTE: inference() is a blocking call made from this coroutine.
        text_response, audio_out = inference(tmp_file_path)
        output_audio_b64 = process_audio_output(audio_out)

        text_to_return = text_response[0] if text_response else ""
        return AudioResponse(audio_data=output_audio_b64, text=text_to_return)
    except Exception as e:
        logger.exception("Inference error: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        # Always remove the temp file, including when inference fails;
        # previously a failure left the file behind on disk.
        if tmp_file_path is not None:
            try:
                os.remove(tmp_file_path)
            except OSError:
                logger.warning("Could not remove temp file %s", tmp_file_path)
|
|
@app.get("/api/v1/health")
def health_check():
    """Report service liveness and whether the model global has been set."""
    is_model_loaded = model is not None
    return {"status": "ok", "model_loaded": is_model_loaded}
|
|
if __name__ == "__main__":
    # Script entry point: serve the app with uvicorn on all interfaces,
    # port 8000. Imported here so uvicorn is only needed when run directly.
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)