PregoPal / core /model_loader.py
J.B-Lin
ๅ…จๅŒๅทฅ่ฏญ้Ÿณๅฏน่ฏๅฎž็Žฐ
edca135
Raw
History Blame Contribute Delete
6.13 kB
"""
PregoPal - 模型加载器（全双工版本）
======================================
对接本地 llama-server 全双工 API。
架构：
    core/model_loader.py ←HTTP→ api/go_server.py → llama-server (omni)
                                                        ↓
                                                  本地推理 + TTS
用法：
    from core.model_loader import ModelLoader
    loader = ModelLoader()
    # 文本对话
    resp = loader.chat([{"role": "user", "content": "你好"}])
    # 语音对话（全双工）
    result = loader.voice_chat("/path/to/audio.wav")
"""
import os
import io
import json
import base64
import logging
import numpy as np
import soundfile as sf
import requests as req
from typing import Optional
logger = logging.getLogger(__name__)

# Backend endpoint configuration.
# LLAMA_SERVER_URL is the local llama-server address; MINICPM_API_BASE, when
# set, overrides it as the default base URL used by ModelLoader.
LLAMA_SERVER_URL = os.getenv("LLAMA_SERVER_URL", "http://127.0.0.1:8081")
API_BASE = os.getenv("MINICPM_API_BASE", LLAMA_SERVER_URL)
class ModelLoader:
    """MiniCPM-o 4.5 client supporting text chat and full-duplex voice chat.

    This class loads no model weights itself. Every method talks over HTTP
    (``requests``) to a backend at ``api_base``. The backend exposes
    ``/v1/chat/completions``, ``/v1/omni/voice_chat`` and ``/health``.

    NOTE(review): the Chinese text inside the log messages below appears to
    be mis-encoded. Confirm the file's encoding before editing those strings.
    """

    # Sample rate (Hz) of the audio uploaded to the voice endpoint.
    _TARGET_SR = 16000

    def __init__(self, api_base: Optional[str] = None):
        """Create a client.

        Args:
            api_base: Backend base URL. Falls back to the module-level
                ``API_BASE`` when omitted or empty. A trailing "/" is stripped.
        """
        self.api_base = (api_base or API_BASE).rstrip("/")
        # Kept for compatibility; nothing in this class sets it to True.
        self._omni_initialized = False

    # ── Text chat ───────────────────────────────────────────────

    def chat(self, messages: list[dict], max_tokens: int = 300,
             temperature: float = 0.7, stream: bool = False) -> dict:
        """Send a text chat request to ``/v1/chat/completions``.

        Args:
            messages: Chat messages, e.g.
                ``[{"role": "system" | "user", "content": "..."}]``.
            max_tokens: Maximum number of tokens to generate.
            temperature: Sampling temperature.
            stream: Accepted for API compatibility only. Streaming is not
                supported, and the request is always sent with
                ``"stream": False``.

        Returns:
            ``{"text": str, "success": True}`` on success. On failure it
            returns ``{"text": "", "success": False, "error": str}``.
            Exceptions are caught and reported through ``error``.
        """
        body = {
            "messages": messages,
            "max_tokens": max_tokens,
            "temperature": temperature,
            "stream": False,
        }
        try:
            url = f"{self.api_base}/v1/chat/completions"
            resp = req.post(url, json=body, timeout=120)
            if resp.status_code != 200:
                logger.error(f"chat ๅคฑ่ดฅ: {resp.status_code}")
                return {"text": "", "success": False, "error": str(resp.status_code)}
            data = resp.json()
            # OpenAI-style responses may carry a null ``content``. Normalize it
            # so callers always receive a str.
            text = data["choices"][0]["message"]["content"] or ""
            return {"text": text, "success": True}
        except Exception as e:
            logger.error(f"chat ๅผ‚ๅธธ: {e}")
            return {"text": "", "success": False, "error": str(e)}

    def ask(self, prompt: str, system_prompt: Optional[str] = None,
            max_tokens: int = 300) -> str:
        """Send a single-turn prompt and return only the reply text.

        Args:
            prompt: User message.
            system_prompt: Optional system message prepended to the chat.
            max_tokens: Maximum number of tokens to generate.

        Returns:
            The reply text, or "" if the request failed.
        """
        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": prompt})
        result = self.chat(messages, max_tokens=max_tokens)
        return result.get("text", "")

    # ── Full-duplex voice chat ──────────────────────────────────

    @staticmethod
    def _resample_linear(audio: np.ndarray, orig_sr: int, target_sr: int) -> np.ndarray:
        """Resample 1-D ``audio`` using linear interpolation (numpy-only fallback).

        This has no anti-aliasing filter, so its quality is lower than
        librosa's. It does keep duration and pitch correct.
        """
        n_in = len(audio)
        # Integer ceiling division, so the output length agrees with
        # librosa.resample.
        n_out = (n_in * target_sr + orig_sr - 1) // orig_sr
        if n_in == 0 or n_out == 0:
            return np.zeros(0, dtype=np.float32)
        t_in = np.arange(n_in) / orig_sr
        t_out = np.arange(n_out) / target_sr
        return np.interp(t_out, t_in, audio).astype(np.float32)

    @staticmethod
    def _prepare_audio(audio_data: np.ndarray, sr: int,
                       target_sr: int = _TARGET_SR) -> np.ndarray:
        """Down-mix to mono, resample to ``target_sr`` and clip to [-1, 1].

        Args:
            audio_data: Samples shaped ``(frames,)`` or ``(frames, channels)``.
            sr: Sample rate of ``audio_data`` in Hz.
            target_sr: Desired output sample rate in Hz.

        Returns:
            A 1-D float32 array sampled at ``target_sr``.
        """
        audio = np.asarray(audio_data, dtype=np.float32)
        if audio.ndim > 1:
            # Average the channels (axis 1) into a single mono track.
            audio = audio.mean(axis=1)
        if audio.size == 0:
            return audio.astype(np.float32, copy=False)
        if sr != target_sr:
            try:
                import librosa
            except ImportError:
                # Without librosa, resample with numpy. Sending the original
                # samples labelled as target_sr would distort speed and pitch.
                audio = ModelLoader._resample_linear(audio, sr, target_sr)
            else:
                audio = librosa.resample(audio, orig_sr=sr, target_sr=target_sr)
        # Resampling can overshoot slightly, and PCM_16 encoding expects
        # floats in [-1, 1].
        return np.clip(audio, -1.0, 1.0).astype(np.float32, copy=False)

    @staticmethod
    def _encode_wav_base64(audio: np.ndarray, sample_rate: int) -> str:
        """Encode mono float audio as a 16-bit PCM WAV and return it base64-encoded."""
        buf = io.BytesIO()
        sf.write(buf, audio, sample_rate, format='WAV', subtype='PCM_16')
        return base64.b64encode(buf.getvalue()).decode('utf-8')

    @staticmethod
    def _voice_error(message: str) -> dict:
        """Build a failed ``voice_chat`` result that carries every success-path key."""
        return {
            "success": False,
            "text": "",
            "audio_base64": "",
            "round": 0,
            "error": message,
        }

    def voice_chat(self, audio_path: str, max_tokens: int = 300) -> dict:
        """Run one full-duplex voice turn: upload audio, receive text and TTS audio.

        The clip is posted to ``/v1/omni/voice_chat``. The module's design
        notes describe the server-side flow as omni prefill → decode → TTS.

        Args:
            audio_path: Path to an audio file readable by soundfile.
                Multi-channel audio is averaged to mono. Other sample rates
                are resampled to 16 kHz before upload.
            max_tokens: Maximum number of tokens to generate.

        Returns:
            A dict with ``success`` (bool), ``text`` (reply text),
            ``audio_base64`` (TTS audio, base64) and ``round`` (int). A failed
            call returns the same keys with empty/zero values, plus an
            ``error`` message. Exceptions are caught and reported this way.
        """
        try:
            # 1. Read the audio and normalize it to mono / 16 kHz / [-1, 1].
            audio_data, sr = sf.read(audio_path, dtype='float32')
            audio_data = self._prepare_audio(audio_data, sr, self._TARGET_SR)
            if audio_data.size == 0:
                return self._voice_error("empty audio")

            # 2. Encode as a base64 16-bit PCM WAV.
            audio_b64 = self._encode_wav_base64(audio_data, self._TARGET_SR)

            # 3. Call the backend.
            body = {
                "audio_base64": audio_b64,
                "sample_rate": self._TARGET_SR,
                "max_tokens": max_tokens,
            }
            url = f"{self.api_base}/v1/omni/voice_chat"
            resp = req.post(url, json=body, timeout=180)
            if resp.status_code != 200:
                logger.error("voice_chat HTTP %s", resp.status_code)
                return self._voice_error(f"HTTP {resp.status_code}")
            data = resp.json()
            return {
                "success": data.get("success", False),
                "text": data.get("text", ""),
                "audio_base64": data.get("audio_base64", ""),
                "round": data.get("round", 0),
            }
        except Exception as e:
            logger.error(f"voice_chat ๅผ‚ๅธธ: {e}")
            return self._voice_error(str(e))

    # ── Health check ────────────────────────────────────────────

    def health(self) -> dict:
        """Query the backend's ``/health`` endpoint.

        Returns:
            The parsed JSON body on HTTP 200. Otherwise returns
            ``{"status": "error", "message": ...}``, including when the
            request or JSON parsing raises.
        """
        try:
            resp = req.get(f"{self.api_base}/health", timeout=5)
            if resp.status_code == 200:
                return resp.json()
            return {"status": "error", "message": f"HTTP {resp.status_code}"}
        except Exception as e:
            return {"status": "error", "message": str(e)}

    def unload(self):
        """Release client-side state (currently only resets ``_omni_initialized``)."""
        self._omni_initialized = False