Spaces:
Running
Running
File size: 4,852 Bytes
ebb8326 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 |
"""Answer extraction and validation utilities.
Consolidates answer-related logic:
- Extraction from LLM responses (CoT format)
- Validation against valid choices
- Normalization with fallback defaults
"""
import re
import string
from src.utils.logging import print_log
def extract_answer(response: str, num_choices: int = 4, require_end: bool = False) -> str | None:
"""Extract answer letter from LLM response using strict explicit answer lines.
Only accepts answers from explicit final-answer lines with colon:
- "Đáp án: A", "Answer: B" (preferred)
- "Lựa chọn: C" (secondary)
Returns the LAST valid explicit answer line found (later lines override earlier).
Args:
response: Response text from LLM
num_choices: Number of valid choices
require_end: If True, only extract answer from last 20% of response
Returns:
Answer letter (A, B, C, D) or None if no explicit answer found
"""
if not response:
return None
valid_labels = string.ascii_uppercase[:num_choices]
# If require_end, only look at last 20% of response
search_text = response
if require_end and len(response) > 100:
cutoff = int(len(response) * 0.8)
search_text = response[cutoff:]
# Pattern for primary labels: "Đáp án:" or "Answer:" (highest priority)
primary_pattern = r"(?:Đáp\s*án|Answer)[ \t]*[::][ \t]*\**([A-Z])\b"
# Pattern for secondary label: "Lựa chọn:" (lower priority)
secondary_pattern = r"Lựa\s*chọn[ \t]*[::][ \t]*\**([A-Z])\b"
# Find all matches for both patterns
primary_matches = re.findall(primary_pattern, search_text, flags=re.IGNORECASE)
secondary_matches = re.findall(secondary_pattern, search_text, flags=re.IGNORECASE)
if primary_matches:
answer = primary_matches[-1].upper()
if answer in valid_labels:
return answer
if secondary_matches:
answer = secondary_matches[-1].upper()
if answer in valid_labels:
return answer
# Single letter response (entire response is just a letter)
clean_response = search_text.strip()
if len(clean_response) == 1 and clean_response.upper() in valid_labels:
return clean_response.upper()
return None
def validate_answer(answer: str, num_choices: int) -> tuple[bool, str]:
    """Check that an answer is a single valid choice letter and normalize it.

    Valid answers are the first ``num_choices`` uppercase ASCII letters; the
    comparison is case-insensitive.

    Args:
        answer: Raw answer string from the model (empty/None is treated as
            invalid).
        num_choices: Number of choices available (A, B, C, D, ...).

    Returns:
        Tuple of (is_valid, normalized_answer). For a valid answer the second
        element is the uppercase letter; otherwise it is the original answer,
        or "" when the answer is empty/None.
    """
    valid_answers = string.ascii_uppercase[:num_choices]
    if answer:
        candidate = answer.upper()
        # `in` on a string is a substring test, so require exactly one
        # character; otherwise "AB" or "BCD" would be accepted as valid.
        if len(candidate) == 1 and candidate in valid_answers:
            return True, candidate
    return False, answer or ""
def normalize_answer(
answer: str | None,
num_choices: int,
question_id: str | None = None,
default: str = "A",
) -> str:
"""Normalize and validate answer with fallback to default.
Combines extraction, validation, and normalization:
- Validates answer is within valid range (A, B, C, D, ...)
- Normalizes refusal responses
- Falls back to default for invalid answers
Args:
answer: Raw answer string from model (can be None)
num_choices: Number of choices available
question_id: Optional question ID for logging warnings
default: Default answer if validation fails
Returns:
Normalized answer string
"""
if answer is None:
if question_id:
print_log(
f" [Warning] No answer extracted for {question_id}, "
f"defaulting to {default}"
)
return default
is_valid, normalized = validate_answer(answer, num_choices)
if not is_valid:
if question_id:
print_log(
f" [Warning] Invalid answer '{answer}' for {question_id}, "
f"defaulting to {default}"
)
return default
return normalized
def extract_and_normalize(
    response: str,
    num_choices: int,
    question_id: str | None = None,
    default: str = "A",
) -> str:
    """Extract an answer from ``response`` and normalize it in one call.

    Convenience wrapper: feeds the result of ``extract_answer`` (called with
    its default ``require_end``) straight into ``normalize_answer``.

    Args:
        response: Raw LLM response text.
        num_choices: Number of valid choices.
        question_id: Optional question ID, forwarded for logging.
        default: Answer returned when extraction or validation fails.

    Returns:
        Normalized answer string as produced by ``normalize_answer``.
    """
    return normalize_answer(
        extract_answer(response, num_choices=num_choices),
        num_choices,
        question_id,
        default,
    )
|