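"""Verification agent: checks whether an LLM-generated answer is supported
by retrieved documents, using a local Hugging Face seq2seq model."""
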
from typing import Dict, List

import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from langchain_core.documents.base import Document

from config.settings import settings


class VerificationAgent:
    def __init__(self):
        """
        Initialize the verification agent with a local Hugging Face LLM.
        """
        print("Initializing VerificationAgent with lightweight Hugging Face model...")
        # Use a smaller, CPU-friendly model by default (configured in settings)
        model_name = settings.HF_MODEL_NAME
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # Use float32 on CPU (fp16 only works on GPU)
        torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name, torch_dtype=torch_dtype).to(self.device)
        print(f"Model '{model_name}' loaded on {self.device} with dtype={torch_dtype}.")

    def sanitize_response(self, response_text: str) -> str:
        """
        Sanitize the LLM's response by stripping unnecessary whitespace.
        """
        return response_text.strip()

    def generate_prompt(self, answer: str, context: str) -> str:
        """
        Generate a structured prompt for the LLM to verify the answer against the context.
        """
        prompt = f"""You are a strict verification agent. Your task is to verify if an answer is supported by the provided context.

CRITICAL RULES:
1. ONLY use information from the context provided below. Do NOT use any external knowledge or assumptions.
2. If a claim in the answer is NOT explicitly or implicitly supported by the context, mark it as unsupported.
3. If the answer contradicts information in the context, mark it as a contradiction.
4. If you cannot verify a claim using ONLY the context, mark it as unsupported.
5. Be strict - do not assume or infer beyond what is clearly stated in the context.
6. Respond EXACTLY in the format specified below - no additional text, explanations, or formatting.

**VERIFICATION FORMAT (follow exactly):**
Supported: YES
Unsupported Claims: []
Contradictions: []
Relevant: YES
Additional Details: None

OR if unsupported claims/contradictions are found:
Supported: NO
Unsupported Claims: [list each unsupported claim exactly as it appears in the answer]
Contradictions: [list each contradiction exactly as it appears]
Relevant: YES or NO
Additional Details: [brief explanation of why claims are unsupported or contradicted]

**Answer to verify:**
{answer}

**Context (use ONLY this for verification):**
{context}

**Your verification (respond ONLY with the format above):**
"""
        return prompt

    def generate_with_hf(self, prompt: str, max_new_tokens: int = 512) -> str:
        """
        Generate output using the local Hugging Face model.
        """
        inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=2048).to(self.device)
        outputs = self.model.generate(**inputs, max_new_tokens=max_new_tokens)
        return self.tokenizer.decode(outputs[0], skip_special_tokens=True)

    @staticmethod
    def _parse_list_value(value: str) -> List[str]:
        """
        Parse a list-like field value ("[]", "[a, b]", "a, b", "None") into a list of strings.
        """
        value = value.strip()
        if value.lower() in ('none', 'n/a', '[]', ''):
            return []
        if value.startswith('[') and value.endswith(']'):
            value = value[1:-1].strip()
        return [item.strip().strip('"').strip("'").strip()
                for item in value.split(',')
                if item.strip() and item.strip().lower() not in ('none', 'n/a')]

    def parse_verification_response(self, response_text: str) -> Dict:
        """
        Parse the LLM's verification response into a structured dictionary.

        Expects "Key: value" lines matching the prompt's verification format,
        e.g. "Supported: YES", "Unsupported Claims: []".
        """
        try:
            # Normalize the response - strip whitespace and any markdown code fences
            response_text = response_text.strip()
            if response_text.startswith('```'):
                lines = response_text.split('\n')
                response_text = '\n'.join(lines[1:-1]) if len(lines) > 2 else response_text
            print(f"[DEBUG] Parsing verification response (first 500 chars): {response_text[:500]}")
            verification = {}
            for line in response_text.split('\n'):
                line = line.strip()
                if not line or ':' not in line:
                    continue
                # Split on the first colon only; keys are matched case-insensitively
                key, value = line.split(':', 1)
                key_lower = key.strip().lower()
                value = value.strip()
                # Check 'unsupported' before 'supported' - the substring
                # 'supported' also occurs in 'Unsupported Claims'
                if 'unsupported' in key_lower:
                    verification["Unsupported Claims"] = self._parse_list_value(value)
                elif 'supported' in key_lower:
                    # Extract YES/NO, handling variations
                    value_upper = value.upper()
                    print(f"[DEBUG] Found 'Supported' key with value: '{value}' (upper: '{value_upper}')")
                    if 'YES' in value_upper or 'TRUE' in value_upper or value_upper == 'Y':
                        verification["Supported"] = "YES"
                        print("[DEBUG] Set Supported to YES")
                    elif 'NO' in value_upper or 'FALSE' in value_upper or value_upper == 'N':
                        verification["Supported"] = "NO"
                        print("[DEBUG] Set Supported to NO")
                    else:
                        # Value is empty or unclear; decide later based on
                        # whether unsupported claims or contradictions turn up
                        print(f"[DEBUG] Supported value unclear: '{value}', will decide based on claims/contradictions")
                        verification["Supported"] = None  # Mark as undecided
                elif 'contradiction' in key_lower:
                    verification["Contradictions"] = self._parse_list_value(value)
                elif 'relevant' in key_lower:
                    value_upper = value.upper()
                    if 'YES' in value_upper or 'TRUE' in value_upper:
                        verification["Relevant"] = "YES"
                    elif 'NO' in value_upper or 'FALSE' in value_upper:
                        verification["Relevant"] = "NO"
                    else:
                        verification["Relevant"] = "YES"  # Default to YES if unclear
                elif 'additional' in key_lower or 'detail' in key_lower:
                    verification["Additional Details"] = "" if value.lower() in ('none', 'n/a', '') else value
            # Ensure all required keys are present with defaults
            if verification.get("Supported") is None:
                # If missing or undecided, fall back on whether any issues were found
                unsupported_claims = verification.get("Unsupported Claims", [])
                contradictions = verification.get("Contradictions", [])
                if not unsupported_claims and not contradictions:
                    verification["Supported"] = "YES"
                    print("[DEBUG] Supported was missing/undecided and no claims/contradictions were found; defaulting to YES")
                else:
                    verification["Supported"] = "NO"
                    print(f"[DEBUG] Supported was missing/undecided but found {len(unsupported_claims)} unsupported claims and {len(contradictions)} contradictions; defaulting to NO")
            verification.setdefault("Unsupported Claims", [])
            verification.setdefault("Contradictions", [])
            verification.setdefault("Relevant", "YES")
            verification.setdefault("Additional Details", "")
            print(f"[DEBUG] Final parsed verification: Supported={verification.get('Supported')}, "
                  f"Unsupported Claims={len(verification['Unsupported Claims'])}, "
                  f"Contradictions={len(verification['Contradictions'])}")
            return verification
        except Exception as e:
            print(f"Error parsing verification response: {e}")
            print(f"Response text was: {response_text}")
            # Return a safe default
            return {
                "Supported": "NO",
                "Unsupported Claims": [],
                "Contradictions": [],
                "Relevant": "NO",
                "Additional Details": f"Parsing error: {str(e)}"
            }

    def format_verification_report(self, verification: Dict) -> str:
        """
        Format the verification dictionary into a readable markdown report.
        """
        supported = verification.get("Supported", "NO")
        unsupported_claims = verification.get("Unsupported Claims", [])
        contradictions = verification.get("Contradictions", [])
        relevant = verification.get("Relevant", "NO")
        additional_details = verification.get("Additional Details", "")
        # Use markdown formatting with status icons for better display
        report = "### Verification Report\n\n"
        supported_icon = "✅" if supported == "YES" else "❌"
        report += f"**Supported:** {supported_icon} {supported}\n\n"
        if unsupported_claims:
            report += "**⚠️ Unsupported Claims:**\n"
            for claim in unsupported_claims:
                report += f"- {claim}\n"
            report += "\n"
        else:
            report += "**Unsupported Claims:** None\n\n"
        if contradictions:
            report += "**🔴 Contradictions:**\n"
            for contradiction in contradictions:
                report += f"- {contradiction}\n"
            report += "\n"
        else:
            report += "**Contradictions:** None\n\n"
        relevant_icon = "✅" if relevant == "YES" else "❌"
        report += f"**Relevant:** {relevant_icon} {relevant}\n\n"
        if additional_details and additional_details.lower() not in ('none', 'n/a'):
            report += f"**Additional Details:**\n{additional_details}\n"
        else:
            report += "**Additional Details:** None\n"
        return report

    def generate_out_of_context_report(self) -> str:
        """
        Generate a verification report for questions that are out of context.
        """
        verification = {
            "Supported": "NO",
            "Unsupported Claims": ["The question is not related to the provided documents."],
            "Contradictions": [],
            "Relevant": "NO",
            "Additional Details": "The question cannot be answered using the provided documents as it is out of context."
        }
        return self.format_verification_report(verification)

    def check(self, answer: str, documents: List[Document]) -> Dict:
        """
        Verify the answer against the provided documents.
        """
        print(f"VerificationAgent.check called with answer='{answer}' and {len(documents)} documents.")
        # Combine all document contents into one string
        context = "\n\n".join(doc.page_content for doc in documents)
        # Truncate overly long context to prevent token overflow, keeping the
        # trailing characters (usually the most relevant content)
        MAX_CONTEXT_LENGTH = 10000  # Approximate character limit
        if len(context) > MAX_CONTEXT_LENGTH:
            print(f"Context too long ({len(context)} chars), truncating to last {MAX_CONTEXT_LENGTH} chars")
            context = context[-MAX_CONTEXT_LENGTH:]
        print(f"Combined context length: {len(context)} characters.")
        # Create a prompt for the LLM to verify the answer
        prompt = self.generate_prompt(answer, context)
        print("Prompt created for the LLM.")
        try:
            print("Generating response with local Hugging Face model...")
            llm_response = self.generate_with_hf(prompt)
            print("LLM response received.")
        except Exception as e:
            print(f"Error during model inference: {e}")
            raise RuntimeError("Failed to verify answer due to a model error.") from e
        # Sanitize the response
        sanitized_response = self.sanitize_response(llm_response) if llm_response else ""
        if not sanitized_response:
            print("LLM returned an empty response.")
            verification_report = {
                "Supported": "NO",
                "Unsupported Claims": [],
                "Contradictions": [],
                "Relevant": "NO",
                "Additional Details": "Empty response from the model."
            }
        else:
            # Parse the response into the expected format
            verification_report = self.parse_verification_response(sanitized_response)
            if verification_report is None:
                # Defensive fallback; parse_verification_response normally returns a dict
                print("LLM did not respond with the expected format. Using default verification report.")
                verification_report = {
                    "Supported": "NO",
                    "Unsupported Claims": [],
                    "Contradictions": [],
                    "Relevant": "NO",
                    "Additional Details": "Failed to parse the model's response."
                }
        # Format the verification report as markdown
        verification_report_formatted = self.format_verification_report(verification_report)
        print(f"Verification report:\n{verification_report_formatted}")
        print(f"Context used: {context}")
        return {
            "verification_report": verification_report_formatted,
            "context_used": context
        }
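

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the agent). It assumes
# settings.HF_MODEL_NAME points at a seq2seq model such as "google/flan-t5-base"
# and builds two hypothetical in-memory Documents; in the real pipeline the
# documents would come from a retriever.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    docs = [
        Document(page_content="The Eiffel Tower is located in Paris and was completed in 1889."),
        Document(page_content="It was designed and built by Gustave Eiffel's engineering company."),
    ]
    agent = VerificationAgent()
    result = agent.check(
        answer="The Eiffel Tower, completed in 1889, stands in Paris.",
        documents=docs,
    )
    print(result["verification_report"])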