quanho114
Deploy VietQA API
ebb8326
"""Answer extraction and validation utilities.
Consolidates answer-related logic:
- Extraction from LLM responses (CoT format)
- Validation against valid choices
- Normalization with fallback defaults
"""
import re
import string
from src.utils.logging import print_log
def extract_answer(response: str, num_choices: int = 4, require_end: bool = False) -> str | None:
"""Extract answer letter from LLM response using strict explicit answer lines.
Only accepts answers from explicit final-answer lines with colon:
- "Đáp án: A", "Answer: B" (preferred)
- "Lựa chọn: C" (secondary)
Returns the LAST valid explicit answer line found (later lines override earlier).
Args:
response: Response text from LLM
num_choices: Number of valid choices
require_end: If True, only extract answer from last 20% of response
Returns:
Answer letter (A, B, C, D) or None if no explicit answer found
"""
if not response:
return None
valid_labels = string.ascii_uppercase[:num_choices]
# If require_end, only look at last 20% of response
search_text = response
if require_end and len(response) > 100:
cutoff = int(len(response) * 0.8)
search_text = response[cutoff:]
# Pattern for primary labels: "Đáp án:" or "Answer:" (highest priority)
primary_pattern = r"(?:Đáp\s*án|Answer)[ \t]*[::][ \t]*\**([A-Z])\b"
# Pattern for secondary label: "Lựa chọn:" (lower priority)
secondary_pattern = r"Lựa\s*chọn[ \t]*[::][ \t]*\**([A-Z])\b"
# Find all matches for both patterns
primary_matches = re.findall(primary_pattern, search_text, flags=re.IGNORECASE)
secondary_matches = re.findall(secondary_pattern, search_text, flags=re.IGNORECASE)
if primary_matches:
answer = primary_matches[-1].upper()
if answer in valid_labels:
return answer
if secondary_matches:
answer = secondary_matches[-1].upper()
if answer in valid_labels:
return answer
# Single letter response (entire response is just a letter)
clean_response = search_text.strip()
if len(clean_response) == 1 and clean_response.upper() in valid_labels:
return clean_response.upper()
return None
def validate_answer(answer: str, num_choices: int) -> tuple[bool, str]:
"""Validate if answer is within valid range and normalize it.
Args:
answer: Raw answer string from model
num_choices: Number of choices available (A, B, C, D, ...)
Returns:
Tuple of (is_valid, normalized_answer)
"""
valid_answers = string.ascii_uppercase[:num_choices]
if answer and answer.upper() in valid_answers:
return True, answer.upper()
return False, answer or ""
def normalize_answer(
    answer: str | None,
    num_choices: int,
    question_id: str | None = None,
    default: str = "A",
) -> str:
    """Validate an already-extracted answer, falling back to ``default``.

    This function does not parse raw LLM output (``extract_answer`` does that)
    and has no special handling for refusals: any value that is not a valid
    choice label simply falls back to ``default``.

    - ``None`` (nothing was extracted) -> ``default``
    - rejected by ``validate_answer`` -> ``default``
    - otherwise -> the normalized label returned by ``validate_answer``

    Each fallback is reported through ``print_log``, but only when a truthy
    ``question_id`` is supplied. ``default`` itself is returned as-is and is
    not checked against ``num_choices``.

    Args:
        answer: Candidate answer string (may be None).
        num_choices: Number of choices available (labels A, B, C, ...).
        question_id: Optional question ID used in warning messages.
        default: Value returned when the answer is missing or invalid.

    Returns:
        The normalized answer label, or ``default``.
    """
    if answer is None:
        problem = "No answer extracted"
    else:
        is_valid, normalized = validate_answer(answer, num_choices)
        if is_valid:
            return normalized
        problem = f"Invalid answer '{answer}'"

    # Warnings are only emitted when the caller identifies the question.
    if question_id:
        print_log(f" [Warning] {problem} for {question_id}, defaulting to {default}")
    return default
def extract_and_normalize(
    response: str,
    num_choices: int,
    question_id: str | None = None,
    default: str = "A",
    *,
    require_end: bool = False,
) -> str:
    """Extract an answer from ``response`` and normalize it in one call.

    Equivalent to calling ``extract_answer`` and passing its result to
    ``normalize_answer``.

    Args:
        response: Raw LLM response text.
        num_choices: Number of valid choices.
        question_id: Optional question ID for warning messages.
        default: Answer returned when extraction or validation fails.
        require_end: Forwarded to ``extract_answer``; when True, only the last
            20% of responses longer than 100 characters is searched. Defaults
            to False, matching the previous behavior.

    Returns:
        The normalized answer label, or ``default``.
    """
    extracted = extract_answer(response, num_choices=num_choices, require_end=require_end)
    return normalize_answer(extracted, num_choices, question_id, default)