|
|
from fastapi import FastAPI, HTTPException |
|
|
from fastapi.middleware.cors import CORSMiddleware |
|
|
from pydantic import BaseModel |
|
|
import base64 |
|
|
import io |
|
|
import numpy as np |
|
|
import tempfile |
|
|
import os |
|
|
import logging |
|
|
import torch |
|
|
from transformers import Qwen2_5OmniModel, Qwen2_5OmniProcessor |
|
|
|
|
|
from qwen_omni_utils import process_mm_info |
|
|
import soundfile as sf |
|
|
|
|
|
|
|
|
# Module-wide logging at INFO so model loading and request handling are visible.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# FastAPI application exposing the voice-to-voice inference API.
app = FastAPI()
|
|
|
|
|
|
|
|
# CORS: fully open — suitable for development only.
# NOTE(review): browsers reject credentialed requests when the server answers
# with a wildcard origin, so allow_origins=["*"] with allow_credentials=True
# is effectively inconsistent — lock origins down for production. Verify.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
|
|
|
|
class AudioRequest(BaseModel):
    # Base64-encoded WAV file bytes (decoded with sf.read in run_inference).
    audio_data: str
    # Sample rate (Hz) the client claims for the audio; the WAV header may disagree.
    sample_rate: int
|
|
|
|
|
class AudioResponse(BaseModel):
    # Base64-encoded WAV of the model's speech output; "" when no audio was produced.
    audio_data: str
    # Decoded text response from the model (first batch element).
    text: str = ""
    # Sample rate (Hz) of the returned audio (MODEL_OUTPUT_SAMPLE_RATE).
    output_sample_rate: int
|
|
|
|
|
|
|
|
# Populated by initialize_model() at startup; None until then. Endpoints
# check these before serving (503 when not yet loaded).
model = None
processor = None

# Sample rate stamped into the output WAV. Presumably matches the rate the
# Qwen2.5-Omni talker emits — TODO confirm against the model card.
MODEL_OUTPUT_SAMPLE_RATE = 24000
|
|
|
|
|
def initialize_model():
    """Load the Qwen2.5-Omni model and processor into the module globals.

    Populates ``model`` and ``processor``. Any load failure is logged and
    re-raised so startup aborts loudly instead of serving a broken service.
    """
    global model, processor

    model_path = "Qwen/Qwen2.5-Omni-7B"
    try:
        # bfloat16 + device_map="auto" lets accelerate shard/place the weights.
        model = Qwen2_5OmniModel.from_pretrained(
            model_path,
            torch_dtype=torch.bfloat16,
            device_map="auto",
        )
        processor = Qwen2_5OmniProcessor.from_pretrained(model_path)
        logger.info("Model loaded successfully")
    except Exception as e:
        logger.error(f"Error loading model: {e}")
        raise
|
|
|
|
|
@app.on_event("startup")
async def startup_event():
    # Load the model once at process startup so the first request doesn't pay
    # the multi-minute load cost. NOTE(review): @app.on_event is deprecated in
    # newer FastAPI in favor of lifespan handlers — consider migrating.
    initialize_model()
|
|
|
|
|
def inference(audio_data: np.ndarray, input_sample_rate: int):
    """
    Run Qwen2.5-Omni inference on a raw audio waveform.

    Args:
        audio_data: 1-D float waveform (as produced by sf.read in run_inference).
        input_sample_rate: Sample rate (Hz) attached to the waveform.
            NOTE(review): no resampling happens here; whether the processor
            resamples internally or expects a fixed rate (e.g. 16 kHz) should
            be confirmed against the Qwen2.5-Omni processor docs.

    Returns:
        (text_out, audio_out_array): list of decoded strings from batch_decode,
        and a 1-D float NumPy array of output audio (or None if the model
        returned no audio).

    Raises:
        ValueError: if the module globals model/processor are not initialized.
    """
    if processor is None or model is None:
        raise ValueError("Model and processor are not initialized.")

    # Wrap the raw array in the dict shape the chat-message audio entry expects.
    audio_input_for_processor = [{"array": audio_data, "sampling_rate": input_sample_rate}]

    messages = [
        {
            "role": "system",
            "content": (
                "You are Qwen, an advanced multimodal AI assistant developed by the Qwen Team at Alibaba Group. "
                "Your job is to listen carefully to the provided audio input and generate a clear, concise, and friendly response. "
                "Please analyze the audio content, understand the user's intent, and provide an informative answer both in text and in speech."
            )
        },
        {
            "role": "user",
            "content": [
                {"type": "audio", "audio": audio_input_for_processor[0]},
            ]
        },
    ]

    # Render the chat template to a prompt string (generation prompt appended).
    text_prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

    # Extract audio/image/video payloads from the messages for the processor.
    audios, images, videos = process_mm_info(messages, use_audio_in_video=True)

    # NOTE(review): the audio keyword is `audios=` here; some transformers
    # versions spell it `audio=` — verify against the installed version.
    inputs = processor(
        text=text_prompt,
        audios=audios,
        images=images,
        videos=videos,
        return_tensors="pt",
        padding=True
    )

    # Move tensors to the model's device and cast to its dtype (bfloat16).
    inputs = inputs.to(model.device).to(model.dtype)

    # return_audio=True asks the talker head for speech output alongside text ids.
    output = model.generate(**inputs, use_audio_in_video=True, return_audio=True)

    # generate() is expected to return (token_ids, audio_tensor); fall back to
    # text-only if the return shape differs (version-dependent behavior).
    if isinstance(output, tuple) and len(output) >= 2:
        text_out_ids = output[0]
        audio_out_array = output[1]
    else:
        logger.warning(f"Unexpected output format from model.generate: {type(output)}")
        text_out_ids = output
        audio_out_array = None

    # NOTE(review): this decodes the full returned sequence — if generate()
    # returns prompt + completion, the prompt text will be included. Confirm.
    text_out = processor.batch_decode(text_out_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)

    # Flatten the audio tensor to a 1-D float array on CPU for WAV encoding.
    if audio_out_array is not None:
        audio_out_array = audio_out_array.reshape(-1).detach().cpu().numpy()

    return text_out, audio_out_array
