PregoPal / api /go_server.py
J.B-Lin
Fix Gradio duplex audio playback clearing
0598330
Raw
History Blame Contribute Delete
18.3 kB
"""
PregoPal 全双工后端 — 语音对话服务器
=====================================
基于 llama.cpp-omni 的 omni 全双工能力。
通过 /v1/stream/* 端点实现:语音输入→音频prefill→AI推理→TTS语音输出 的完整闭环。
架构:
FastAPI (本服务) → llama-server (omni模式, /v1/stream/*)
启动顺序:
1. 确保 llama-server 已在 omni 模式启动 (start_llama_server.py)
2. python -m uvicorn api.go_server:app --host 127.0.0.1 --port 8090
"""
import os
import json
import time
import base64
import asyncio
import logging
import re
import shutil
import numpy as np
import soundfile as sf
import httpx
from pathlib import Path
from datetime import datetime
from typing import Optional, AsyncGenerator
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse, JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
# Dedicated logger for this service: INFO level, one stream handler whose
# format prefixes every record with "[PregoAPI]" and a timestamp.
logger = logging.getLogger("prego_api")
ch = logging.StreamHandler()
ch.setFormatter(logging.Formatter("[PregoAPI] %(asctime)s %(message)s"))
logger.addHandler(ch)
logger.setLevel(logging.INFO)
# ── Configuration ─────────────────────────────────────
# BASE_DIR is the directory two levels above this file. The server URL, the
# two working directories and the system prompt can each be overridden through
# an environment variable of the same name.
BASE_DIR = Path(__file__).resolve().parents[1]
LLAMA_SERVER_URL = os.getenv("LLAMA_SERVER_URL", "http://127.0.0.1:8081")
OMNI_OUTPUT_DIR = os.getenv("OMNI_OUTPUT_DIR", str(BASE_DIR / "omni_output"))
TEMP_DIR = os.getenv("TEMP_DIR", str(BASE_DIR / "api" / "temp"))
LLM_SYSTEM_PROMPT = os.getenv(
    "LLM_SYSTEM_PROMPT",
    "你是PregoPal,一位贴心的孕期营养健康顾问。"
    "请用中文回答,给出简短实用的建议。"
    "如果用户提到吃了什么,尝试记录饮食信息[EXTRACT_DIET]。"
    "如果有家庭成员信息,记录为[EXTRACT_FAMILY]。",
)
# Both working directories are created at import time so later writes cannot
# fail on a missing parent.
os.makedirs(TEMP_DIR, exist_ok=True)
os.makedirs(OMNI_OUTPUT_DIR, exist_ok=True)
# FastAPI application object; the route and lifecycle decorators in this file
# register on it.
app = FastAPI(title="PregoPal Omni Backend")
# CORS is fully open: any origin, any method, any header.
# NOTE(review): presumably meant for local development — confirm before
# exposing this service beyond localhost.
app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])
# Global state
class SessionState:
    """Mutable bookkeeping shared by every request handled by this process."""

    def __init__(self):
        # Index of the next conversation round; handlers use it to name the
        # temp audio file and the per-round output directory.
        self.round_counter = 0
        # Sample rate in Hz.
        self.sample_rate = 16000
        # Flipped to True once the omni_init call has succeeded.
        self.initialized = False


state = SessionState()
# Shared httpx.AsyncClient; assigned by the startup hook.
http_client = None
@app.on_event("startup")
async def startup():
    """Create the shared HTTP client and probe llama-server's /health once.

    A failed or non-200 probe is only logged as a warning; startup still
    completes.
    """
    global http_client
    http_client = httpx.AsyncClient(timeout=httpx.Timeout(300.0, connect=5.0))
    # Make sure llama-server is alive.
    health_url = f"{LLAMA_SERVER_URL}/health"
    try:
        probe = await http_client.get(health_url, timeout=3)
        if probe.status_code != 200:
            logger.warning(f"⚠️ llama-server 响应异常: {probe.status_code}")
        else:
            logger.info(f"✅ llama-server 健康: {probe.json()}")
    except Exception as e:
        logger.warning(f"⚠️ llama-server 连接失败: {e}")
    logger.info("PregoAPI 启动完成")
