Spaces:
Sleeping
Sleeping
import json
import os
import re
from concurrent.futures import ThreadPoolExecutor, as_completed

from langchain_core.messages import HumanMessage, AIMessage
from langchain_openai import ChatOpenAI

from prompts import (
    CANDIDATE_GENERATION_PROMPT,
    CLIENT_SIMULATOR_PROMPT,
    MCTS_EVALUATOR_PROMPT,
)
class MCTSReasoner:
    """Choose a counselor reply via generate -> simulate -> evaluate -> select.

    Despite the name, the pipeline in this class is a single pass: it asks one
    LLM for candidate replies, asks a second LLM to simulate the client's
    reaction to each candidate, asks a third LLM to score each simulated
    reaction, and returns the candidate with the highest score.
    """

    # Number of worker threads used for the simulation and evaluation fan-out.
    MAX_WORKERS = 5
    # Scores returned by the evaluator are clamped into this closed range.
    MIN_SCORE = 0
    MAX_SCORE = 10

    def __init__(self):
        """Create the three chat clients (generator, simulator, evaluator).

        All three target the same DeepSeek endpoint and model; they differ
        only in temperature and token budget. The API key is read from the
        ``DEEPSEEK_API_KEY`` environment variable.
        """
        base_kwargs = dict(
            model="deepseek-chat",
            base_url="https://api.deepseek.com/v1",
            api_key=os.getenv("DEEPSEEK_API_KEY"),
        )
        self.gen_llm = ChatOpenAI(**base_kwargs, temperature=0.7, max_tokens=1024)
        self.sim_llm = ChatOpenAI(**base_kwargs, temperature=0.7, max_tokens=256)
        self.eval_llm = ChatOpenAI(**base_kwargs, temperature=0.0, max_tokens=64)

    @staticmethod
    def _as_text(value):
        """Return *value* as a string, mapping ``None`` to the empty string."""
        return "" if value is None else str(value)

    @staticmethod
    def _fill_prompt(template, **values):
        """Replace each ``{name}`` placeholder in *template* in one pass.

        Chained ``str.replace`` calls rescan text that was just substituted,
        so a user message containing e.g. ``{therapist_response}`` would be
        rewritten by a later replacement. A single regex pass only ever
        touches placeholders that are part of the template itself. Braces
        that do not match a given name (such as JSON examples in the prompt)
        are left alone.
        """
        if not values:
            return template
        pattern = re.compile(
            "|".join(re.escape("{" + name + "}") for name in values)
        )
        # A callable replacement avoids backslash-escape processing of values.
        return pattern.sub(lambda m: str(values[m.group(0)[1:-1]]), template)

    @classmethod
    def _coerce_score(cls, raw):
        """Convert a raw evaluator score to an int clamped to the valid range.

        Accepts ints, floats and numeric strings (``"7"``, ``"7.5"``).

        Raises:
            ValueError: if *raw* cannot be interpreted as a finite number, so
                that callers can treat it like any other parse failure.
        """
        try:
            value = int(raw) if isinstance(raw, int) else int(float(raw))
        except (TypeError, ValueError, OverflowError):
            raise ValueError(f"无效的评分: {raw!r}") from None
        return max(cls.MIN_SCORE, min(cls.MAX_SCORE, value))

    @staticmethod
    def _validate_candidates(candidates):
        """Raise ``ValueError`` unless *candidates* is a usable candidate list.

        Every later step indexes ``candidate["id"]`` and
        ``candidate["response"]``, so both keys are required on every item.
        """
        if not candidates:
            raise ValueError("候选回复列表为空")
        for item in candidates:
            if (
                not isinstance(item, dict)
                or "id" not in item
                or "response" not in item
            ):
                raise ValueError(f"候选回复格式无效: {str(item)[:100]}")

    def _format_history(self, history):
        """Render a langchain message history as readable text.

        Only ``HumanMessage`` and ``AIMessage`` entries are rendered; any
        other message type (e.g. a system message) is skipped. Returns a
        fixed placeholder line when nothing was rendered.
        """
        lines = []
        for msg in history:
            if isinstance(msg, HumanMessage):
                lines.append(f"来访者:{msg.content}")
            elif isinstance(msg, AIMessage):
                lines.append(f"咨询师:{msg.content}")
        return "\n".join(lines) if lines else "(首次对话)"

    def _parse_json(self, text, expected_type=None):
        """Extract the first JSON array or object embedded in LLM output.

        The text is scanned left to right; at every ``[`` or ``{`` a decode
        is attempted, and the first value that decodes (and, when
        *expected_type* is given, is an instance of it) is returned. This
        tolerates surrounding prose or code fences, and it does not confuse
        an array nested inside an object with the top-level value.

        Args:
            text: Raw model output.
            expected_type: Optional ``list`` or ``dict`` (or a tuple of
                types) that the returned value must be an instance of.

        Raises:
            ValueError: if no matching JSON value is found.
        """
        content = text.strip()
        decoder = json.JSONDecoder()
        for position, char in enumerate(content):
            if char not in "[{":
                continue
            try:
                value, _ = decoder.raw_decode(content, position)
            except json.JSONDecodeError:
                continue
            if expected_type is None or isinstance(value, expected_type):
                return value
        raise ValueError(f"无法解析 JSON: {content[:100]}")

    def _run_parallel(self, worker, items):
        """Apply *worker* to every item on a thread pool; sort results by id.

        Any exception raised by a worker propagates to the caller.
        """
        with ThreadPoolExecutor(max_workers=self.MAX_WORKERS) as executor:
            futures = [executor.submit(worker, item) for item in items]
            results = [future.result() for future in as_completed(futures)]
        # Completion order is arbitrary, so restore a stable order by id.
        results.sort(key=lambda entry: entry["id"])
        return results

    def generate_candidates(self, history, user_message):
        """Step 1: ask the generator LLM for candidate counselor replies.

        Makes up to three attempts; an attempt fails when the output holds no
        JSON array or the array is not a non-empty list of dicts carrying
        ``id`` and ``response``.

        Returns:
            The parsed list of candidate dicts.

        Raises:
            ValueError: if the third attempt also fails.
        """
        prompt = self._fill_prompt(
            CANDIDATE_GENERATION_PROMPT,
            conversation_history=self._format_history(history),
            user_message=user_message,
        )
        max_attempts = 3
        for attempt in range(1, max_attempts + 1):
            try:
                result = self.gen_llm.invoke(prompt)
                candidates = self._parse_json(result.content, expected_type=list)
                self._validate_candidates(candidates)
                return candidates
            except ValueError:  # json.JSONDecodeError is a ValueError subclass
                if attempt == max_attempts:
                    raise

    def _simulate_one(self, candidate, history_text, user_message):
        """Simulate the client's reaction to a single candidate reply.

        Makes up to two attempts. If neither yields a JSON object, a
        placeholder result is returned instead of raising, so one bad
        simulation does not abort the whole run.
        """
        prompt = self._fill_prompt(
            CLIENT_SIMULATOR_PROMPT,
            conversation_history=history_text,
            user_message=user_message,
            therapist_response=candidate["response"],
        )
        for _ in range(2):
            try:
                result = self.sim_llm.invoke(prompt)
                parsed = self._parse_json(result.content, expected_type=dict)
            except ValueError:
                continue
            return {
                "id": candidate["id"],
                # Coerce to text: later steps substitute this into a prompt
                # and take its length, which would fail on a JSON null.
                "simulated_client_response": self._as_text(
                    parsed.get("simulated_response", "")
                ),
                "emotional_state": self._as_text(
                    parsed.get("emotional_state", "")
                ),
            }
        return {
            "id": candidate["id"],
            "simulated_client_response": "(模拟失败)",
            "emotional_state": "未知",
        }

    def simulate_client_reactions(self, candidates, history, user_message):
        """Step 2: simulate the client's reaction to every candidate in parallel.

        Returns:
            A list of simulation dicts sorted by ``id``.
        """
        history_text = self._format_history(history)
        return self._run_parallel(
            lambda candidate: self._simulate_one(
                candidate, history_text, user_message
            ),
            candidates,
        )

    def _evaluate_one(self, simulation):
        """Score the disclosure depth of a single simulated reaction.

        Makes up to two attempts. If neither yields a JSON object with a
        usable numeric score, returns score 0 with a failure reason.
        """
        prompt = self._fill_prompt(
            MCTS_EVALUATOR_PROMPT,
            client_response=simulation["simulated_client_response"],
        )
        for _ in range(2):
            try:
                result = self.eval_llm.invoke(prompt)
                parsed = self._parse_json(result.content, expected_type=dict)
                score = self._coerce_score(parsed.get("score", 0))
            except ValueError:
                continue
            return {
                "id": simulation["id"],
                "score": score,
                "reason": self._as_text(parsed.get("reason", "")),
            }
        return {"id": simulation["id"], "score": 0, "reason": "评估解析失败"}

    def evaluate_disclosures(self, simulations):
        """Step 3: score every simulated reaction in parallel.

        Returns:
            A list of evaluation dicts sorted by ``id``.
        """
        return self._run_parallel(self._evaluate_one, simulations)

    def select_best(self, candidates, simulations, evaluations):
        """Step 4: pick the highest-scoring candidate.

        When several candidates share the top score, the one whose simulated
        ``emotional_state`` text is longest wins (a rough proxy for emotional
        depth); if that is also tied, the earliest in *evaluations* wins.

        Returns:
            A ``(best_id, best_response, reason)`` tuple.

        Raises:
            ValueError: if *evaluations* is empty or the winning id has no
                matching entry in *candidates*.
        """
        if not evaluations:
            raise ValueError("没有可供选择的评估结果")

        max_score = max(e["score"] for e in evaluations)
        top_candidates = [e for e in evaluations if e["score"] == max_score]

        if len(top_candidates) == 1:
            best_id = top_candidates[0]["id"]
            reason = "最高揭露深度评分"
        else:
            # Index simulations once; keep the first entry for each id.
            first_sim_by_id = {}
            for sim in simulations:
                first_sim_by_id.setdefault(sim["id"], sim)

            def depth(evaluation):
                sim = first_sim_by_id.get(evaluation["id"])
                if not sim:
                    return 0
                return len(sim.get("emotional_state") or "")

            # max() returns the first maximal element, so ties keep the
            # earliest top-scoring candidate.
            best_id = max(top_candidates, key=depth)["id"]
            reason = "同分中情感深度更高"

        for candidate in candidates:
            if candidate["id"] == best_id:
                return best_id, candidate["response"], reason
        raise ValueError(f"候选列表中找不到 id: {best_id!r}")

    def run(self, history, user_message):
        """Run the full pipeline and return the best reply plus a trace.

        Returns:
            A ``(best_response, mcts_trace)`` tuple, where ``mcts_trace`` is
            a dict holding the candidates, simulations, evaluations, the
            selected id and the selection reason.
        """
        # Step 1
        candidates = self.generate_candidates(history, user_message)
        # Step 2
        simulations = self.simulate_client_reactions(
            candidates, history, user_message
        )
        # Step 3
        evaluations = self.evaluate_disclosures(simulations)
        # Step 4
        selected_id, best_response, selection_reason = self.select_best(
            candidates, simulations, evaluations
        )
        mcts_trace = {
            "candidates": candidates,
            "simulations": simulations,
            "evaluations": evaluations,
            "selected": selected_id,
            "selection_reason": selection_reason,
        }
        return best_response, mcts_trace