# freud-zero-mvp / mcts_reasoner.py
# Feng Chike
# Freud Zero MVP: 心理咨询AI系统(清洁部署)
# 408f650
import json
import os
from concurrent.futures import ThreadPoolExecutor, as_completed
from langchain_openai import ChatOpenAI
from langchain_core.messages import HumanMessage, AIMessage
from prompts import (
CANDIDATE_GENERATION_PROMPT,
CLIENT_SIMULATOR_PROMPT,
MCTS_EVALUATOR_PROMPT,
)
class MCTSReasoner:
    """Pick a counselor reply by generating, simulating and scoring candidates.

    The pipeline implemented by :meth:`run` is a single generate -> simulate ->
    evaluate -> select pass:

    1. ask the generation model for candidate counselor replies,
    2. ask the simulation model how the client would react to each candidate,
    3. ask the evaluation model to score each simulated reaction (0-10),
    4. return the candidate whose simulated reaction scored highest.
    """

    # Number of LLM calls to try before giving up on a parseable answer.
    GENERATION_ATTEMPTS = 3
    SIMULATION_ATTEMPTS = 2
    EVALUATION_ATTEMPTS = 2
    # Upper bound on concurrent LLM calls in the simulate / evaluate steps.
    MAX_WORKERS = 5

    def __init__(self):
        """Create the three chat clients used by the pipeline.

        All clients share the same model and endpoint and read the API key
        from the ``DEEPSEEK_API_KEY`` environment variable; they differ only
        in temperature and ``max_tokens``.
        """
        base_kwargs = dict(
            model="deepseek-chat",
            base_url="https://api.deepseek.com/v1",
            api_key=os.getenv("DEEPSEEK_API_KEY"),
        )
        self.gen_llm = ChatOpenAI(**base_kwargs, temperature=0.7, max_tokens=1024)
        self.sim_llm = ChatOpenAI(**base_kwargs, temperature=0.7, max_tokens=256)
        self.eval_llm = ChatOpenAI(**base_kwargs, temperature=0.0, max_tokens=64)

    def _format_history(self, history) -> str:
        """Render a langchain message history as readable text.

        Only ``HumanMessage`` and ``AIMessage`` entries are rendered; any
        other message type (e.g. a system message) is skipped. Returns a
        placeholder line when nothing was rendered.
        """
        lines = []
        for msg in history:
            if isinstance(msg, HumanMessage):
                lines.append(f"来访者:{msg.content}")
            elif isinstance(msg, AIMessage):
                lines.append(f"咨询师:{msg.content}")
        return "\n".join(lines) if lines else "(首次对话)"

    def _parse_json(self, text: str):
        """Extract the first JSON array or object embedded in LLM output.

        Decoding starts at whichever of ``[`` / ``{`` appears first, so an
        object whose string values happen to contain brackets is parsed as an
        object, and any prose or code-fence markers around the JSON value are
        ignored.

        Raises:
            ValueError: if the text contains neither ``[`` nor ``{``.
            json.JSONDecodeError: if the text at the first opener is not
                valid JSON (a subclass of ``ValueError``).
        """
        content = text.strip()
        openers = [
            index
            for index in (content.find("["), content.find("{"))
            if index != -1
        ]
        if not openers:
            raise ValueError(f"无法解析 JSON: {content[:100]}")
        # raw_decode stops at the end of the first complete value, so
        # trailing text after the JSON does not break parsing.
        value, _end = json.JSONDecoder().raw_decode(content, min(openers))
        return value

    @staticmethod
    def _normalize_candidates(parsed) -> list:
        """Validate parsed generation output and return the candidate list.

        Accepts either a list of candidates or an object wrapping such a list
        (the first list-valued entry is used). Every candidate must be a dict
        with a string ``"response"``; a missing ``"id"`` is filled with the
        1-based position.

        Raises:
            ValueError: if no non-empty, well-formed candidate list is found.
        """
        if isinstance(parsed, dict):
            parsed = next(
                (value for value in parsed.values() if isinstance(value, list)),
                None,
            )
        if not isinstance(parsed, list) or not parsed:
            raise ValueError("候选回复格式无效:需要非空的 JSON 数组")
        for position, candidate in enumerate(parsed, start=1):
            if not isinstance(candidate, dict) or not isinstance(
                candidate.get("response"), str
            ):
                raise ValueError(f"候选回复格式无效:第 {position} 项缺少 response")
            candidate.setdefault("id", position)
        return parsed

    def generate_candidates(self, history, user_message):
        """Step 1: ask the generation model for candidate counselor replies.

        Retries when the output cannot be parsed or is not a well-formed
        candidate list; the error from the last attempt is re-raised.

        Returns:
            A non-empty list of dicts, each with ``"id"`` and ``"response"``.

        Raises:
            ValueError: (including ``json.JSONDecodeError``) if every attempt
                produced unusable output.
        """
        prompt = CANDIDATE_GENERATION_PROMPT.replace(
            "{conversation_history}", self._format_history(history)
        ).replace("{user_message}", user_message)
        for attempt in range(1, self.GENERATION_ATTEMPTS + 1):
            try:
                result = self.gen_llm.invoke(prompt)
                return self._normalize_candidates(self._parse_json(result.content))
            except ValueError:  # json.JSONDecodeError is a ValueError
                if attempt == self.GENERATION_ATTEMPTS:
                    raise

    def _simulate_one(self, candidate, history_text, user_message):
        """Simulate the client's reaction to a single candidate reply.

        Returns a dict with ``"id"``, ``"simulated_client_response"`` and
        ``"emotional_state"``. If no attempt yields a JSON object, a fixed
        placeholder result is returned instead of raising.
        """
        prompt = (
            CLIENT_SIMULATOR_PROMPT
            .replace("{conversation_history}", history_text)
            .replace("{user_message}", user_message)
            .replace("{therapist_response}", candidate["response"])
        )
        for _ in range(self.SIMULATION_ATTEMPTS):
            try:
                result = self.sim_llm.invoke(prompt)
                parsed = self._parse_json(result.content)
                if not isinstance(parsed, dict):
                    # A list or scalar has no fields to read; treat it like
                    # any other unparseable answer and retry.
                    raise ValueError("模拟结果不是 JSON 对象")
            except ValueError:
                continue
            return {
                "id": candidate["id"],
                # Coerce to str so a JSON null / number cannot break the
                # string handling in the later steps.
                "simulated_client_response": str(
                    parsed.get("simulated_response") or ""
                ),
                "emotional_state": str(parsed.get("emotional_state") or ""),
            }
        return {
            "id": candidate["id"],
            "simulated_client_response": "(模拟失败)",
            "emotional_state": "未知",
        }

    def _run_parallel(self, worker, items) -> list:
        """Apply ``worker`` to every item on a thread pool.

        Results are returned sorted by their ``"id"``; if the ids are not
        mutually comparable the input order is kept. An exception raised by
        any worker propagates to the caller.
        """
        items = list(items)
        if not items:
            return []
        pool_size = min(self.MAX_WORKERS, len(items))
        with ThreadPoolExecutor(max_workers=pool_size) as executor:
            # map() yields results in input order regardless of completion
            # order, which keeps the output deterministic.
            results = list(executor.map(worker, items))
        try:
            return sorted(results, key=lambda item: item["id"])
        except TypeError:
            return results

    def simulate_client_reactions(self, candidates, history, user_message):
        """Step 2: simulate the client's reaction to every candidate in parallel."""
        history_text = self._format_history(history)
        return self._run_parallel(
            lambda candidate: self._simulate_one(
                candidate, history_text, user_message
            ),
            candidates,
        )

    @staticmethod
    def _coerce_score(parsed) -> int:
        """Read ``"score"`` from an evaluator answer and clamp it to 0..10.

        Raises:
            ValueError: if ``parsed`` is not a dict or the score is not numeric.
        """
        if not isinstance(parsed, dict):
            raise ValueError("评估结果不是 JSON 对象")
        try:
            score = int(float(parsed.get("score", 0)))
        except (TypeError, ValueError, OverflowError) as err:
            raise ValueError("评估分数无效") from err
        return max(0, min(10, score))

    def _evaluate_one(self, simulation):
        """Score the disclosure depth of a single simulated reaction.

        Returns a dict with ``"id"``, ``"score"`` (int in 0..10) and
        ``"reason"``. If no attempt yields a usable answer, a score of 0 with
        a fixed failure reason is returned instead of raising.
        """
        prompt = MCTS_EVALUATOR_PROMPT.replace(
            "{client_response}", str(simulation["simulated_client_response"])
        )
        for _ in range(self.EVALUATION_ATTEMPTS):
            try:
                result = self.eval_llm.invoke(prompt)
                parsed = self._parse_json(result.content)
                score = self._coerce_score(parsed)
            except ValueError:
                continue
            return {
                "id": simulation["id"],
                "score": score,
                "reason": parsed.get("reason", ""),
            }
        return {"id": simulation["id"], "score": 0, "reason": "评估解析失败"}

    def evaluate_disclosures(self, simulations):
        """Step 3: score every simulated reaction in parallel."""
        return self._run_parallel(self._evaluate_one, simulations)

    def select_best(self, candidates, simulations, evaluations) -> tuple:
        """Step 4: pick the candidate with the highest evaluation score.

        On a tie, the candidate whose simulated ``emotional_state`` text is
        longest wins (a rough proxy for emotional depth); if that is tied too,
        the earliest of the tied evaluations wins.

        Returns:
            ``(best_id, best_response, reason)``.

        Raises:
            ValueError: if ``evaluations`` is empty or the winning id has no
                matching entry in ``candidates``.
        """
        if not evaluations:
            raise ValueError("没有可供选择的评估结果")
        max_score = max(entry["score"] for entry in evaluations)
        top = [entry for entry in evaluations if entry["score"] == max_score]

        if len(top) == 1:
            best_id = top[0]["id"]
            reason = "最高揭露深度评分"
        else:
            def depth(entry):
                sim = next(
                    (s for s in simulations if s["id"] == entry["id"]), None
                )
                if sim is None:
                    return 0
                return len(str(sim.get("emotional_state") or ""))

            # max() keeps the first of equally deep entries.
            best_id = max(top, key=depth)["id"]
            reason = "同分中情感深度更高"

        for candidate in candidates:
            if candidate["id"] == best_id:
                return best_id, candidate["response"], reason
        raise ValueError(f"找不到 id 为 {best_id!r} 的候选回复")

    def run(self, history, user_message) -> tuple:
        """Run the full pipeline and return the best reply plus a trace.

        Returns:
            ``(best_response, mcts_trace)`` where ``mcts_trace`` holds the
            candidates, simulations, evaluations, the selected id and the
            selection reason.
        """
        candidates = self.generate_candidates(history, user_message)
        simulations = self.simulate_client_reactions(
            candidates, history, user_message
        )
        evaluations = self.evaluate_disclosures(simulations)
        selected_id, best_response, selection_reason = self.select_best(
            candidates, simulations, evaluations
        )
        mcts_trace = {
            "candidates": candidates,
            "simulations": simulations,
            "evaluations": evaluations,
            "selected": selected_id,
            "selection_reason": selection_reason,
        }
        return best_response, mcts_trace