Spaces:
Sleeping
Sleeping
File size: 7,294 Bytes
408f650 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 | import json
import os
from concurrent.futures import ThreadPoolExecutor, as_completed
from langchain_openai import ChatOpenAI
from langchain_core.messages import HumanMessage, AIMessage
from prompts import (
CANDIDATE_GENERATION_PROMPT,
CLIENT_SIMULATOR_PROMPT,
MCTS_EVALUATOR_PROMPT,
)
class MCTSReasoner:
    """Choose a counselor reply via generate -> simulate -> evaluate -> select.

    One call to ``run`` performs a single round: generate candidate replies,
    simulate the client's reaction to each one in parallel, score each
    simulated reaction in parallel, and return the best-scoring candidate
    together with a trace of every intermediate result.
    """

    def __init__(self, model="deepseek-chat",
                 base_url="https://api.deepseek.com/v1", api_key=None):
        """Create the three chat clients used by the pipeline.

        Args:
            model: Chat model name passed to every client.
            base_url: OpenAI-compatible API endpoint.
            api_key: API key; when omitted, read from the
                ``DEEPSEEK_API_KEY`` environment variable.
        """
        base_kwargs = dict(
            model=model,
            base_url=base_url,
            api_key=api_key if api_key is not None else os.getenv("DEEPSEEK_API_KEY"),
        )
        # Generation and simulation sample at temperature 0.7 (presumably for
        # variety); the evaluator is deterministic with a short output budget.
        self.gen_llm = ChatOpenAI(**base_kwargs, temperature=0.7, max_tokens=1024)
        self.sim_llm = ChatOpenAI(**base_kwargs, temperature=0.7, max_tokens=256)
        self.eval_llm = ChatOpenAI(**base_kwargs, temperature=0.0, max_tokens=64)

    def _format_history(self, history):
        """Render a LangChain message history as readable dialogue text.

        HumanMessage entries are labelled as the client and AIMessage entries
        as the counselor; every other message type (e.g. system messages) is
        skipped. Returns a "first conversation" placeholder when nothing was
        rendered.
        """
        lines = []
        for msg in history:
            if isinstance(msg, HumanMessage):
                lines.append(f"来访者:{msg.content}")
            elif isinstance(msg, AIMessage):
                lines.append(f"咨询师:{msg.content}")
        return "\n".join(lines) if lines else "(首次对话)"

    @staticmethod
    def _fill_template(template, values):
        """Replace ``{name}`` placeholders for every name in ``values`` in one pass.

        Braces not matching a given name are left untouched. Because the
        substitution is a single regex pass, text inserted for one placeholder
        is never scanned again, so user text that happens to contain e.g.
        ``{user_message}`` is not substituted a second time.

        Args:
            template: Prompt template text.
            values: Mapping of placeholder name -> replacement (passed
                through ``str``).

        Returns:
            The filled-in template.
        """
        import re

        if not values:
            return template
        pattern = re.compile(
            "|".join(re.escape("{" + name + "}") for name in values)
        )
        # A callable replacement keeps backslashes in inserted text literal.
        return pattern.sub(lambda m: str(values[m.group(0)[1:-1]]), template)

    def _parse_json(self, text):
        """Decode the first JSON array or object embedded in LLM output.

        Surrounding prose or Markdown code fences are tolerated: decoding
        starts at whichever of ``[`` / ``{`` occurs first, and any text after
        the complete JSON value is ignored.

        Raises:
            ValueError: if ``text`` is not a string or contains no ``[``/``{``.
            json.JSONDecodeError: (a ValueError subclass) if the JSON value
                starting there is malformed.
        """
        if not isinstance(text, str):
            raise ValueError(f"无法解析 JSON: {str(text)[:100]}")
        content = text.strip()
        openers = [i for i in (content.find("["), content.find("{")) if i != -1]
        if not openers:
            raise ValueError(f"无法解析 JSON: {content[:100]}")
        # raw_decode stops at the end of the first complete value, so trailing
        # text (including stray brackets) cannot corrupt the slice.
        value, _ = json.JSONDecoder().raw_decode(content[min(openers):])
        return value

    @staticmethod
    def _validate_candidates(parsed):
        """Normalize parsed generator output into a non-empty candidate list.

        A top-level object wrapping the list (e.g. ``{"candidates": [...]}``)
        is unwrapped to its first list-valued field.

        Raises:
            ValueError: unless the result is a non-empty list of dicts that
                each carry ``id`` and ``response`` keys.
        """
        if isinstance(parsed, dict):
            parsed = next((v for v in parsed.values() if isinstance(v, list)), parsed)
        if not isinstance(parsed, list) or not parsed:
            raise ValueError("候选回复列表为空或格式错误")
        for item in parsed:
            if not isinstance(item, dict) or "id" not in item or "response" not in item:
                raise ValueError(f"候选回复格式错误: {str(item)[:100]}")
        return parsed

    def generate_candidates(self, history, user_message):
        """Step 1: ask the generator LLM for candidate counselor replies.

        The number of candidates is set by CANDIDATE_GENERATION_PROMPT
        (presumably five). Up to three attempts are made.

        Returns:
            A non-empty list of dicts with at least ``id`` and ``response``.

        Raises:
            ValueError: the last parse/validation error after three failed
                attempts (json.JSONDecodeError is a ValueError subclass).
        """
        prompt = self._fill_template(CANDIDATE_GENERATION_PROMPT, {
            "conversation_history": self._format_history(history),
            "user_message": user_message,
        })
        for attempt in range(3):
            try:
                result = self.gen_llm.invoke(prompt)
                return self._validate_candidates(self._parse_json(result.content))
            except ValueError:
                if attempt == 2:
                    raise

    def _simulate_one(self, candidate, history_text, user_message):
        """Simulate the client's reaction to one candidate reply.

        Makes up to two attempts; if both fail to yield a JSON object, a
        placeholder "simulation failed" record is returned instead of raising.

        Returns:
            ``{"id", "simulated_client_response", "emotional_state"}``.
        """
        prompt = self._fill_template(CLIENT_SIMULATOR_PROMPT, {
            "conversation_history": history_text,
            "user_message": user_message,
            "therapist_response": candidate["response"],
        })
        for _ in range(2):
            try:
                result = self.sim_llm.invoke(prompt)
                parsed = self._parse_json(result.content)
                if not isinstance(parsed, dict):
                    raise ValueError("模拟结果不是 JSON 对象")
            except ValueError:  # includes json.JSONDecodeError
                continue
            return {
                "id": candidate["id"],
                "simulated_client_response": parsed.get("simulated_response", ""),
                "emotional_state": parsed.get("emotional_state", ""),
            }
        return {
            "id": candidate["id"],
            "simulated_client_response": "(模拟失败)",
            "emotional_state": "未知",
        }

    def simulate_client_reactions(self, candidates, history, user_message):
        """Step 2: simulate the client's reaction to every candidate in parallel.

        Returns:
            One simulation record per candidate, sorted by ``id``.
        """
        history_text = self._format_history(history)
        with ThreadPoolExecutor(max_workers=5) as executor:
            simulations = list(executor.map(
                lambda c: self._simulate_one(c, history_text, user_message),
                candidates,
            ))
        simulations.sort(key=lambda x: x["id"])
        return simulations

    @staticmethod
    def _coerce_score(raw):
        """Convert an LLM-reported score to an int clamped to [0, 10].

        Accepts ints, floats and numeric strings; fractional parts are
        truncated toward zero.

        Raises:
            ValueError: if ``raw`` is not a finite number (e.g. None, "high").
        """
        try:
            value = int(float(raw))
        except (TypeError, ValueError, OverflowError) as err:
            raise ValueError(f"无效评分: {raw!r}") from err
        return max(0, min(10, value))

    def _evaluate_one(self, simulation):
        """Score how much one simulated client reaction discloses.

        Makes up to two attempts; if both fail, a zero-score record with a
        "parse failed" reason is returned instead of raising.

        Returns:
            ``{"id", "score", "reason"}`` with ``score`` an int in [0, 10].
        """
        prompt = self._fill_template(MCTS_EVALUATOR_PROMPT, {
            "client_response": simulation["simulated_client_response"],
        })
        for _ in range(2):
            try:
                result = self.eval_llm.invoke(prompt)
                parsed = self._parse_json(result.content)
                if not isinstance(parsed, dict):
                    raise ValueError("评估结果不是 JSON 对象")
                score = self._coerce_score(parsed.get("score", 0))
            except ValueError:  # includes json.JSONDecodeError
                continue
            return {
                "id": simulation["id"],
                "score": score,
                "reason": parsed.get("reason", ""),
            }
        return {"id": simulation["id"], "score": 0, "reason": "评估解析失败"}

    def evaluate_disclosures(self, simulations):
        """Step 3: score every simulated reaction in parallel.

        Returns:
            One evaluation record per simulation, sorted by ``id``.
        """
        with ThreadPoolExecutor(max_workers=5) as executor:
            evaluations = list(executor.map(self._evaluate_one, simulations))
        evaluations.sort(key=lambda x: x["id"])
        return evaluations

    def select_best(self, candidates, simulations, evaluations):
        """Step 4: pick the highest-scoring candidate.

        Ties on score are broken by the length of the simulated
        ``emotional_state`` text (a rough proxy for emotional depth); if that
        also ties, the earliest entry in ``evaluations`` wins.

        Returns:
            ``(best_id, best_response, reason)``.

        Raises:
            ValueError: if ``evaluations`` is empty.
        """
        if not evaluations:
            raise ValueError("没有可供选择的评估结果")
        max_score = max(e["score"] for e in evaluations)
        top_ids = [e["id"] for e in evaluations if e["score"] == max_score]
        if len(top_ids) == 1:
            best_id = top_ids[0]
            reason = "最高揭露深度评分"
        else:
            # First simulation per id wins; missing or non-string states
            # count as depth 0.
            depth_by_id = {}
            for sim in simulations:
                state = sim.get("emotional_state", "")
                depth_by_id.setdefault(
                    sim["id"], len(state) if isinstance(state, str) else 0
                )
            # max() returns the first maximal id, so the earliest tied
            # evaluation is kept when depths are equal.
            best_id = max(top_ids, key=lambda cid: depth_by_id.get(cid, 0))
            reason = "同分中情感深度更高"
        best_response = next(c["response"] for c in candidates if c["id"] == best_id)
        return best_id, best_response, reason

    def run(self, history, user_message):
        """Run the full pipeline and return the chosen reply plus a trace.

        Returns:
            ``(best_response, trace)`` where ``trace`` holds the candidates,
            simulations, evaluations, selected id and selection reason.
        """
        candidates = self.generate_candidates(history, user_message)
        simulations = self.simulate_client_reactions(candidates, history, user_message)
        evaluations = self.evaluate_disclosures(simulations)
        selected_id, best_response, selection_reason = self.select_best(
            candidates, simulations, evaluations
        )
        mcts_trace = {
            "candidates": candidates,
            "simulations": simulations,
            "evaluations": evaluations,
            "selected": selected_id,
            "selection_reason": selection_reason,
        }
        return best_response, mcts_trace
|