"""
PregoPal - Model loader (full-duplex version)
=============================================

Client for a local llama-server full-duplex API.

Architecture:
    core/model_loader.py ←HTTP→ api/go_server.py → llama-server (omni)
                                                      ↓
                                          local inference + TTS

The backend base URL is taken from ``MINICPM_API_BASE`` if set, otherwise
from ``LLAMA_SERVER_URL`` (default ``http://127.0.0.1:8081``).

Usage:
    from core.model_loader import ModelLoader

    loader = ModelLoader()

    # Text chat
    resp = loader.chat([{"role": "user", "content": "你好"}])

    # Voice chat (full duplex)
    result = loader.voice_chat("/path/to/audio.wav")
"""

import os
import io
import json
import base64
import logging
import numpy as np
import soundfile as sf
import requests as req
from typing import Optional

logger = logging.getLogger(__name__)

# Backend API address. MINICPM_API_BASE, when set, overrides LLAMA_SERVER_URL.
LLAMA_SERVER_URL = os.environ.get("LLAMA_SERVER_URL", "http://127.0.0.1:8081")
API_BASE = os.environ.get("MINICPM_API_BASE", LLAMA_SERVER_URL)

# Sample rate the audio is converted to (and declared as) before being sent
# to the /v1/omni/voice_chat endpoint.
_TARGET_SR = 16000


def _resample(audio: np.ndarray, orig_sr: int, target_sr: int) -> np.ndarray:
    """Resample a 1-D float signal from ``orig_sr`` to ``target_sr``.

    Uses ``librosa.resample`` when librosa is installed. Otherwise it falls
    back to NumPy linear interpolation, so the returned samples are always at
    ``target_sr``. Callers label the result with ``target_sr``; leaving it at
    the original rate would distort pitch and speed.

    Args:
        audio: 1-D array of samples.
        orig_sr: Sample rate of ``audio``.
        target_sr: Desired output sample rate.

    Returns:
        The resampled signal (returned unchanged if the rates already match
        or the signal is empty).
    """
    if orig_sr == target_sr or audio.size == 0:
        return audio
    try:
        import librosa  # optional dependency; higher-quality resampling
    except ImportError:
        logger.warning(
            "librosa not installed; resampling %d Hz -> %d Hz with linear "
            "interpolation",
            orig_sr,
            target_sr,
        )
        n_in = audio.shape[0]
        n_out = max(1, int(round(n_in * target_sr / orig_sr)))
        # Place input and output samples on a common time axis (seconds) and
        # interpolate; points past the last input sample clamp to its value.
        src_times = np.arange(n_in, dtype=np.float64) / orig_sr
        dst_times = np.arange(n_out, dtype=np.float64) / target_sr
        return np.interp(dst_times, src_times, audio).astype(np.float32)
    return librosa.resample(audio, orig_sr=orig_sr, target_sr=target_sr)


