File size: 11,657 Bytes
ebb8326
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55f1010
ebb8326
55f1010
ebb8326
 
 
 
 
 
 
 
 
 
55f1010
 
 
 
 
 
 
ebb8326
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
"""Logic solver node implementing a Manual Code Execution workflow."""

import re
import string

from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
from langchain_experimental.utilities import PythonREPL

from src.data_processing.answer import extract_answer
from src.data_processing.formatting import format_choices
from src.state import GraphState
from src.utils.llm import get_large_model
from src.utils.logging import print_log
from src.utils.prompts import load_prompt

_python_repl = PythonREPL()


def extract_python_code(text: str) -> str | None:
    """Find and extract Python code from block ``` python ...   ```"""
    match = re.search(r"```(?:python)?\s*(.*?)```", text, re.DOTALL | re.IGNORECASE)
    if match:
        return match.group(1).strip()
    return None


def _validate_code_syntax(code: str) -> tuple[bool, str]:
    """Check if code has valid Python syntax. Returns (is_valid, error_message)."""
    try:
        compile(code, "<string>", "exec")
        return True, ""
    except SyntaxError as e:
        return False, str(e)


def _is_placeholder_code(code: str) -> bool:
    """Check if code contains placeholders or is incomplete."""
    if not code or len(code.strip()) < 10:
        return True
    if "..." in code:
        return True
    # Check for {key}-style placeholders (but not f-string or dict literals)
    if re.search(r"\{[a-zA-Z_][a-zA-Z0-9_]*\}", code):
        # Exclude common dict/set patterns and f-strings
        if not re.search(r'["\'][^"\']*\{[a-zA-Z_]', code):
            return True
    return False


def _indent_code(code: str) -> str:
    """Format code to make it easier to read in the terminal."""
    return "\n".join(f"        {line}" for line in code.splitlines())


def _fallback_text_reasoning(llm, question: str, choices_text: str) -> dict:
    """Fallback to CoT reasoning when code execution fails."""
    print_log("        [Logic] Falling back to CoT reasoning...")

    fallback_system = (
        "Nhiệm vụ của bạn là trả lời câu hỏi "
        "được đưa ra bằng khả năng phân tích và suy luận logic. "
        "Hãy phân tích vấn đề và suy luận đề từng bước một. " 
        "Cuối cùng, hãy trả lời theo đúng định dạng: 'Đáp án: X' "
        "trong đó X là ký tự đại diện cho lựa chọn đúng (A, B, C, D, ...)."
    )

    fallback_user = (
        f"Câu hỏi: {question}\n"
        f"{choices_text}"
    )

    fallback_messages: list[BaseMessage] = [
        SystemMessage(content=fallback_system),
        HumanMessage(content=fallback_user)
    ]

    fallback_response = llm.invoke(fallback_messages)
    fallback_content = fallback_response.content
    print_log(f"        [Logic] Fallback response received.")

    return {"text": fallback_content}


def _request_final_answer(llm, question: str, choices_text: str, computed_results: str) -> str:
    """Request a strict final answer from the model."""
    system_prompt = (
        "Bạn là trợ lý AI. Dựa vào kết quả tính toán được cung cấp, "
        "hãy đưa ra đáp án cuối cùng. CHỈ trả lời đúng một dòng: Đáp án: X "
        "(trong đó X là A, B, C hoặc D)."
    )
    user_prompt = (
        f"Câu hỏi: {question}\n"
        f"{choices_text}\n"
        f"Kết quả tính toán: {computed_results}\n\n"
        "Trả lời đúng một dòng: Đáp án: X"
    )
    
    messages: list[BaseMessage] = [
        SystemMessage(content=system_prompt),
        HumanMessage(content=user_prompt)
    ]
    
    response = llm.invoke(messages)
    return response.content


