Spaces:
Runtime error
Runtime error
| """ | |
| PregoPal 全双工后端 — 语音对话服务器 | |
| ===================================== | |
| 基于 llama.cpp-omni 的 omni 全双工能力。 | |
| 通过 /v1/stream/* 端点实现:语音输入→音频prefill→AI推理→TTS语音输出 的完整闭环。 | |
| 架构: | |
| FastAPI (本服务) → llama-server (omni模式, /v1/stream/*) | |
| 启动顺序: | |
| 1. 确保 llama-server 已在 omni 模式启动 (start_llama_server.py) | |
| 2. python -m uvicorn api.go_server:app --host 127.0.0.1 --port 8090 | |
| """ | |
| import os | |
| import json | |
| import time | |
| import base64 | |
| import asyncio | |
| import logging | |
| import re | |
| import shutil | |
| import numpy as np | |
| import soundfile as sf | |
| import httpx | |
| from pathlib import Path | |
| from datetime import datetime | |
| from typing import Optional, AsyncGenerator | |
| from fastapi import FastAPI, HTTPException | |
| from fastapi.responses import StreamingResponse, JSONResponse | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel | |
# Module-wide logger for this service.
# The stream handler is attached only when the logger has no handler yet, so
# executing this module more than once in the same process (for example via
# importlib.reload) does not make every message appear multiple times.
logger = logging.getLogger("prego_api")
logger.setLevel(logging.INFO)
ch = logging.StreamHandler()
ch.setFormatter(logging.Formatter("[PregoAPI] %(asctime)s %(message)s"))
if not logger.handlers:
    logger.addHandler(ch)
# ── Configuration ─────────────────────────────────────
# Every setting below can be overridden through an environment variable of
# the same name; the second argument is the fallback value.
BASE_DIR = Path(__file__).resolve().parents[1]
LLAMA_SERVER_URL = os.getenv("LLAMA_SERVER_URL", "http://127.0.0.1:8081")
OMNI_OUTPUT_DIR = os.getenv("OMNI_OUTPUT_DIR", str(BASE_DIR.joinpath("omni_output")))
TEMP_DIR = os.getenv("TEMP_DIR", str(BASE_DIR.joinpath("api", "temp")))
# NOTE(review): LLM_SYSTEM_PROMPT is not referenced anywhere else in the
# visible part of this file — confirm where it is consumed.
LLM_SYSTEM_PROMPT = os.getenv(
    "LLM_SYSTEM_PROMPT",
    "你是PregoPal,一位贴心的孕期营养健康顾问。"
    "请用中文回答,给出简短实用的建议。"
    "如果用户提到吃了什么,尝试记录饮食信息[EXTRACT_DIET]。"
    "如果有家庭成员信息,记录为[EXTRACT_FAMILY]。",
)

# Create the working directories at import time (temp dir first, then output dir).
for _required_dir in (TEMP_DIR, OMNI_OUTPUT_DIR):
    Path(_required_dir).mkdir(parents=True, exist_ok=True)
del _required_dir

app = FastAPI(title="PregoPal Omni Backend")
# CORS is fully open: any origin, method and header is accepted.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
# Global state
class SessionState:
    """Mutable process-wide bookkeeping for the omni session.

    Attributes are plain values set in ``__init__``:
    ``initialized`` starts as False, ``round_counter`` as 0 and
    ``sample_rate`` as 16000.
    """

    def __init__(self):
        self.initialized, self.round_counter = False, 0
        self.sample_rate = 16000


# Single shared instance used by the request handlers in this module.
state = SessionState()
# Shared HTTP client; stays None until startup() assigns an httpx.AsyncClient.
http_client = None
async def startup():
    """Create the shared HTTP client and probe llama-server once.

    The probe result is only logged: a non-200 answer or any exception
    produces a warning and startup still completes.

    NOTE(review): no registration of this coroutine with ``app`` is visible
    in this file — confirm it is wired up as a startup hook, otherwise
    ``http_client`` stays None.
    """
    global http_client
    http_client = httpx.AsyncClient(timeout=httpx.Timeout(300.0, connect=5.0))

    # Make sure llama-server is alive.
    health_url = f"{LLAMA_SERVER_URL}/health"
    try:
        probe = await http_client.get(health_url, timeout=3)
        if probe.status_code != 200:
            logger.warning(f"⚠️ llama-server 响应异常: {probe.status_code}")
        else:
            logger.info(f"✅ llama-server 健康: {probe.json()}")
    except Exception as exc:
        logger.warning(f"⚠️ llama-server 连接失败: {exc}")

    logger.info("PregoAPI 启动完成")
