|
|
"""Verification Node - Final quality control and output formatting""" |
|
|
import re
from typing import Dict, Any

from langchain_core.messages import SystemMessage, HumanMessage, AIMessage
from langchain_groq import ChatGroq

from src.tracing import get_langfuse_callback_handler
|
|
|
|
|
|
|
|
def load_verification_prompt() -> str:
    """Return the verification system prompt.

    Reads the prompt text from ``./prompts/verification_prompt.txt``,
    stripped of surrounding whitespace. When the file is missing, a minimal
    built-in fallback prompt is returned instead.
    """
    prompt_path = "./prompts/verification_prompt.txt"
    try:
        with open(prompt_path, "r", encoding="utf-8") as prompt_file:
            contents = prompt_file.read()
    except FileNotFoundError:
        # No prompt file on disk — fall back to a bare-bones instruction.
        return """You are a verification agent. Ensure responses meet quality standards and format requirements."""
    return contents.strip()
|
|
|
|
|
|
|
|
def extract_final_answer(response_content: str) -> str:
    """Extract and format the final answer according to system prompt requirements.

    Normalization steps, in order:
    1. Strip *paired* markdown emphasis markers (``**bold**``, ``*italic*``).
    2. Remove a leading answer prefix such as "Final Answer:" or "Result:".
    3. Strip surrounding quotes/brackets.
    4. Collapse a pure bullet list into one comma-separated line.

    Args:
        response_content: Raw text of the LLM response.

    Returns:
        The cleaned, single-purpose answer string.
    """
    answer = response_content.strip()

    # Strip only PAIRED emphasis markers. The previous blanket
    # replace("*", "") destroyed '*'-style bullet markers (making the bullet
    # branch below unreachable for them) and corrupted legitimate asterisks
    # such as "2*3". The single-star pattern excludes newlines so an unpaired
    # bullet '*' at the start of a line is never consumed.
    answer = re.sub(r"\*\*([^*]+)\*\*", r"\1", answer)
    answer = re.sub(r"\*([^*\n]+)\*", r"\1", answer)

    # Common lead-ins the model may prepend; checked case-insensitively.
    prefixes_to_remove = [
        "Final Answer:", "Answer:", "The answer is:", "The final answer is:",
        "Result:", "Solution:", "Response:", "Output:", "Conclusion:"
    ]
    for prefix in prefixes_to_remove:
        if answer.lower().startswith(prefix.lower()):
            answer = answer[len(prefix):].strip()

    # Drop surrounding quotes and brackets (both ends only).
    answer = answer.strip('"\'()[]{}')

    # If every non-empty line is a bullet item, join them into one line.
    lines = [line.strip() for line in answer.split('\n') if line.strip()]
    if '\n' in answer and lines and all(line.startswith(('-', '*', '•')) for line in lines):
        answer = ', '.join(line.lstrip('-*•').strip() for line in lines)

    return answer.strip()
|
|
|
|
|
|
|
|
def verification_node(state: Dict[str, Any]) -> Dict[str, Any]:
    """
    Verification node that performs final quality control and formatting.

    Pulls the latest agent response and the critic's verdict out of *state*,
    retries the pipeline while quality is insufficient (up to 3 attempts),
    and otherwise asks the LLM to reformat the response into the final
    answer.

    Args:
        state: Graph state dict; reads ``messages``, ``quality_pass``,
            ``quality_score``, ``critic_assessment``, ``agent_response`` and
            ``attempt_count``.

    Returns:
        The updated state with ``final_answer``, ``verification_status``
        ("passed", "failed", "failed_max_attempts", or "error") and
        ``current_step`` populated.
    """
    print("Verification Node: Performing final quality control")

    # Bind up front so the except-handler can test it safely even when the
    # failure happens before a response is located (previously this raised
    # NameError inside the error path, masking the real exception).
    agent_response = None

    try:
        verification_prompt = load_verification_prompt()

        # temperature=0.0 keeps the formatting pass deterministic.
        llm = ChatGroq(model="qwen-qwq-32b", temperature=0.0)

        # Optional Langfuse tracing; skipped when no handler is configured.
        callback_handler = get_langfuse_callback_handler()
        callbacks = [callback_handler] if callback_handler else []

        messages = state.get("messages", [])
        quality_pass = state.get("quality_pass", True)
        quality_score = state.get("quality_score", 7)
        critic_assessment = state.get("critic_assessment", "")

        # Prefer the explicit agent_response; otherwise fall back to the most
        # recent AI message in the conversation history.
        agent_response = state.get("agent_response")
        if not agent_response:
            for msg in reversed(messages):
                if msg.type == "ai":
                    agent_response = msg
                    break

        if not agent_response:
            print("Verification Node: No response to verify")
            return {
                **state,
                "final_answer": "No response found to verify",
                "verification_status": "failed",
                "current_step": "complete"
            }

        # The most recent human message is taken as the query being answered.
        user_query = None
        for msg in reversed(messages):
            if msg.type == "human":
                user_query = msg.content
                break

        # Scores below this threshold force a retry even when quality_pass
        # is True.
        failure_threshold = 4
        attempt_count = state.get("attempt_count", 1)

        if not quality_pass or quality_score < failure_threshold:
            if attempt_count >= 3:
                print("Verification Node: Maximum attempts reached, proceeding with fallback")
                return {
                    **state,
                    "final_answer": "Unable to provide a satisfactory answer after multiple attempts",
                    "verification_status": "failed_max_attempts",
                    "current_step": "fallback"
                }
            else:
                print(f"Verification Node: Quality check failed (score: {quality_score}), retrying")
                return {
                    **state,
                    "verification_status": "failed",
                    "attempt_count": attempt_count + 1,
                    "current_step": "routing"
                }

        print("Verification Node: Quality check passed, formatting final answer")

        verification_messages = [SystemMessage(content=verification_prompt)]

        verification_request = f"""
Please verify and format the following response according to the exact-match output rules:

Original Query: {user_query or "Unknown query"}

Response to Verify:
{agent_response.content}

Quality Assessment: {critic_assessment}

Ensure the final output strictly adheres to the format requirements specified in the system prompt.
"""

        verification_messages.append(HumanMessage(content=verification_request))

        verification_response = llm.invoke(verification_messages, config={"callbacks": callbacks})

        final_answer = extract_final_answer(verification_response.content)

        return {
            **state,
            "messages": messages + [verification_response],
            "final_answer": final_answer,
            "verification_status": "passed",
            "current_step": "complete"
        }

    except Exception as e:
        # Broad catch is deliberate: this is the terminal node, so any failure
        # degrades to the best available answer instead of crashing the graph.
        print(f"Verification Node Error: {e}")

        if agent_response:
            fallback_answer = extract_final_answer(agent_response.content)
        else:
            fallback_answer = f"Error during verification: {e}"

        return {
            **state,
            "final_answer": fallback_answer,
            "verification_status": "error",
            "current_step": "complete"
        }
|
|
|
|
|
|
|
|
def should_retry(state: Dict[str, Any]) -> bool:
    """Decide whether the pipeline should make another attempt.

    A retry happens only when verification failed outright AND fewer than
    three attempts have been made so far (attempt count defaults to 1).
    """
    if state.get("verification_status", "") != "failed":
        return False
    return state.get("attempt_count", 1) < 3