File size: 7,435 Bytes
fe36046
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2b46018
 
1d1a146
2b46018
1d1a146
 
 
 
 
 
 
 
 
 
 
 
 
 
2b46018
fe36046
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
"""Verification Node - Final quality control and output formatting"""
from typing import Dict, Any
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage
from langchain_groq import ChatGroq
from src.tracing import get_langfuse_callback_handler


def load_verification_prompt() -> str:
    """Load the verification prompt from disk, falling back to a built-in default.

    Returns:
        The stripped contents of ./prompts/verification_prompt.txt, or a
        generic verification instruction when that file does not exist.
    """
    prompt_path = "./prompts/verification_prompt.txt"
    try:
        handle = open(prompt_path, "r", encoding="utf-8")
    except FileNotFoundError:
        # No prompt file shipped with this deployment — use the baked-in default.
        return """You are a verification agent. Ensure responses meet quality standards and format requirements."""
    with handle:
        return handle.read().strip()


def extract_final_answer(response_content: str) -> str:
    """Normalize an LLM response into a bare final-answer string.

    Cleanup steps, in order: strip whitespace, drop markdown emphasis,
    remove leading answer-style prefixes, unwrap matched quote/bracket
    pairs, convert a bullet list to comma-separated values, and finally
    keep only the first substantive line of a multi-line response.

    Args:
        response_content: Raw text content of the model's reply.

    Returns:
        The cleaned single-line answer (may be empty if the input was).
    """
    answer = response_content.strip()

    # Remove markdown emphasis ("**" first so "*" doesn't split bold markers).
    answer = answer.replace("**", "").replace("*", "")

    # Remove common answer prefixes. Loop until no prefix matches so that
    # stacked prefixes (e.g. "Final Answer: Result: 42") are fully removed;
    # the original single pass could leave a later prefix behind.
    prefixes_to_remove = [
        "Final Answer:", "Answer:", "The answer is:", "The final answer is:",
        "Result:", "Solution:", "Response:", "Output:", "Conclusion:"
    ]
    removed = True
    while removed:
        removed = False
        for prefix in prefixes_to_remove:
            if answer.lower().startswith(prefix.lower()):
                answer = answer[len(prefix):].strip()
                removed = True

    # Remove quotes/brackets only when they wrap the whole answer as a
    # matched pair. (The previous str.strip('"\'()[]{}') also ate unmatched
    # trailing characters, turning e.g. "option A)" into "option A".)
    wrappers = {'"': '"', "'": "'", "(": ")", "[": "]", "{": "}"}
    while len(answer) >= 2 and answer[0] in wrappers and answer[-1] == wrappers[answer[0]]:
        answer = answer[1:-1].strip()

    # Handle lists - format with comma and space separation
    if '\n' in answer and all(line.strip().startswith(('-', '*', '•')) for line in answer.split('\n') if line.strip()):
        # Convert bullet list to comma-separated
        items = [line.strip().lstrip('-*•').strip() for line in answer.split('\n') if line.strip()]
        answer = ', '.join(items)

    # If there are still multiple lines, keep only the first non-empty line
    # (avoids trailing explanations / chain-of-thought leakage).
    if '\n' in answer:
        candidate = None
        for line in answer.split('\n'):
            if not line.strip():
                continue
            cleaned_line = line.strip()
            # Skip meta-thinking placeholders such as "<think>" or "[thinking]"
            lower_line = cleaned_line.lower()
            if lower_line in {"<think>", "think", "[thinking]", "<thinking>", "[think]"}:
                continue
            # Skip XML-like tags
            if lower_line.startswith("<") and lower_line.endswith(">"):
                continue
            candidate = cleaned_line
            break
        if candidate is not None:
            answer = candidate

    return answer.strip()