@app.on_event("shutdown")
async def shutdown():
    """Close the shared HTTP client if the startup hook created one."""
    global http_client
    if not http_client:
        return
    await http_client.aclose()
# ── Audio handling ────────────────────────────────────
# Sample rate (Hz) that incoming audio is converted to and saved at.
_SAMPLE_RATE = 16_000
# Upper bound on the number of TTS wav chunks merged into a single reply.
_MAX_TTS_WAV_FILES = 24
def audio_to_wav_bytes(audio_data: np.ndarray, sr: int = _SAMPLE_RATE) -> bytes:
    """Encode a numpy audio array as a 16-bit PCM WAV and return the file bytes."""
    from io import BytesIO

    with BytesIO() as wav_buffer:
        sf.write(wav_buffer, audio_data, sr, format='WAV', subtype='PCM_16')
        return wav_buffer.getvalue()
def save_temp_audio(audio_data: np.ndarray, session_id: str, cnt: int) -> str:
    """Write audio to TEMP_DIR as a 16-bit PCM WAV and return the file path.

    The file is named ``prefill_<session_id>_<cnt>.wav`` and is written at
    ``_SAMPLE_RATE``.
    """
    target_path = os.path.join(TEMP_DIR, f"prefill_{session_id}_{cnt}.wav")
    sf.write(target_path, audio_data, _SAMPLE_RATE, format='WAV', subtype='PCM_16')
    return target_path
def _safe_round_dir(cnt: int) -> Path:
    """Return the resolved ``round_NNN`` directory under OMNI_OUTPUT_DIR.

    Raises:
        RuntimeError: if the resolved directory is not a direct child of the
            resolved output root.
    """
    root = Path(OMNI_OUTPUT_DIR).resolve()
    candidate = (root / f"round_{cnt:03d}").resolve()
    if candidate.parent == root:
        return candidate
    raise RuntimeError(f"非法 round 输出目录: {candidate}")
def _reset_round_dir(cnt: int) -> Path:
    """Recreate the output directory for round ``cnt`` empty and return it.

    Anything left there from an earlier run is deleted first, so decode
    output for this round never mixes with stale chunks.
    """
    target = _safe_round_dir(cnt)
    had_stale_output = target.exists()
    if had_stale_output:
        shutil.rmtree(target)
        logger.info(f"已清理旧 round 输出: {target}")
    target.mkdir(parents=True, exist_ok=True)
    return target
def _numeric_suffix(path_or_name: str, prefix: str) -> int:
    """Return the integer that follows ``prefix`` in the path's base name.

    Used as a sort key so that e.g. ``wav_10`` orders after ``wav_2``. Names
    without ``prefix`` followed by digits get 10**9, which sorts them last.
    """
    base_name = os.path.basename(str(path_or_name))
    found = re.search(re.escape(prefix) + r"(\d+)", base_name)
    if found is None:
        return 10**9
    return int(found.group(1))
def _is_unusable_llm_text(text: str) -> bool:
    """Return True when the model text is too degenerate to show to the user.

    Text is rejected when, ignoring whitespace, it is shorter than two
    characters; when it is at least 20 characters long and more than 35% of
    them are ``|``; or when it consists only of punctuation and whitespace.
    """
    raw = text or ""
    squeezed = re.sub(r"\s+", "", raw)
    length = len(squeezed)
    if length < 2:
        return True
    mostly_pipes = length >= 20 and squeezed.count("|") / length > 0.35
    only_punctuation = re.fullmatch(r"[\|\s,。,.!?!?::;;\-_=+~`·…]+", raw) is not None
    return mostly_pipes or only_punctuation
def _fallback_voice_text() -> str:
    """Return the canned reply used when the model produced no usable text."""
    fragments = (
        "我刚刚收到了你的语音。为了孕期饮食更稳妥,建议今天优先选择清淡、熟透、",
        "高蛋白和富含叶酸/铁/钙的家常菜,例如鸡蛋、鱼虾或瘦肉搭配深绿色蔬菜和主食。",
        "如果你告诉我今天想吃的菜和家里现有食材,我可以继续帮你估算是否适合孕妇。",
    )
    return "".join(fragments)
