File size: 5,380 Bytes
8bb6cdd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1b7cf74
 
 
 
 
3d98bd6
 
1b7cf74
 
 
 
 
 
 
 
 
 
 
 
3d98bd6
 
1b7cf74
 
 
 
 
 
a599583
 
 
 
1b7cf74
 
 
3d98bd6
 
1b7cf74
 
 
 
 
 
a599583
1b7cf74
 
 
 
 
 
a599583
 
 
1b7cf74
 
3d98bd6
 
 
1b7cf74
3d98bd6
 
1b7cf74
 
3d98bd6
a599583
 
3d98bd6
a599583
 
 
1b7cf74
 
 
3d98bd6
 
1b7cf74
 
 
 
 
 
a599583
 
1b7cf74
 
a599583
 
 
1b7cf74
 
 
 
 
 
 
3d98bd6
1b7cf74
3d98bd6
1b7cf74
 
3d98bd6
1b7cf74
 
 
 
 
 
 
 
 
 
 
3d98bd6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
"""
A validation framework to post-process and correct agent answers.

This module provides a system for validating and correcting the final answers
generated by the AI agent. This is crucial for ensuring that the output adheres
to the strict formatting and content requirements of benchmarks like GAIA.

The framework consists of:
-   A `Validator` abstract base class that defines a common interface for all
    validation logic, including a `can_handle` method to determine if a
    validator is applicable to a given task.
-   A `ValidatorRegistry` that holds all validator instances and runs them in
    a chain, sorted by a `priority` attribute.
-   Several concrete validator implementations for common cleanup tasks like
    removing trailing punctuation, correcting capitalization, extracting numbers,
    and handling domain-specific rules (e.g., the botany question).
"""
import re
from abc import ABC, abstractmethod

class Validator(ABC):
    """Common interface that every answer validator implements.

    Subclasses decide whether they apply to a task (``can_handle``) and how
    to clean up an answer (``validate``).
    """
    # Sort key used to order validators; subclasses override it as needed.
    priority = 100  # Default priority

    @abstractmethod
    def can_handle(self, task_description: str) -> bool:
        """Report whether this validator applies to ``task_description``.

        Args:
            task_description: Text of the task the answer belongs to.

        Returns:
            True when ``validate`` should be run for this task.
        """

    @abstractmethod
    def validate(self, answer: str) -> str:
        """Check ``answer`` and return it, corrected where necessary.

        Args:
            answer: The answer text produced so far.

        Returns:
            The (possibly modified) answer text.
        """

class PunctuationValidator(Validator):
    """Validator that strips trailing punctuation from an answer.

    It applies to every task and carries the largest priority value in this
    module because it is intended to run last.
    """
    priority = 999  # Should run last

    def can_handle(self, task_description: str) -> bool:
        """Return True unconditionally: this validator applies to any task."""
        return True

    def validate(self, answer: str) -> str:
        """Return ``answer`` without trailing punctuation or outer whitespace.

        Args:
            answer: The answer text to clean up.

        Returns:
            The answer with leading whitespace removed and the whole trailing
            run of whitespace and ``. , ! ? ; :`` characters removed.
        """
        print(f"Running PunctuationValidator on: '{answer}'")
        # Whitespace is part of the character class so that punctuation
        # followed by spaces ("Paris. ") and repeated marks ("Paris?!") are
        # removed completely, not only when one mark is the final character.
        validated_answer = re.sub(r"[\s.,!?;:]+\Z", "", answer).lstrip()
        print(f"PunctuationValidator result: '{validated_answer}'")
        return validated_answer

class BotanyValidator(Validator):
    """Validator for the botany question.

    Treats the answer as a comma-separated list and removes every entry that
    is listed in ``BOTANICAL_FRUITS``.
    """
    priority = 10  # High priority for content-specific validation

    # Entries that must not appear in the vegetable list. Built once as a
    # frozenset so membership tests do not rebuild or scan a list per call.
    BOTANICAL_FRUITS = frozenset({
        "corn", "zucchini", "green beans", "bell pepper", "tomato",
        "pepper", "cucumber", "beans", "peas", "okra", "eggplant", "acorns",
    })

    def can_handle(self, task_description: str) -> bool:
        """Handle tasks whose text mentions both "botany" and "vegetable"."""
        lowered = task_description.lower()
        return "botany" in lowered and "vegetable" in lowered

    def validate(self, answer: str) -> str:
        """Ensure no botanical fruits are in the vegetable list.

        Args:
            answer: Comma-separated list of items.

        Returns:
            The remaining items, lower-cased, sorted alphabetically and
            joined with ", ". Empty entries are dropped.
        """
        print(f"Running BotanyValidator on: '{answer}'")
        # Trailing sentence punctuation is removed from each item before the
        # lookup; otherwise "zucchini." would slip past the fruit filter.
        items = (
            raw.strip().rstrip(".!?;:").rstrip().lower()
            for raw in answer.split(",")
        )
        # Skip empty entries (from "a, b," or "a,,b") so the joined result
        # never starts with a stray ", ".
        final_vegetables = [
            item for item in items
            if item and item not in self.BOTANICAL_FRUITS
        ]
        validated_answer = ", ".join(sorted(final_vegetables))
        print(f"BotanyValidator result: '{validated_answer}'")
        return validated_answer