async def shutdown():
    """Close the shared HTTP client if one has been created.

    NOTE(review): no registration of this coroutine with ``app`` is visible
    in this file — confirm it is wired up as a shutdown hook.
    """
    global http_client
    if not http_client:
        return
    await http_client.aclose()
# ── Audio handling ────────────────────────────────────
# Sample rate used for audio written by this service.
_SAMPLE_RATE = 16000
# Upper bound on the number of TTS wav fragments merged into one reply.
_MAX_TTS_WAV_FILES = 24


def audio_to_wav_bytes(audio_data: np.ndarray, sr: int = _SAMPLE_RATE) -> bytes:
    """Encode a numpy audio array as an in-memory 16-bit PCM WAV file.

    Args:
        audio_data: Samples to encode, passed straight to ``soundfile.write``.
        sr: Sample rate written into the WAV header.

    Returns:
        The complete WAV file as bytes.
    """
    from io import BytesIO

    with BytesIO() as wav_buffer:
        sf.write(wav_buffer, audio_data, sr, format='WAV', subtype='PCM_16')
        return wav_buffer.getvalue()
def save_temp_audio(audio_data: np.ndarray, session_id: str, cnt: int) -> str:
    """Write audio to a temporary 16-bit PCM WAV file and return its path.

    The file is placed in ``TEMP_DIR`` and named
    ``prefill_<session_id>_<cnt>.wav``; it is written at ``_SAMPLE_RATE``.
    """
    target = os.path.join(TEMP_DIR, f"prefill_{session_id}_{cnt}.wav")
    sf.write(target, audio_data, _SAMPLE_RATE, format='WAV', subtype='PCM_16')
    return target
def _safe_round_dir(cnt: int) -> Path:
    """Resolve the output directory of one omni round.

    The directory is ``<OMNI_OUTPUT_DIR>/round_<cnt, zero-padded to 3>``.

    Raises:
        RuntimeError: If the resolved directory is not a direct child of the
            resolved output root.
    """
    root = Path(OMNI_OUTPUT_DIR).resolve()
    candidate = root.joinpath(f"round_{cnt:03d}").resolve()
    if candidate.parent == root:
        return candidate
    raise RuntimeError(f"非法 round 输出目录: {candidate}")
def _reset_round_dir(cnt: int) -> Path:
    """Give round ``cnt`` an empty output directory and return it.

    Any existing directory for the round is deleted (and the deletion is
    logged) so that leftovers from an earlier run are not picked up; the
    directory is then (re)created.
    """
    target = _safe_round_dir(cnt)
    had_stale_output = target.exists()
    if had_stale_output:
        shutil.rmtree(target)
        logger.info(f"已清理旧 round 输出: {target}")
    target.mkdir(parents=True, exist_ok=True)
    return target