def _read_llm_text(llm_debug_dir: str) -> str:
    """Concatenate ``llm_text.txt`` from every ``chunk_*`` directory, in chunk order.

    Each fragment is stripped before joining. Returns "" when the combined
    text is judged unusable by ``_is_unusable_llm_text``.
    """
    import glob

    chunk_paths = glob.glob(os.path.join(llm_debug_dir, "chunk_*"))
    chunk_paths.sort(key=lambda p: _numeric_suffix(p, "chunk_"))

    fragments = []
    for chunk_path in chunk_paths:
        text_file = os.path.join(chunk_path, "llm_text.txt")
        if not os.path.exists(text_file):
            continue
        with open(text_file, "r", encoding="utf-8") as fh:
            fragments.append(fh.read().strip())

    combined = "".join(fragments).strip()
    if _is_unusable_llm_text(combined):
        return ""
    return combined
def _merge_wavs_to_base64(wav_dir: str, wav_files: list) -> str:
    """Concatenate TTS wav chunks into one 16-bit PCM WAV and return it as base64.

    Chunks are ordered by the number following ``wav_`` in their names, and at
    most ``_MAX_TTS_WAV_FILES`` of them are merged. The output uses the sample
    rate of the first readable chunk.

    Args:
        wav_dir: Directory that contains the chunk files.
        wav_files: File names (not paths) of the chunks inside ``wav_dir``.

    Returns:
        Base64 text of the merged WAV, or "" when no chunk could be read.
    """
    import io

    sorted_files = sorted(wav_files, key=lambda n: _numeric_suffix(n, "wav_"))
    if len(sorted_files) > _MAX_TTS_WAV_FILES:
        logger.warning(
            f"TTS wav 片段过多({len(sorted_files)}),仅合并前 {_MAX_TTS_WAV_FILES} 个用于演示"
        )
        sorted_files = sorted_files[:_MAX_TTS_WAV_FILES]

    all_data = []
    sample_rate = None
    for wf in sorted_files:
        wav_path = os.path.join(wav_dir, wf)
        try:
            data, sr = sf.read(wav_path)
        except (OSError, RuntimeError) as e:
            # The file list can include chunks that are still being written
            # (the wait helper returns its last listing on timeout). One bad
            # chunk must not discard the rest of the reply.
            logger.warning(f"跳过无法读取的 TTS wav 片段 {wav_path}: {e}")
            continue
        if sample_rate is None:
            sample_rate = sr
        elif sr != sample_rate:
            # Still concatenated as before, but no longer silently: this
            # chunk will play at the wrong speed in the merged file.
            logger.warning(
                f"TTS wav 采样率不一致: {wav_path} 为 {sr}, 期望 {sample_rate}"
            )
        all_data.append(data)

    if not all_data:
        return ""

    merged = np.concatenate(all_data)
    buf = io.BytesIO()
    sf.write(buf, merged, sample_rate or 24000, format='WAV', subtype='PCM_16')
    return base64.b64encode(buf.getvalue()).decode("utf-8")
def _wait_for_tts_files(tts_wav_dir: str, timeout_s: float = 45.0) -> list:
    """Block until the TTS wav chunks in ``tts_wav_dir`` look complete.

    The decode request may return before the wav files are finished, so the
    directory is polled every 0.25 s. Waiting stops as soon as either

    * ``generation_done.flag`` exists alongside at least one wav file, or
    * the same set of wav files, each larger than 44 bytes, has been seen on
      12 consecutive re-checks (about 3 s without change).

    This function sleeps synchronously; call it from a worker thread when
    used from async code.

    Args:
        tts_wav_dir: Directory in which ``wav_<n>.wav`` files appear.
        timeout_s: Maximum time to wait, in seconds.

    Returns:
        The wav file names sorted by numeric suffix. On timeout this is the
        last listing that was seen, which may be empty or still incomplete.
    """
    # Monotonic clock: a wall-clock adjustment must not shorten or extend the wait.
    deadline = time.monotonic() + timeout_s
    done_flag = os.path.join(tts_wav_dir, "generation_done.flag")
    last_files = []
    stable_seen = 0
    while time.monotonic() < deadline:
        try:
            names = os.listdir(tts_wav_dir)
        except OSError:
            # Directory not created yet, or removed while we were polling.
            names = []
        wav_files = sorted(
            (f for f in names if f.startswith("wav_") and f.endswith(".wav")),
            key=lambda n: _numeric_suffix(n, "wav_"),
        )
        if wav_files:
            try:
                sizes = [
                    os.path.getsize(os.path.join(tts_wav_dir, f))
                    for f in wav_files
                ]
            except OSError:
                # A file vanished between listing and stat: not stable.
                sizes = None
            # 44 bytes is the size of a minimal PCM WAV header, so a larger
            # file presumably already holds sample data.
            if (
                sizes is not None
                and wav_files == last_files
                and all(size > 44 for size in sizes)
            ):
                stable_seen += 1
            else:
                stable_seen = 0
            last_files = wav_files
            if os.path.exists(done_flag):
                return wav_files
            if stable_seen >= 12:
                return wav_files
        time.sleep(0.25)
    return last_files
