Spaces:
Sleeping
Sleeping
File size: 5,380 Bytes
8bb6cdd 1b7cf74 3d98bd6 1b7cf74 3d98bd6 1b7cf74 a599583 1b7cf74 3d98bd6 1b7cf74 a599583 1b7cf74 a599583 1b7cf74 3d98bd6 1b7cf74 3d98bd6 1b7cf74 3d98bd6 a599583 3d98bd6 a599583 1b7cf74 3d98bd6 1b7cf74 a599583 1b7cf74 a599583 1b7cf74 3d98bd6 1b7cf74 3d98bd6 1b7cf74 3d98bd6 1b7cf74 3d98bd6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 |
"""
A validation framework to post-process and correct agent answers.
This module provides a system for validating and correcting the final answers
generated by the AI agent. This is crucial for ensuring that the output adheres
to the strict formatting and content requirements of benchmarks like GAIA.
The framework consists of:
- A `Validator` abstract base class that defines a common interface for all
validation logic, including a `can_handle` method to determine if a
validator is applicable to a given task.
- A `ValidatorRegistry` that holds all validator instances and runs them in
a chain, sorted by a `priority` attribute.
- Several concrete validator implementations for common cleanup tasks like
removing trailing punctuation, correcting capitalization, extracting numbers,
and handling domain-specific rules (e.g., the botany question).
"""
import re
from abc import ABC, abstractmethod
class Validator(ABC):
    """Common interface shared by every answer validator.

    A concrete validator answers two questions: does it apply to a given task
    (``can_handle``), and how should an answer be rewritten (``validate``).
    The registry orders validators by ascending ``priority``.
    """

    # Ordering key; lower values are run earlier by the registry.
    priority = 100  # Default priority

    @abstractmethod
    def can_handle(self, task_description: str) -> bool:
        """Return True when this validator should run for *task_description*."""

    @abstractmethod
    def validate(self, answer: str) -> str:
        """Return *answer*, corrected if this validator finds something to fix."""
class PunctuationValidator(Validator):
    """Validator that strips trailing punctuation from every answer.

    It has the largest ``priority`` value so it runs last and cleans up
    whatever earlier validators produced.
    """

    priority = 999  # Should run last

    # A whole run of sentence punctuation at the very end of the text.
    # ``\Z`` anchors at the true end of the string; ``$`` would also match
    # just before a trailing newline.
    _TRAILING_PUNCTUATION = re.compile(r"[.,!?;:]+\Z")

    def can_handle(self, task_description: str) -> bool:
        """Apply to every task.

        Args:
            task_description: Ignored; this validator is task-agnostic.

        Returns:
            Always ``True``.
        """
        return True

    def validate(self, answer: str) -> str:
        """Remove trailing punctuation (and surrounding whitespace) from *answer*.

        The answer is stripped *before* matching, so ``"Paris. "`` becomes
        ``"Paris"``. Previously the punctuation had to be the very last
        character, so trailing whitespace hid it. A whole run such as ``"..."``
        or ``"?!"`` is removed, not just its final character.

        Args:
            answer: Candidate final answer.

        Returns:
            The answer without trailing ``. , ! ? ; :`` characters and without
            leading/trailing whitespace.
        """
        print(f"Running PunctuationValidator on: '{answer}'")
        # Strip again after removal in case whitespace preceded the punctuation.
        validated_answer = self._TRAILING_PUNCTUATION.sub("", answer.strip()).strip()
        print(f"PunctuationValidator result: '{validated_answer}'")
        return validated_answer
class BotanyValidator(Validator):
    """Validator for the botany question about fruits and vegetables.

    It filters a comma-separated answer, dropping items listed in
    ``BOTANICAL_FRUITS``. The remaining items are returned lowercased, sorted
    and joined with ``", "``.
    """

    priority = 10  # High priority for content-specific validation

    # Items that must not appear in the vegetable list. This is a class-level
    # frozenset so it is built once, membership is O(1), and subclasses can
    # extend it without touching ``validate``.
    BOTANICAL_FRUITS = frozenset({
        "corn", "zucchini", "green beans", "bell pepper", "tomato",
        "pepper", "cucumber", "beans", "peas", "okra", "eggplant", "acorns",
    })

    def can_handle(self, task_description: str) -> bool:
        """Handle tasks mentioning both "botany" and "vegetable" (case-insensitive).

        Args:
            task_description: The task/question text.

        Returns:
            True when both keywords occur in the lowercased description.
        """
        lowered = task_description.lower()
        return "botany" in lowered and "vegetable" in lowered

    def _is_botanical_fruit(self, item: str) -> bool:
        """Return True if lowercased *item*, or a simple singular of it, is excluded."""
        fruits = self.BOTANICAL_FRUITS
        if item in fruits:
            return True
        # Match simple plurals of listed items: "tomatoes" -> "tomato",
        # "bell peppers" -> "bell pepper".
        if item.endswith("es") and item[:-2] in fruits:
            return True
        return item.endswith("s") and item[:-1] in fruits

    def validate(self, answer: str) -> str:
        """Remove botanical fruits from a comma-separated vegetable list.

        Args:
            answer: Comma-separated list of items.

        Returns:
            The remaining items, lowercased and alphabetically sorted, joined
            with ``", "``. Empty items from stray or trailing commas are
            dropped, so ``"a, b,"`` yields ``"a, b"`` and not ``", a, b"``.
        """
        print(f"Running BotanyValidator on: '{answer}'")
        items = (part.strip().lower() for part in answer.split(","))
        final_vegetables = [
            item for item in items
            if item and not self._is_botanical_fruit(item)
        ]
        validated_answer = ", ".join(sorted(final_vegetables))
        print(f"BotanyValidator result: '{validated_answer}'")
        return validated_answer