| def _numeric_suffix(path_or_name: str, prefix: str) -> int: | |
| name = os.path.basename(str(path_or_name)) | |
| match = re.search(rf"{re.escape(prefix)}(\d+)", name) | |
| return int(match.group(1)) if match else 10**9 | |
| def _is_unusable_llm_text(text: str) -> bool: | |
| compact = re.sub(r"\s+", "", text or "") | |
| if len(compact) < 2: | |
| return True | |
| pipe_count = compact.count("|") | |
| if len(compact) >= 20 and pipe_count / len(compact) > 0.35: | |
| return True | |
| if re.fullmatch(r"[\|\s,。,.!?!?::;;\-_=+~`·…]+", text or ""): | |
| return True | |
| return False | |
| def _fallback_voice_text() -> str: | |
| return ( | |
| "我刚刚收到了你的语音。为了孕期饮食更稳妥,建议今天优先选择清淡、熟透、" | |
| "高蛋白和富含叶酸/铁/钙的家常菜,例如鸡蛋、鱼虾或瘦肉搭配深绿色蔬菜和主食。" | |
| "如果你告诉我今天想吃的菜和家里现有食材,我可以继续帮你估算是否适合孕妇。" | |
| ) | |
def _read_llm_text(llm_debug_dir: str) -> str:
    """Concatenate the ``llm_text.txt`` of every ``chunk_*`` directory.

    Chunk directories are visited in ascending order of their numeric
    suffix; chunks without an ``llm_text.txt`` are skipped. Each piece is
    stripped before joining. An empty string is returned when the combined
    text is judged unusable by ``_is_unusable_llm_text``.
    """
    from glob import glob

    ordered_chunks = glob(os.path.join(llm_debug_dir, "chunk_*"))
    ordered_chunks.sort(key=lambda chunk: _numeric_suffix(chunk, "chunk_"))

    pieces = []
    for chunk in ordered_chunks:
        text_file = os.path.join(chunk, "llm_text.txt")
        if not os.path.exists(text_file):
            continue
        with open(text_file, "r", encoding="utf-8") as handle:
            pieces.append(handle.read().strip())

    combined = "".join(pieces).strip()
    if _is_unusable_llm_text(combined):
        return ""
    return combined
def _merge_wavs_to_base64(wav_dir: str, wav_files: list) -> str:
    """Merge wav fragments into one 16-bit PCM WAV and return it as base64.

    Fragments are ordered by the number after ``wav_`` in their name and at
    most ``_MAX_TTS_WAV_FILES`` of them are used (a warning is logged when
    the list is truncated). The sample rate of the first fragment is used
    for the merged file. Returns an empty string when there is nothing to
    merge.
    """
    from io import BytesIO

    ordered = sorted(wav_files, key=lambda name: _numeric_suffix(name, "wav_"))
    total = len(ordered)
    if total > _MAX_TTS_WAV_FILES:
        logger.warning(
            f"TTS wav 片段过多({total}),仅合并前 {_MAX_TTS_WAV_FILES} 个用于演示"
        )
        del ordered[_MAX_TTS_WAV_FILES:]

    chunks = []
    rate = None
    for name in ordered:
        samples, chunk_rate = sf.read(os.path.join(wav_dir, name))
        if rate is None:
            rate = chunk_rate
        chunks.append(samples)

    if not chunks:
        return ""

    out = BytesIO()
    sf.write(out, np.concatenate(chunks), rate or 24000, format='WAV', subtype='PCM_16')
    return base64.b64encode(out.getvalue()).decode("utf-8")
| def _wait_for_tts_files(tts_wav_dir: str, timeout_s: float = 45.0) -> list: | |
| """Wait briefly because llama.cpp-omni may return before TTS wav files finish.""" | |
| deadline = time.time() + timeout_s | |
| last_files = [] | |
| stable_seen = 0 | |
| while time.time() < deadline: | |
| if os.path.exists(tts_wav_dir): | |
| wav_files = [f for f in os.listdir(tts_wav_dir) | |
| if f.startswith("wav_") and f.endswith(".wav")] | |
| wav_files = sorted(wav_files, key=lambda n: _numeric_suffix(n, "wav_")) | |
| if wav_files: | |
| sizes = [ | |
| os.path.getsize(os.path.join(tts_wav_dir, f)) | |
| for f in wav_files | |
| ] | |
| if wav_files == last_files and all(size > 44 for size in sizes): | |
| stable_seen += 1 | |
| else: | |
| stable_seen = 0 | |
| last_files = wav_files | |
| if os.path.exists(os.path.join(tts_wav_dir, "generation_done.flag")): | |
| return wav_files | |
| if stable_seen >= 12: | |
| return wav_files | |
| time.sleep(0.25) | |
| return last_files | |
def wav_bytes_to_numpy(wav_bytes: bytes) -> np.ndarray:
    """Decode WAV bytes into a mono float32 array at ``_SAMPLE_RATE``.

    Args:
        wav_bytes: Contents of an audio file readable by ``soundfile``.

    Returns:
        One-dimensional ``float32`` samples, resampled to ``_SAMPLE_RATE``
        when the file uses a different rate.
    """
    import io

    data, sr = sf.read(io.BytesIO(wav_bytes))
    # Down-mix BEFORE resampling: soundfile returns multi-channel audio as
    # (frames, channels) while librosa.resample works along the last axis by
    # default, so resampling first would stretch the channel axis instead of
    # the time axis.
    if data.ndim > 1:
        data = data.mean(axis=1)
    if sr != _SAMPLE_RATE:
        import librosa
        data = librosa.resample(data, orig_sr=sr, target_sr=_SAMPLE_RATE)
    return data.astype(np.float32)
