File size: 5,755 Bytes
b024d42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
import base64
import io
import logging
import os
import tempfile
import threading

import numpy as np
import soundfile as sf
import torch
from fastapi import FastAPI, HTTPException
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from qwen_omni_utils import process_mm_info
from transformers import Qwen2_5OmniModel, Qwen2_5OmniProcessor

# Root logging at INFO; this module writes through its own named logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI()

# Cross-origin policy: every origin, method and header is accepted and
# credentials are allowed (adjust as needed).
# NOTE(review): wildcard origins together with allow_credentials=True lets any
# website issue credentialed requests — tighten before exposing publicly.
_cors_options = dict(
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(CORSMiddleware, **_cors_options)

# Request/response models
class AudioRequest(BaseModel):
    # JSON request body for POST /api/v1/inference.
    audio_data: str  # Base64-encoded WAV file bytes (voice input)
    # NOTE(review): sample_rate is required by the schema but run_inference never
    # reads it; presumably the audio loader takes the rate from the WAV header — confirm.
    sample_rate: int

class AudioResponse(BaseModel):
    # JSON response body for POST /api/v1/inference.
    audio_data: str  # Base64-encoded WAV file bytes (voice output)
    text: str = ""   # Text reply from the model; "" when no text was decoded

# Shared model/processor handles, filled in by initialize_model() at startup.
# Both stay None until loading succeeds; request handlers check for that.
model = processor = None

def initialize_model(
    model_path: str = "./model/Qwen2.5-Omni-7B",
    attn_implementation: str = "flash_attention_2",
):
    """
    Load the Qwen2.5-Omni model and processor into the module globals.

    Both objects are loaded into locals first and published together, so a
    failure part-way through never leaves ``model`` set while ``processor``
    is still ``None``.

    Args:
        model_path: Directory (or hub id) passed to ``from_pretrained``.
        attn_implementation: Attention backend for the model. The default,
            ``"flash_attention_2"``, needs the ``flash-attn`` package; pass
            e.g. ``"sdpa"`` to run without it.

    Raises:
        Exception: Whatever ``from_pretrained`` raised, re-raised after being
            logged with its traceback.
    """
    global model, processor
    try:
        loaded_model = Qwen2_5OmniModel.from_pretrained(
            model_path,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            attn_implementation=attn_implementation,
        )
        loaded_processor = Qwen2_5OmniProcessor.from_pretrained(model_path)
    except Exception as e:
        # logger.exception keeps the traceback, which logger.error dropped.
        logger.exception("Error loading model: %s", e)
        raise
    # Publish both only once both loaded successfully.
    model, processor = loaded_model, loaded_processor
    logger.info("Model loaded successfully")

@app.on_event("startup")
async def startup_event():
    """FastAPI startup hook: load the model and processor with default settings."""
    # initialize_model() re-raises on failure, so a broken model load
    # propagates out of this hook instead of being silently ignored.
    return initialize_model()

# Qwen2.5-Omni's model card states that speech output only works as expected
# when this exact system prompt is used; a custom system prompt may break the
# generated audio.
QWEN_OMNI_SYSTEM_PROMPT = (
    "You are Qwen, a virtual human developed by the Qwen Team, Alibaba Group, "
    "capable of perceiving auditory and visual inputs, as well as generating "
    "text and speech."
)


def inference(audio_file_path: str, system_prompt: str = QWEN_OMNI_SYSTEM_PROMPT):
    """
    Run voice-to-voice inference on a WAV file.

    Args:
        audio_file_path: Path to the input WAV file.
        system_prompt: System message for the chat template. Defaults to the
            prompt Qwen2.5-Omni requires for speech output; overriding it may
            degrade or break the returned audio.

    Returns:
        ``(text_out, audio_out)``: ``text_out`` is the list returned by
        ``processor.batch_decode`` for the newly generated tokens only (the
        prompt is not echoed back), and ``audio_out`` is a flat float32 NumPy
        array of waveform samples on the CPU.
    """
    messages = [
        {"role": "system", "content": system_prompt},
        {
            "role": "user",
            "content": [
                {"type": "audio", "audio": audio_file_path},
            ],
        },
    ]
    # Build the text prompt and load the audio input.
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    audios, images, videos = process_mm_info(messages, use_audio_in_video=True)
    inputs = processor(
        text=text,
        audios=audios,
        images=images,
        videos=videos,
        return_tensors="pt",
        padding=True,
    )
    inputs = inputs.to(model.device).to(model.dtype)

    # Generate both text token ids and the speech waveform.
    text_ids, audio = model.generate(**inputs, use_audio_in_video=True, return_audio=True)

    # generate() returns prompt + completion; decode only the completion so the
    # reply does not contain the system prompt and chat-template role markers.
    prompt_len = inputs["input_ids"].shape[-1]
    text_out = processor.batch_decode(
        text_ids[:, prompt_len:],
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )

    # The waveform may come back as a torch tensor (possibly on GPU); NumPy
    # cannot read device tensors directly, so move it to a CPU float32 array.
    if isinstance(audio, torch.Tensor):
        audio = audio.detach().to("cpu", dtype=torch.float32).numpy()
    audio_out = np.asarray(audio, dtype=np.float32).reshape(-1)
    return text_out, audio_out

def process_audio_output(audio_out, sample_rate: int = 24000):
    """
    Encode a waveform as a base64 string of a 16-bit PCM WAV file.

    Args:
        audio_out: Waveform samples, as a NumPy array or a torch tensor (any
            device or float dtype). Values are expected in [-1, 1]; anything
            outside that range is clipped. Singleton dimensions such as a
            leading batch axis of size 1 are squeezed away.
        sample_rate: Sample rate written into the WAV header. Defaults to
            24000 Hz, the rate this service has always used for output.

    Returns:
        The WAV file bytes, base64-encoded and decoded to a ``str``.
    """
    if isinstance(audio_out, torch.Tensor):
        # NumPy cannot convert CUDA or bfloat16 tensors directly.
        audio_out = audio_out.detach().to("cpu", dtype=torch.float32).numpy()
    # A (1, N) array would otherwise be written as one frame of N channels.
    samples = np.atleast_1d(np.squeeze(np.asarray(audio_out, dtype=np.float32)))
    # Clip before scaling: out-of-range samples would otherwise wrap around in
    # int16 and produce loud clicks.
    audio_int16 = np.rint(np.clip(samples, -1.0, 1.0) * 32767).astype(np.int16)
    wav_buffer = io.BytesIO()
    sf.write(wav_buffer, audio_int16, samplerate=sample_rate, format="WAV")
    return base64.b64encode(wav_buffer.getvalue()).decode("utf-8")

# Serializes access to the single shared model instance. Work now runs in a
# worker thread, and this keeps the original one-request-at-a-time behaviour.
_INFERENCE_LOCK = threading.Lock()


def _run_pipeline(audio_file_path: str):
    """Run inference and WAV encoding for one request; executed in a worker thread."""
    with _INFERENCE_LOCK:
        text_response, audio_out = inference(audio_file_path)
    return text_response, process_audio_output(audio_out)


@app.post("/api/v1/inference", response_model=AudioResponse)
async def run_inference(request: AudioRequest) -> AudioResponse:
    """
    Voice-to-voice inference endpoint.

    Expects a JSON payload with a base64-encoded WAV audio file and its sample
    rate (the sample rate is currently not used). Returns a base64-encoded WAV
    audio file as output and the corresponding text.

    Raises:
        HTTPException: 503 if the model is not loaded, 400 if ``audio_data``
            is not valid base64 or decodes to nothing, 500 if inference or
            audio encoding fails.
    """
    if model is None or processor is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    # Malformed input is a client error, not a server error.
    try:
        audio_bytes = base64.b64decode(request.audio_data)
    except ValueError as e:  # binascii.Error is a ValueError subclass
        raise HTTPException(status_code=400, detail=f"Invalid base64 audio_data: {e}") from e
    if not audio_bytes:
        raise HTTPException(status_code=400, detail="audio_data is empty")

    tmp_file_path = None
    try:
        # The audio loader reads from a path, so write the bytes to a temp file.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
            tmp_file_path = tmp.name  # record first so cleanup runs even if write fails
            tmp.write(audio_bytes)

        # Generation is slow, blocking work; run it off the event loop so the
        # server keeps accepting connections (e.g. health checks) meanwhile.
        text_response, output_audio_b64 = await run_in_threadpool(_run_pipeline, tmp_file_path)

        # Use the first element of the text output (if available)
        text_to_return = text_response[0] if text_response else ""
        return AudioResponse(audio_data=output_audio_b64, text=text_to_return)
    except Exception as e:
        logger.exception("Inference error: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        # Always clean up the temporary file, including on failure.
        if tmp_file_path is not None:
            try:
                os.remove(tmp_file_path)
            except OSError:
                logger.warning("Could not remove temporary file %s", tmp_file_path)

@app.get("/api/v1/health")
def health_check():
    """
    Report liveness and whether the service can serve inference.

    ``model_loaded`` is true only when both the model and its processor are
    loaded. The inference endpoint checks the same condition before accepting
    a request and answers 503 otherwise.
    """
    ready = model is not None and processor is not None
    return {"status": "ok", "model_loaded": ready}

if __name__ == "__main__":
    # uvicorn is only needed when this file is run directly, so import it here.
    import uvicorn

    # Listen on every interface, port 8000.
    bind_host, bind_port = "0.0.0.0", 8000
    uvicorn.run(app, host=bind_host, port=bind_port)