Spaces:
Sleeping
Sleeping
| """ | |
| A validation framework to post-process and correct agent answers. | |
| This module provides a system for validating and correcting the final answers | |
| generated by the AI agent. This is crucial for ensuring that the output adheres | |
| to the strict formatting and content requirements of benchmarks like GAIA. | |
| The framework consists of: | |
| - A `Validator` abstract base class that defines a common interface for all | |
| validation logic, including a `can_handle` method to determine if a | |
| validator is applicable to a given task. | |
| - A `ValidatorRegistry` that holds all validator instances and runs them in | |
| a chain, sorted by a `priority` attribute. | |
| - Several concrete validator implementations for common cleanup tasks like | |
| removing trailing punctuation, correcting capitalization, extracting numbers, | |
| and handling domain-specific rules (e.g., the botany question). | |
| """ | |
| import re | |
| from abc import ABC, abstractmethod | |
class Validator(ABC):
    """Abstract base class for a validator.

    Concrete validators must implement both ``can_handle`` and ``validate``;
    the class cannot be instantiated until they do.

    ``priority`` is the sort key used by ``ValidatorRegistry``: validators are
    sorted by ascending priority, so lower values run earlier in the chain.
    """

    priority = 100  # Default priority

    @abstractmethod
    def can_handle(self, task_description: str) -> bool:
        """Return whether this validator applies to the given task.

        Args:
            task_description: The text of the task/question being answered.

        Returns:
            True if ``validate`` should be run for this task.
        """

    @abstractmethod
    def validate(self, answer: str) -> str:
        """Validate and potentially correct the answer.

        Args:
            answer: The current answer text.

        Returns:
            The answer text, corrected where necessary.
        """
class PunctuationValidator(Validator):
    """Validator that removes trailing punctuation from an answer."""

    priority = 999  # Should run last

    # A trailing run of sentence punctuation and/or whitespace. Whitespace is
    # part of the class so that "Paris. " and "Paris ." are cleaned too, and
    # ``\Z`` anchors at the true end of the string.
    _TRAILING_PUNCTUATION = re.compile(r"[.,!?;:\s]+\Z")

    def can_handle(self, task_description: str) -> bool:
        """This validator can handle any task."""
        return True

    def validate(self, answer: str) -> str:
        """Remove all trailing punctuation and surrounding whitespace.

        Args:
            answer: The current answer text.

        Returns:
            The answer without leading whitespace and without any trailing
            ``. , ! ? ; :`` characters or whitespace. Punctuation inside the
            answer (e.g. the dot in ``3.14``) is left untouched.
        """
        print(f"Running PunctuationValidator on: '{answer}'")
        validated_answer = self._TRAILING_PUNCTUATION.sub("", answer).lstrip()
        print(f"PunctuationValidator result: '{validated_answer}'")
        return validated_answer
class BotanyValidator(Validator):
    """Validator for the botany question."""

    priority = 10  # High priority for content-specific validation

    # Items that are removed from a comma-separated vegetable list.
    # Built once at class level; membership tests are O(1).
    BOTANICAL_FRUITS = frozenset({
        "corn", "zucchini", "green beans", "bell pepper", "tomato",
        "pepper", "cucumber", "beans", "peas", "okra", "eggplant", "acorns",
    })

    def can_handle(self, task_description: str) -> bool:
        """Handles the botany question about fruits and vegetables."""
        description = task_description.lower()
        return "botany" in description and "vegetable" in description

    def validate(self, answer: str) -> str:
        """Ensure no botanical fruits are in the vegetable list.

        The answer is treated as a comma-separated list. Each item is
        lower-cased and trimmed, items listed in ``BOTANICAL_FRUITS`` are
        dropped, and the rest are returned sorted alphabetically.

        Args:
            answer: Comma-separated list of items.

        Returns:
            The remaining items joined with ``", "``; an empty string when
            nothing remains.
        """
        print(f"Running BotanyValidator on: '{answer}'")
        items = []
        for raw_item in answer.split(","):
            # Trailing sentence punctuation (e.g. "tomato.") would otherwise
            # let an item slip past the exact-match fruit filter.
            item = raw_item.strip().lower().rstrip(".!?;:").rstrip()
            # Skip blanks produced by stray or trailing commas so they do not
            # show up as a leading ", " after sorting.
            if item and item not in self.BOTANICAL_FRUITS:
                items.append(item)
        validated_answer = ", ".join(sorted(items))
        print(f"BotanyValidator result: '{validated_answer}'")
        return validated_answer
class CaseValidator(Validator):
    """Validator that capitalizes answers made of one all-lowercase word."""

    priority = 900  # Low priority, but before final punctuation strip

    def can_handle(self, task_description: str) -> bool:
        """Applies to every task, regardless of its description."""
        return True

    def validate(self, answer: str) -> str:
        """Capitalize the answer when it is a single, all-lowercase word.

        Any other answer (multiple words, mixed case, no cased characters)
        is returned unchanged.
        """
        print(f"Running CaseValidator on: '{answer}'")
        single_lowercase_word = answer.islower() and len(answer.split()) == 1
        result = answer.capitalize() if single_lowercase_word else answer
        print(f"CaseValidator result: '{result}'")
        return result
class NumericValidator(Validator):
    """Validator to ensure numeric answers are just numbers."""

    priority = 50  # Medium priority

    # First alternative: digits grouped by thousands separators (1,234 or
    # 1,234,567.89). The (?!\d) guard rejects malformed groupings such as
    # "1,2345". Second alternative: a plain integer or decimal.
    _NUMBER = re.compile(
        r"\d{1,3}(?:,\d{3})+(?!\d)(?:\.\d+)?"
        r"|\d+(?:\.\d+)?"
    )

    def can_handle(self, task_description: str) -> bool:
        """Handles tasks that ask 'how many'."""
        return "how many" in task_description.lower()

    def validate(self, answer: str) -> str:
        """Extract the first number from the answer.

        Thousands separators are removed and a decimal part is kept, so
        ``"about 1,234 albums"`` yields ``"1234"`` and ``"3.5 km"`` yields
        ``"3.5"`` instead of being truncated at the first non-digit.

        Args:
            answer: The current answer text.

        Returns:
            The first number found, without commas; the answer unchanged
            when it contains no digits.
        """
        print(f"Running NumericValidator on: '{answer}'")
        validated_answer = answer
        match = self._NUMBER.search(answer)
        if match:
            validated_answer = match.group(0).replace(",", "")
        print(f"NumericValidator result: '{validated_answer}'")
        return validated_answer
class ValidatorRegistry:
    """Holds validators and runs the applicable ones as an ordered chain."""

    def __init__(self):
        self.validators = []

    def register(self, validator: Validator):
        """Add a validator, keeping the list sorted by ascending priority."""

        def by_priority(entry):
            return entry.priority

        self.validators.append(validator)
        # list.sort is stable, so equal priorities keep registration order.
        self.validators.sort(key=by_priority)

    def process(self, task_description: str, answer: str) -> str:
        """Run the answer through every applicable validator, in priority order.

        Each validator that reports it can handle the task receives the
        output of the previous one; validators that decline are skipped.
        """
        current = answer
        for candidate in self.validators:
            if not candidate.can_handle(task_description):
                continue
            current = candidate.validate(current)
        return current
# Build the shared module-level registry. Registration order only matters
# between validators of equal priority, because the registry re-sorts by
# priority on every register() call.
validator_registry = ValidatorRegistry()
for _validator_class in (
    PunctuationValidator,
    BotanyValidator,
    CaseValidator,
    NumericValidator,
):
    validator_registry.register(_validator_class())