Spaces: Running on Zero
import re

import torch
import torch.nn.functional as F
def compute_first_boxed_answer_probs(b, gen_ids, gen_out, ans, task, tokenizer):
    """Estimate the model's confidence in the first \\boxed{...} answer.

    Args:
        b: batch index into ``gen_out.scores[t]`` (int).
        gen_ids: generated token ids for sample ``b``; indexed per step and
            moved/concatenated as a tensor, so it is assumed to be a 1-D
            torch tensor of token ids — TODO confirm at call site.
        gen_out: generation output exposing ``.scores`` (one ``[V]``-sized
            logit row per step per batch element) — presumably the output of
            ``model.generate(..., output_scores=True)``; verify against caller.
        ans: decoded answer string for sample ``b``.
        task: task name; tasks starting with ``"mvp_"`` get special handling.
        tokenizer: HF-style tokenizer (callable, with ``.encode``).

    Returns:
        float: exp(mean log-prob) over the tokens of the first boxed answer,
        or -1 as a sentinel when no usable boxed span was found.
    """
    # extract logprobs for each step based on the gen_ids
    cur_lp = []
    for t, tok_id in enumerate(gen_ids):
        # Defensive: generated ids can outnumber recorded score steps.
        if t >= len(gen_out.scores):
            break
        step_scores = gen_out.scores[t][b]  # [V]
        step_logprobs = F.log_softmax(step_scores, dim=-1)
        # Log-probability of the token that was actually emitted at step t.
        cur_lp.append(step_logprobs[tok_id.item()].unsqueeze(0))
    lp_vec = torch.cat(cur_lp, dim=0) if cur_lp else torch.empty(0)
    if task.startswith("mvp_"):
        # For mvp_* tasks the "\boxed{Answer:" prefix is not part of the
        # generated text, so it is synthesized: prepend its tokens to gen_ids
        # with zero (log-prob 0, i.e. prob 1) placeholder log-probs so that
        # token indices stay aligned with lp_vec.
        ans = f"\\boxed{{Answer:{ans}"
        _prefix_ids_tensor = tokenizer(
            "\\boxed{Answer:",
            add_special_tokens=False,
            return_tensors="pt",
        ).input_ids[0]
        prefix_ids = _prefix_ids_tensor.to(
            device=gen_ids.device,
            dtype=gen_ids.dtype,
        )
        fake_lp = torch.zeros(
            prefix_ids.shape[0],
            device=lp_vec.device,
            dtype=lp_vec.dtype,
        )
        gen_ids = torch.cat([prefix_ids, gen_ids], dim=0)
        lp_vec = torch.cat([fake_lp, lp_vec], dim=0)
    # extract_first_boxed_content returns (content, start, end) or False,
    # and requires at least two \boxed{...} occurrences in `ans`.
    first = extract_first_boxed_content(ans)
    first_box_tok_logprobs = None
    if first:
        first_content, _, _ = first  # content only
        # Exclude cases starting with "Let's analyze" (ignoring leading whitespace).
        if not first_content.lstrip().startswith("Let's analyze"):
            # Normalize (normalize_text may clean up choices/letters to forms like 'H', etc.)
            first_box_norm = normalize_text(first_content)
            # Locate the normalized text within the token sequence
            span = find_token_span_for_text(
                gen_ids=gen_ids,
                text_piece=first_box_norm,
                tokenizer=tokenizer,
                decoded_answer=ans,
            )
            if span is not None and lp_vec.numel() > 0:
                s, e = span
                # Defensive clipping (shouldn't be necessary in theory)
                s = max(0, min(s, lp_vec.shape[0]))
                e = max(0, min(e, lp_vec.shape[0]))
                if e > s:
                    first_box_tok_logprobs = lp_vec[s:e]
                    # for mvp, the answer is like "Answer: A", so we use the last token
                    if task.startswith("mvp_"):
                        first_box_tok_logprobs = first_box_tok_logprobs[-1]
    if first_box_tok_logprobs is None:
        first_box_probs = -1
    else:
        # Geometric-mean-style confidence: exp of the mean token log-prob.
        first_box_probs = first_box_tok_logprobs.mean().exp().item()
    return first_box_probs
# Matches \boxed{...} allowing one level of nested braces inside the content.
_PATTERN_BOXED = re.compile(r"\\boxed\{([^{}]*(?:\{(?:[^{}]+|\{[^{}]*\})*\}[^{}]*)*)\}")


