""" A validation framework to post-process and correct agent answers. This module provides a system for validating and correcting the final answers generated by the AI agent. This is crucial for ensuring that the output adheres to the strict formatting and content requirements of benchmarks like GAIA. The framework consists of: - A `Validator` abstract base class that defines a common interface for all validation logic, including a `can_handle` method to determine if a validator is applicable to a given task. - A `ValidatorRegistry` that holds all validator instances and runs them in a chain, sorted by a `priority` attribute. - Several concrete validator implementations for common cleanup tasks like removing trailing punctuation, correcting capitalization, extracting numbers, and handling domain-specific rules (e.g., the botany question). """ import re from abc import ABC, abstractmethod class Validator(ABC): """Abstract base class for a validator.""" priority = 100 # Default priority @abstractmethod def can_handle(self, task_description: str) -> bool: """Whether this validator can handle the given task.""" pass @abstractmethod def validate(self, answer: str) -> str: """Validate and potentially correct the answer.""" pass class PunctuationValidator(Validator): """Validator to remove trailing punctuation.""" priority = 999 # Should run last def can_handle(self, task_description: str) -> bool: """This validator can handle any task.""" return True def validate(self, answer: str) -> str: """Removes trailing punctuation from the answer.""" print(f"Running PunctuationValidator on: '{answer}'") validated_answer = re.sub(r"[.,!?;:]$", "", answer).strip() print(f"PunctuationValidator result: '{validated_answer}'") return validated_answer class BotanyValidator(Validator): """Validator for the botany question.""" priority = 10 # High priority for content-specific validation def can_handle(self, task_description: str) -> bool: """Handles the botany question about fruits and vegetables.""" return "botany" in task_description.lower() and "vegetable" in task_description.lower() def validate(self, answer: str) -> str: """Ensures no botanical fruits are in the vegetable list.""" print(f"Running BotanyValidator on: '{answer}'") botanical_fruits = [ "corn", "zucchini", "green beans", "bell pepper", "tomato", "pepper", "cucumber", "beans", "peas", "okra", "eggplant", "acorns" ] vegetables = [veg.strip().lower() for veg in answer.split(',')] final_vegetables = [veg for veg in vegetables if veg not in botanical_fruits] validated_answer = ", ".join(sorted(final_vegetables)) print(f"BotanyValidator result: '{validated_answer}'") return validated_answer class CaseValidator(Validator): """Validator to capitalize single-word, all-lowercase answers.""" priority = 900 # Low priority, but before final punctuation strip def can_handle(self, task_description: str) -> bool: """This validator can handle any task.""" return True def validate(self, answer: str) -> str: """Capitalizes the answer if it's a single, all-lowercase word.""" print(f"Running CaseValidator on: '{answer}'") validated_answer = answer if len(answer.split()) == 1 and answer.islower(): validated_answer = answer.capitalize() print(f"CaseValidator result: '{validated_answer}'") return validated_answer class NumericValidator(Validator): """Validator to ensure numeric answers are just numbers.""" priority = 50 # Medium priority def can_handle(self, task_description: str) -> bool: """Handles tasks that ask 'how many'.""" return "how many" in task_description.lower() def validate(self, answer: str) -> str: """Extracts only the numeric digits from the answer.""" print(f"Running NumericValidator on: '{answer}'") validated_answer = answer numbers = re.findall(r'\d+', answer) if numbers: validated_answer = numbers[0] print(f"NumericValidator result: '{validated_answer}'") return validated_answer class ValidatorRegistry: """Registry for validators.""" def __init__(self): self.validators = [] def register(self, validator: Validator): """Register a validator and maintain sorted order by priority.""" self.validators.append(validator) self.validators.sort(key=lambda v: v.priority) def process(self, task_description: str, answer: str) -> str: """Process the answer through all applicable validators in order of priority.""" processed_answer = answer for validator in self.validators: if validator.can_handle(task_description): processed_answer = validator.validate(processed_answer) return processed_answer # Instantiate and register validators validator_registry = ValidatorRegistry() validator_registry.register(PunctuationValidator()) validator_registry.register(BotanyValidator()) validator_registry.register(CaseValidator()) validator_registry.register(NumericValidator())