# freud-zero-mvp / strategic_advisor.py
# Feng Chike
# Freud Zero MVP: psychological-counseling AI system (clean deployment)
# 408f650
import json
import math
import os
import time
from collections import Counter, defaultdict
from concurrent.futures import ThreadPoolExecutor, as_completed
from langchain_openai import ChatOpenAI
from langchain_core.messages import HumanMessage, AIMessage
from prompts import (
SUMMARY_AND_SEEDS_PROMPT,
THERAPIST_REPLY_PROMPT, L2_MERGED_PROMPT, L3_MERGED_PROMPT, L4_MERGED_PROMPT,
QUICK_EVAL_PROMPT, L5_L6_MERGED_PROMPT,
RELATIVE_DISCLOSURE_EVAL_PROMPT, PATH_DISTILLATION_PROMPT,
)
class StrategicAdvisor:
    """PUCT variant: UCB-driven budget allocation plus variable-depth exploration.

    The advisor simulates a small dialogue tree with an LLM:

    * L1 - one therapist reply per seed perspective,
    * L2 - two simulated client responses per L1 reply,
    * L3 - three therapist continuations per selected L2 path,
    * L4 - two simulated client responses per selected L3 path,
    * L5/L6 - one extra therapist/client exchange for the top-UCB paths.

    Paths are plain dicts that accumulate keys as they move down the tree
    (``id``, ``l2_dir``, ``branch``, ``l4_dir``, ``score`` ...).  The scored
    paths are finally distilled into a guidance dict by one more LLM call.
    """

    # Labels of the seed perspectives requested from the LLM.
    SEED_IDS = ("A", "B", "C")
    # Seed text used whenever the LLM response does not provide one.
    DEFAULT_SEED = "从你独特的临床视角出发"

    def __init__(self, c_puct=1.5):
        """Create the LLM client.

        Args:
            c_puct: Exploration constant used by :meth:`compute_ucb`.
        """
        dashscope_base = dict(
            base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
            api_key=os.getenv("DASHSCOPE_API_KEY"),
        )
        self.llm = ChatOpenAI(
            model="qwen-turbo", **dashscope_base, temperature=0.7, max_tokens=256
        )
        self.c_puct = c_puct

    # ===== Utilities =====

    def _format_history(self, history):
        """Render the message history as one ``speaker:text`` line per message.

        Messages that are neither ``HumanMessage`` nor ``AIMessage`` are
        skipped.  Returns the placeholder ``"(无)"`` when nothing is rendered.
        """
        lines = []
        for msg in history:
            if isinstance(msg, HumanMessage):
                lines.append(f"来访者:{msg.content}")
            elif isinstance(msg, AIMessage):
                lines.append(f"咨询师:{msg.content}")
        return "\n".join(lines) if lines else "(无)"

    def _parse_json(self, text):
        """Extract and decode the outermost ``{...}`` span found in ``text``.

        Because the decoded span starts with ``{`` and ends with ``}``, a
        successful result is always a dict.

        Raises:
            ValueError: No usable ``{...}`` span exists, or the span is not
                valid JSON (``json.JSONDecodeError`` is a ``ValueError``).
        """
        content = text.strip()
        start = content.find("{")
        end = content.rfind("}") + 1
        # `end <= start` also covers the case where the only "}" precedes "{".
        if start == -1 or end <= start:
            raise ValueError(f"无法解析 JSON: {content[:100]}")
        return json.loads(content[start:end])

    @staticmethod
    def _clamp_score(value):
        """Convert an LLM-provided score to an int clamped to ``[1, 10]``.

        Goes through ``float`` so that strings such as ``"7.5"`` are accepted.

        Raises:
            TypeError, ValueError, OverflowError: ``value`` is not numeric;
                callers treat this like any other scoring failure.
        """
        return max(1, min(10, int(float(value))))

    @staticmethod
    def _as_text(value, fallback):
        """Return ``value`` as a string, or ``fallback`` when it is ``None``.

        LLM JSON fields are later spliced into prompts with ``str.replace``
        and sliced for logging, both of which require real strings.
        Non-string values are serialised back to JSON text.
        """
        if value is None:
            return fallback
        if isinstance(value, str):
            return value
        return json.dumps(value, ensure_ascii=False)

    def _parallel_map(self, fn, items, max_workers=None):
        """Apply ``fn`` to every item on a thread pool, keeping input order.

        Args:
            fn: Callable taking one item.
            items: Iterable of work items.
            max_workers: Pool size; defaults to one thread per item.

        Returns:
            A list of results in the same order as ``items``; an empty list
            for empty input (``ThreadPoolExecutor`` rejects a pool size of 0).
        """
        items = list(items)
        if not items:
            return []
        workers = max_workers if max_workers else len(items)
        with ThreadPoolExecutor(max_workers=workers) as executor:
            return list(executor.map(fn, items))

    # ===== PUCT core =====

    def compute_ucb(self, paths):
        """Compute a UCB score for every path and store it under ``"ucb"``.

        ``ucb = q + c_puct * sqrt(len(paths)) / (1 + n_seed)`` where ``q`` is
        the path score divided by 10 and ``n_seed`` is the number of paths
        sharing the path's seed ``id``.

        The final ``"score"`` is preferred over ``"_quick_score"``: fully
        scored paths still carry the quick score inherited from an earlier
        layer, and that stale value must not shadow the final evaluation.

        Returns:
            The same ``paths`` list, mutated in place.
        """
        seed_counts = Counter(p["id"] for p in paths)
        n_total = len(paths)
        for p in paths:
            q = p.get("score", p.get("_quick_score", 1)) / 10.0
            n_seed = seed_counts[p["id"]]
            exploration = self.c_puct * math.sqrt(n_total) / (1 + n_seed)
            p["ucb"] = q + exploration
        return paths

    def allocate_budget(self, paths, total_budget, min_per_seed=1):
        """Select up to ``total_budget`` paths by UCB with per-seed guarantees.

        Selection happens in ``min_per_seed`` guarantee rounds - each round
        takes the best not-yet-selected path of every seed, in UCB order -
        followed by a fill phase that takes the remaining best paths.
        The guarantee is best effort: selection stops as soon as the budget
        is exhausted.

        Args:
            paths: Path dicts; ``"ucb"`` is (re)computed on them in place.
            total_budget: Maximum number of paths to return.
            min_per_seed: Number of guarantee rounds.

        Returns:
            A new list of selected paths (guarantee picks first, then fill).
        """
        self.compute_ucb(paths)
        if total_budget <= 0:
            return []
        ranked = sorted(paths, key=lambda x: -x["ucb"])
        selected = []
        # Track object identity: comparing dicts by value is O(n) per test
        # and would merge distinct paths that happen to have equal content.
        chosen = set()

        for _ in range(max(0, min_per_seed)):
            seeds_this_round = set()
            for p in ranked:
                if id(p) in chosen or p["id"] in seeds_this_round:
                    continue
                selected.append(p)
                chosen.add(id(p))
                seeds_this_round.add(p["id"])
                if len(selected) >= total_budget:
                    return selected

        # Fill whatever budget is left strictly by UCB rank.
        for p in ranked:
            if len(selected) >= total_budget:
                break
            if id(p) not in chosen:
                selected.append(p)
                chosen.add(id(p))
        return selected

    # ===== Quick scoring =====

    def _quick_score_one(self, text, current_disclosure):
        """Ask the LLM for a quick 1-10 score of ``text``; return 1 on failure."""
        prompt = QUICK_EVAL_PROMPT.replace(
            "{current_disclosure_score}", str(current_disclosure)
        ).replace("{user_message}", text)
        try:
            result = self.llm.invoke(prompt)
            parsed = self._parse_json(result.content)
            return self._clamp_score(parsed.get("score", 1))
        except Exception:
            # Best effort: a failed evaluation counts as the lowest score.
            return 1

    def quick_score_paths(self, paths, text_key, current_disclosure):
        """Quick-score ``paths`` in parallel, writing ``"_quick_score"`` in place.

        Args:
            paths: Path dicts to score.
            text_key: Key of the text field that is evaluated.
            current_disclosure: Current disclosure score passed to the prompt.

        Returns:
            The same ``paths`` object.
        """
        scores = self._parallel_map(
            lambda p: self._quick_score_one(p[text_key], current_disclosure), paths
        )
        for path, score in zip(paths, scores):
            path["_quick_score"] = score
        return paths

    # ===== Step 1: summary + 3 seeds =====

    def summarize_and_seeds(self, history):
        """Summarise the conversation and obtain one seed per seed label.

        The LLM is asked up to three times; attempts whose response cannot be
        parsed as JSON are retried.  Other exceptions raised by the LLM call
        propagate to the caller.

        Returns:
            ``(summary, seeds)`` where ``seeds`` maps each of ``SEED_IDS`` to
            a string.  Missing values are replaced by defaults.
        """
        prompt = SUMMARY_AND_SEEDS_PROMPT.replace(
            "{conversation_history}", self._format_history(history)
        )
        for _ in range(3):
            try:
                result = self.llm.invoke(prompt)
                parsed = self._parse_json(result.content)
            except (json.JSONDecodeError, ValueError):
                continue
            summary = self._as_text(parsed.get("summary"), "总结失败")
            # Exact label match; a substring test against "ABC" would also
            # accept keys such as "" or "AB".
            seeds = {
                k: self._as_text(parsed.get(k), self.DEFAULT_SEED)
                for k in self.SEED_IDS
            }
            return summary, seeds
        return "总结失败", {k: self.DEFAULT_SEED for k in self.SEED_IDS}

    # ===== Step 2 / L1: therapist replies =====

    def _gen_therapist_reply(self, seed_id, seed, history_text):
        """Generate one therapist reply for a seed; errors become the reply text."""
        prompt = THERAPIST_REPLY_PROMPT.replace(
            "{conversation_history}", history_text
        ).replace("{seed_perspective}", seed)
        try:
            result = self.llm.invoke(prompt)
            return {"id": seed_id, "seed": seed, "reply": result.content.strip()}
        except Exception as e:
            return {"id": seed_id, "seed": seed, "reply": f"(生成失败: {e})"}

    def generate_l1(self, seeds, history):
        """Generate one L1 therapist reply per seed, sorted by seed id."""
        history_text = self._format_history(history)
        results = self._parallel_map(
            lambda sid: self._gen_therapist_reply(sid, seeds[sid], history_text),
            list(seeds),
            max_workers=3,
        )
        results.sort(key=lambda x: x["id"])
        return results

    # ===== Step 3 / L2: merged direction + client response =====

    def _gen_l2_merged(self, l1_item, history_text):
        """Simulate two client responses (directions A and B) to an L1 reply.

        Returns two copies of ``l1_item`` extended with ``"l2_dir"`` and
        ``"client_response"``; a placeholder is used where generation failed.
        """
        prompt = L2_MERGED_PROMPT.replace(
            "{conversation_history}", history_text
        ).replace("{therapist_reply}", l1_item["reply"])
        try:
            result = self.llm.invoke(prompt)
            parsed = self._parse_json(result.content)
        except Exception:
            parsed = {}
        return [
            {
                **l1_item,
                "l2_dir": did,
                "client_response": self._as_text(parsed.get(did), "(模拟失败)"),
            }
            for did in ("A", "B")
        ]

    def generate_l2(self, l1_results, history):
        """Expand every L1 item into its L2 paths, sorted by (id, l2_dir)."""
        history_text = self._format_history(history)
        results = []
        for batch in self._parallel_map(
            lambda item: self._gen_l2_merged(item, history_text),
            l1_results,
            max_workers=3,
        ):
            results.extend(batch)
        results.sort(key=lambda x: (x["id"], x["l2_dir"]))
        return results

    # ===== Step 4 / L3: merged seed + therapist continuation =====

    def _gen_l3_merged(self, l2_item, history_text):
        """Generate three therapist continuations (branches A/B/C) for an L2 path.

        The returned dicts use the L3+ key layout (``l1_reply``, ``l2_client``,
        ``l3_reply``) and do not carry over other keys of ``l2_item``.
        """
        prompt = L3_MERGED_PROMPT.replace(
            "{conversation_history}", history_text
        ).replace("{l1_therapist_reply}", l2_item["reply"]
        ).replace("{l2_client_response}", l2_item["client_response"])
        try:
            result = self.llm.invoke(prompt)
            parsed = self._parse_json(result.content)
        except Exception:
            parsed = {}
        return [
            {
                "id": l2_item["id"],
                "l2_dir": l2_item["l2_dir"],
                "branch": bid,
                "seed": l2_item["seed"],
                "l1_reply": l2_item["reply"],
                "l2_client": l2_item["client_response"],
                "l3_reply": self._as_text(parsed.get(bid), "(生成失败)"),
            }
            for bid in ("A", "B", "C")
        ]

    def generate_l3(self, l2_selected, history):
        """Expand the selected L2 paths into L3 paths, sorted by (id, l2_dir, branch)."""
        history_text = self._format_history(history)
        results = []
        for batch in self._parallel_map(
            lambda item: self._gen_l3_merged(item, history_text), l2_selected
        ):
            results.extend(batch)
        results.sort(key=lambda x: (x["id"], x["l2_dir"], x["branch"]))
        return results

    # ===== Step 5 / L4: merged direction + client response =====

    def _gen_l4_merged(self, l3_item, history_text):
        """Simulate two client responses (directions A and B) to an L3 reply.

        Returns two copies of ``l3_item`` extended with ``"l4_dir"`` and
        ``"l4_client"``; a placeholder is used where generation failed.
        """
        prompt = L4_MERGED_PROMPT.replace(
            "{conversation_history}", history_text
        ).replace("{l1_therapist_reply}", l3_item["l1_reply"]
        ).replace("{l2_client_response}", l3_item["l2_client"]
        ).replace("{l3_therapist_reply}", l3_item["l3_reply"])
        try:
            result = self.llm.invoke(prompt)
            parsed = self._parse_json(result.content)
        except Exception:
            parsed = {}
        return [
            {
                **l3_item,
                "l4_dir": did,
                "l4_client": self._as_text(parsed.get(did), "(模拟失败)"),
            }
            for did in ("A", "B")
        ]

    def generate_l4(self, l3_selected, history):
        """Expand the selected L3 paths into L4 paths in a deterministic order."""
        history_text = self._format_history(history)
        results = []
        for batch in self._parallel_map(
            lambda item: self._gen_l4_merged(item, history_text), l3_selected
        ):
            results.extend(batch)
        results.sort(
            key=lambda x: (x["id"], x["l2_dir"], x["branch"], x.get("l4_dir", ""))
        )
        return results

    # ===== Step 5.5: final scoring =====

    def _score_relative(self, item, current_disclosure):
        """Score the L4 client response relative to ``current_disclosure``.

        Returns a copy of ``item`` with ``score``, ``delta``, ``dims`` and
        ``reason``.  On any failure the score falls back to 1.
        """
        prompt = RELATIVE_DISCLOSURE_EVAL_PROMPT.replace(
            "{current_disclosure_score}", str(current_disclosure)
        ).replace("{user_message}", item["l4_client"])
        try:
            result = self.llm.invoke(prompt)
            parsed = self._parse_json(result.content)
            score = self._clamp_score(parsed.get("score", 1))
            dims = {k: parsed.get(k, False) for k in "ABCDE"}
            return {
                **item,
                "score": score,
                "delta": score - current_disclosure,
                "dims": dims,
                "reason": parsed.get("reasoning", ""),
            }
        except Exception:
            return {
                **item,
                "score": 1,
                "delta": 1 - current_disclosure,
                "dims": {},
                "reason": "评分失败",
            }

    def score_all(self, l4_results, current_disclosure=1):
        """Score every L4 path in parallel and return them in a stable order."""
        results = self._parallel_map(
            lambda item: self._score_relative(item, current_disclosure), l4_results
        )
        # Include l4_dir so sibling paths do not depend on completion order.
        results.sort(
            key=lambda x: (
                x["id"], x.get("l2_dir", ""), x["branch"], x.get("l4_dir", "")
            )
        )
        return results

    # ===== Step 6: deep exploration of high-UCB paths (L5 + L6) =====

    def _gen_l5_l6(self, item, history_text):
        """Extend a scored path by one therapist reply (L5) and client reply (L6).

        Returns a copy of ``item`` with ``l5_reply``, ``l6_client`` and
        ``depth`` set to 6; placeholders are used where generation failed.
        """
        prompt = L5_L6_MERGED_PROMPT.replace(
            "{conversation_history}", history_text
        ).replace("{l1_therapist_reply}", item["l1_reply"]
        ).replace("{l2_client_response}", item["l2_client"]
        ).replace("{l3_therapist_reply}", item["l3_reply"]
        ).replace("{l4_client_response}", item["l4_client"])
        try:
            result = self.llm.invoke(prompt)
            parsed = self._parse_json(result.content)
        except Exception:
            parsed = {}
        return {
            **item,
            "l5_reply": self._as_text(parsed.get("l5_reply"), "(生成失败)"),
            "l6_client": self._as_text(parsed.get("l6_client"), "(模拟失败)"),
            "depth": 6,
        }

    def _score_l6(self, item, current_disclosure):
        """Score the L6 client response, adding ``l6_score`` and ``l6_delta``.

        On failure the path keeps its existing ``score`` as ``l6_score`` and
        gets an ``l6_delta`` of 0.
        """
        prompt = RELATIVE_DISCLOSURE_EVAL_PROMPT.replace(
            "{current_disclosure_score}", str(current_disclosure)
        ).replace("{user_message}", item["l6_client"])
        try:
            result = self.llm.invoke(prompt)
            parsed = self._parse_json(result.content)
            score = self._clamp_score(parsed.get("score", 1))
            return {
                **item,
                "l6_score": score,
                "l6_delta": score - current_disclosure,
                "reason": parsed.get("reasoning", item.get("reason", "")),
            }
        except Exception:
            return {**item, "l6_score": item.get("score", 1), "l6_delta": 0}

    def deep_explore(self, top_paths, history, current_disclosure):
        """Run L5+L6 generation and L6 scoring for the given paths.

        Returns the extended paths in the same order as ``top_paths``.
        """
        history_text = self._format_history(history)
        deep_results = self._parallel_map(
            lambda item: self._gen_l5_l6(item, history_text), top_paths
        )
        return self._parallel_map(
            lambda item: self._score_l6(item, current_disclosure), deep_results
        )

    # ===== Step 7: distillation =====

    @staticmethod
    def _format_distill_path(index, item):
        """Render one path as the text block inserted into the distillation prompt."""
        is_deep = item.get("depth") == 6
        rounds = 6 if is_deep else 4
        delta = item.get("l6_delta", 0) if is_deep else item.get("delta", 0)
        # `:+` always prints the sign, so a negative delta reads "-2", not "+-2".
        lines = [
            f"路径{index}(种子{item['id']}.{item['branch']},深度={rounds}轮,揭露度{delta:+}):",
            f" 咨询师①:{item['l1_reply']}",
            f" 来访者①:{item['l2_client']}",
            f" 咨询师②:{item['l3_reply']}",
            f" 来访者②:{item['l4_client']}",
        ]
        if is_deep:
            lines.append(f" 咨询师③:{item['l5_reply']}")
            lines.append(f" 来访者③:{item['l6_client']}")
        return "\n".join(lines)

    def distill_paths(self, scored_4layer, deep_paths, summary):
        """Distil the best simulated paths into a guidance dict via the LLM.

        Candidates are the 4-layer paths with a positive ``delta`` plus all
        deep paths.  If there are none, the highest-scoring 4-layer path is
        used.  Candidates are ranked by score, with deep paths ranked by
        ``l6_score * 1.5``, and the top five are sent to the LLM.

        Args:
            scored_4layer: Paths returned by :meth:`score_all`.
            deep_paths: Paths returned by :meth:`deep_explore`.
            summary: Conversation summary text.

        Returns:
            The parsed LLM response extended with ``_distill_count``,
            ``_deep_count`` and ``_distill_ids``, or a fallback dict built
            from the best path if the LLM call or parsing fails.
        """
        effective_4 = [item for item in scored_4layer if item.get("delta", 0) > 0]
        # Copy before tagging so the caller's dicts are not mutated.
        effective_6 = [{**item, "_distill_weight": 2} for item in deep_paths]
        all_effective = effective_4 + effective_6
        if not all_effective and scored_4layer:
            # Degenerate case: fall back to the best 4-layer path.
            all_effective = [max(scored_4layer, key=lambda x: x["score"])]

        def sort_key(x):
            base = x.get("l6_score", x.get("score", 0))
            depth_bonus = 1.5 if x.get("depth") == 6 else 1.0
            return base * depth_bonus

        top = sorted(all_effective, key=sort_key, reverse=True)[:5]

        effective_paths_text = "\n\n".join(
            self._format_distill_path(i, item) for i, item in enumerate(top, 1)
        )
        prompt = PATH_DISTILLATION_PROMPT.replace(
            "{summary}", summary
        ).replace("{effective_paths}", effective_paths_text)

        n_deep = sum(1 for t in top if t.get("depth") == 6)
        seeds_in = set(t["id"] for t in top)
        print(f"[PUCT] 蒸馏输入: {len(top)}条路径({n_deep}条6轮深度, {len(seeds_in)}个种子覆盖)")

        try:
            result = self.llm.invoke(prompt)
            parsed = self._parse_json(result.content)
            parsed["_distill_count"] = len(top)
            parsed["_deep_count"] = n_deep
            parsed["_distill_ids"] = [f"{i['id']}.{i['branch']}" for i in top]
            return parsed
        except Exception:
            if not top:
                # No path at all: return an empty but well-formed guidance.
                return {
                    "direction": "",
                    "principles": [],
                    "evidence": f"模拟显示{len(top)}条路径有效",
                    "_distill_count": 0, "_deep_count": 0,
                    "_distill_ids": [],
                }
            best = top[0]
            return {
                "direction": best.get("seed", ""),
                "principles": [f"沿着「{best.get('seed', '')}」的方向继续探索"],
                "evidence": f"模拟显示{len(top)}条路径有效",
                "_distill_count": len(top), "_deep_count": 0,
                "_distill_ids": [f"{best['id']}.{best['branch']}"],
            }

    # ===== Full PUCT pipeline =====

    def run(self, history, current_disclosure=1):
        """Run the full pipeline and return ``(best, guidance, strategic_trace)``.

        Args:
            history: Conversation messages.
            current_disclosure: Current disclosure score used as the baseline
                for relative scoring.

        Returns:
            ``best`` - the best simulated path (a 4-layer path, or a deep path
            when its ``l6_score`` beats every other candidate);
            ``guidance`` - the dict from :meth:`distill_paths`;
            ``strategic_trace`` - a dict with summary, seeds, candidates,
            deep paths, the selected path id, guidance and per-step timings.
        """
        total_start = time.time()

        # Step 1: summary + seeds
        t = time.time()
        summary, seeds = self.summarize_and_seeds(history)
        t1 = time.time() - t
        print(f"[PUCT] Step1 总结+种子: {t1:.1f}s | {summary[:60]}")
        for sid, seed in seeds.items():
            print(f" {sid}: {seed[:50]}")

        # Step 2 / L1: therapist replies
        t = time.time()
        l1 = self.generate_l1(seeds, history)
        t2 = time.time() - t
        print(f"[PUCT] L1 {len(l1)}×咨询师: {t2:.1f}s")

        # Step 3 / L2: client responses
        t = time.time()
        l2 = self.generate_l2(l1, history)
        t3 = time.time() - t
        print(f"[PUCT] L2 {len(l2)}×来访者: {t3:.1f}s")

        # Step 3.5: quick-score L2
        t = time.time()
        l2 = self.quick_score_paths(l2, "client_response", current_disclosure)
        t3_5 = time.time() - t
        print(f"[PUCT] L2快评: {t3_5:.1f}s")
        for item in l2:
            print(f" L2-{item['id']}.{item['l2_dir']}: qs={item['_quick_score']} | {item['client_response'][:30]}")

        # Step 4: UCB selection -> L3 (at most 6 L2 paths are expanded)
        l2_budget = min(6, len(l2))
        l2_selected = self.allocate_budget(l2, l2_budget)
        print(f"[PUCT] UCB选择L2→L3: {len(l2_selected)}条 (from {len(l2)})")
        for item in l2_selected:
            print(f" 选中 {item['id']}.{item['l2_dir']}: ucb={item['ucb']:.2f} qs={item['_quick_score']}")
        t = time.time()
        l3 = self.generate_l3(l2_selected, history)
        t4 = time.time() - t
        print(f"[PUCT] L3 {len(l3)}×咨询师: {t4:.1f}s")

        # Step 4.5: quick-score the L3 therapist replies
        t = time.time()
        l3 = self.quick_score_paths(l3, "l3_reply", current_disclosure)
        t4_5 = time.time() - t
        print(f"[PUCT] L3快评: {t4_5:.1f}s")

        # Step 5: UCB selection -> L4 (at most 12 L3 paths are expanded)
        l3_budget = min(12, len(l3))
        l3_selected = self.allocate_budget(l3, l3_budget)
        print(f"[PUCT] UCB选择L3→L4: {len(l3_selected)}条 (from {len(l3)})")
        t = time.time()
        l4 = self.generate_l4(l3_selected, history)
        t5 = time.time() - t
        print(f"[PUCT] L4 {len(l4)}×来访者: {t5:.1f}s")

        # Step 5.5: final scoring of L4
        t = time.time()
        scored = self.score_all(l4, current_disclosure)
        t5_5 = time.time() - t
        print(f"[PUCT] L4终评({len(scored)}条): {t5_5:.1f}s")
        for item in sorted(scored, key=lambda x: -x["score"])[:5]:
            print(f" {item['id']}.{item.get('l2_dir','')}.{item['branch']}: score={item['score']} delta={item['delta']}")

        # Current best 4-layer path: first path with the highest final score.
        best = max(scored, key=lambda x: x["score"])
        print(f"[PUCT] 4层最优: {best['id']}.{best.get('l2_dir','')}.{best['branch']} score={best['score']} delta={best['delta']}")

        # Step 6: deep exploration (L5+L6) of the top-3 paths by UCB
        self.compute_ucb(scored)
        top3 = sorted(scored, key=lambda x: -x["ucb"])[:3]
        top3_desc = [f"{p['id']}.{p.get('l2_dir','')}.{p['branch']}(ucb={p['ucb']:.2f})" for p in top3]
        print(f"[PUCT] 深度探索 top-3: {top3_desc}")
        t = time.time()
        deep_paths = self.deep_explore(top3, history, current_disclosure)
        t6 = time.time() - t
        print(f"[PUCT] L5+L6深探: {t6:.1f}s")
        for dp in deep_paths:
            print(f" 深探 {dp['id']}.{dp['branch']}: L5={dp['l5_reply'][:30]} → L6 score={dp.get('l6_score','?')} delta={dp.get('l6_delta','?')}")

        # Promote a deep path when it beats the best score seen so far.  The
        # running best score is tracked separately because a deep path's own
        # "score" key still holds its L4 score, not its l6_score.
        best_score = best.get("score", 0)
        for dp in deep_paths:
            if dp.get("l6_score", 0) > best_score:
                best = dp
                best_score = dp["l6_score"]
                print(f"[PUCT] 深探更优: {dp['id']}.{dp['branch']} l6_score={dp['l6_score']}")

        # Step 7: distillation
        t = time.time()
        guidance = self.distill_paths(scored, deep_paths, summary)
        t7 = time.time() - t
        print(f"[PUCT] 蒸馏: {t7:.1f}s")
        print(f" 方向: {guidance.get('direction', '?')}")
        for p in guidance.get("principles", []):
            print(f" 原则: {p}")

        total_cost = time.time() - total_start
        print(f"[PUCT] 总耗时: {total_cost:.1f}s")

        strategic_trace = {
            "summary": summary,
            "seeds": seeds,
            "candidates": [
                {
                    "id": item["id"], "branch": item["branch"],
                    # Direction labels make each candidate uniquely identifiable.
                    "l2_dir": item.get("l2_dir", ""), "l4_dir": item.get("l4_dir", ""),
                    "l1_reply": item["l1_reply"], "l2_client": item["l2_client"],
                    "l3_reply": item["l3_reply"], "l4_client": item["l4_client"],
                    "score": item["score"], "delta": item["delta"], "reason": item.get("reason", ""),
                }
                for item in scored
            ],
            "deep_paths": [
                {
                    "id": dp["id"], "branch": dp["branch"],
                    "l2_dir": dp.get("l2_dir", ""), "l4_dir": dp.get("l4_dir", ""),
                    "l5_reply": dp.get("l5_reply", ""), "l6_client": dp.get("l6_client", ""),
                    "l6_score": dp.get("l6_score", 0), "l6_delta": dp.get("l6_delta", 0),
                }
                for dp in deep_paths
            ],
            "selected": f"{best['id']}.{best.get('l2_dir','')}.{best['branch']}",
            "guidance": guidance,
            "current_disclosure": current_disclosure,
            "timing": {
                "total_seconds": round(total_cost, 1),
                "step1_summary_seeds": round(t1, 1),
                "L1_therapist": round(t2, 1),
                "L2_merged": round(t3, 1),
                "L2_quick_score": round(t3_5, 1),
                "L3_merged": round(t4, 1),
                "L3_quick_score": round(t4_5, 1),
                "L4_merged": round(t5, 1),
                "L4_final_score": round(t5_5, 1),
                "L5_L6_deep": round(t6, 1),
                "distillation": round(t7, 1),
            },
        }
        return best, guidance, strategic_trace