| | import re
|
| | import torch
|
| |
|
| |
|
def check_multiple_choice_with_regex(model_outputs, correct_answers):
    """Check whether each model output contains its expected answer choice.

    For every (output, answer) pair, the answer letter is normalized
    (trailing newline stripped, uppercased) and searched for in the model
    output as a standalone token, a token followed by punctuation, or a
    token inside parentheses.

    Args:
        model_outputs: Sequence of model-generated answer strings.
        correct_answers: Sequence of expected answer labels (e.g. "A", "b\n"),
            same length as ``model_outputs`` (extra items are ignored by zip).

    Returns:
        list[bool]: True where the output contains the expected answer.

    Note:
        Matching is case-sensitive on the model output side — a lowercase
        "a" in the output will not match answer "A".
    """
    results = []
    for model_output, correct_answer in zip(model_outputs, correct_answers):
        correct_answer = correct_answer.rstrip('\n').upper()

        # Escape the answer so metacharacters (e.g. "A)", "A+") are matched
        # literally instead of being interpreted as regex syntax.
        answer = re.escape(correct_answer)

        patterns = [
            rf"\b{answer}\b",       # standalone token: "The answer is B"
            rf"\b{answer}[.,)]",    # followed by punctuation: "B." / "B,"
            rf"\(.*{answer}.*\)",   # inside parentheses: "(B)"
        ]

        results.append(any(re.search(p, model_output) for p in patterns))
    return results
|
| |
|
| |
|
def top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float('Inf')):
    """Apply top-k and/or nucleus (top-p) filtering to logits.

    Tokens outside the top-k highest logits, and tokens outside the smallest
    set whose cumulative probability reaches ``top_p``, have their logits
    replaced by ``filter_value`` so they are never sampled.

    Args:
        logits: Tensor of shape (..., vocab_size). The last dimension is the
            vocabulary; filtering is applied independently per row.
        top_k: Keep only the ``top_k`` highest-logit tokens (0 disables).
        top_p: Keep the smallest token set with cumulative probability
            >= ``top_p`` (1.0 disables).
        filter_value: Value assigned to removed logits (default -inf).

    Returns:
        Tensor of the same shape with filtered positions set to
        ``filter_value``. The input tensor is not modified.
    """
    # Cap top_k at the vocabulary size so topk() cannot fail.
    top_k = min(top_k, logits.size(-1))

    if top_k > 0:
        # Threshold is the k-th largest logit per row; everything below it goes.
        kth_largest = torch.topk(logits, top_k)[0][..., -1, None]
        logits = logits.masked_fill(logits < kth_largest, filter_value)

    if top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.softmax(sorted_logits, dim=-1).cumsum(dim=-1)

        # Mark tokens past the nucleus. Shift the mask right so the token
        # that first pushes the cumulative probability over top_p is KEPT —
        # the nucleus is the smallest set with mass >= top_p. Without the
        # shift, the kept set's mass would fall short of top_p.
        sorted_indices_to_remove = cumulative_probs > top_p
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        # Always keep at least the single most probable token.
        sorted_indices_to_remove[..., 0] = False

        # Map the mask from sorted order back to vocabulary order.
        # dim=-1 works for any batch rank, not just 2-D inputs.
        indices_to_remove = sorted_indices_to_remove.scatter(
            -1, sorted_indices, sorted_indices_to_remove
        )
        logits = logits.masked_fill(indices_to_remove, filter_value)

    return logits
|
| |
|