Spaces:
Running
Running
| """Answer extraction and validation utilities. | |
| Consolidates answer-related logic: | |
| - Extraction from LLM responses (CoT format) | |
| - Validation against valid choices | |
| - Normalization with fallback defaults | |
| """ | |
| import re | |
| import string | |
| from src.utils.logging import print_log | |
| def extract_answer(response: str, num_choices: int = 4, require_end: bool = False) -> str | None: | |
| """Extract answer letter from LLM response using strict explicit answer lines. | |
| Only accepts answers from explicit final-answer lines with colon: | |
| - "Đáp án: A", "Answer: B" (preferred) | |
| - "Lựa chọn: C" (secondary) | |
| Returns the LAST valid explicit answer line found (later lines override earlier). | |
| Args: | |
| response: Response text from LLM | |
| num_choices: Number of valid choices | |
| require_end: If True, only extract answer from last 20% of response | |
| Returns: | |
| Answer letter (A, B, C, D) or None if no explicit answer found | |
| """ | |
| if not response: | |
| return None | |
| valid_labels = string.ascii_uppercase[:num_choices] | |
| # If require_end, only look at last 20% of response | |
| search_text = response | |
| if require_end and len(response) > 100: | |
| cutoff = int(len(response) * 0.8) | |
| search_text = response[cutoff:] | |
| # Pattern for primary labels: "Đáp án:" or "Answer:" (highest priority) | |
| primary_pattern = r"(?:Đáp\s*án|Answer)[ \t]*[::][ \t]*\**([A-Z])\b" | |
| # Pattern for secondary label: "Lựa chọn:" (lower priority) | |
| secondary_pattern = r"Lựa\s*chọn[ \t]*[::][ \t]*\**([A-Z])\b" | |
| # Find all matches for both patterns | |
| primary_matches = re.findall(primary_pattern, search_text, flags=re.IGNORECASE) | |
| secondary_matches = re.findall(secondary_pattern, search_text, flags=re.IGNORECASE) | |
| if primary_matches: | |
| answer = primary_matches[-1].upper() | |
| if answer in valid_labels: | |
| return answer | |
| if secondary_matches: | |
| answer = secondary_matches[-1].upper() | |
| if answer in valid_labels: | |
| return answer | |
| # Single letter response (entire response is just a letter) | |
| clean_response = search_text.strip() | |
| if len(clean_response) == 1 and clean_response.upper() in valid_labels: | |
| return clean_response.upper() | |
| return None | |
def validate_answer(answer: str, num_choices: int) -> tuple[bool, str]:
    """Check that an answer is a single valid choice letter and normalize it.

    Valid answers are the first ``num_choices`` uppercase ASCII letters.
    Surrounding whitespace is ignored and the comparison is
    case-insensitive.

    Args:
        answer: Raw answer string from the model (empty or None is invalid).
        num_choices: Number of choices available (A, B, C, D, ...).

    Returns:
        Tuple of ``(is_valid, normalized_answer)``. When valid, the second
        element is the uppercase letter. When invalid, it is the input
        unchanged, or ``""`` if the input was empty or None.
    """
    valid_answers = string.ascii_uppercase[:num_choices]
    candidate = answer.strip().upper() if answer else ""

    # ``in`` on a string is a substring test, so the length must be checked
    # explicitly; otherwise "AB" or "BCD" would pass as valid answers.
    if len(candidate) == 1 and candidate in valid_answers:
        return True, candidate
    return False, answer or ""
def normalize_answer(
    answer: str | None,
    num_choices: int,
    question_id: str | None = None,
    default: str = "A",
) -> str:
    """Return a validated answer, falling back to ``default`` when unusable.

    The answer is checked with ``validate_answer``. If it is missing (None)
    or fails validation, ``default`` is returned instead. A warning is
    logged through ``print_log`` in those cases, but only when a
    ``question_id`` is supplied.

    Args:
        answer: Raw answer string from the model (may be None).
        num_choices: Number of choices available.
        question_id: Optional question ID; enables the warning log lines.
        default: Value returned when the answer is missing or invalid.

    Returns:
        The normalized answer from ``validate_answer`` or ``default``.
    """
    if answer is not None:
        ok, canonical = validate_answer(answer, num_choices)
        if ok:
            return canonical
        # Present but not an acceptable choice.
        if question_id:
            print_log(
                f" [Warning] Invalid answer '{answer}' for {question_id}, "
                f"defaulting to {default}"
            )
        return default

    # Nothing was extracted at all.
    if question_id:
        print_log(
            f" [Warning] No answer extracted for {question_id}, "
            f"defaulting to {default}"
        )
    return default
def extract_and_normalize(
    response: str,
    num_choices: int,
    question_id: str | None = None,
    default: str = "A",
) -> str:
    """Extract an answer from a raw response and normalize it in one call.

    Convenience wrapper that feeds the result of ``extract_answer`` into
    ``normalize_answer``.

    Args:
        response: Raw LLM response text.
        num_choices: Number of valid choices.
        question_id: Optional question ID, forwarded for logging.
        default: Answer returned when extraction or validation fails.

    Returns:
        Normalized answer string.
    """
    return normalize_answer(
        answer=extract_answer(response, num_choices=num_choices),
        num_choices=num_choices,
        question_id=question_id,
        default=default,
    )