File size: 7,294 Bytes
408f650
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
import json
import os
import re
from concurrent.futures import ThreadPoolExecutor, as_completed

from langchain_core.messages import HumanMessage, AIMessage
from langchain_openai import ChatOpenAI

from prompts import (
    CANDIDATE_GENERATION_PROMPT,
    CLIENT_SIMULATOR_PROMPT,
    MCTS_EVALUATOR_PROMPT,
)


class MCTSReasoner:
    """Pick a counselor reply with a generate / simulate / evaluate / select pass.

    The pipeline is a single level deep: candidate replies are generated,
    a simulated client reaction is produced for each, every reaction is
    scored, and the highest-scoring candidate is returned together with a
    trace of all intermediate results.
    """

    # Upper bound on concurrent LLM calls in the simulate and evaluate steps.
    _MAX_WORKERS = 5

    def __init__(self):
        """Create the three chat clients used by the pipeline.

        All clients target the same DeepSeek endpoint and read the API key
        from the ``DEEPSEEK_API_KEY`` environment variable. They differ only
        in sampling temperature and output budget: generation (0.7, 1024
        tokens), client simulation (0.7, 256 tokens) and evaluation
        (0.0, 64 tokens).
        """
        base_kwargs = dict(
            model="deepseek-chat",
            base_url="https://api.deepseek.com/v1",
            api_key=os.getenv("DEEPSEEK_API_KEY"),
        )
        self.gen_llm = ChatOpenAI(**base_kwargs, temperature=0.7, max_tokens=1024)
        self.sim_llm = ChatOpenAI(**base_kwargs, temperature=0.7, max_tokens=256)
        self.eval_llm = ChatOpenAI(**base_kwargs, temperature=0.0, max_tokens=64)

    def _format_history(self, history):
        """Render a langchain message history as readable text.

        Only ``HumanMessage`` and ``AIMessage`` entries are rendered; any
        other message type (e.g. a system message) is skipped. Returns a
        placeholder line when nothing was rendered.
        """
        lines = []
        for msg in history:
            if isinstance(msg, HumanMessage):
                lines.append(f"来访者:{msg.content}")
            elif isinstance(msg, AIMessage):
                lines.append(f"咨询师:{msg.content}")
        return "\n".join(lines) if lines else "(首次对话)"

    @staticmethod
    def _fill_template(template, **values):
        """Substitute ``{name}`` placeholders in *template* in a single pass.

        Unlike chained ``str.replace`` calls, text that has already been
        inserted is never re-scanned, so a user message that happens to
        contain ``{therapist_response}`` is left exactly as typed. Braces
        that do not match one of the given names (e.g. JSON examples in the
        prompt) are untouched. Values are converted with ``str``.
        """
        if not values:
            return template
        pattern = re.compile(
            "|".join(re.escape("{" + name + "}") for name in values)
        )
        # A function replacement avoids backslash processing of the values.
        return pattern.sub(lambda m: str(values[m.group(0)[1:-1]]), template)

    @staticmethod
    def _as_text(value):
        """Return *value* as a string, mapping ``None`` to the empty string."""
        if value is None:
            return ""
        return value if isinstance(value, str) else str(value)

    @staticmethod
    def _coerce_score(raw):
        """Convert a raw score into an int clamped to the range 0..10.

        Accepts ints, floats and numeric strings (fractions are truncated).

        Raises:
            ValueError: if *raw* cannot be interpreted as a finite number.
        """
        try:
            value = int(float(raw))
        except (TypeError, ValueError, OverflowError) as exc:
            raise ValueError(f"invalid score: {raw!r}") from exc
        return max(0, min(10, value))

    @staticmethod
    def _validate_candidates(parsed):
        """Check that *parsed* is a non-empty list of candidate dicts.

        Every item must be a dict carrying both an ``"id"`` and a
        ``"response"`` key, because the later steps index those keys.

        Raises:
            ValueError: if the shape does not match.
        """
        if not isinstance(parsed, list) or not parsed:
            raise ValueError("candidate output must be a non-empty JSON array")
        for item in parsed:
            if (
                not isinstance(item, dict)
                or "id" not in item
                or "response" not in item
            ):
                raise ValueError(f"malformed candidate entry: {item!r}")
        return parsed

    def _parse_json(self, text, expected=None):
        """Extract the first JSON array or object embedded in *text*.

        The text is scanned left to right; at every ``[`` or ``{`` a decode
        is attempted, and the first value that decodes is returned. The
        surrounding prose or markdown fences are ignored.

        Args:
            text: Raw LLM output.
            expected: Optional type (or tuple of types) the value must be
                an instance of, e.g. ``dict`` or ``list``. Decoded values
                of another type are skipped and the scan continues.

        Raises:
            ValueError: if no matching JSON value is found.
        """
        content = text.strip()
        decoder = json.JSONDecoder()
        for index, char in enumerate(content):
            if char not in "[{":
                continue
            try:
                value, _ = decoder.raw_decode(content, index)
            except json.JSONDecodeError:
                continue
            if expected is None or isinstance(value, expected):
                return value
        raise ValueError(f"无法解析 JSON: {content[:100]}")

    def generate_candidates(self, history, user_message):
        """Step 1: generate the candidate counselor replies.

        The number of candidates is decided by the prompt, not by this
        method. Up to three attempts are made when the output cannot be
        parsed into a valid candidate list.

        Returns:
            A non-empty list of dicts, each with ``"id"`` and ``"response"``.

        Raises:
            ValueError: if the third attempt still yields unusable output.
        """
        prompt = self._fill_template(
            CANDIDATE_GENERATION_PROMPT,
            conversation_history=self._format_history(history),
            user_message=user_message,
        )

        for attempt in range(3):
            try:
                result = self.gen_llm.invoke(prompt)
                parsed = self._parse_json(result.content, expected=list)
                return self._validate_candidates(parsed)
            except ValueError:  # json.JSONDecodeError is a ValueError subclass
                if attempt == 2:
                    raise

    def _simulate_one(self, candidate, history_text, user_message):
        """Simulate the client's reaction to one candidate reply.

        Makes up to two attempts. Unparseable output on both attempts
        yields a placeholder result instead of an exception; errors raised
        by the LLM call other than ``ValueError`` propagate to the caller.

        Returns:
            A dict with ``"id"``, ``"simulated_client_response"`` and
            ``"emotional_state"`` (both text fields are strings).
        """
        prompt = self._fill_template(
            CLIENT_SIMULATOR_PROMPT,
            conversation_history=history_text,
            user_message=user_message,
            therapist_response=candidate["response"],
        )

        for _ in range(2):
            try:
                result = self.sim_llm.invoke(prompt)
                parsed = self._parse_json(result.content, expected=dict)
            except ValueError:
                continue
            return {
                "id": candidate["id"],
                # Coerce to text so the evaluator prompt and the tie-break
                # in select_best never receive a non-string value.
                "simulated_client_response": self._as_text(
                    parsed.get("simulated_response", "")
                ),
                "emotional_state": self._as_text(
                    parsed.get("emotional_state", "")
                ),
            }

        return {
            "id": candidate["id"],
            "simulated_client_response": "(模拟失败)",
            "emotional_state": "未知",
        }

    def simulate_client_reactions(self, candidates, history, user_message):
        """Step 2: simulate the client's reaction to every candidate in parallel.

        Returns:
            The simulation dicts sorted by ``"id"``.
        """
        history_text = self._format_history(history)

        with ThreadPoolExecutor(max_workers=self._MAX_WORKERS) as executor:
            # map() keeps input order, so the stable sort below is
            # deterministic even when two entries share an id.
            simulations = list(
                executor.map(
                    lambda c: self._simulate_one(c, history_text, user_message),
                    candidates,
                )
            )

        simulations.sort(key=lambda x: x["id"])
        return simulations

    def _evaluate_one(self, simulation):
        """Score the disclosure depth of one simulated client reaction.

        Makes up to two attempts. If the output cannot be parsed or the
        score is not a usable number on both attempts, a score of 0 is
        returned; errors raised by the LLM call other than ``ValueError``
        propagate to the caller.

        Returns:
            A dict with ``"id"``, ``"score"`` (int, 0..10) and ``"reason"``.
        """
        prompt = self._fill_template(
            MCTS_EVALUATOR_PROMPT,
            client_response=simulation["simulated_client_response"],
        )

        for _ in range(2):
            try:
                result = self.eval_llm.invoke(prompt)
                parsed = self._parse_json(result.content, expected=dict)
                score = self._coerce_score(parsed.get("score", 0))
            except ValueError:
                continue
            return {
                "id": simulation["id"],
                "score": score,
                "reason": parsed.get("reason", ""),
            }

        return {"id": simulation["id"], "score": 0, "reason": "评估解析失败"}

    def evaluate_disclosures(self, simulations):
        """Step 3: score every simulated reaction in parallel.

        Returns:
            The evaluation dicts sorted by ``"id"``.
        """
        with ThreadPoolExecutor(max_workers=self._MAX_WORKERS) as executor:
            evaluations = list(executor.map(self._evaluate_one, simulations))

        evaluations.sort(key=lambda x: x["id"])
        return evaluations

    def select_best(self, candidates, simulations, evaluations):
        """Step 4: pick the candidate with the highest score.

        When several candidates share the top score, the one whose
        simulated ``emotional_state`` text is longest wins (a rough proxy
        for emotional depth); if that is also tied, the earliest one in
        *evaluations* is kept.

        Returns:
            A ``(best_id, best_response, reason)`` tuple.

        Raises:
            ValueError: if *evaluations* is empty or the winning id has no
                matching entry in *candidates*.
        """
        if not evaluations:
            raise ValueError("select_best requires at least one evaluation")

        max_score = max(e["score"] for e in evaluations)
        top_candidates = [e for e in evaluations if e["score"] == max_score]

        if len(top_candidates) == 1:
            best_id = top_candidates[0]["id"]
            reason = "最高揭露深度评分"
        else:

            def depth(entry):
                # Length of the emotional-state text; 0 when no simulation
                # exists for this id.
                sim = next(
                    (s for s in simulations if s["id"] == entry["id"]), None
                )
                if sim is None:
                    return 0
                return len(self._as_text(sim.get("emotional_state", "")))

            # max() returns the first maximal element, which preserves the
            # "earliest wins on a full tie" rule.
            best_id = max(top_candidates, key=depth)["id"]
            reason = "同分中情感深度更高"

        for candidate in candidates:
            if candidate["id"] == best_id:
                return best_id, candidate["response"], reason
        raise ValueError(f"no candidate found for selected id {best_id!r}")

    def run(self, history, user_message):
        """Run the whole pipeline and return the best reply with its trace.

        Returns:
            A ``(best_response, mcts_trace)`` tuple, where ``mcts_trace``
            holds the candidates, simulations, evaluations, the selected id
            and the selection reason.
        """
        # Step 1
        candidates = self.generate_candidates(history, user_message)

        # Step 2
        simulations = self.simulate_client_reactions(candidates, history, user_message)

        # Step 3
        evaluations = self.evaluate_disclosures(simulations)

        # Step 4
        selected_id, best_response, selection_reason = self.select_best(
            candidates, simulations, evaluations
        )

        mcts_trace = {
            "candidates": candidates,
            "simulations": simulations,
            "evaluations": evaluations,
            "selected": selected_id,
            "selection_reason": selection_reason,
        }

        return best_response, mcts_trace