import json
import math
import os
import time
from collections import Counter, defaultdict
from concurrent.futures import ThreadPoolExecutor, as_completed

from langchain_openai import ChatOpenAI
from langchain_core.messages import HumanMessage, AIMessage

from prompts import (
    SUMMARY_AND_SEEDS_PROMPT,
    THERAPIST_REPLY_PROMPT,
    L2_MERGED_PROMPT,
    L3_MERGED_PROMPT,
    L4_MERGED_PROMPT,
    QUICK_EVAL_PROMPT,
    L5_L6_MERGED_PROMPT,
    RELATIVE_DISCLOSURE_EVAL_PROMPT,
    PATH_DISTILLATION_PROMPT,
)


class StrategicAdvisor:
    """PUCT variant: UCB-adaptive budget allocation plus variable-depth exploration.

    The advisor simulates a small dialogue tree (therapist reply -> client
    response -> therapist reply -> client response, optionally two more
    turns), scores the simulated client messages with the LLM, and distills
    the best paths into guidance.  Paths are plain dicts that accumulate keys
    as they move through the layers (``id``, ``seed``, ``reply``, ``l2_dir``,
    ``client_response``, ``branch``, ``l1_reply``, ``l2_client``,
    ``l3_reply``, ``l4_dir``, ``l4_client``, ``score``, ``delta``, ...).
    """

    # Seed identifiers requested from the summary prompt.
    _SEED_IDS = ("A", "B", "C")
    # Upper bound on threads used for a single fan-out step.
    _MAX_WORKERS = 32

    def __init__(self, c_puct=1.5):
        """Create the LLM client and store the exploration constant.

        Args:
            c_puct: Multiplier of the exploration term in :meth:`compute_ucb`.
        """
        dashscope_base = dict(
            base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
            api_key=os.getenv("DASHSCOPE_API_KEY"),
        )
        self.llm = ChatOpenAI(model="qwen-turbo", **dashscope_base, temperature=0.7, max_tokens=256)
        self.c_puct = c_puct

    # ===== Utilities =====

    def _format_history(self, history):
        """Render Human/AI messages as one labelled line each.

        Messages of any other type are skipped.  Returns "(无)" when no
        line was produced.
        """
        lines = []
        for msg in history:
            if isinstance(msg, HumanMessage):
                lines.append(f"来访者:{msg.content}")
            elif isinstance(msg, AIMessage):
                lines.append(f"咨询师:{msg.content}")
        return "\n".join(lines) if lines else "(无)"

    def _parse_json(self, text):
        """Parse the span from the first ``{`` to the last ``}`` of *text*.

        Raises:
            ValueError: if no brace pair is found (``json.JSONDecodeError``,
                a ``ValueError`` subclass, if the span is not valid JSON).
        """
        content = text.strip()
        start = content.find("{")
        end = content.rfind("}") + 1
        if start == -1 or end == 0:
            raise ValueError(f"无法解析 JSON: {content[:100]}")
        return json.loads(content[start:end])

    @staticmethod
    def _as_text(value, default):
        """Return *value* as a string, or *default* when it is ``None``.

        LLM JSON occasionally carries a nested object or ``null`` where a
        plain string is expected; downstream code feeds these values to
        ``str.replace`` and slices them, both of which need a real string.
        """
        if value is None:
            return default
        if isinstance(value, str):
            return value
        return json.dumps(value, ensure_ascii=False)

    @staticmethod
    def _clamp_score(raw):
        """Convert *raw* to an int clamped to ``[1, 10]``.

        Goes through ``float`` so that values such as ``"7.5"`` are accepted.
        Raises ``TypeError``/``ValueError``/``OverflowError`` for values that
        cannot be converted; callers treat that as a scoring failure.
        """
        return max(1, min(10, int(float(raw))))

    @staticmethod
    def _signed(delta):
        """Format *delta* with an explicit sign ("+3", "+0", "-2")."""
        try:
            return f"+{delta}" if delta >= 0 else f"{delta}"
        except TypeError:
            return f"+{delta}"

    def _run_parallel(self, worker, items, max_workers=None):
        """Apply *worker* to every item on a thread pool.

        Args:
            worker: Callable taking one item.
            items: Iterable of work items.
            max_workers: Optional thread cap; defaults to one thread per
                item, bounded by ``_MAX_WORKERS``.

        Returns:
            A list of results in the same order as *items* (``[]`` for empty
            input, where ``ThreadPoolExecutor`` would reject 0 workers).
            The first worker exception encountered is re-raised.
        """
        items = list(items)
        if not items:
            return []
        workers = min(max_workers or len(items), len(items), self._MAX_WORKERS)
        results = [None] * len(items)
        with ThreadPoolExecutor(max_workers=workers) as executor:
            futures = {executor.submit(worker, item): idx for idx, item in enumerate(items)}
            for future in as_completed(futures):
                # Store by index so the output order is deterministic.
                results[futures[future]] = future.result()
        return results

    # ===== PUCT core =====

    def compute_ucb(self, paths, score_key=None):
        """Compute a UCB score for every path and store it under ``"ucb"``.

        ``ucb = value / 10 + c_puct * sqrt(len(paths)) / (1 + n_seed)`` where
        ``n_seed`` is the number of paths sharing the path's ``"id"``.

        Args:
            paths: List of path dicts; mutated in place.
            score_key: Key holding the value estimate.  When ``None`` the
                value is ``_quick_score`` if present, else ``score``, else 1.

        Returns:
            The same *paths* list.
        """
        seed_counts = Counter(p["id"] for p in paths)
        explore_scale = self.c_puct * math.sqrt(len(paths))
        for p in paths:
            if score_key is not None:
                raw = p.get(score_key, 1)
            else:
                raw = p.get("_quick_score", p.get("score", 1))
            p["ucb"] = raw / 10.0 + explore_scale / (1 + seed_counts[p["id"]])
        return paths

    def allocate_budget(self, paths, total_budget, min_per_seed=1):
        """Select up to *total_budget* paths by UCB with a per-seed floor.

        Every seed first receives up to *min_per_seed* slots (values below 1
        are treated as 1), handed out one round at a time so no seed gets a
        second slot before every seed has had its first.  Remaining budget is
        filled in descending UCB order.

        Args:
            paths: Candidate path dicts; ``"ucb"`` is written to each.
            total_budget: Maximum number of paths to return.
            min_per_seed: Guaranteed slots per seed, budget permitting.

        Returns:
            The selected paths (``[]`` when the budget is not positive).
        """
        self.compute_ucb(paths)
        if total_budget <= 0 or not paths:
            return []
        ranked = sorted(paths, key=lambda x: -x["ucb"])

        selected = []
        taken = set()        # id() of chosen dicts: identity, not equality
        per_seed = Counter()

        # Guaranteed slots, one round per quota level.
        for quota in range(1, max(1, min_per_seed) + 1):
            for p in ranked:
                if id(p) in taken or per_seed[p["id"]] >= quota:
                    continue
                selected.append(p)
                taken.add(id(p))
                per_seed[p["id"]] += 1
                if len(selected) >= total_budget:
                    return selected

        # Fill whatever budget is left strictly by UCB rank.
        for p in ranked:
            if len(selected) >= total_budget:
                break
            if id(p) not in taken:
                selected.append(p)
                taken.add(id(p))
        return selected

    # ===== Quick scoring =====

    def _quick_score_one(self, text, current_disclosure):
        """Ask the LLM for a quick 1-10 score of *text*; 1 on any failure."""
        prompt = QUICK_EVAL_PROMPT.replace(
            "{current_disclosure_score}", str(current_disclosure)
        ).replace("{user_message}", text)
        try:
            result = self.llm.invoke(prompt)
            parsed = self._parse_json(result.content)
            return self._clamp_score(parsed.get("score", 1))
        except Exception:
            return 1

    def quick_score_paths(self, paths, text_key, current_disclosure):
        """Quick-score ``p[text_key]`` of every path in parallel.

        The result is written to each path's ``_quick_score`` field.
        Returns the same *paths* list.
        """
        scores = self._run_parallel(
            lambda p: self._quick_score_one(p[text_key], current_disclosure), paths
        )
        for path, score in zip(paths, scores):
            path["_quick_score"] = score
        return paths

    # ===== Step 1: summary + 3 seeds =====

    def summarize_and_seeds(self, history):
        """Summarize the conversation and obtain seeds A, B and C.

        Tries up to three times while the reply cannot be parsed as JSON.
        Missing or null seeds fall back to a default perspective.

        Returns:
            ``(summary, seeds)`` where *seeds* maps "A"/"B"/"C" to text.
        """
        default_seed = "从你独特的临床视角出发"
        prompt = SUMMARY_AND_SEEDS_PROMPT.replace(
            "{conversation_history}", self._format_history(history)
        )
        for _ in range(3):
            try:
                result = self.llm.invoke(prompt)
                parsed = self._parse_json(result.content)
            except (json.JSONDecodeError, ValueError):
                continue
            summary = self._as_text(parsed.pop("summary", None), "总结失败")
            # Exact key match: `k in "ABC"` would also accept "", "AB", "BC".
            seeds = {k: self._as_text(parsed.get(k), default_seed) for k in self._SEED_IDS}
            return summary, seeds
        return "总结失败", {k: default_seed for k in self._SEED_IDS}

    # ===== Step 2 / L1: 3 x therapist reply =====

    def _gen_therapist_reply(self, seed_id, seed, history_text):
        """Generate one therapist reply for a seed; errors become the reply text."""
        prompt = THERAPIST_REPLY_PROMPT.replace(
            "{conversation_history}", history_text
        ).replace("{seed_perspective}", seed)
        try:
            result = self.llm.invoke(prompt)
            return {"id": seed_id, "seed": seed, "reply": result.content.strip()}
        except Exception as e:
            return {"id": seed_id, "seed": seed, "reply": f"(生成失败: {e})"}

    def generate_l1(self, seeds, history):
        """Generate one L1 therapist reply per seed, sorted by seed id."""
        history_text = self._format_history(history)
        results = self._run_parallel(
            lambda sid: self._gen_therapist_reply(sid, seeds[sid], history_text),
            seeds,
            max_workers=3,
        )
        results.sort(key=lambda x: x["id"])
        return results

    # ===== Step 3 / L2: merged direction + client response =====

    def _gen_l2_merged(self, l1_item, history_text):
        """Simulate client responses for directions A and B of one L1 reply."""
        prompt = L2_MERGED_PROMPT.replace(
            "{conversation_history}", history_text
        ).replace("{therapist_reply}", l1_item["reply"])
        try:
            result = self.llm.invoke(prompt)
            parsed = self._parse_json(result.content)
        except Exception:
            parsed = {}
        return [
            {**l1_item, "l2_dir": did,
             "client_response": self._as_text(parsed.get(did), "(模拟失败)")}
            for did in ["A", "B"]
        ]

    def generate_l2(self, l1_results, history):
        """Expand every L1 item into two L2 items, sorted by (id, l2_dir)."""
        history_text = self._format_history(history)
        batches = self._run_parallel(
            lambda item: self._gen_l2_merged(item, history_text), l1_results, max_workers=3
        )
        results = [entry for batch in batches for entry in batch]
        results.sort(key=lambda x: (x["id"], x["l2_dir"]))
        return results

    # ===== Step 4 / L3: merged seed + therapist continuation =====

    def _gen_l3_merged(self, l2_item, history_text):
        """Generate therapist continuations A, B and C for one L2 item."""
        prompt = L3_MERGED_PROMPT.replace(
            "{conversation_history}", history_text
        ).replace("{l1_therapist_reply}", l2_item["reply"]
        ).replace("{l2_client_response}", l2_item["client_response"])
        try:
            result = self.llm.invoke(prompt)
            parsed = self._parse_json(result.content)
        except Exception:
            parsed = {}
        return [{
            "id": l2_item["id"],
            "l2_dir": l2_item["l2_dir"],
            "branch": bid,
            "seed": l2_item["seed"],
            "l1_reply": l2_item["reply"],
            "l2_client": l2_item["client_response"],
            "l3_reply": self._as_text(parsed.get(bid), "(生成失败)"),
        } for bid in ["A", "B", "C"]]

    def generate_l3(self, l2_selected, history):
        """Expand the selected L2 items into L3 items, sorted by (id, l2_dir, branch)."""
        history_text = self._format_history(history)
        batches = self._run_parallel(
            lambda item: self._gen_l3_merged(item, history_text), l2_selected
        )
        results = [entry for batch in batches for entry in batch]
        results.sort(key=lambda x: (x["id"], x["l2_dir"], x["branch"]))
        return results

    # ===== Step 5 / L4: merged direction + client response =====

    def _gen_l4_merged(self, l3_item, history_text):
        """Simulate client responses for directions A and B of one L3 item."""
        prompt = L4_MERGED_PROMPT.replace(
            "{conversation_history}", history_text
        ).replace("{l1_therapist_reply}", l3_item["l1_reply"]
        ).replace("{l2_client_response}", l3_item["l2_client"]
        ).replace("{l3_therapist_reply}", l3_item["l3_reply"])
        try:
            result = self.llm.invoke(prompt)
            parsed = self._parse_json(result.content)
        except Exception:
            parsed = {}
        return [
            {**l3_item, "l4_dir": did,
             "l4_client": self._as_text(parsed.get(did), "(模拟失败)")}
            for did in ["A", "B"]
        ]

    def generate_l4(self, l3_selected, history):
        """Expand the selected L3 items into L4 items, sorted by path keys."""
        history_text = self._format_history(history)
        batches = self._run_parallel(
            lambda item: self._gen_l4_merged(item, history_text), l3_selected
        )
        results = [entry for batch in batches for entry in batch]
        results.sort(key=lambda x: (x["id"], x["l2_dir"], x["branch"], x.get("l4_dir", "")))
        return results

    # ===== Step 5.5: final evaluation =====

    def _score_relative(self, item, current_disclosure):
        """Score ``item["l4_client"]`` relative to the current disclosure level.

        Adds ``score``, ``delta`` (score minus *current_disclosure*), ``dims``
        (fields A-E of the reply) and ``reason``.  On failure the score is 1.
        """
        prompt = RELATIVE_DISCLOSURE_EVAL_PROMPT.replace(
            "{current_disclosure_score}", str(current_disclosure)
        ).replace("{user_message}", item["l4_client"])
        try:
            result = self.llm.invoke(prompt)
            parsed = self._parse_json(result.content)
            score = self._clamp_score(parsed.get("score", 1))
            dims = {k: parsed.get(k, False) for k in "ABCDE"}
            return {**item, "score": score, "delta": score - current_disclosure,
                    "dims": dims, "reason": parsed.get("reasoning", "")}
        except Exception:
            return {**item, "score": 1, "delta": 1 - current_disclosure,
                    "dims": {}, "reason": "评分失败"}

    def score_all(self, l4_results, current_disclosure=1):
        """Score every L4 path in parallel, sorted by (id, l2_dir, branch)."""
        results = self._run_parallel(
            lambda item: self._score_relative(item, current_disclosure), l4_results
        )
        results.sort(key=lambda x: (x["id"], x.get("l2_dir", ""), x["branch"]))
        return results

    # ===== Step 6: deep exploration of high-UCB paths (L5+L6) =====

    def _gen_l5_l6(self, item, history_text):
        """Extend a 4-turn path with a therapist turn (L5) and client turn (L6)."""
        prompt = L5_L6_MERGED_PROMPT.replace(
            "{conversation_history}", history_text
        ).replace("{l1_therapist_reply}", item["l1_reply"]
        ).replace("{l2_client_response}", item["l2_client"]
        ).replace("{l3_therapist_reply}", item["l3_reply"]
        ).replace("{l4_client_response}", item["l4_client"])
        try:
            result = self.llm.invoke(prompt)
            parsed = self._parse_json(result.content)
        except Exception:
            parsed = {}
        return {**item,
                "l5_reply": self._as_text(parsed.get("l5_reply"), "(生成失败)"),
                "l6_client": self._as_text(parsed.get("l6_client"), "(模拟失败)"),
                "depth": 6}

    def _score_l6(self, item, current_disclosure):
        """Score ``item["l6_client"]``; on failure reuse the 4-layer score."""
        prompt = RELATIVE_DISCLOSURE_EVAL_PROMPT.replace(
            "{current_disclosure_score}", str(current_disclosure)
        ).replace("{user_message}", item["l6_client"])
        try:
            result = self.llm.invoke(prompt)
            parsed = self._parse_json(result.content)
            score = self._clamp_score(parsed.get("score", 1))
            return {**item, "l6_score": score, "l6_delta": score - current_disclosure,
                    "reason": parsed.get("reasoning", item.get("reason", ""))}
        except Exception:
            return {**item, "l6_score": item.get("score", 1), "l6_delta": 0}

    def deep_explore(self, top_paths, history, current_disclosure):
        """Run L5+L6 generation and L6 scoring for the given paths.

        Returns the extended, scored paths in the order of *top_paths*.
        """
        history_text = self._format_history(history)
        # Stage 1: generate L5+L6 in parallel.
        deep_results = self._run_parallel(
            lambda item: self._gen_l5_l6(item, history_text), top_paths
        )
        # Stage 2: score the L6 client turns in parallel.
        return self._run_parallel(
            lambda item: self._score_l6(item, current_disclosure), deep_results
        )

    # ===== Step 7: distillation (depth-weighted ranking) =====

    def distill_paths(self, scored_4layer, deep_paths, summary):
        """Merge 4-turn and 6-turn paths and distill the top five into guidance.

        4-turn paths qualify when ``delta > 0``; all deep paths qualify and
        are tagged with ``_distill_weight = 2`` (mutating the input dicts).
        When nothing qualifies the best 4-turn path by score is used.  Paths
        are ranked by score, with 6-turn paths multiplied by 1.5.

        Returns:
            The parsed LLM guidance augmented with ``_distill_count``,
            ``_deep_count`` and ``_distill_ids``; a seed-based fallback dict
            when the LLM call or parsing fails.
        """
        effective_4 = [item for item in scored_4layer if item.get("delta", 0) > 0]

        effective_6 = []
        for item in deep_paths:
            item["_distill_weight"] = 2
            effective_6.append(item)

        all_effective = effective_4 + effective_6
        if not all_effective:
            # Degenerate case: fall back to the highest-scoring 4-turn path.
            ranked = sorted(scored_4layer, key=lambda x: x["score"], reverse=True)
            all_effective = [ranked[0]] if ranked else []

        def sort_key(x):
            # Depth bonus: 6-turn paths count 1.5x.
            base = x.get("l6_score", x.get("score", 0))
            depth_bonus = 1.5 if x.get("depth") == 6 else 1.0
            return base * depth_bonus

        ranked = sorted(all_effective, key=sort_key, reverse=True)
        top = ranked[:5]

        path_texts = []
        for i, item in enumerate(top, 1):
            depth = item.get("depth", 4)
            if depth == 6:
                path_texts.append(
                    f"路径{i}(种子{item['id']}.{item['branch']},深度=6轮,揭露度{self._signed(item.get('l6_delta', 0))}):\n"
                    f" 咨询师①:{item['l1_reply']}\n"
                    f" 来访者①:{item['l2_client']}\n"
                    f" 咨询师②:{item['l3_reply']}\n"
                    f" 来访者②:{item['l4_client']}\n"
                    f" 咨询师③:{item['l5_reply']}\n"
                    f" 来访者③:{item['l6_client']}"
                )
            else:
                path_texts.append(
                    f"路径{i}(种子{item['id']}.{item['branch']},深度=4轮,揭露度{self._signed(item.get('delta', 0))}):\n"
                    f" 咨询师①:{item['l1_reply']}\n"
                    f" 来访者①:{item['l2_client']}\n"
                    f" 咨询师②:{item['l3_reply']}\n"
                    f" 来访者②:{item['l4_client']}"
                )

        effective_paths_text = "\n\n".join(path_texts)
        prompt = PATH_DISTILLATION_PROMPT.replace(
            "{summary}", summary
        ).replace("{effective_paths}", effective_paths_text)

        n_deep = sum(1 for t in top if t.get("depth") == 6)
        seeds_in = set(t["id"] for t in top)
        print(f"[PUCT] 蒸馏输入: {len(top)}条路径({n_deep}条6轮深度, {len(seeds_in)}个种子覆盖)")

        try:
            result = self.llm.invoke(prompt)
            parsed = self._parse_json(result.content)
            parsed["_distill_count"] = len(top)
            parsed["_deep_count"] = n_deep
            parsed["_distill_ids"] = [f"{i['id']}.{i['branch']}" for i in top]
            return parsed
        except Exception:
            if top:
                best = top[0]
            elif scored_4layer:
                best = scored_4layer[0]
            else:
                # Nothing was simulated at all: return an empty guidance
                # instead of raising IndexError.
                return {
                    "direction": "",
                    "principles": [],
                    "evidence": f"模拟显示{len(top)}条路径有效",
                    "_distill_count": 0,
                    "_deep_count": 0,
                    "_distill_ids": [],
                }
            return {
                "direction": best.get("seed", ""),
                "principles": [f"沿着「{best.get('seed', '')}」的方向继续探索"],
                "evidence": f"模拟显示{len(top)}条路径有效",
                "_distill_count": len(top),
                "_deep_count": 0,
                "_distill_ids": [f"{best['id']}.{best['branch']}"],
            }

    # ===== Full PUCT pipeline =====

    def run(self, history, current_disclosure=1):
        """Run the full pipeline and return ``(best, guidance, strategic_trace)``.

        Args:
            history: Conversation messages (HumanMessage / AIMessage).
            current_disclosure: Current disclosure score used as the baseline
                for all relative scoring.

        Returns:
            *best* is the best path dict (a deep path when its L6 score beats
            every other candidate), *guidance* the distilled guidance dict,
            and *strategic_trace* a dict with the summary, seeds, candidates,
            deep paths, selection and per-step timings.
        """
        total_start = time.time()

        # Step 1: summary + 3 seeds
        t = time.time()
        summary, seeds = self.summarize_and_seeds(history)
        t1 = time.time() - t
        print(f"[PUCT] Step1 总结+种子: {t1:.1f}s | {summary[:60]}")
        for sid, seed in seeds.items():
            print(f" {sid}: {seed[:50]}")

        # Step 2 / L1: 3 x therapist
        t = time.time()
        l1 = self.generate_l1(seeds, history)
        t2 = time.time() - t
        print(f"[PUCT] L1 {len(l1)}×咨询师: {t2:.1f}s")

        # Step 3 / L2: 6 x client
        t = time.time()
        l2 = self.generate_l2(l1, history)
        t3 = time.time() - t
        print(f"[PUCT] L2 {len(l2)}×来访者: {t3:.1f}s")

        # Step 3.5: quick-score L2
        t = time.time()
        l2 = self.quick_score_paths(l2, "client_response", current_disclosure)
        t3_5 = time.time() - t
        print(f"[PUCT] L2快评: {t3_5:.1f}s")
        for item in l2:
            print(f" L2-{item['id']}.{item['l2_dir']}: qs={item['_quick_score']} | {item['client_response'][:30]}")

        # Step 4: UCB selection -> L3 (at most 6 L2 items advance)
        l2_budget = min(6, len(l2))
        l2_selected = self.allocate_budget(l2, l2_budget)
        print(f"[PUCT] UCB选择L2→L3: {len(l2_selected)}条 (from {len(l2)})")
        for item in l2_selected:
            print(f" 选中 {item['id']}.{item['l2_dir']}: ucb={item['ucb']:.2f} qs={item['_quick_score']}")

        t = time.time()
        l3 = self.generate_l3(l2_selected, history)
        t4 = time.time() - t
        print(f"[PUCT] L3 {len(l3)}×咨询师: {t4:.1f}s")

        # Step 4.5: quick-score L3 (how well the therapist reply pushes forward)
        t = time.time()
        l3 = self.quick_score_paths(l3, "l3_reply", current_disclosure)
        t4_5 = time.time() - t
        print(f"[PUCT] L3快评: {t4_5:.1f}s")

        # Step 5: UCB selection -> L4 (at most 12 L3 items advance)
        l3_budget = min(12, len(l3))
        l3_selected = self.allocate_budget(l3, l3_budget)
        print(f"[PUCT] UCB选择L3→L4: {len(l3_selected)}条 (from {len(l3)})")

        t = time.time()
        l4 = self.generate_l4(l3_selected, history)
        t5 = time.time() - t
        print(f"[PUCT] L4 {len(l4)}×来访者: {t5:.1f}s")

        # Step 5.5: final evaluation of L4
        t = time.time()
        scored = self.score_all(l4, current_disclosure)
        t5_5 = time.time() - t
        print(f"[PUCT] L4终评({len(scored)}条): {t5_5:.1f}s")
        for item in sorted(scored, key=lambda x: -x["score"])[:5]:
            print(f" {item['id']}.{item.get('l2_dir','')}.{item['branch']}: score={item['score']} delta={item['delta']}")

        # Pick the current best: best per seed, then best overall.
        groups = defaultdict(list)
        for item in scored:
            groups[item["id"]].append(item)
        seed_best = {sid: max(items, key=lambda x: x["score"]) for sid, items in groups.items()}
        best = max(seed_best.values(), key=lambda x: x["score"])
        print(f"[PUCT] 4层最优: {best['id']}.{best.get('l2_dir','')}.{best['branch']} score={best['score']} delta={best['delta']}")

        # Step 6: deep exploration (L5+L6) of the top-3 paths by UCB.
        # Rank on the final score: these dicts still carry the L3
        # `_quick_score`, which the default lookup would prefer.
        self.compute_ucb(scored, score_key="score")
        top3 = sorted(scored, key=lambda x: -x["ucb"])[:3]
        top3_desc = [f"{p['id']}.{p.get('l2_dir','')}.{p['branch']}(ucb={p['ucb']:.2f})" for p in top3]
        print(f"[PUCT] 深度探索 top-3: {top3_desc}")

        t = time.time()
        deep_paths = self.deep_explore(top3, history, current_disclosure)
        t6 = time.time() - t
        print(f"[PUCT] L5+L6深探: {t6:.1f}s")
        for dp in deep_paths:
            print(f" 深探 {dp['id']}.{dp['branch']}: L5={dp['l5_reply'][:30]} → L6 score={dp.get('l6_score','?')} delta={dp.get('l6_delta','?')}")

        # Promote a deep path when it beats the best value seen so far.
        # Track that value explicitly: a promoted deep path's "score" key is
        # still its 4-layer score, not the L6 score it won with.
        best_value = best.get("score", 0)
        for dp in deep_paths:
            if dp.get("l6_score", 0) > best_value:
                best = dp
                best_value = dp["l6_score"]
                print(f"[PUCT] 深探更优: {dp['id']}.{dp['branch']} l6_score={dp['l6_score']}")

        # Step 7: distillation
        t = time.time()
        guidance = self.distill_paths(scored, deep_paths, summary)
        t7 = time.time() - t
        print(f"[PUCT] 蒸馏: {t7:.1f}s")
        print(f" 方向: {guidance.get('direction', '?')}")
        for p in guidance.get("principles", []):
            print(f" 原则: {p}")

        total_cost = time.time() - total_start
        print(f"[PUCT] 总耗时: {total_cost:.1f}s")

        strategic_trace = {
            "summary": summary,
            "seeds": seeds,
            "candidates": [
                {
                    "id": item["id"],
                    "branch": item["branch"],
                    "l1_reply": item["l1_reply"],
                    "l2_client": item["l2_client"],
                    "l3_reply": item["l3_reply"],
                    "l4_client": item["l4_client"],
                    "score": item["score"],
                    "delta": item["delta"],
                    "reason": item.get("reason", ""),
                }
                for item in scored
            ],
            "deep_paths": [
                {
                    "id": dp["id"],
                    "branch": dp["branch"],
                    "l5_reply": dp.get("l5_reply", ""),
                    "l6_client": dp.get("l6_client", ""),
                    "l6_score": dp.get("l6_score", 0),
                    "l6_delta": dp.get("l6_delta", 0),
                }
                for dp in deep_paths
            ],
            "selected": f"{best['id']}.{best.get('l2_dir','')}.{best['branch']}",
            "guidance": guidance,
            "current_disclosure": current_disclosure,
            "timing": {
                "total_seconds": round(total_cost, 1),
                "step1_summary_seeds": round(t1, 1),
                "L1_therapist": round(t2, 1),
                "L2_merged": round(t3, 1),
                "L2_quick_score": round(t3_5, 1),
                "L3_merged": round(t4, 1),
                "L3_quick_score": round(t4_5, 1),
                "L4_merged": round(t5, 1),
                "L4_final_score": round(t5_5, 1),
                "L5_L6_deep": round(t6, 1),
                "distillation": round(t7, 1),
            },
        }
        return best, guidance, strategic_trace