|
|
|
|
|
def process_audio_output(audio_out: np.ndarray, sample_rate: int) -> str:
    """
    Convert the output audio (a float NumPy array) to an in-memory 16-bit
    PCM WAV file and return its base64 encoding.

    Args:
        audio_out: 1-D float waveform nominally in [-1.0, 1.0], or None.
        sample_rate: Sample rate (Hz) to stamp into the WAV header.

    Returns:
        Base64-encoded WAV bytes, or "" when audio_out is None.
    """
    if audio_out is None:
        return ""

    # Clip before scaling: float samples even slightly outside [-1, 1] would
    # otherwise overflow and wrap around in the int16 cast, producing loud
    # clicks/pops in the rendered audio.
    clipped = np.clip(audio_out, -1.0, 1.0)
    audio_int16 = (clipped * 32767).astype(np.int16)

    wav_buffer = io.BytesIO()
    sf.write(wav_buffer, audio_int16, samplerate=sample_rate, format="WAV")
    wav_buffer.seek(0)
    wav_bytes = wav_buffer.getvalue()
    return base64.b64encode(wav_bytes).decode("utf-8")
|
|
|
|
|
@app.post("/api/v1/inference", response_model=AudioResponse)
async def run_inference(request: AudioRequest) -> AudioResponse:
    """
    Voice-to-voice inference endpoint.

    Expects a JSON payload with a base64-encoded WAV audio file and its sample
    rate. Returns a base64-encoded WAV audio file and the corresponding text.

    Raises:
        HTTPException 503: model not loaded yet.
        HTTPException 400: payload is not valid base64 / decodable WAV.
        HTTPException 500: inference itself failed.
    """
    if model is None or processor is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    # Decode/parse failures are client errors — report 400, not 500.
    try:
        audio_bytes = base64.b64decode(request.audio_data)
        audio_input_array, file_sample_rate = sf.read(io.BytesIO(audio_bytes), dtype='float32')
    except Exception as e:
        logger.error(f"Invalid audio payload: {e}")
        raise HTTPException(status_code=400, detail=f"Invalid audio payload: {e}")

    # Downmix multi-channel audio to mono: sf.read returns a (frames, channels)
    # array for stereo files, but inference() expects a 1-D waveform.
    if audio_input_array.ndim > 1:
        audio_input_array = audio_input_array.mean(axis=1)

    # Trust the sample rate embedded in the WAV header over the client-supplied
    # field, which previously was used unchecked and could silently disagree.
    if file_sample_rate != request.sample_rate:
        logger.warning(
            f"Client-declared sample rate {request.sample_rate} differs from "
            f"WAV header rate {file_sample_rate}; using the header rate"
        )

    try:
        text_response, audio_out_array = inference(audio_input_array, file_sample_rate)

        output_audio_b64 = process_audio_output(audio_out_array, MODEL_OUTPUT_SAMPLE_RATE)

        text_to_return = text_response[0] if text_response else ""

        return AudioResponse(
            audio_data=output_audio_b64,
            text=text_to_return,
            output_sample_rate=MODEL_OUTPUT_SAMPLE_RATE,
        )
    except Exception as e:
        logger.error(f"Inference error: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Inference failed: {str(e)}")
|
|
|
|
|
@app.get("/api/v1/health")
def health_check():
    """Report service liveness and whether the model has finished loading."""
    model_is_loaded = model is not None
    return {"status": "ok", "model_loaded": model_is_loaded}
|
|
|
|
|
if __name__ == "__main__":
    # Development entry point: serve on all interfaces, port 8000.
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)