class ModelLoader:
    """HTTP client for a MiniCPM-o 4.5 backend (text chat + full-duplex voice).

    Despite the name, no model is loaded in this process; every call is
    forwarded over HTTP to the configured backend.
    """

    def __init__(self, api_base: Optional[str] = None):
        """
        Args:
            api_base: Backend base URL. Defaults to the module-level
                ``API_BASE``. A trailing slash is stripped.
        """
        self.api_base = (api_base or API_BASE).rstrip("/")
        self._omni_initialized = False

    # ── Text chat ────────────────────────────────────────────

    def chat(self, messages: list[dict], max_tokens: int = 300,
             temperature: float = 0.7, stream: bool = False) -> dict:
        """
        Text chat via the backend's ``/v1/chat/completions`` endpoint.

        Args:
            messages: [{"role": "system"/"user", "content": "..."}]
            max_tokens: Maximum number of output tokens.
            temperature: Sampling temperature.
            stream: Streaming is not supported yet. The request is always sent
                non-streaming, and a warning is logged if this is True.

        Returns:
            dict: ``{"text": str, "success": bool}``. On failure it also
            includes ``"error": str``, and ``"text"`` is ``""``.
        """
        if stream:
            logger.warning("chat: stream=True is not supported; "
                           "returning a complete response instead")
        body = {
            "messages": messages,
            "max_tokens": max_tokens,
            "temperature": temperature,
            "stream": False,
        }
        try:
            url = f"{self.api_base}/v1/chat/completions"
            resp = req.post(url, json=body, timeout=120)
            if resp.status_code != 200:
                logger.error(f"chat 失败: {resp.status_code}")
                return {"text": "", "success": False,
                        "error": str(resp.status_code)}
            data = resp.json()
            content = data["choices"][0]["message"]["content"]
            # OpenAI-style responses may carry a null content; keep "text" a str.
            return {"text": content or "", "success": True}
        except Exception as e:
            logger.error(f"chat 异常: {e}")
            return {"text": "", "success": False, "error": str(e)}

    def ask(self, prompt: str, system_prompt: Optional[str] = None,
            max_tokens: int = 300) -> str:
        """Single-turn text chat; returns only the reply text ("" on failure)."""
        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": prompt})
        result = self.chat(messages, max_tokens=max_tokens)
        return result.get("text", "")

    # ── Full-duplex voice chat ───────────────────────────────

    def voice_chat(self, audio_path: str, max_tokens: int = 300) -> dict:
        """
        Full-duplex voice chat via ``/v1/omni/voice_chat``.

        Pipeline: audio file → mono 16 kHz → 16-bit PCM WAV → base64 → backend
        (omni prefill → decode → TTS audio output).

        Args:
            audio_path: Path to any audio file soundfile can read.
                Multichannel audio is averaged to mono, and other sample rates
                are resampled to 16 kHz.
            max_tokens: Maximum number of output tokens.

        Returns:
            dict: {
                "text": str,          # AI reply text ("" on failure)
                "audio_base64": str,  # TTS audio, base64 ("" on failure)
                "success": bool,
                "round": int,         # present on HTTP 200 responses only
                "error": str,         # present on failure only
            }
        """
        try:
            # 1. Read the audio as float32 and downmix to mono.
            audio_data, sr = sf.read(audio_path, dtype="float32")
            if audio_data.ndim > 1:
                audio_data = audio_data.mean(axis=1)
            # Always resample, so the rate we declare below is the real one.
            audio_data = _resample(audio_data, sr, _TARGET_SR)

            # 2. Encode as 16-bit PCM WAV in memory, then base64.
            buf = io.BytesIO()
            sf.write(buf, audio_data, _TARGET_SR, format="WAV",
                     subtype="PCM_16")
            audio_b64 = base64.b64encode(buf.getvalue()).decode("utf-8")

            # 3. Call the backend.
            body = {
                "audio_base64": audio_b64,
                "sample_rate": _TARGET_SR,
                "max_tokens": max_tokens,
            }
            url = f"{self.api_base}/v1/omni/voice_chat"
            resp = req.post(url, json=body, timeout=180)
            if resp.status_code != 200:
                logger.error(f"voice_chat 失败: {resp.status_code}")
                return {"success": False, "text": "", "audio_base64": "",
                        "error": f"HTTP {resp.status_code}"}
            data = resp.json()
            return {
                "success": data.get("success", False),
                "text": data.get("text", ""),
                "audio_base64": data.get("audio_base64", ""),
                "round": data.get("round", 0),
            }
        except Exception as e:
            logger.error(f"voice_chat 异常: {e}")
            return {"success": False, "text": "", "audio_base64": "",
                    "error": str(e)}

    # ── Health check ─────────────────────────────────────────

    def health(self) -> dict:
        """Return the backend's ``/health`` JSON, or an error status dict."""
        try:
            resp = req.get(f"{self.api_base}/health", timeout=5)
            if resp.status_code == 200:
                return resp.json()
            return {"status": "error", "message": f"HTTP {resp.status_code}"}
        except Exception as e:
            return {"status": "error", "message": str(e)}

    def unload(self):
        """Reset local client state (clears the omni-initialized flag only)."""
        self._omni_initialized = False