class CaseValidator(Validator):
    """Validator that capitalizes single-word, all-lowercase answers."""

    priority = 900  # Low priority, but before final punctuation strip

    def can_handle(self, task_description: str) -> bool:
        """Apply to every task.

        Args:
            task_description: Ignored; this validator is task-agnostic.

        Returns:
            Always ``True``.
        """
        return True

    def validate(self, answer: str) -> str:
        """Capitalize *answer* when it is a single, all-lowercase word.

        Surrounding whitespace is ignored when deciding and dropped when
        capitalizing. ``str.capitalize`` only upper-cases index 0, so before
        this change ``" paris"`` passed the check but was returned unchanged.

        Args:
            answer: Candidate final answer.

        Returns:
            The stripped, capitalized word if the rule applies; otherwise
            *answer* unchanged.
        """
        print(f"Running CaseValidator on: '{answer}'")
        validated_answer = answer
        stripped = answer.strip()
        # One token from split() means there is no internal whitespace.
        # islower() needs at least one cased character, so purely numeric
        # answers are left alone.
        if len(stripped.split()) == 1 and stripped.islower():
            validated_answer = stripped.capitalize()
        print(f"CaseValidator result: '{validated_answer}'")
        return validated_answer
class NumericValidator(Validator):
    """Validator that reduces answers to "how many" questions to a bare number."""

    priority = 50  # Medium priority

    # The first number in the text. It is either digits grouped with thousands
    # commas ("1,234,567") or a plain digit run, each optionally followed by a
    # decimal part. The grouped form is tried first so "1,234" is not cut to
    # "1". ``(?!\d)`` rejects malformed groups such as "1,2345", which fall
    # back to the plain form. A period not followed by a digit (sentence end)
    # is not consumed.
    _NUMBER_PATTERN = re.compile(r"\d{1,3}(?:,\d{3})+(?!\d)(?:\.\d+)?|\d+(?:\.\d+)?")

    def can_handle(self, task_description: str) -> bool:
        """Handle tasks whose description contains "how many" (case-insensitive).

        Args:
            task_description: The task/question text.

        Returns:
            True if the lowercased description contains "how many".
        """
        return "how many" in task_description.lower()

    def validate(self, answer: str) -> str:
        """Extract the first number from *answer*.

        Thousands separators are removed and a decimal part is kept. So
        ``"There are 1,234 birds."`` becomes ``"1234"`` and ``"3.5 hours"``
        becomes ``"3.5"``; the old ``\\d+`` pattern produced ``"1"`` and ``"3"``.

        Args:
            answer: Candidate final answer.

        Returns:
            The first number found, or *answer* unchanged if it has no digits.
        """
        print(f"Running NumericValidator on: '{answer}'")
        validated_answer = answer
        match = self._NUMBER_PATTERN.search(answer)
        if match:
            validated_answer = match.group().replace(",", "")
        print(f"NumericValidator result: '{validated_answer}'")
        return validated_answer
class ValidatorRegistry:
    """Ordered chain of validators.

    ``validators`` is kept sorted by ascending ``priority``. ``list.sort`` is
    stable, so validators with equal priority keep their registration order.
    """

    def __init__(self):
        # Sorted in place by ``register``; the list object itself never changes.
        self.validators = []

    def register(self, validator: Validator):
        """Append *validator*, then re-sort the chain by ``priority``."""
        chain = self.validators
        chain.append(validator)
        chain.sort(key=lambda item: item.priority)

    def process(self, task_description: str, answer: str) -> str:
        """Feed *answer* through each validator that accepts *task_description*.

        Validators run in priority order, and each receives the previous
        validator's output. ``can_handle`` is checked lazily, just before that
        validator would run.
        """
        current = answer
        applicable = (
            candidate for candidate in self.validators
            if candidate.can_handle(task_description)
        )
        for step in applicable:
            current = step.validate(current)
        return current
# Build the shared module-level registry. Each validator is instantiated and
# registered in turn; ``register`` keeps the chain sorted by priority, so the
# listing order below does not determine execution order.
validator_registry = ValidatorRegistry()
for _validator_cls in (
    PunctuationValidator,
    BotanyValidator,
    CaseValidator,
    NumericValidator,
):
    validator_registry.register(_validator_cls())
del _validator_cls  # don't leak the loop variable into the module namespace
|