# agents/verification_agent.py
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
from typing import Dict, List
from langchain_core.documents.base import Document
from config.settings import settings
class VerificationAgent:
def __init__(self):
"""
Initialize the verification agent with LLM.
"""
print("Initializing RelevanceChecker with lightweight Hugging Face model...")
        # Use a smaller, CPU-friendly seq2seq model by default; the fallback
        # name below is an assumption; set settings.HF_MODEL_NAME to override it
        model_name = getattr(settings, "HF_MODEL_NAME", "google/flan-t5-small")
self.device = "cuda" if torch.cuda.is_available() else "cpu"
# Use float32 on CPU (fp16 only works on GPU)
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name, torch_dtype=torch_dtype).to(self.device)
print(f"Model '{model_name}' loaded on {self.device} with dtype={torch_dtype}.")
def sanitize_response(self, response_text: str) -> str:
"""
Sanitize the LLM's response by stripping unnecessary whitespace.
"""
return response_text.strip()
def generate_prompt(self, answer: str, context: str) -> str:
"""
Generate a structured prompt for the LLM to verify the answer against the context.
"""
prompt = f"""You are a strict verification agent. Your task is to verify if an answer is supported by the provided context.
CRITICAL RULES:
1. ONLY use information from the context provided below. Do NOT use any external knowledge or assumptions.
2. If a claim in the answer is NOT explicitly or implicitly supported by the context, mark it as unsupported.
3. If the answer contradicts information in the context, mark it as a contradiction.
4. If you cannot verify a claim using ONLY the context, mark it as unsupported.
5. Be strict - do not assume or infer beyond what is clearly stated in the context.
6. Respond EXACTLY in the format specified below - no additional text, explanations, or formatting.
**VERIFICATION FORMAT (follow exactly):**
Supported: YES
Unsupported Claims: []
Contradictions: []
Relevant: YES
Additional Details: None
OR if unsupported/contradictions found:
Supported: NO
Unsupported Claims: [list each unsupported claim exactly as it appears in the answer]
Contradictions: [list each contradiction exactly as it appears]
Relevant: YES or NO
Additional Details: [brief explanation of why claims are unsupported or contradicted]
**Answer to verify:**
{answer}
**Context (use ONLY this for verification):**
{context}
**Your verification (respond ONLY with the format above):**
"""
return prompt
    def generate_with_hf(self, prompt: str, max_new_tokens: int = 512) -> str:
        """
        Generate output using the local Hugging Face model.
        """
        inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=2048).to(self.device)
        # Run generation without gradient tracking to reduce memory use
        with torch.no_grad():
            outputs = self.model.generate(**inputs, max_new_tokens=max_new_tokens)
        return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
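    # Note: with a seq2seq model, decoding the generated ids yields only the
    # newly generated text (the decoder output), so the prompt does not need
    # to be stripped from the response before parsing.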
def parse_verification_response(self, response_text: str) -> Dict:
"""
Parse the LLM's verification response into a structured dictionary.
"""
try:
# Normalize the response - remove markdown formatting, extra whitespace
response_text = response_text.strip()
# Remove any markdown code blocks if present
            if response_text.startswith('```'):
                lines = response_text.split('\n')
                # Drop the opening fence, and the closing fence only when present,
                # so real content on the last line is not discarded
                if len(lines) > 2 and lines[-1].strip().startswith('```'):
                    response_text = '\n'.join(lines[1:-1])
                else:
                    response_text = '\n'.join(lines[1:])
print(f"[DEBUG] Parsing verification response (first 500 chars): {response_text[:500]}")
verification = {}
lines = response_text.split('\n')
for line in lines:
line = line.strip()
                if not line or ':' not in line:
continue
# Split on first colon only
parts = line.split(':', 1)
if len(parts) != 2:
continue
key = parts[0].strip()
value = parts[1].strip()
# Normalize key names (case-insensitive matching)
key_lower = key.lower()
                # 'unsupported' also contains 'supported', so exclude it here
                if 'supported' in key_lower and 'unsupported' not in key_lower:
# Extract YES/NO, handle variations
value_upper = value.upper()
print(f"[DEBUG] Found 'Supported' key with value: '{value}' (upper: '{value_upper}')")
                    if 'YES' in value_upper or 'TRUE' in value_upper or value_upper == 'Y':
                        verification["Supported"] = "YES"
                        print("[DEBUG] Set Supported to YES")
                    elif 'NO' in value_upper or 'FALSE' in value_upper or value_upper == 'N':
                        verification["Supported"] = "NO"
                        print("[DEBUG] Set Supported to NO")
else:
# If value is empty or unclear, check if there are unsupported claims/contradictions
# If no issues found later, default to YES; otherwise NO
print(f"[DEBUG] Supported value unclear: '{value}', will decide based on claims/contradictions")
verification["Supported"] = None # Mark as undecided
                elif 'unsupported' in key_lower or 'contradiction' in key_lower:
                    # Both fields use the same list syntax, so parse them once
                    items = []
                    value = value.strip()
                    if value.lower() in ['none', 'n/a', '[]', '']:
                        items = []
                    elif value.startswith('[') and value.endswith(']'):
                        # Bracketed list: split on commas and strip surrounding quotes
                        list_content = value[1:-1].strip()
                        if list_content:
                            items = [item.strip().strip('"').strip("'").strip()
                                     for item in list_content.split(',')
                                     if item.strip()]
                    else:
                        # Single item or comma-separated values without brackets
                        items = [item.strip().strip('"').strip("'")
                                 for item in value.split(',')
                                 if item.strip() and item.strip().lower() not in ['none', 'n/a']]
                    if 'unsupported' in key_lower:
                        verification["Unsupported Claims"] = items
                    else:
                        verification["Contradictions"] = items
elif 'relevant' in key_lower:
value_upper = value.upper()
if 'YES' in value_upper or 'TRUE' in value_upper:
verification["Relevant"] = "YES"
elif 'NO' in value_upper or 'FALSE' in value_upper:
verification["Relevant"] = "NO"
else:
verification["Relevant"] = "YES" # Default to YES if unclear
elif 'additional' in key_lower or 'detail' in key_lower:
if value.lower() in ['none', 'n/a', '']:
verification["Additional Details"] = ""
else:
verification["Additional Details"] = value
# Ensure all required keys are present with defaults
if "Supported" not in verification or verification.get("Supported") is None:
# If undecided, check if there are unsupported claims or contradictions
unsupported_claims = verification.get("Unsupported Claims", [])
contradictions = verification.get("Contradictions", [])
if not unsupported_claims and not contradictions:
verification["Supported"] = "YES" # No issues found, default to YES
print(f"[DEBUG] Supported was missing/undecided, but no claims/contradictions found, defaulting to YES")
else:
verification["Supported"] = "NO" # Issues found, default to NO
print(f"[DEBUG] Supported was missing/undecided, but found {len(unsupported_claims)} unsupported claims and {len(contradictions)} contradictions, defaulting to NO")
if "Unsupported Claims" not in verification:
verification["Unsupported Claims"] = []
if "Contradictions" not in verification:
verification["Contradictions"] = []
if "Relevant" not in verification:
verification["Relevant"] = "YES"
if "Additional Details" not in verification:
verification["Additional Details"] = ""
print(f"[DEBUG] Final parsed verification: Supported={verification.get('Supported')}, Unsupported Claims={len(verification.get('Unsupported Claims', []))}, Contradictions={len(verification.get('Contradictions', []))}")
return verification
except Exception as e:
print(f"Error parsing verification response: {e}")
print(f"Response text was: {response_text}")
# Return a safe default
return {
"Supported": "NO",
"Unsupported Claims": [],
"Contradictions": [],
"Relevant": "NO",
"Additional Details": f"Parsing error: {str(e)}"
}
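    # Illustrative walk-through of the parser (the model output below is an
    # assumed sample, not a recorded response):
    #
    #   Input text:
    #       Supported: NO
    #       Unsupported Claims: ["The tower is 400m tall"]
    #       Contradictions: []
    #       Relevant: YES
    #       Additional Details: The context states a height of 330m.
    #
    #   Parsed result:
    #       {"Supported": "NO",
    #        "Unsupported Claims": ["The tower is 400m tall"],
    #        "Contradictions": [],
    #        "Relevant": "YES",
    #        "Additional Details": "The context states a height of 330m."}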
def format_verification_report(self, verification: Dict) -> str:
"""
Format the verification report dictionary into a readable markdown-formatted report.
"""
supported = verification.get("Supported", "NO")
unsupported_claims = verification.get("Unsupported Claims", [])
contradictions = verification.get("Contradictions", [])
relevant = verification.get("Relevant", "NO")
additional_details = verification.get("Additional Details", "")
# Use markdown formatting for better display
report = f"### Verification Report\n\n"
# Add status indicators
supported_icon = "✅" if supported == "YES" else "❌"
report += f"**Supported:** {supported_icon} {supported}\n\n"
if unsupported_claims:
report += f"**⚠️ Unsupported Claims:**\n"
for claim in unsupported_claims:
report += f"- {claim}\n"
report += "\n"
else:
report += f"**Unsupported Claims:** None\n\n"
if contradictions:
report += f"**🔴 Contradictions:**\n"
for contradiction in contradictions:
report += f"- {contradiction}\n"
report += "\n"
else:
report += f"**Contradictions:** None\n\n"
relevant_icon = "✅" if relevant == "YES" else "❌"
report += f"**Relevant:** {relevant_icon} {relevant}\n\n"
if additional_details and additional_details.lower() not in ['none', 'n/a', '']:
report += f"**Additional Details:**\n{additional_details}\n"
else:
report += f"**Additional Details:** None\n"
return report
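    # With the parsed example above, this renders markdown along these lines
    # (blank lines between fields elided):
    #
    #   ### Verification Report
    #   **Supported:** ❌ NO
    #   **⚠️ Unsupported Claims:**
    #   - The tower is 400m tall
    #   **Contradictions:** None
    #   **Relevant:** ✅ YES
    #   **Additional Details:**
    #   The context states a height of 330m.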
def generate_out_of_context_report(self) -> str:
"""
Generate a verification report for questions that are out of context.
"""
verification = {
"Supported": "NO",
"Unsupported Claims": ["The question is not related to the provided documents."],
"Contradictions": [],
"Relevant": "NO",
"Additional Details": "The question cannot be answered using the provided documents as it is out of context."
}
return self.format_verification_report(verification)
def check(self, answer: str, documents: List[Document]) -> Dict:
"""
Verify the answer against the provided documents.
"""
print(f"VerificationAgent.check called with answer='{answer}' and {len(documents)} documents.")
        # Combine all document contents into one string and cap the size to
        # avoid token overflow (keep the last MAX_CONTEXT_LENGTH chars if too long)
context_parts = [doc.page_content for doc in documents]
context = "\n\n".join(context_parts)
# Truncate context if too long (keep most recent content which is usually more relevant)
MAX_CONTEXT_LENGTH = 10000 # Approximate character limit
if len(context) > MAX_CONTEXT_LENGTH:
print(f"Context too long ({len(context)} chars), truncating to last {MAX_CONTEXT_LENGTH} chars")
context = context[-MAX_CONTEXT_LENGTH:]
print(f"Combined context length: {len(context)} characters.")
# Create a prompt for the LLM to verify the answer
prompt = self.generate_prompt(answer, context)
print("Prompt created for the LLM.")
try:
print("Generating response with local Hugging Face model...")
llm_response = self.generate_with_hf(prompt)
print("LLM response received.")
except Exception as e:
print(f"Error during model inference: {e}")
raise RuntimeError("Failed to verify answer due to a model error.") from e
# Sanitize the response
sanitized_response = self.sanitize_response(llm_response) if llm_response else ""
if not sanitized_response:
print("LLM returned an empty response.")
verification_report = {
"Supported": "NO",
"Unsupported Claims": [],
"Contradictions": [],
"Relevant": "NO",
"Additional Details": "Empty response from the model."
}
else:
# Parse the response into the expected format
verification_report = self.parse_verification_response(sanitized_response)
if verification_report is None:
print("LLM did not respond with the expected format. Using default verification report.")
verification_report = {
"Supported": "NO",
"Unsupported Claims": [],
"Contradictions": [],
"Relevant": "NO",
"Additional Details": "Failed to parse the model's response."
}
# Format the verification report into a paragraph
verification_report_formatted = self.format_verification_report(verification_report)
print(f"Verification report:\n{verification_report_formatted}")
print(f"Context used: {context}")
return {
"verification_report": verification_report_formatted,
"context_used": context
}
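

# Minimal usage sketch (assumes config.settings provides HF_MODEL_NAME or that
# the fallback model is acceptable; the document text below is invented purely
# for illustration):
if __name__ == "__main__":
    docs = [
        Document(page_content="The Eiffel Tower is 330 metres tall and stands in Paris."),
    ]
    agent = VerificationAgent()
    result = agent.check(
        answer="The Eiffel Tower is 330 metres tall.",
        documents=docs,
    )
    print(result["verification_report"])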