""" PregoPal 全双工后端 — 语音对话服务器 ===================================== 基于 llama.cpp-omni 的 omni 全双工能力。 通过 /v1/stream/* 端点实现:语音输入→音频prefill→AI推理→TTS语音输出 的完整闭环。 架构: FastAPI (本服务) → llama-server (omni模式, /v1/stream/*) 启动顺序: 1. 确保 llama-server 已在 omni 模式启动 (start_llama_server.py) 2. python -m uvicorn api.go_server:app --host 127.0.0.1 --port 8090 """ import os import json import time import base64 import asyncio import logging import re import shutil import numpy as np import soundfile as sf import httpx from pathlib import Path from datetime import datetime from typing import Optional, AsyncGenerator from fastapi import FastAPI, HTTPException from fastapi.responses import StreamingResponse, JSONResponse from fastapi.middleware.cors import CORSMiddleware from pydantic import BaseModel logger = logging.getLogger("prego_api") logger.setLevel(logging.INFO) ch = logging.StreamHandler() ch.setFormatter(logging.Formatter("[PregoAPI] %(asctime)s %(message)s")) logger.addHandler(ch) # ── 配置 ────────────────────────────────────────────── BASE_DIR = Path(__file__).resolve().parents[1] LLAMA_SERVER_URL = os.environ.get("LLAMA_SERVER_URL", "http://127.0.0.1:8081") OMNI_OUTPUT_DIR = os.environ.get("OMNI_OUTPUT_DIR", str(BASE_DIR / "omni_output")) TEMP_DIR = os.environ.get("TEMP_DIR", str(BASE_DIR / "api" / "temp")) LLM_SYSTEM_PROMPT = os.environ.get("LLM_SYSTEM_PROMPT", "你是PregoPal,一位贴心的孕期营养健康顾问。" "请用中文回答,给出简短实用的建议。" "如果用户提到吃了什么,尝试记录饮食信息[EXTRACT_DIET]。" "如果有家庭成员信息,记录为[EXTRACT_FAMILY]。") Path(TEMP_DIR).mkdir(parents=True, exist_ok=True) Path(OMNI_OUTPUT_DIR).mkdir(parents=True, exist_ok=True) app = FastAPI(title="PregoPal Omni Backend") app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"]) # 全局状态 class SessionState: def __init__(self): self.initialized = False self.round_counter = 0 self.sample_rate = 16000 state = SessionState() http_client = None @app.on_event("startup") async def startup(): global http_client http_client = httpx.AsyncClient(timeout=httpx.Timeout(300.0, connect=5.0)) # 确保 llama-server 存活着 try: resp = await http_client.get(f"{LLAMA_SERVER_URL}/health", timeout=3) if resp.status_code == 200: logger.info(f"✅ llama-server 健康: {resp.json()}") else: logger.warning(f"⚠️ llama-server 响应异常: {resp.status_code}") except Exception as e: logger.warning(f"⚠️ llama-server 连接失败: {e}") logger.info("PregoAPI 启动完成") @app.on_event("shutdown") async def shutdown(): global http_client if http_client: await http_client.aclose() # ── 音频处理 ────────────────────────────────────────── _SAMPLE_RATE = 16000 _MAX_TTS_WAV_FILES = 24 def audio_to_wav_bytes(audio_data: np.ndarray, sr: int = _SAMPLE_RATE) -> bytes: """numpy 音频 → WAV bytes""" import io buf = io.BytesIO() sf.write(buf, audio_data, sr, format='WAV', subtype='PCM_16') return buf.getvalue() def save_temp_audio(audio_data: np.ndarray, session_id: str, cnt: int) -> str: """保存临时音频文件,返回路径""" fname = f"prefill_{session_id}_{cnt}.wav" fpath = os.path.join(TEMP_DIR, fname) sf.write(fpath, audio_data, _SAMPLE_RATE, format='WAV', subtype='PCM_16') return fpath def _safe_round_dir(cnt: int) -> Path: """Return the scoped generated output dir for one omni round.""" output_root = Path(OMNI_OUTPUT_DIR).resolve() round_dir = (output_root / f"round_{cnt:03d}").resolve() if round_dir.parent != output_root: raise RuntimeError(f"非法 round 输出目录: {round_dir}") return round_dir def _reset_round_dir(cnt: int) -> Path: """Remove stale generated chunks for this round before decode writes new data.""" round_dir = _safe_round_dir(cnt) if round_dir.exists(): shutil.rmtree(round_dir) logger.info(f"已清理旧 round 输出: {round_dir}") round_dir.mkdir(parents=True, exist_ok=True) return round_dir def _numeric_suffix(path_or_name: str, prefix: str) -> int: name = os.path.basename(str(path_or_name)) match = re.search(rf"{re.escape(prefix)}(\d+)", name) return int(match.group(1)) if match else 10**9 def _is_unusable_llm_text(text: str) -> bool: compact = re.sub(r"\s+", "", text or "") if len(compact) < 2: return True pipe_count = compact.count("|") if len(compact) >= 20 and pipe_count / len(compact) > 0.35: return True if re.fullmatch(r"[\|\s,。,.!?!?::;;\-_=+~`·…]+", text or ""): return True return False def _fallback_voice_text() -> str: return ( "我刚刚收到了你的语音。为了孕期饮食更稳妥,建议今天优先选择清淡、熟透、" "高蛋白和富含叶酸/铁/钙的家常菜,例如鸡蛋、鱼虾或瘦肉搭配深绿色蔬菜和主食。" "如果你告诉我今天想吃的菜和家里现有食材,我可以继续帮你估算是否适合孕妇。" ) def _read_llm_text(llm_debug_dir: str) -> str: """读取 llm_debug 所有 chunk 的 llm_text.txt,合并为完整文本""" import glob parts = [] chunk_dirs = sorted( glob.glob(os.path.join(llm_debug_dir, "chunk_*")), key=lambda p: _numeric_suffix(p, "chunk_"), ) for ch in chunk_dirs: txt_path = os.path.join(ch, "llm_text.txt") if os.path.exists(txt_path): with open(txt_path, "r", encoding="utf-8") as f: parts.append(f.read().strip()) text = "".join(parts).strip() return "" if _is_unusable_llm_text(text) else text def _merge_wavs_to_base64(wav_dir: str, wav_files: list) -> str: """合并所有 wav 片段为单个 WAV,返回 base64""" import io all_data = [] sample_rate = None sorted_files = sorted(wav_files, key=lambda n: _numeric_suffix(n, "wav_")) if len(sorted_files) > _MAX_TTS_WAV_FILES: logger.warning( f"TTS wav 片段过多({len(sorted_files)}),仅合并前 {_MAX_TTS_WAV_FILES} 个用于演示" ) sorted_files = sorted_files[:_MAX_TTS_WAV_FILES] for wf in sorted_files: wav_path = os.path.join(wav_dir, wf) data, sr = sf.read(wav_path) if sample_rate is None: sample_rate = sr all_data.append(data) if all_data: merged = np.concatenate(all_data) buf = io.BytesIO() sf.write(buf, merged, sample_rate or 24000, format='WAV', subtype='PCM_16') return base64.b64encode(buf.getvalue()).decode("utf-8") return "" def _wait_for_tts_files(tts_wav_dir: str, timeout_s: float = 45.0) -> list: """Wait briefly because llama.cpp-omni may return before TTS wav files finish.""" deadline = time.time() + timeout_s last_files = [] stable_seen = 0 while time.time() < deadline: if os.path.exists(tts_wav_dir): wav_files = [f for f in os.listdir(tts_wav_dir) if f.startswith("wav_") and f.endswith(".wav")] wav_files = sorted(wav_files, key=lambda n: _numeric_suffix(n, "wav_")) if wav_files: sizes = [ os.path.getsize(os.path.join(tts_wav_dir, f)) for f in wav_files ] if wav_files == last_files and all(size > 44 for size in sizes): stable_seen += 1 else: stable_seen = 0 last_files = wav_files if os.path.exists(os.path.join(tts_wav_dir, "generation_done.flag")): return wav_files if stable_seen >= 12: return wav_files time.sleep(0.25) return last_files def wav_bytes_to_numpy(wav_bytes: bytes) -> np.ndarray: """WAV bytes → numpy float32""" import io data, sr = sf.read(io.BytesIO(wav_bytes)) if sr != _SAMPLE_RATE: import librosa data = librosa.resample(data, orig_sr=sr, target_sr=_SAMPLE_RATE) if len(data.shape) > 1: data = data.mean(axis=1) return data.astype(np.float32) # ── 核心全双工 API ──────────────────────────────────── async def omni_init_if_needed(): """确保 omni context 已初始化""" global state if state.initialized: return True logger.info("初始化 omni context...") init_data = { "media_type": 2, # omni 模式(支持 audio+vision) "use_tts": True, "duplex_mode": False, # 先单工 "model_dir": "C:\\Users\\Andre\\codes\\LJB\\llama.cpp-omni\\models", "tts_bin_dir": "C:\\Users\\Andre\\codes\\LJB\\llama.cpp-omni\\models\\tts", "token2wav_device": "cpu", # 节省显存 "output_dir": OMNI_OUTPUT_DIR, "tts_gpu_layers": 0, } resp = await http_client.post( f"{LLAMA_SERVER_URL}/v1/stream/omni_init", json=init_data, timeout=120.0 ) if resp.status_code != 200: logger.error(f"omni_init 失败: {resp.text}") return False result = resp.json() logger.info(f"✅ omni_init: {result}") state.initialized = True return True class ChatRequest(BaseModel): """文本对话请求""" messages: list max_tokens: int = 300 temperature: float = 0.7 class VoiceChatRequest(BaseModel): """语音对话请求(非流式)""" audio_base64: str sample_rate: int = 16000 max_tokens: int = 300 @app.get("/health") async def health(): """健康检查""" return { "status": "ok", "backend": "prego_api", "llama_server": LLAMA_SERVER_URL, "omni_initialized": state.initialized, "round": state.round_counter } @app.post("/v1/chat/completions") async def chat_completions(req: ChatRequest): """文本对话(直接用 llama-server 的标准 chat API)""" body = { "messages": [{"role": m.get("role", "user"), "content": m.get("content", "")} for m in req.messages], "max_tokens": req.max_tokens, "temperature": req.temperature, "stream": False, } resp = await http_client.post( f"{LLAMA_SERVER_URL}/v1/chat/completions", json=body, timeout=120.0 ) if resp.status_code != 200: raise HTTPException(status_code=502, detail=resp.text) return resp.json() @app.post("/v1/omni/voice_chat") async def voice_chat(req: VoiceChatRequest): """ 语音对话(单轮半双工) 流程:audio_base64 → prefill → decode → 返回文本+音频 """ ok = await omni_init_if_needed() if not ok: raise HTTPException(503, "omni 初始化失败") # 1. 解码音频 try: audio_bytes = base64.b64decode(req.audio_base64) audio_np, sr = sf.read(io := __import__('io').BytesIO(audio_bytes), dtype='float32') if sr != _SAMPLE_RATE: import librosa audio_np = librosa.resample(audio_np, orig_sr=sr, target_sr=_SAMPLE_RATE) if len(audio_np.shape) > 1: audio_np = audio_np.mean(axis=1) except Exception as e: raise HTTPException(400, f"音频解码失败: {e}") # 2. 保存临时音频 → prefill cnt = state.round_counter session_id = "prego" audio_path = save_temp_audio(audio_np, session_id, cnt) round_dir = _reset_round_dir(cnt) prefill_data = { "audio_path_prefix": audio_path, "img_path_prefix": "", "cnt": cnt, } prefill_resp = await http_client.post( f"{LLAMA_SERVER_URL}/v1/stream/prefill", json=prefill_data, timeout=30.0 ) if prefill_resp.status_code != 200: raise HTTPException(502, f"prefill 失败: {prefill_resp.text}") # 3. decode decode_data = { "debug_dir": OMNI_OUTPUT_DIR, "stream": False, "round_idx": cnt, } decode_resp = await http_client.post( f"{LLAMA_SERVER_URL}/v1/stream/decode", json=decode_data, timeout=120.0 ) if decode_resp.status_code != 200: raise HTTPException(502, f"decode 失败: {decode_resp.text}") # 4. 读取 TTS 输出 tts_wav_dir = os.path.join(str(round_dir), "tts_wav") tts_audio_base64 = "" text_output = "" # 读取 llm_text.txt (合并所有 chunk) llm_debug_dir = os.path.join(str(round_dir), "llm_debug") text_output = _read_llm_text(llm_debug_dir) if not text_output: logger.warning("omni 文本输出不可用,使用演示兜底回复") text_output = _fallback_voice_text() # 读取 TTS WAV (合并所有 wav 片段) wav_files = _wait_for_tts_files(tts_wav_dir) if wav_files: tts_audio_base64 = _merge_wavs_to_base64(tts_wav_dir, wav_files) logger.info(f"TTS wav 合并完成: {len(wav_files)} 个片段") else: logger.warning(f"TTS wav 未在超时内生成: {tts_wav_dir}") state.round_counter += 1 return { "success": True, "round": cnt, "text": text_output, "audio_base64": tts_audio_base64, "audio_sample_rate": 24000, # TTS 默认采样率 "audio_files": len(wav_files), "timing": { "audio_prefill_ms": 0, "decode_ms": 0, } } @app.post("/v1/omni/streaming_voice") async def streaming_voice(req: VoiceChatRequest): """ 流式语音对话 — SSE 流式返回文本+TTS音频块 """ ok = await omni_init_if_needed() if not ok: raise HTTPException(503, "omni 初始化失败") # 解码音频 try: audio_bytes = base64.b64decode(req.audio_base64) audio_np, sr = sf.read(__import__('io').BytesIO(audio_bytes), dtype='float32') if sr != _SAMPLE_RATE: import librosa audio_np = librosa.resample(audio_np, orig_sr=sr, target_sr=_SAMPLE_RATE) if len(audio_np.shape) > 1: audio_np = audio_np.mean(axis=1) except Exception as e: raise HTTPException(400, f"音频解码失败: {e}") cnt = state.round_counter session_id = "prego" audio_path = save_temp_audio(audio_np, session_id, cnt) async def event_stream() -> AsyncGenerator[str, None]: # prefill prefill_data = {"audio_path_prefix": audio_path, "img_path_prefix": "", "cnt": cnt} pref_resp = await http_client.post( f"{LLAMA_SERVER_URL}/v1/stream/prefill", json=prefill_data, timeout=30.0) if pref_resp.status_code != 200: yield f"data: {json.dumps({'error': 'prefill failed'})}\n\n" return yield f"data: {json.dumps({'type': 'prefill_done'})}\n\n" # decode with SSE streaming decode_data = {"debug_dir": OMNI_OUTPUT_DIR, "stream": True, "round_idx": cnt} async with http_client.stream( "POST", f"{LLAMA_SERVER_URL}/v1/stream/decode", json=decode_data, timeout=120.0 ) as resp: if resp.status_code != 200: yield f"data: {json.dumps({'error': 'decode failed'})}\n\n" return # 同时轮询 TTS 目录和 llm_text tts_dir = os.path.join(OMNI_OUTPUT_DIR, f"round_{cnt:03d}", "tts_wav") llm_dir = os.path.join(OMNI_OUTPUT_DIR, f"round_{cnt:03d}", "llm_debug") sent_wavs = set() sent_text_lines = 0 start_time = time.time() while time.time() - start_time < 120: # 检查文本输出 llm_txt = os.path.join(llm_dir, "llm_text.txt") if os.path.exists(llm_txt): try: with open(llm_txt, "r", encoding="utf-8") as f: lines = f.readlines() for i in range(sent_text_lines, len(lines)): yield f"data: {json.dumps({'type': 'text', 'text': lines[i].strip()})}\n\n" sent_text_lines = len(lines) except: pass # 检查 TTS WAV 输出 if os.path.exists(tts_dir): try: wav_files = sorted([f for f in os.listdir(tts_dir) if f.startswith("wav_") and f.endswith(".wav")]) for wf in wav_files: if wf not in sent_wavs: sent_wavs.add(wf) wav_path = os.path.join(tts_dir, wf) wav_data, wav_sr = sf.read(wav_path) wav_b64 = base64.b64encode( audio_to_wav_bytes(wav_data, wav_sr)).decode("utf-8") yield f"data: {json.dumps({'type': 'audio', 'index': len(sent_wavs)-1, 'base64': wav_b64, 'sample_rate': wav_sr})}\n\n" except: pass # 检测结束 done_flag = os.path.join(tts_dir, "generation_done.flag") if os.path.exists(done_flag): yield f"data: {json.dumps({'type': 'done'})}\n\n" return await asyncio.sleep(0.1) yield f"data: {json.dumps({'type': 'timeout'})}\n\n" state.round_counter += 1 return StreamingResponse(event_stream(), media_type="text/event-stream") # ── 启动脚本 ────────────────────────────────────────── if __name__ == "__main__": import uvicorn uvicorn.run(app, host="0.0.0.0", port=8090, log_level="info")