def verification_node(state: Dict[str, Any]) -> Dict[str, Any]:
    """Perform final quality control and format the answer for output.

    Reads the critic's quality verdict from ``state``; on failure either
    retries (bumping ``attempt_count`` and routing back) or gives up after
    3 attempts. On success, asks the verification LLM to reformat the agent
    response, then post-processes it with ``extract_final_answer``.

    Args:
        state: Graph state dict. Keys read: ``messages``, ``quality_pass``,
            ``quality_score``, ``critic_assessment``, ``agent_response``,
            ``attempt_count``.

    Returns:
        A new state dict with ``final_answer``, ``verification_status``
        and ``current_step`` set (plus the verification message appended
        on the success path).
    """
    print("Verification Node: Performing final quality control")

    # Initialize before the try so the except handler can safely reference
    # it even when the failure happens before the response is looked up
    # (previously this raised NameError inside the handler, masking the
    # original exception).
    agent_response = None

    try:
        # Get verification prompt
        verification_prompt = load_verification_prompt()

        # Initialize LLM for verification; temperature 0 for deterministic,
        # consistent formatting.
        llm = ChatGroq(model="qwen-qwq-32b", temperature=0.0)

        # Get callback handler for tracing (optional — may be None)
        callback_handler = get_langfuse_callback_handler()
        callbacks = [callback_handler] if callback_handler else []

        # Pull the critic's verdict out of the state (permissive defaults:
        # assume a pass with a mid-range score when the critic didn't run).
        messages = state.get("messages", [])
        quality_pass = state.get("quality_pass", True)
        quality_score = state.get("quality_score", 7)
        critic_assessment = state.get("critic_assessment", "")

        # Get the agent response to verify; fall back to the most recent
        # AI message in the transcript.
        agent_response = state.get("agent_response")
        if not agent_response:
            for msg in reversed(messages):
                if msg.type == "ai":
                    agent_response = msg
                    break

        if not agent_response:
            print("Verification Node: No response to verify")
            return {
                **state,
                "final_answer": "No response found to verify",
                "verification_status": "failed",
                "current_step": "complete"
            }

        # Get the most recent user query for context
        user_query = None
        for msg in reversed(messages):
            if msg.type == "human":
                user_query = msg.content
                break

        # Determine if we should proceed or trigger fallback
        failure_threshold = 4
        attempt_count = state.get("attempt_count", 1)

        if not quality_pass or quality_score < failure_threshold:
            if attempt_count >= 3:
                print("Verification Node: Maximum attempts reached, proceeding with fallback")
                return {
                    **state,
                    "final_answer": "Unable to provide a satisfactory answer after multiple attempts",
                    "verification_status": "failed_max_attempts",
                    "current_step": "fallback"
                }
            else:
                print(f"Verification Node: Quality check failed (score: {quality_score}), retrying")
                return {
                    **state,
                    "verification_status": "failed",
                    "attempt_count": attempt_count + 1,
                    "current_step": "routing"  # Retry from routing
                }

        # Quality passed, format the final answer
        print("Verification Node: Quality check passed, formatting final answer")

        # Build verification messages
        verification_messages = [SystemMessage(content=verification_prompt)]

        verification_request = f"""
Please verify and format the following response according to the exact-match output rules:

Original Query: {user_query or "Unknown query"}

Response to Verify:
{agent_response.content}

Quality Assessment: {critic_assessment}

Ensure the final output strictly adheres to the format requirements specified in the system prompt.
"""

        verification_messages.append(HumanMessage(content=verification_request))

        # Get verification response
        verification_response = llm.invoke(verification_messages, config={"callbacks": callbacks})

        # Extract and format the final answer
        final_answer = extract_final_answer(verification_response.content)

        # Store the final formatted answer
        return {
            **state,
            "messages": messages + [verification_response],
            "final_answer": final_answer,
            "verification_status": "passed",
            "current_step": "complete"
        }

    except Exception as e:
        print(f"Verification Node Error: {e}")
        # Best-effort fallback: format whatever agent response we managed
        # to find before the failure, otherwise surface the error text.
        if agent_response:
            fallback_answer = extract_final_answer(agent_response.content)
        else:
            fallback_answer = f"Error during verification: {e}"

        return {
            **state,
            "final_answer": fallback_answer,
            "verification_status": "error",
            "current_step": "complete"
        }


def should_retry(state: Dict[str, Any]) -> bool:
    """Return True when the last verification failed and retries remain (fewer than 3 attempts)."""
    if state.get("verification_status", "") != "failed":
        return False
    return state.get("attempt_count", 1) < 3