# ── Core duplex API ───────────────────────────────────
async def omni_init_if_needed():
    """Initialise the omni context on llama-server once per process.

    The model and TTS directories can be overridden with the
    ``OMNI_MODEL_DIR`` and ``OMNI_TTS_BIN_DIR`` environment variables; when
    unset, the previously hard-coded paths are used.

    Returns:
        True when the context is (already) initialised, False when the init
        request could not be sent or llama-server answered with a non-200
        status.

    NOTE(review): two requests arriving before the first init completes will
    both send ``omni_init`` — confirm the server tolerates that.
    """
    if state.initialized:
        return True
    logger.info("初始化 omni context...")
    init_data = {
        "media_type": 2,  # omni mode (audio + vision)
        "use_tts": True,
        "duplex_mode": False,  # simplex for now
        "model_dir": os.environ.get(
            "OMNI_MODEL_DIR",
            "C:\\Users\\Andre\\codes\\LJB\\llama.cpp-omni\\models",
        ),
        "tts_bin_dir": os.environ.get(
            "OMNI_TTS_BIN_DIR",
            "C:\\Users\\Andre\\codes\\LJB\\llama.cpp-omni\\models\\tts",
        ),
        "token2wav_device": "cpu",  # saves GPU memory
        "output_dir": OMNI_OUTPUT_DIR,
        "tts_gpu_layers": 0,
    }
    try:
        resp = await http_client.post(
            f"{LLAMA_SERVER_URL}/v1/stream/omni_init",
            json=init_data,
            timeout=120.0
        )
    except httpx.HTTPError as exc:
        # Connection/timeout problems are reported the same way as a bad
        # status so callers can answer with their normal "init failed" path.
        logger.error(f"omni_init 失败: {exc}")
        return False
    if resp.status_code != 200:
        logger.error(f"omni_init 失败: {resp.text}")
        return False
    result = resp.json()
    logger.info(f"✅ omni_init: {result}")
    state.initialized = True
    return True
class ChatRequest(BaseModel):
    """Text chat request."""
    # Chat messages; chat_completions reads "role" and "content" from each
    # item via .get(), so items are expected to be dicts.
    messages: list
    # Forwarded as "max_tokens" to llama-server's chat completion endpoint.
    max_tokens: int = 300
    # Forwarded as "temperature" to llama-server's chat completion endpoint.
    temperature: float = 0.7
class VoiceChatRequest(BaseModel):
    """Voice chat request (used by both voice_chat and streaming_voice in this file)."""
    # Base64-encoded audio file contents; the handlers base64-decode it and
    # parse the result with soundfile.
    audio_base64: str
    # NOTE(review): sample_rate and max_tokens are not read by the handlers
    # visible in this file — the rate is taken from the decoded audio itself.
    sample_rate: int = 16000
    max_tokens: int = 300
async def health():
    """Health check: report service identity and current omni session state.

    NOTE(review): no route registration is visible for this handler in this
    file — confirm how it is exposed.
    """
    report = {"status": "ok", "backend": "prego_api"}
    report["llama_server"] = LLAMA_SERVER_URL
    report["omni_initialized"] = state.initialized
    report["round"] = state.round_counter
    return report
