import re
import torch
import torch.nn.functional as F

def compute_first_boxed_answer_probs(b, gen_ids, gen_out, ans, task, tokenizer):
    # Extract the log-probability of each generated token from the per-step scores.
    cur_lp = []
    for t, tok_id in enumerate(gen_ids):
        if t >= len(gen_out.scores):
            break
        step_scores = gen_out.scores[t][b]  # [V]
        step_logprobs = F.log_softmax(step_scores, dim=-1)
        cur_lp.append(step_logprobs[tok_id.item()].unsqueeze(0))
    lp_vec = torch.cat(cur_lp, dim=0) if cur_lp else torch.empty(0)

    if task.startswith("mvp_"):
        ans = f"\\boxed{{Answer:{ans}"
        _prefix_ids_tensor = tokenizer(
            "\\boxed{Answer:",
            add_special_tokens=False,
            return_tensors="pt",
        ).input_ids[0]
        prefix_ids = _prefix_ids_tensor.to(
            device=gen_ids.device,
            dtype=gen_ids.dtype,
        )
        fake_lp = torch.zeros(
            prefix_ids.shape[0],
            device=lp_vec.device,
            dtype=lp_vec.dtype,
        )
        gen_ids = torch.cat([prefix_ids, gen_ids], dim=0)
        lp_vec = torch.cat([fake_lp, lp_vec], dim=0)

    first = extract_first_boxed_content(ans)
    first_box_tok_logprobs = None
    if first:
        first_content, _, _ = first  # content only
        # Exclude cases starting with "Let's analyze" (ignoring leading whitespace).
        if not first_content.lstrip().startswith("Let's analyze"):
            # Normalize (normalize_text may reduce choice markers to a bare letter, e.g. 'H').
            first_box_norm = normalize_text(first_content)
            # Locate the normalized text within the token sequence.
            span = find_token_span_for_text(
                gen_ids=gen_ids,
                text_piece=first_box_norm,
                tokenizer=tokenizer,
                decoded_answer=ans,
            )
            if span is not None and lp_vec.numel() > 0:
                s, e = span
                # Defensive clipping (shouldn't be necessary in theory).
                s = max(0, min(s, lp_vec.shape[0]))
                e = max(0, min(e, lp_vec.shape[0]))
                if e > s:
                    first_box_tok_logprobs = lp_vec[s:e]
                    # For mvp tasks the answer looks like "Answer: A", so keep only the last token.
                    if task.startswith("mvp_"):
                        first_box_tok_logprobs = first_box_tok_logprobs[-1]

    if first_box_tok_logprobs is None:
        first_box_probs = -1
    else:
        first_box_probs = first_box_tok_logprobs.mean().exp().item()
    return first_box_probs
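
# Hedged usage sketch: how compute_first_boxed_answer_probs expects its inputs
# to be produced. `model` and `tokenizer` are assumed Hugging Face objects and
# the task name is a placeholder (any name not starting with "mvp_" leaves the
# decoded answer untouched); the essential part is generating with
# return_dict_in_generate=True and output_scores=True so that gen_out.scores
# holds one [batch, vocab] logit tensor per generated step.
def _demo_compute_first_boxed_answer_probs(model, tokenizer, prompt, task="default"):
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    gen_out = model.generate(
        **inputs,
        max_new_tokens=256,
        return_dict_in_generate=True,
        output_scores=True,
    )
    b = 0  # single-prompt batch
    gen_ids = gen_out.sequences[b, inputs.input_ids.shape[1]:]  # new tokens only
    ans = tokenizer.decode(gen_ids, skip_special_tokens=True)
    return compute_first_boxed_answer_probs(b, gen_ids, gen_out, ans, task, tokenizer)
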
_PATTERN_BOXED = re.compile(r"\\boxed\{([^{}]*(?:\{(?:[^{}]+|\{[^{}]*\})*\}[^{}]*)*)\}")

def extract_first_boxed_content(text: str):
    """
    Returns:
        (content, inner_start, inner_end)
        - content: inner text of the first \\boxed{...} (group 1)
        - inner_start, inner_end: character indices of that inner content in `text` (end is exclusive)
    Requirement: the text must contain at least two \\boxed{...} occurrences; otherwise return False.
    """
    it = _PATTERN_BOXED.finditer(text)
    m1 = next(it, None)
    if m1 is None:
        return False
    if next(it, None) is None:  # require at least two boxed occurrences
        return False
    content = m1.group(1)
    inner_start, inner_end = m1.span(1)  # span of the *inner* content only
    return content, inner_start, inner_end
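
# A minimal sanity check for extract_first_boxed_content on illustrative strings:
# with two \boxed{...} occurrences the inner content and its character span come
# back; with only one, the two-box requirement makes the function return False.
def _demo_extract_first_boxed_content():
    two = r"step 1 gives \boxed{42}; therefore \boxed{42}"
    content, s, e = extract_first_boxed_content(two)
    assert content == "42" and two[s:e] == "42"
    assert extract_first_boxed_content(r"only \boxed{42} here") is False
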
def _find_subsequence(haystack_ids, needle_ids):
    """
    Return (start_idx, end_idx) of the first contiguous match; return None if not
    found, or if the needle is empty or longer than the haystack.
    """
    if not needle_ids:
        return None
    n = len(needle_ids)
    limit = len(haystack_ids) - n + 1  # <= 0 when the needle is longer than the haystack
    for i in range(max(0, limit)):
        if haystack_ids[i : i + n] == needle_ids:
            return i, i + n
    return None
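
# Quick illustration of _find_subsequence semantics on plain Python lists (the
# ids must be list-like so that slicing plus == yields a single bool):
def _demo_find_subsequence():
    assert _find_subsequence([7, 1, 2, 3], [1, 2]) == (1, 3)
    assert _find_subsequence([7, 1, 2, 3], [3, 7]) is None  # not contiguous
    assert _find_subsequence([7], [7, 7]) is None           # needle longer than haystack
    assert _find_subsequence([7], []) is None               # empty needle
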
def _first_nonempty_find(text, variants):
    """Return the first non-empty variant (in list order) found in `text`, as (variant, char_pos); (None, -1) if none match."""
    for v in variants:
        if not v:
            continue
        pos = text.find(v)
        if pos != -1:
            return v, pos
    return None, -1
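
# _first_nonempty_find prefers variant order over position in the text: the first
# listed variant that occurs anywhere wins, even if a later variant appears earlier.
def _demo_first_nonempty_find():
    assert _first_nonempty_find("alpha beta", ["beta", "alpha"]) == ("beta", 6)
    assert _first_nonempty_find("the answer is 42", ["", "missing", "42"]) == ("42", 14)
    assert _first_nonempty_find("nothing here", ["x", "y"]) == (None, -1)
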
def find_token_span_for_text(gen_ids, text_piece, tokenizer, decoded_answer):
    """
    Goal: Given the decoded complete answer string `decoded_answer`, its generated token sequence `gen_ids`,
    and a text fragment `text_piece`, find the corresponding token span for that fragment.
    Strategy:
        A) Encode `text_piece` into tokens and search it as a subsequence in `gen_ids`,
           trying multiple textual variants: original / stripped / lstripped / prefixed with a space.
        B) If (A) fails: locate the fragment via `str.find()` in `decoded_answer`, then
           re-encode `decoded_answer[:pos]` and the chosen fragment to infer the token span by length.
    Returns: (tok_start, tok_end) or None
    """
    # Work on a plain Python list: a tensor slice compared with == against a list
    # yields an elementwise tensor, not the single bool _find_subsequence needs.
    if torch.is_tensor(gen_ids):
        gen_ids = gen_ids.tolist()
    # Common variants: original, strip, lstrip, prefixed space.
    candidates_text = [
        text_piece,
        text_piece.strip(),
        text_piece.lstrip(),
        (" " + text_piece) if not text_piece.startswith(" ") else text_piece,
    ]
    # (A) Direct token subsequence match.
    for cand in candidates_text:
        cand_ids = tokenizer.encode(cand, add_special_tokens=False)
        if not cand_ids:
            continue
        span = _find_subsequence(gen_ids, cand_ids)
        if span is not None:
            return span
    # (B) Fallback: use character position + re-encoding to estimate the token span.
    chosen, pos = _first_nonempty_find(decoded_answer, candidates_text)
    if chosen is not None:
        prefix_ids = tokenizer.encode(decoded_answer[:pos], add_special_tokens=False)
        chosen_ids = tokenizer.encode(chosen, add_special_tokens=False)
        start = len(prefix_ids)
        end = start + len(chosen_ids)
        if end <= len(gen_ids):
            return (start, end)
    return None
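
# Hedged usage sketch for find_token_span_for_text. "gpt2" is only an assumed
# tokenizer for illustration; any Hugging Face tokenizer exposing .encode/.decode
# works the same way. When a span is found, it should decode back to the fragment.
def _demo_find_token_span_for_text():
    from transformers import AutoTokenizer
    tok = AutoTokenizer.from_pretrained("gpt2")
    answer = r"The result is \boxed{42} and again \boxed{42}."
    ids = tok.encode(answer, add_special_tokens=False)
    span = find_token_span_for_text(ids, "42", tok, answer)
    if span is not None:  # strategies (A)/(B) may still fail for exotic tokenizations
        s, e = span
        assert tok.decode(ids[s:e]).strip() == "42"
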
_CHOICE_PAREN = re.compile(r"""^\s*[\(\[\{]\s*([A-Za-z])\s*[\)\]\}]\s*(?:[.)/:;\-]\s*)?""", re.X)
_CHOICE_BARE_WITH_DELIM = re.compile(r"""^\s*([A-Za-z])\s*[.)/:;\-]\s*""", re.X)
_CHOICE_SINGLE_LETTER = re.compile(r"""^\s*([A-Za-z])\s*[.]?\s*$""", re.X)

def normalize_text(s):
    m = _CHOICE_PAREN.match(s) or _CHOICE_BARE_WITH_DELIM.match(s) or _CHOICE_SINGLE_LETTER.match(s)
    if m:
        return m.group(1)
    return s
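
# Illustrative checks for normalize_text, based on the three choice-marker
# regexes above: bracketed or delimited multiple-choice letters collapse to the
# bare letter, while anything else passes through unchanged.
def _demo_normalize_text():
    assert normalize_text("(A) ") == "A"     # _CHOICE_PAREN
    assert normalize_text("B. ") == "B"      # _CHOICE_BARE_WITH_DELIM
    assert normalize_text(" C ") == "C"      # _CHOICE_SINGLE_LETTER
    assert normalize_text("3.14") == "3.14"  # free-form answers are untouched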