def wav_bytes_to_numpy(wav_bytes: bytes) -> np.ndarray:
    """Decode WAV bytes into a mono float32 array at ``_SAMPLE_RATE``.

    Multi-channel audio is averaged down to one channel, and audio at any
    other sample rate is resampled with librosa (imported only when needed).

    Args:
        wav_bytes: Contents of an audio file readable by soundfile.

    Returns:
        1-D float32 numpy array of samples.
    """
    import io

    data, sr = sf.read(io.BytesIO(wav_bytes))
    # Downmix BEFORE resampling. soundfile returns multi-channel audio as
    # (frames, channels), while librosa.resample works along the last axis by
    # default, so resampling first would stretch the channel axis, not time.
    if data.ndim > 1:
        data = data.mean(axis=1)
    if sr != _SAMPLE_RATE:
        import librosa
        data = librosa.resample(data, orig_sr=sr, target_sr=_SAMPLE_RATE)
    return data.astype(np.float32)
# ── 核心全双工 API ────────────────────────────────────
async def omni_init_if_needed():
    """Initialise the llama-server omni context once per process.

    Posts the init payload to ``/v1/stream/omni_init`` unless an earlier call
    already succeeded. The model and TTS directories can be overridden with
    the ``OMNI_MODEL_DIR`` and ``OMNI_TTS_BIN_DIR`` environment variables; the
    defaults are the paths that were previously hard-coded here.

    Returns:
        True when the context is (or already was) initialised; False when the
        request could not be sent or the server answered with a non-200 status.
    """
    if state.initialized:
        return True

    logger.info("初始化 omni context...")
    model_dir = os.environ.get(
        "OMNI_MODEL_DIR",
        "C:\\Users\\Andre\\codes\\LJB\\llama.cpp-omni\\models",
    )
    tts_bin_dir = os.environ.get(
        "OMNI_TTS_BIN_DIR",
        "C:\\Users\\Andre\\codes\\LJB\\llama.cpp-omni\\models\\tts",
    )
    init_data = {
        "media_type": 2,  # omni mode (audio + vision)
        "use_tts": True,
        "duplex_mode": False,  # simplex for now
        "model_dir": model_dir,
        "tts_bin_dir": tts_bin_dir,
        "token2wav_device": "cpu",  # saves GPU memory
        "output_dir": OMNI_OUTPUT_DIR,
        "tts_gpu_layers": 0,
    }
    try:
        resp = await http_client.post(
            f"{LLAMA_SERVER_URL}/v1/stream/omni_init",
            json=init_data,
            timeout=120.0
        )
    except httpx.HTTPError as e:
        # Connection refused / timeout: report "not initialised" so callers
        # answer with their 503 instead of an unhandled 500.
        logger.error(f"omni_init 失败: {e}")
        return False

    if resp.status_code != 200:
        logger.error(f"omni_init 失败: {resp.text}")
        return False

    result = resp.json()
    logger.info(f"✅ omni_init: {result}")
    state.initialized = True
    return True
class ChatRequest(BaseModel):
    """Text chat request."""
    # Conversation messages; the chat handler reads "role" and "content" from
    # each item with .get(), so items are expected to be dict-like.
    messages: list
    # Forwarded to llama-server as "max_tokens".
    max_tokens: int = 300
    # Forwarded to llama-server as "temperature".
    temperature: float = 0.7