async def chat_completions(req: ChatRequest):
    """Text chat: forward the conversation to llama-server's chat API.

    Each message is reduced to its ``role`` (default ``"user"``) and
    ``content`` (default ``""``) before being forwarded, non-streaming.

    Returns:
        The JSON body returned by llama-server.

    Raises:
        HTTPException: 400 when a message item is not an object; 502 when
            llama-server cannot be reached or answers with a non-200 status.
    """
    messages = []
    for m in req.messages:
        # `messages` is an untyped list, so guard against items that have no
        # .get() instead of failing with an AttributeError (HTTP 500).
        if not isinstance(m, dict):
            raise HTTPException(status_code=400, detail="messages 中的每一项必须是对象")
        messages.append({"role": m.get("role", "user"), "content": m.get("content", "")})

    body = {
        "messages": messages,
        "max_tokens": req.max_tokens,
        "temperature": req.temperature,
        "stream": False,
    }
    try:
        resp = await http_client.post(
            f"{LLAMA_SERVER_URL}/v1/chat/completions",
            json=body,
            timeout=120.0
        )
    except httpx.HTTPError as exc:
        # Connection errors and timeouts are upstream failures, same as a bad status.
        raise HTTPException(status_code=502, detail=f"llama-server 请求失败: {exc}") from exc
    if resp.status_code != 200:
        raise HTTPException(status_code=502, detail=resp.text)
    return resp.json()
async def voice_chat(req: VoiceChatRequest):
    """Voice chat, single half-duplex round.

    Flow: audio_base64 → prefill → decode → return text + audio.

    The uploaded audio is converted to mono at ``_SAMPLE_RATE``, written to a
    temp file and handed to llama-server's ``/v1/stream/prefill`` and
    ``/v1/stream/decode`` endpoints. The reply text is then read from the
    round's ``llm_debug`` directory and the TTS fragments from its
    ``tts_wav`` directory.

    Returns:
        A dict with ``success``, ``round``, ``text``, ``audio_base64``,
        ``audio_sample_rate``, ``audio_files`` and ``timing`` (request
        durations in milliseconds).

    Raises:
        HTTPException: 503 when omni init fails, 400 when the audio cannot be
            decoded, 502 when prefill/decode fail or llama-server is
            unreachable.

    NOTE(review): ``state.round_counter`` is read here and only incremented at
    the end, so overlapping requests would share a round number and output
    directory — confirm requests are serialised by the caller.
    """
    import io

    ok = await omni_init_if_needed()
    if not ok:
        raise HTTPException(503, "omni 初始化失败")

    # 1. Decode the uploaded audio.
    try:
        audio_bytes = base64.b64decode(req.audio_base64)
        audio_np, sr = sf.read(io.BytesIO(audio_bytes), dtype='float32')
        # Down-mix BEFORE resampling: soundfile yields (frames, channels) and
        # librosa.resample works along the last axis by default, so the old
        # order resampled across channels instead of across time.
        if audio_np.ndim > 1:
            audio_np = audio_np.mean(axis=1)
        if sr != _SAMPLE_RATE:
            import librosa
            audio_np = librosa.resample(audio_np, orig_sr=sr, target_sr=_SAMPLE_RATE)
    except Exception as e:
        raise HTTPException(400, f"音频解码失败: {e}") from e

    # 2. Save the temp audio, then prefill.
    cnt = state.round_counter
    session_id = "prego"
    audio_path = save_temp_audio(audio_np, session_id, cnt)
    round_dir = _reset_round_dir(cnt)

    prefill_data = {
        "audio_path_prefix": audio_path,
        "img_path_prefix": "",
        "cnt": cnt,
    }
    prefill_started = time.perf_counter()
    try:
        prefill_resp = await http_client.post(
            f"{LLAMA_SERVER_URL}/v1/stream/prefill",
            json=prefill_data,
            timeout=30.0
        )
    except httpx.HTTPError as e:
        raise HTTPException(502, f"prefill 失败: {e}") from e
    prefill_ms = int((time.perf_counter() - prefill_started) * 1000)
    if prefill_resp.status_code != 200:
        raise HTTPException(502, f"prefill 失败: {prefill_resp.text}")

    # 3. Decode.
    decode_data = {
        "debug_dir": OMNI_OUTPUT_DIR,
        "stream": False,
        "round_idx": cnt,
    }
    decode_started = time.perf_counter()
    try:
        decode_resp = await http_client.post(
            f"{LLAMA_SERVER_URL}/v1/stream/decode",
            json=decode_data,
            timeout=120.0
        )
    except httpx.HTTPError as e:
        raise HTTPException(502, f"decode 失败: {e}") from e
    decode_ms = int((time.perf_counter() - decode_started) * 1000)
    if decode_resp.status_code != 200:
        raise HTTPException(502, f"decode 失败: {decode_resp.text}")

    # 4. Collect the generated output.
    tts_wav_dir = os.path.join(str(round_dir), "tts_wav")
    tts_audio_base64 = ""

    # Read llm_text.txt (all chunks merged).
    llm_debug_dir = os.path.join(str(round_dir), "llm_debug")
    text_output = _read_llm_text(llm_debug_dir)
    if not text_output:
        logger.warning("omni 文本输出不可用,使用演示兜底回复")
        text_output = _fallback_voice_text()

    # Read the TTS wavs (all fragments merged). _wait_for_tts_files sleeps
    # in a loop for up to 45 s, so it runs in a worker thread to keep the
    # event loop free for other requests.
    loop = asyncio.get_running_loop()
    wav_files = await loop.run_in_executor(None, _wait_for_tts_files, tts_wav_dir)
    if wav_files:
        tts_audio_base64 = _merge_wavs_to_base64(tts_wav_dir, wav_files)
        logger.info(f"TTS wav 合并完成: {len(wav_files)} 个片段")
    else:
        logger.warning(f"TTS wav 未在超时内生成: {tts_wav_dir}")

    state.round_counter += 1
    return {
        "success": True,
        "round": cnt,
        "text": text_output,
        "audio_base64": tts_audio_base64,
        "audio_sample_rate": 24000,  # default TTS sample rate
        "audio_files": len(wav_files),
        "timing": {
            "audio_prefill_ms": prefill_ms,
            "decode_ms": decode_ms,
        }
    }
