File size: 4,852 Bytes
ebb8326
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
"""Answer extraction and validation utilities.

Consolidates answer-related logic:
- Extraction from LLM responses (CoT format)
- Validation against valid choices
- Normalization with fallback defaults
"""

import re
import string

from src.utils.logging import print_log


def extract_answer(response: str, num_choices: int = 4, require_end: bool = False) -> str | None:
    """Extract answer letter from LLM response using strict explicit answer lines.
    
    Only accepts answers from explicit final-answer lines with colon:
    - "Đáp án: A", "Answer: B" (preferred)
    - "Lựa chọn: C" (secondary)
    
    Returns the LAST valid explicit answer line found (later lines override earlier).
    
    Args:
        response: Response text from LLM
        num_choices: Number of valid choices
        require_end: If True, only extract answer from last 20% of response
        
    Returns:
        Answer letter (A, B, C, D) or None if no explicit answer found
    """
    if not response:
        return None
    
    valid_labels = string.ascii_uppercase[:num_choices]
    
    # If require_end, only look at last 20% of response
    search_text = response
    if require_end and len(response) > 100:
        cutoff = int(len(response) * 0.8)
        search_text = response[cutoff:]
    
    # Pattern for primary labels: "Đáp án:" or "Answer:" (highest priority)
    primary_pattern = r"(?:Đáp\s*án|Answer)[ \t]*[::][ \t]*\**([A-Z])\b"
    
    # Pattern for secondary label: "Lựa chọn:" (lower priority)
    secondary_pattern = r"Lựa\s*chọn[ \t]*[::][ \t]*\**([A-Z])\b"
    
    # Find all matches for both patterns
    primary_matches = re.findall(primary_pattern, search_text, flags=re.IGNORECASE)
    secondary_matches = re.findall(secondary_pattern, search_text, flags=re.IGNORECASE)
    
    if primary_matches:
        answer = primary_matches[-1].upper()
        if answer in valid_labels:
            return answer
    
    if secondary_matches:
        answer = secondary_matches[-1].upper()
        if answer in valid_labels:
            return answer
    
    # Single letter response (entire response is just a letter)
    clean_response = search_text.strip()
    if len(clean_response) == 1 and clean_response.upper() in valid_labels:
        return clean_response.upper()
    
    return None


def validate_answer(answer: str, num_choices: int) -> tuple[bool, str]:
    """Validate that an answer is a single in-range choice letter and normalize it.

    Surrounding whitespace is ignored and the letter is upper-cased before
    comparison against the first ``num_choices`` uppercase ASCII letters.

    Args:
        answer: Raw answer string from the model (empty or None is invalid).
        num_choices: Number of choices available (A, B, C, D, ...).

    Returns:
        Tuple of (is_valid, normalized_answer). When valid, the second item
        is the single uppercase letter. When invalid, it is the original
        ``answer`` unchanged, or "" if ``answer`` was empty/None.
    """
    valid_answers = string.ascii_uppercase[:num_choices]
    candidate = answer.strip().upper() if answer else ""

    # The length check is essential: ``in`` on a str is a substring test, so
    # without it multi-letter inputs such as "AB" or "BCD" would be accepted.
    if len(candidate) == 1 and candidate in valid_answers:
        return True, candidate

    return False, answer or ""


def normalize_answer(
    answer: str | None,
    num_choices: int,
    question_id: str | None = None,
    default: str = "A",
) -> str:
    """Normalize and validate answer with fallback to default.
    
    Combines extraction, validation, and normalization:
    - Validates answer is within valid range (A, B, C, D, ...)
    - Normalizes refusal responses  
    - Falls back to default for invalid answers
    
    Args:
        answer: Raw answer string from model (can be None)
        num_choices: Number of choices available
        question_id: Optional question ID for logging warnings
        default: Default answer if validation fails
        
    Returns:
        Normalized answer string
    """
    if answer is None:
        if question_id:
            print_log(
                f"        [Warning] No answer extracted for {question_id}, "
                f"defaulting to {default}"
            )
        return default
    
    is_valid, normalized = validate_answer(answer, num_choices)
    
    if not is_valid:
        if question_id:
            print_log(
                f"        [Warning] Invalid answer '{answer}' for {question_id}, "
                f"defaulting to {default}"
            )
        return default
    
    return normalized


def extract_and_normalize(
    response: str,
    num_choices: int,
    question_id: str | None = None,
    default: str = "A",
) -> str:
    """Extract an answer from a response and normalize it in one call.

    Convenience wrapper: runs ``extract_answer`` on the response and feeds
    the result (possibly None) to ``normalize_answer``.

    Args:
        response: Raw LLM response text.
        num_choices: Number of valid choices.
        question_id: Optional question ID, forwarded for warning logs.
        default: Answer returned if extraction or validation fails.

    Returns:
        Normalized answer string.
    """
    candidate = extract_answer(response, num_choices=num_choices)
    return normalize_answer(
        candidate,
        num_choices,
        question_id,
        default,
    )