GAIA_Agent / validators.py
nikhmr1235's picture
add detailed docstring for validators.py
8bb6cdd verified
"""
A validation framework to post-process and correct agent answers.
This module provides a system for validating and correcting the final answers
generated by the AI agent. This is crucial for ensuring that the output adheres
to the strict formatting and content requirements of benchmarks like GAIA.
The framework consists of:
- A `Validator` abstract base class that defines a common interface for all
validation logic, including a `can_handle` method to determine if a
validator is applicable to a given task.
- A `ValidatorRegistry` that holds all validator instances and runs the
  applicable ones in a chain, in ascending order of their `priority`
  attribute (lower values run first).
- Several concrete validator implementations for common cleanup tasks like
removing trailing punctuation, correcting capitalization, extracting numbers,
and handling domain-specific rules (e.g., the botany question).
"""
import re
from abc import ABC, abstractmethod
class Validator(ABC):
    """Common interface for every answer validator.

    A subclass reports through ``can_handle`` whether it applies to a task
    and rewrites the answer in ``validate``. ``ValidatorRegistry`` sorts
    validators by ascending ``priority``, so smaller numbers run earlier.
    """

    priority = 100  # Default priority used when a subclass does not override it

    @abstractmethod
    def can_handle(self, task_description: str) -> bool:
        """Return True when this validator should be applied to the task."""

    @abstractmethod
    def validate(self, answer: str) -> str:
        """Return the (possibly corrected) form of *answer*."""
class PunctuationValidator(Validator):
    """Validator that removes trailing punctuation from the final answer.

    It has the largest ``priority`` value so it runs after every other
    validator and cleans up whatever they leave behind.
    """

    priority = 999  # Should run last

    # A trailing run of sentence punctuation, possibly interleaved with
    # whitespace, anchored at the absolute end of the string. Including
    # whitespace matters: with the old "strip afterwards" order, "Paris. "
    # kept its period because the space, not the ".", was the last character.
    _TRAILING_PUNCTUATION = re.compile(r"[\s.,!?;:]+\Z")

    def can_handle(self, task_description: str) -> bool:
        """Apply to every task."""
        return True

    def validate(self, answer: str) -> str:
        """Remove trailing ``. , ! ? ; :`` characters and surrounding whitespace.

        Examples: ``"Paris."`` -> ``"Paris"``, ``"Paris. "`` -> ``"Paris"``,
        ``"Paris..."`` -> ``"Paris"``; interior punctuation such as the
        decimal point in ``"3.14"`` is kept.

        Args:
            answer: The candidate answer.

        Returns:
            The answer with the trailing punctuation run removed and
            leading/trailing whitespace stripped.
        """
        print(f"Running PunctuationValidator on: '{answer}'")
        validated_answer = self._TRAILING_PUNCTUATION.sub("", answer).strip()
        print(f"PunctuationValidator result: '{validated_answer}'")
        return validated_answer
class BotanyValidator(Validator):
    """Validator for the botany question about fruits and vegetables.

    It applies to tasks that mention both "botany" and "vegetable". It removes
    items that are botanically fruits from a comma-separated answer, then
    returns the remaining items lowercased and sorted alphabetically.
    """

    priority = 10  # High priority for content-specific validation

    # Items that are botanical fruits and must not appear in the vegetable
    # list. Singular and plural spellings are both listed so that, for
    # example, "bell peppers" is filtered just like "bell pepper".
    _BOTANICAL_FRUITS = frozenset({
        "corn", "zucchini", "green beans", "bell pepper", "tomato",
        "pepper", "cucumber", "beans", "peas", "okra", "eggplant", "acorns",
        # singular/plural counterparts of the items above
        "zucchinis", "green bean", "bell peppers", "tomatoes", "peppers",
        "cucumbers", "bean", "pea", "eggplants", "acorn",
    })

    def can_handle(self, task_description: str) -> bool:
        """Handle tasks mentioning both "botany" and "vegetable" (case-insensitive)."""
        lowered = task_description.lower()
        return "botany" in lowered and "vegetable" in lowered

    def validate(self, answer: str) -> str:
        """Remove botanical fruits from a comma-separated vegetable list.

        Args:
            answer: Comma-separated list of items.

        Returns:
            The remaining items, lowercased, sorted alphabetically, and joined
            with ", ". Blank entries, such as those produced by a trailing
            comma, are dropped. Without that, they would sort first and the
            result would start with ", ".
        """
        print(f"Running BotanyValidator on: '{answer}'")
        items = (item.strip().lower() for item in answer.split(","))
        final_vegetables = [
            item for item in items
            if item and item not in self._BOTANICAL_FRUITS
        ]
        validated_answer = ", ".join(sorted(final_vegetables))
        print(f"BotanyValidator result: '{validated_answer}'")
        return validated_answer