class CaseValidator(Validator):
    """Validator that capitalizes an answer made of one all-lowercase word."""
    priority = 900  # Low priority, but before final punctuation strip

    def can_handle(self, task_description: str) -> bool:
        """Return True unconditionally: this validator applies to any task."""
        return True

    def validate(self, answer: str) -> str:
        """Capitalize ``answer`` when it is a single, all-lowercase word.

        Args:
            answer: The answer text to inspect.

        Returns:
            ``answer.capitalize()`` for a one-word lowercase answer,
            otherwise ``answer`` unchanged.
        """
        print(f"Running CaseValidator on: '{answer}'")
        word_count = len(answer.split())
        if word_count == 1 and answer.islower():
            result = answer.capitalize()
        else:
            result = answer
        print(f"CaseValidator result: '{result}'")
        return result

class NumericValidator(Validator):
    """Validator that reduces a "how many" answer to a bare number."""
    priority = 50  # Medium priority

    # Optional minus sign (only when not glued to a preceding word character,
    # so the hyphen in "10-15" or "item-5" is not read as a sign), then either
    # a comma-grouped integer such as "1,234" or a plain digit run, then an
    # optional decimal fraction.
    _NUMBER_PATTERN = re.compile(
        r"(?:(?<!\w)-)?(?:\d{1,3}(?:,\d{3})+(?!\d)|\d+)(?:\.\d+)?"
    )

    def can_handle(self, task_description: str) -> bool:
        """Handle tasks whose text contains "how many" (case-insensitive)."""
        return "how many" in task_description.lower()

    def validate(self, answer: str) -> str:
        """Extract the first number from the answer.

        Args:
            answer: The answer text, possibly containing surrounding words.

        Returns:
            The first number found, with thousands separators removed
            (e.g. "about 1,234 items" -> "1234", "3.5 km" -> "3.5"), or the
            original answer unchanged when it contains no digits.
        """
        print(f"Running NumericValidator on: '{answer}'")
        validated_answer = answer
        match = self._NUMBER_PATTERN.search(answer)
        if match:
            # Grouping commas are dropped so the result is a plain number
            # rather than a value truncated at the first comma.
            validated_answer = match.group().replace(",", "")
        print(f"NumericValidator result: '{validated_answer}'")
        return validated_answer

class ValidatorRegistry:
    """Priority-ordered collection of validators applied as a chain."""

    def __init__(self):
        # Kept sorted by ascending ``priority`` after every registration.
        self.validators = []

    def register(self, validator: Validator):
        """Add ``validator`` and re-sort the list by ascending priority.

        Args:
            validator: The validator instance to add.
        """
        def priority_of(entry):
            return entry.priority

        self.validators += [validator]
        # list.sort is stable, so equal priorities keep registration order.
        self.validators.sort(key=priority_of)

    def process(self, task_description: str, answer: str) -> str:
        """Run ``answer`` through every applicable validator in priority order.

        Args:
            task_description: Text of the task, passed to each ``can_handle``.
            answer: The initial answer text.

        Returns:
            The answer after each applicable validator has been applied to
            the output of the previous one.
        """
        current = answer
        for candidate in self.validators:
            if not candidate.can_handle(task_description):
                continue
            current = candidate.validate(current)
        return current

# Build the module-level registry and add one instance of each validator.
validator_registry = ValidatorRegistry()
for _validator_cls in (
    PunctuationValidator,
    BotanyValidator,
    CaseValidator,
    NumericValidator,
):
    validator_registry.register(_validator_cls())
del _validator_cls  # keep the loop variable out of the module namespace