async def streaming_voice(req: VoiceChatRequest):
    """Streaming voice chat — returns text and TTS audio chunks as SSE.

    The uploaded audio is converted to mono at ``_SAMPLE_RATE`` and written
    to a temp file. The returned stream then runs prefill, starts a
    streaming decode on llama-server and polls the round's output directory,
    emitting ``data: <json>`` events:

    * ``{"type": "prefill_done"}``
    * ``{"type": "text", "text": ...}`` for each new line of ``llm_text.txt``
    * ``{"type": "audio", "index", "base64", "sample_rate"}`` per wav chunk
    * ``{"type": "done"}`` once ``generation_done.flag`` appears, or
      ``{"type": "timeout"}`` after 120 s
    * ``{"error": ...}`` when prefill or decode fail

    Raises:
        HTTPException: 503 when omni init fails, 400 when the audio cannot
            be decoded.

    NOTE(review): text is polled from ``llm_debug/llm_text.txt`` whereas
    ``_read_llm_text`` reads ``llm_debug/chunk_*/llm_text.txt`` — confirm
    which layout the streaming decode actually produces.
    """
    import io

    ok = await omni_init_if_needed()
    if not ok:
        raise HTTPException(503, "omni 初始化失败")

    # Decode the uploaded audio.
    try:
        audio_bytes = base64.b64decode(req.audio_base64)
        audio_np, sr = sf.read(io.BytesIO(audio_bytes), dtype='float32')
        # Down-mix BEFORE resampling: soundfile yields (frames, channels) and
        # librosa.resample works along the last axis by default.
        if audio_np.ndim > 1:
            audio_np = audio_np.mean(axis=1)
        if sr != _SAMPLE_RATE:
            import librosa
            audio_np = librosa.resample(audio_np, orig_sr=sr, target_sr=_SAMPLE_RATE)
    except Exception as e:
        raise HTTPException(400, f"音频解码失败: {e}") from e

    cnt = state.round_counter
    session_id = "prego"
    audio_path = save_temp_audio(audio_np, session_id, cnt)
    # Same as voice_chat: start from an empty round directory so a stale
    # generation_done.flag or old wav chunks are never streamed.
    round_dir = _reset_round_dir(cnt)

    async def event_stream() -> AsyncGenerator[str, None]:
        # prefill
        prefill_data = {"audio_path_prefix": audio_path, "img_path_prefix": "", "cnt": cnt}
        try:
            pref_resp = await http_client.post(
                f"{LLAMA_SERVER_URL}/v1/stream/prefill", json=prefill_data, timeout=30.0)
        except httpx.HTTPError:
            pref_resp = None
        if pref_resp is None or pref_resp.status_code != 200:
            yield f"data: {json.dumps({'error': 'prefill failed'})}\n\n"
            return
        yield f"data: {json.dumps({'type': 'prefill_done'})}\n\n"

        # decode with SSE streaming
        decode_data = {"debug_dir": OMNI_OUTPUT_DIR, "stream": True, "round_idx": cnt}
        try:
            async with http_client.stream(
                "POST", f"{LLAMA_SERVER_URL}/v1/stream/decode",
                json=decode_data, timeout=120.0
            ) as resp:
                if resp.status_code != 200:
                    yield f"data: {json.dumps({'error': 'decode failed'})}\n\n"
                    return

                # Poll the TTS directory and llm_text while decode is running.
                tts_dir = os.path.join(str(round_dir), "tts_wav")
                llm_dir = os.path.join(str(round_dir), "llm_debug")
                llm_txt = os.path.join(llm_dir, "llm_text.txt")
                done_flag = os.path.join(tts_dir, "generation_done.flag")
                sent_wavs = set()
                sent_text_lines = 0
                deadline = time.monotonic() + 120
                while time.monotonic() < deadline:
                    # Sample the flag BEFORE scanning: everything written
                    # before the flag is then guaranteed to be picked up by
                    # this pass instead of being lost when we return.
                    finished = os.path.exists(done_flag)

                    # Text output. Only the file access is guarded; yields
                    # stay outside the try so generator shutdown and task
                    # cancellation are never swallowed.
                    try:
                        with open(llm_txt, "r", encoding="utf-8") as f:
                            lines = f.readlines()
                    except (OSError, UnicodeDecodeError):
                        lines = None  # missing or mid-write; retry next poll
                    if lines is not None:
                        for line in lines[sent_text_lines:]:
                            yield f"data: {json.dumps({'type': 'text', 'text': line.strip()})}\n\n"
                        sent_text_lines = len(lines)

                    # TTS wav output, in numeric chunk order.
                    try:
                        wav_files = sorted(
                            (f for f in os.listdir(tts_dir)
                             if f.startswith("wav_") and f.endswith(".wav")),
                            key=lambda n: _numeric_suffix(n, "wav_"),
                        )
                    except OSError:
                        wav_files = []
                    for wf in wav_files:
                        if wf in sent_wavs:
                            continue
                        try:
                            wav_data, wav_sr = sf.read(os.path.join(tts_dir, wf))
                            wav_b64 = base64.b64encode(
                                audio_to_wav_bytes(wav_data, wav_sr)).decode("utf-8")
                        except (OSError, RuntimeError, ValueError):
                            if finished:
                                # Generation is over; skip the unreadable chunk.
                                continue
                            # Probably still being written: stop here so later
                            # chunks are not sent out of order, retry next poll.
                            break
                        # Mark as sent only after a successful read.
                        sent_wavs.add(wf)
                        yield f"data: {json.dumps({'type': 'audio', 'index': len(sent_wavs)-1, 'base64': wav_b64, 'sample_rate': wav_sr})}\n\n"

                    # End of generation.
                    if finished:
                        yield f"data: {json.dumps({'type': 'done'})}\n\n"
                        return
                    await asyncio.sleep(0.1)
                yield f"data: {json.dumps({'type': 'timeout'})}\n\n"
        except httpx.HTTPError:
            yield f"data: {json.dumps({'error': 'decode failed'})}\n\n"

    state.round_counter += 1
    return StreamingResponse(event_stream(), media_type="text/event-stream")
# ── Launch script ─────────────────────────────────────
if __name__ == "__main__":
    import uvicorn

    # Host and port can be overridden with PREGO_API_HOST / PREGO_API_PORT;
    # the defaults keep the previous behaviour. Note that "0.0.0.0" listens
    # on every network interface.
    uvicorn.run(
        app,
        host=os.environ.get("PREGO_API_HOST", "0.0.0.0"),
        port=int(os.environ.get("PREGO_API_PORT", "8090")),
        log_level="info",
    )