def extract_first_boxed_content(text: str):
    """Extract the inner text of the first \\boxed{...} occurrence.

    Returns:
        (content, inner_start, inner_end) — the inner text of the first
        \\boxed{...} and the character indices of that inner content in
        `text` (end exclusive).
        Returns False unless `text` contains at least two \\boxed{...}
        occurrences (the second acts as a sanity check required by callers).
    """
    matches = _PATTERN_BOXED.finditer(text)
    first_match = next(matches, None)
    second_match = next(matches, None)
    if first_match is None or second_match is None:
        return False
    inner_start, inner_end = first_match.span(1)  # span of the *inner* content
    return first_match.group(1), inner_start, inner_end
| def _find_subsequence(haystack_ids, needle_ids): | |
| """ | |
| Return (start_idx, end_idx); return None if not found. | |
| """ | |
| if not needle_ids: | |
| return None | |
| n = len(needle_ids) | |
| limit = len(haystack_ids) - n + 1 | |
| for i in range(max(0, 0), max(0, limit)): | |
| if haystack_ids[i : i + n] == needle_ids: | |
| return i, i + n | |
| # Edge case: if the needle is longer than the haystack, fail directly | |
| if limit <= 0 and haystack_ids == needle_ids: | |
| return 0, len(haystack_ids) | |
| return None | |
| def _first_nonempty_find(text, variants): | |
| """Find the first occurring variant in `text` (in order). Return (variant, char_pos) or (None, -1).""" | |
| for v in variants: | |
| if not v: | |
| continue | |
| pos = text.find(v) | |
| if pos != -1: | |
| return v, pos | |
| return None, -1 | |
def find_token_span_for_text(gen_ids, text_piece, tokenizer, decoded_answer):
    """
    Goal: Given the decoded complete answer string `decoded_answer`, its generated token sequence `gen_ids`,
    and a text fragment `text_piece`, find the corresponding token span for that fragment.
    Strategy:
        A) Encode `text_piece` into tokens and search it as a subsequence in `gen_ids`
           using multiple textual variants: original / stripped / lstrip / prefixed with a space.
        B) If (A) fails: locate the fragment via `str.find()` in `decoded_answer`, then
           re-encode `decoded_answer[:pos]` and the chosen fragment to infer the token span by length.
    Returns: (tok_start, tok_end) or None
    """
    # BUG FIX: the caller passes `gen_ids` as a torch tensor, but
    # _find_subsequence compares slices with `==` against a Python list of
    # ints; for a tensor that yields an elementwise comparison whose truth
    # value is ambiguous. Normalize to a plain Python list once, up front.
    if hasattr(gen_ids, "tolist"):
        gen_ids = gen_ids.tolist()
    else:
        gen_ids = list(gen_ids)

    # Common variants: original, strip, lstrip, prefixed space
    candidates_text = [
        text_piece,
        text_piece.strip(),
        text_piece.lstrip(),
        (" " + text_piece) if not text_piece.startswith(" ") else text_piece,
    ]

    # (A) Direct token subsequence match
    for cand in candidates_text:
        cand_ids = tokenizer.encode(cand, add_special_tokens=False)
        if not cand_ids:
            continue
        span = _find_subsequence(gen_ids, cand_ids)
        if span is not None:
            return span

    # (B) Fallback: use character position + re-encoding to estimate the token span
    # NOTE(review): this assumes len(encode(prefix)) + len(encode(fragment))
    # lines up with the tokenization of the full string — only approximate for
    # BPE tokenizers that merge across the boundary; confirm acceptable here.
    chosen, pos = _first_nonempty_find(decoded_answer, candidates_text)
    if chosen is not None:
        prefix_ids = tokenizer.encode(decoded_answer[:pos], add_special_tokens=False)
        chosen_ids = tokenizer.encode(chosen, add_special_tokens=False)
        start = len(prefix_ids)
        end = start + len(chosen_ids)
        if end <= len(gen_ids):
            return (start, end)
    return None
# Choice-letter prefixes: "(A)" / "[b]" / "{C}" with optional trailing delimiter.
_CHOICE_PAREN = re.compile(r"""^\s*[\(\[\{]\s*([A-Za-z])\s*[\)\]\}]\s*(?:[.)/:;\-]\s*)?""", re.X)
# Bare letter followed by a delimiter: "A." / "b)" / "C:".
_CHOICE_BARE_WITH_DELIM = re.compile(r"""^\s*([A-Za-z])\s*[.)/:;\-]\s*""", re.X)
# A single letter alone (optionally dotted): "D" / "D.".
_CHOICE_SINGLE_LETTER = re.compile(r"""^\s*([A-Za-z])\s*[.]?\s*$""", re.X)


def normalize_text(s):
    """Collapse a multiple-choice marker like "(A)", "B.", or a lone letter to
    that single letter; any other string is returned unchanged."""
    for pattern in (_CHOICE_PAREN, _CHOICE_BARE_WITH_DELIM, _CHOICE_SINGLE_LETTER):
        hit = pattern.match(s)
        if hit:
            return hit.group(1)
    return s