class CaseValidator(Validator):
    """Validator to capitalize single-word, all-lowercase answers."""

    priority = 900  # Low priority, but before final punctuation strip

    def can_handle(self, task_description: str) -> bool:
        """This validator can handle any task."""
        return True

    def validate(self, answer: str) -> str:
        """Capitalize the answer if it is a single, all-lowercase word.

        Surrounding whitespace is ignored when deciding and is kept in the
        output, so ``" paris"`` becomes ``" Paris"``. A comma-joined list with
        no spaces (``"apple,banana"``) is not a single word and is returned
        unchanged. Capitalizing it would only change the first item. A
        trailing comma or other trailing punctuation does not count as a list
        separator, so ``"paris,"`` becomes ``"Paris,"``.

        Args:
            answer: The candidate answer.

        Returns:
            The answer with its single word capitalized, or the answer
            unchanged when the rule does not apply.
        """
        print(f"Running CaseValidator on: '{answer}'")
        validated_answer = answer
        word = answer.strip()
        # Ignore trailing punctuation when looking for an internal list comma.
        is_list = "," in word.rstrip(".,!?;:")
        if len(word.split()) == 1 and not is_list and word.islower():
            # Splice the capitalized word back between the original whitespace.
            start = len(answer) - len(answer.lstrip())
            validated_answer = (
                answer[:start] + word.capitalize() + answer[start + len(word):]
            )
        print(f"CaseValidator result: '{validated_answer}'")
        return validated_answer
class NumericValidator(Validator):
    """Validator to reduce "how many" answers to just the number."""

    priority = 50  # Medium priority

    # The first number in the text. It is either a comma-grouped integer
    # ("1,234", not followed by another digit) or a plain run of digits, with
    # an optional decimal part ("2.5"). The old pattern, r"\d+", cut "1,000"
    # down to "1" and "2.5" down to "2".
    _NUMBER = re.compile(r"(?:\d{1,3}(?:,\d{3})+(?!\d)|\d+)(?:\.\d+)?")

    def can_handle(self, task_description: str) -> bool:
        """Handle tasks that ask 'how many' (case-insensitive)."""
        return "how many" in task_description.lower()

    def validate(self, answer: str) -> str:
        """Return the first number found in the answer.

        Thousands separators are removed (``"1,000 birds"`` -> ``"1000"``) and
        a decimal part is kept (``"2.5 hours"`` -> ``"2.5"``).

        Args:
            answer: The candidate answer.

        Returns:
            The first number as a string, or the answer unchanged when it
            contains no digits (e.g. a spelled-out number).
        """
        print(f"Running NumericValidator on: '{answer}'")
        validated_answer = answer
        match = self._NUMBER.search(answer)
        if match:
            validated_answer = match.group().replace(",", "")
        print(f"NumericValidator result: '{validated_answer}'")
        return validated_answer
class ValidatorRegistry:
    """Ordered chain of validators used to post-process an answer.

    The chain is kept sorted by ascending ``priority``, so smaller values run
    first. ``list.sort`` is stable, so validators with equal priority keep
    their registration order.
    """

    def __init__(self):
        # Re-sorted by priority after every registration.
        self.validators = []

    def register(self, validator: Validator):
        """Append *validator* to the chain and re-sort the chain in place by priority."""
        chain = self.validators
        chain.append(validator)
        chain.sort(key=lambda item: item.priority)

    def process(self, task_description: str, answer: str) -> str:
        """Pass *answer* through each validator that accepts *task_description*.

        Every applicable validator receives the previous validator's output.
        The final output is returned.
        """
        current = answer
        for candidate in self.validators:
            if not candidate.can_handle(task_description):
                continue
            current = candidate.validate(current)
        return current
# Build the shared registry used to post-process agent answers. The registry
# re-sorts by priority on each register() call, so this listing order only
# matters for validators that share a priority.
validator_registry = ValidatorRegistry()
for _validator_cls in (
    PunctuationValidator,
    BotanyValidator,
    CaseValidator,
    NumericValidator,
):
    validator_registry.register(_validator_cls())
del _validator_cls  # keep the loop variable out of the module namespace