def logic_solver_node(state: GraphState) -> dict:
    """Solve math/logic questions using Python code execution."""
    llm = get_large_model()
    all_choices = state["all_choices"]
    num_choices = len(all_choices)
    choices_text = format_choices(all_choices)
    is_chat_mode = num_choices == 0  # Chat mode when no choices

    system_prompt = load_prompt("logic_solver.j2", "system", choices=choices_text)
    user_prompt = load_prompt("logic_solver.j2", "user", question=state["question"], choices=choices_text)

    messages: list[BaseMessage] = [
        SystemMessage(content=system_prompt),
        HumanMessage(content=user_prompt)
    ]

    step_texts: list[str] = []
    computed_outputs: list[str] = []

    # Chat mode: just invoke LLM and return natural response
    if is_chat_mode:
        print_log("        [Logic] Chat mode detected - returning natural response")
        response = llm.invoke(messages)
        content = response.content
        return {"answer": "", "raw_response": content, "route": "math"}

    max_steps = 5
    for step in range(max_steps):
        response = llm.invoke(messages)
        content = response.content
        step_texts.append(content)
        messages.append(response)

        code_block = extract_python_code(content)

        if code_block:
            if _is_placeholder_code(code_block):
                print_log(f"        [Logic] Step {step+1}: Placeholder code detected. Requesting complete code...")
                regen_msg = (
                    "Code không hợp lệ (chứa placeholder hoặc không đầy đủ). "
                    "Hãy cung cấp code Python hoàn chỉnh, có thể chạy được, không chứa '...' hay placeholder. "
                    "In ra các giá trị tính toán được. "
                    "Cuối cùng, kết thúc bằng một dòng duy nhất: Đáp án: X (X là A, B, C hoặc D)."
                )
                messages.append(HumanMessage(content=regen_msg))
                continue
            
            print_log(f"        [Logic] Step {step+1}: Found Python code. Executing...")
            
            # Validate syntax before execution
            is_valid, syntax_error = _validate_code_syntax(code_block)
            if not is_valid:
                print_log(f"        [Error] Syntax error detected: {syntax_error}")
                error_msg = f"SyntaxError: {syntax_error}. "
                error_msg += "Lưu ý: KHÔNG sử dụng các từ khóa Python như 'lambda', 'class', 'def' làm tên biến. "
                error_msg += "Hãy đổi tên biến và thử lại."
                messages.append(HumanMessage(content=error_msg))
                continue
            
            print_log(f"        [Logic] Code:\n{_indent_code(code_block)}")

            try:
                if "print" not in code_block:
                    lines = code_block.splitlines()
                    if lines:
                        last_line = lines[-1]
                        if "=" in last_line:
                            var_name = last_line.split("=")[0].strip()
                        else:
                            var_name = last_line.strip()
                        code_block += f"\nprint({var_name})"

                output = _python_repl.run(code_block)
                output = output.strip() if output else "No output."
                print_log(f"        [Logic] Code output: {output}")
                computed_outputs.append(output)

                # Do NOT extract answer from code output directly
                # Instead, feed output back to model and ask for final answer line
                feedback_msg = (
                    f"Kết quả thực thi code: {output}\n\n"
                    "Dựa vào kết quả trên, hãy so sánh với các đáp án và đưa ra câu trả lời cuối cùng. "
                    "Kết thúc bằng đúng một dòng: Đáp án: X (X là A, B, C hoặc D)."
                )
                messages.append(HumanMessage(content=feedback_msg))

            except Exception as e:
                error_msg = f"Error running code: {str(e)}"
                print_log(f"        [Error] {error_msg}")
                messages.append(HumanMessage(content=f"{error_msg}. Hãy kiểm tra logic và sửa lại code."))

            continue

        # Check if current step contains an explicit answer (only at end of response)
        step_answer = extract_answer(content, num_choices=num_choices, require_end=True)
        if step_answer:
            print_log(f"        [Logic] Step {step+1}: Found explicit answer: {step_answer}")
            combined_raw = "\n---STEP---\n".join(step_texts)
            return {"answer": step_answer, "raw_response": combined_raw, "route": "math"}

        # Also check if response contains clear conclusion without "Đáp án:" format
        if any(phrase in content.lower() for phrase in ["kết luận", "vậy đáp án", "do đó", "vì vậy"]):
            # Try to extract any single letter at end of response
            lines = content.strip().split('\n')
            for line in reversed(lines[-3:]):  # Check last 3 lines
                line = line.strip()
                if len(line) == 1 and line.upper() in string.ascii_uppercase[:num_choices]:
                    print_log(f"        [Logic] Step {step+1}: Found implicit answer: {line.upper()}")
                    combined_raw = "\n---STEP---\n".join(step_texts)
                    return {"answer": line.upper(), "raw_response": combined_raw, "route": "math"}

        if step < max_steps - 1:
            print_log("        [Warning] No code or answer found. Reminding model...")
            messages.append(HumanMessage(content="Lưu ý: Bạn vẫn chưa đưa ra đáp án cuối cùng. Hãy kết thúc bằng: Đáp án: X"))

    # Max steps reached - build combined_raw and try to extract answer
    print_log("        [Warning] Max steps reached. Attempting answer extraction from combined text...")
    
    # Build combined_raw from all steps
    combined_raw = "\n---STEP---\n".join(step_texts) if step_texts else ""
    
    # Try fallback text reasoning with error handling
    try:
        fallback_result = _fallback_text_reasoning(llm, state["question"], choices_text)
        fallback_text = fallback_result["text"]
        if fallback_text:
            combined_raw += "\n---FALLBACK---\n" + fallback_text
    except Exception as e:
        print_log(f"        [Error] Fallback reasoning failed: {e}")
        fallback_text = ""
    
    # Extract answer from the entire combined text (takes LAST explicit answer)
    final_answer = extract_answer(combined_raw, num_choices=num_choices)
    
    if final_answer:
        print_log(f"        [Logic] Extracted final answer from combined text: {final_answer}")
        return {"answer": final_answer, "raw_response": combined_raw, "route": "math"}
    
    # Still no answer - do one final strict LLM call with error handling
    print_log("        [Logic] No explicit answer found. Requesting strict final answer...")
    computed_str = "; ".join(computed_outputs) if computed_outputs else "Không có kết quả tính toán"
    try:
        strict_response = _request_final_answer(llm, state["question"], choices_text, computed_str)
        combined_raw += "\n---FINAL---\n" + strict_response
        
        final_answer = extract_answer(strict_response, num_choices=num_choices)
        if final_answer:
            print_log(f"        [Logic] Final strict answer: {final_answer}")
            return {"answer": final_answer, "raw_response": combined_raw, "route": "math"}
    except Exception as e:
        print_log(f"        [Error] Final answer request failed: {e}")
    
    # Absolute fallback - default to A
    print_log("        [Warning] All extraction attempts failed. Defaulting to A.")
    return {"answer": "A", "raw_response": combined_raw, "route": "math"}