class VoiceChatRequest(BaseModel):
    """Voice chat request (non-streaming)."""
    # Base64 text of an audio file; the handlers decode it with
    # base64.b64decode and parse the bytes with soundfile.
    audio_base64: str
    # NOTE(review): sample_rate and max_tokens are not read by the handlers
    # visible in this file — confirm whether they are still needed.
    sample_rate: int = 16000
    max_tokens: int = 300
@app.get("/health")
async def health():
    """Health check: report this backend's status and current session counters."""
    report = {"status": "ok", "backend": "prego_api"}
    report["llama_server"] = LLAMA_SERVER_URL
    report["omni_initialized"] = state.initialized
    report["round"] = state.round_counter
    return report
@app.post("/v1/chat/completions")
async def chat_completions(req: ChatRequest):
    """Text chat, proxied to llama-server's standard chat completions API.

    Only ``role`` (default "user") and ``content`` (default "") are forwarded
    from each message; the upstream call is always non-streaming.

    Raises:
        HTTPException: 400 when a message is not a JSON object; 502 when
            llama-server cannot be reached or answers with a non-200 status.
    """
    # Reject malformed items up front: calling .get() on a non-dict would
    # otherwise surface as an opaque 500.
    if not all(isinstance(m, dict) for m in req.messages):
        raise HTTPException(
            status_code=400,
            detail="each message must be an object with 'role' and 'content'",
        )

    body = {
        "messages": [{"role": m.get("role", "user"), "content": m.get("content", "")}
                     for m in req.messages],
        "max_tokens": req.max_tokens,
        "temperature": req.temperature,
        "stream": False,
    }
    try:
        resp = await http_client.post(
            f"{LLAMA_SERVER_URL}/v1/chat/completions",
            json=body,
            timeout=120.0
        )
    except httpx.HTTPError as e:
        # Upstream unreachable or timed out: a gateway error, not a server bug.
        raise HTTPException(status_code=502, detail=str(e)) from e
    if resp.status_code != 200:
        raise HTTPException(status_code=502, detail=resp.text)
    return resp.json()
@app.post("/v1/omni/voice_chat")
async def voice_chat(req: VoiceChatRequest):
    """Voice chat (single round, half duplex).

    Flow: audio_base64 -> prefill -> decode -> text plus merged TTS audio.

    Returns:
        A dict with the round index, the reply text (a canned fallback when
        the model text is unusable), the merged TTS WAV as base64 ("" when no
        wav appeared in time) and the number of wav chunks found.

    Raises:
        HTTPException: 503 when omni initialisation fails, 400 when the audio
            cannot be decoded, 502 when prefill or decode fails upstream.
    """
    import io

    ok = await omni_init_if_needed()
    if not ok:
        raise HTTPException(503, "omni 初始化失败")

    # 1. Decode the audio into mono float32 at _SAMPLE_RATE.
    try:
        audio_bytes = base64.b64decode(req.audio_base64)
        audio_np, sr = sf.read(io.BytesIO(audio_bytes), dtype='float32')
        # Downmix before resampling: soundfile yields (frames, channels) and
        # librosa.resample works along the last axis by default, so the old
        # order resampled across channels instead of across time.
        if audio_np.ndim > 1:
            audio_np = audio_np.mean(axis=1)
        if sr != _SAMPLE_RATE:
            import librosa
            audio_np = librosa.resample(audio_np, orig_sr=sr, target_sr=_SAMPLE_RATE)
    except Exception as e:
        raise HTTPException(400, f"音频解码失败: {e}") from e

    # 2. Save the temp audio file, clear the round directory, then prefill.
    cnt = state.round_counter
    session_id = "prego"
    audio_path = save_temp_audio(audio_np, session_id, cnt)
    round_dir = _reset_round_dir(cnt)
    prefill_data = {
        "audio_path_prefix": audio_path,
        "img_path_prefix": "",
        "cnt": cnt,
    }
    try:
        prefill_resp = await http_client.post(
            f"{LLAMA_SERVER_URL}/v1/stream/prefill",
            json=prefill_data,
            timeout=30.0
        )
    except httpx.HTTPError as e:
        raise HTTPException(502, f"prefill 失败: {e}") from e
    if prefill_resp.status_code != 200:
        raise HTTPException(502, f"prefill 失败: {prefill_resp.text}")

    # 3. Decode.
    decode_data = {
        "debug_dir": OMNI_OUTPUT_DIR,
        "stream": False,
        "round_idx": cnt,
    }
    try:
        decode_resp = await http_client.post(
            f"{LLAMA_SERVER_URL}/v1/stream/decode",
            json=decode_data,
            timeout=120.0
        )
    except httpx.HTTPError as e:
        raise HTTPException(502, f"decode 失败: {e}") from e
    if decode_resp.status_code != 200:
        raise HTTPException(502, f"decode 失败: {decode_resp.text}")

    # 4. Collect the generated output from the round directory.
    tts_wav_dir = os.path.join(str(round_dir), "tts_wav")
    tts_audio_base64 = ""

    # Text: merge llm_text.txt from every chunk.
    llm_debug_dir = os.path.join(str(round_dir), "llm_debug")
    text_output = _read_llm_text(llm_debug_dir)
    if not text_output:
        logger.warning("omni 文本输出不可用,使用演示兜底回复")
        text_output = _fallback_voice_text()

    # Audio: the wait helper sleeps synchronously (up to its timeout), so run
    # it in the default thread pool to keep the event loop responsive.
    loop = asyncio.get_running_loop()
    wav_files = await loop.run_in_executor(None, _wait_for_tts_files, tts_wav_dir)
    if wav_files:
        tts_audio_base64 = _merge_wavs_to_base64(tts_wav_dir, wav_files)
        logger.info(f"TTS wav 合并完成: {len(wav_files)} 个片段")
    else:
        logger.warning(f"TTS wav 未在超时内生成: {tts_wav_dir}")

    state.round_counter += 1
    return {
        "success": True,
        "round": cnt,
        "text": text_output,
        "audio_base64": tts_audio_base64,
        "audio_sample_rate": 24000,  # TTS default sample rate
        "audio_files": len(wav_files),
        # Placeholders: no timing is measured here.
        "timing": {
            "audio_prefill_ms": 0,
            "decode_ms": 0,
        }
    }
@app.post("/v1/omni/streaming_voice")
async def streaming_voice(req: VoiceChatRequest):
    """Streaming voice chat: text and TTS audio chunks are returned as SSE.

    Events are ``data: <json>`` lines. ``type`` is one of ``prefill_done``,
    ``text``, ``audio`` (base64 WAV plus ``index`` and ``sample_rate``),
    ``done`` or ``timeout``; failures are sent as ``{"error": ...}``.

    Raises:
        HTTPException: 503 when omni initialisation fails, 400 when the audio
            cannot be decoded.
    """
    import io

    ok = await omni_init_if_needed()
    if not ok:
        raise HTTPException(503, "omni 初始化失败")

    # Decode the audio into mono float32 at _SAMPLE_RATE.
    try:
        audio_bytes = base64.b64decode(req.audio_base64)
        audio_np, sr = sf.read(io.BytesIO(audio_bytes), dtype='float32')
        # Downmix before resampling: soundfile yields (frames, channels) and
        # librosa.resample works along the last axis by default.
        if audio_np.ndim > 1:
            audio_np = audio_np.mean(axis=1)
        if sr != _SAMPLE_RATE:
            import librosa
            audio_np = librosa.resample(audio_np, orig_sr=sr, target_sr=_SAMPLE_RATE)
    except Exception as e:
        raise HTTPException(400, f"音频解码失败: {e}") from e

    cnt = state.round_counter
    session_id = "prego"
    audio_path = save_temp_audio(audio_np, session_id, cnt)
    # Start from an empty round directory, as the non-streaming handler does.
    # Otherwise chunks and a generation_done.flag left by an earlier run with
    # the same round index would be replayed immediately.
    round_dir = _reset_round_dir(cnt)
    tts_dir = os.path.join(str(round_dir), "tts_wav")
    llm_dir = os.path.join(str(round_dir), "llm_debug")

    def sse(payload: dict) -> str:
        """Format one server-sent event."""
        return f"data: {json.dumps(payload)}\n\n"

    async def event_stream() -> AsyncGenerator[str, None]:
        # prefill
        prefill_data = {"audio_path_prefix": audio_path, "img_path_prefix": "", "cnt": cnt}
        try:
            pref_resp = await http_client.post(
                f"{LLAMA_SERVER_URL}/v1/stream/prefill", json=prefill_data, timeout=30.0)
        except httpx.HTTPError as e:
            logger.error(f"prefill 请求失败: {e}")
            yield sse({'error': 'prefill failed'})
            return
        if pref_resp.status_code != 200:
            yield sse({'error': 'prefill failed'})
            return
        yield sse({'type': 'prefill_done'})

        # decode with SSE streaming
        decode_data = {"debug_dir": OMNI_OUTPUT_DIR, "stream": True, "round_idx": cnt}
        # NOTE(review): the non-streaming path reads llm_debug/chunk_*/llm_text.txt,
        # while this one polls llm_debug/llm_text.txt — confirm which layout the
        # streaming decode actually produces.
        llm_txt = os.path.join(llm_dir, "llm_text.txt")
        done_flag = os.path.join(tts_dir, "generation_done.flag")
        sent_wavs = set()
        sent_text_lines = 0
        try:
            async with http_client.stream(
                "POST", f"{LLAMA_SERVER_URL}/v1/stream/decode",
                json=decode_data, timeout=120.0
            ) as resp:
                if resp.status_code != 200:
                    yield sse({'error': 'decode failed'})
                    return

                # Poll the TTS directory and the text file while decode runs.
                deadline = time.monotonic() + 120
                while time.monotonic() < deadline:
                    # Sample the done flag BEFORE scanning, so everything
                    # written before the flag is picked up by this pass.
                    finished = os.path.exists(done_flag)

                    # Text output. Nothing is yielded inside a try block whose
                    # handler could swallow the GeneratorExit raised at a
                    # yield when the client disconnects.
                    new_lines = []
                    if os.path.exists(llm_txt):
                        try:
                            with open(llm_txt, "r", encoding="utf-8") as f:
                                lines = f.readlines()
                        except (OSError, UnicodeDecodeError) as e:
                            # Best effort: the file may be mid-write; retry next poll.
                            logger.debug(f"llm_text 读取失败,稍后重试: {e}")
                        else:
                            new_lines = lines[sent_text_lines:]
                            sent_text_lines = len(lines)
                    for line in new_lines:
                        yield sse({'type': 'text', 'text': line.strip()})

                    # TTS wav output, in numeric order (wav_2 before wav_10).
                    try:
                        names = os.listdir(tts_dir)
                    except OSError:
                        names = []
                    pending = sorted(
                        (f for f in names
                         if f.startswith("wav_") and f.endswith(".wav")
                         and f not in sent_wavs),
                        key=lambda n: _numeric_suffix(n, "wav_"),
                    )
                    for wf in pending:
                        wav_path = os.path.join(tts_dir, wf)
                        try:
                            wav_data, wav_sr = sf.read(wav_path)
                            wav_b64 = base64.b64encode(
                                audio_to_wav_bytes(wav_data, wav_sr)).decode("utf-8")
                        except (OSError, RuntimeError, ValueError) as e:
                            logger.debug(f"TTS wav 读取失败 {wav_path}: {e}")
                            if finished:
                                # No further poll will happen: skip this chunk.
                                continue
                            # Probably still being written. Stop here so later
                            # chunks are not sent ahead of it; retry next poll.
                            break
                        # Mark as sent only after a successful read, so a
                        # half-written chunk is retried, not lost.
                        sent_wavs.add(wf)
                        yield sse({'type': 'audio', 'index': len(sent_wavs) - 1,
                                   'base64': wav_b64, 'sample_rate': wav_sr})

                    # End of generation.
                    if finished:
                        yield sse({'type': 'done'})
                        return
                    await asyncio.sleep(0.1)
                yield sse({'type': 'timeout'})
        except httpx.HTTPError as e:
            logger.error(f"decode 请求失败: {e}")
            yield sse({'error': 'decode failed'})

    state.round_counter += 1
    return StreamingResponse(event_stream(), media_type="text/event-stream")
# ── 启动脚本 ──────────────────────────────────────────
if __name__ == "__main__":
    # Direct execution: serve the app with uvicorn on every interface, port 8090.
    from uvicorn import run as _serve

    _serve(app, host="0.0.0.0", port=8090, log_level="info")