File size: 7,393 Bytes
3b97420 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 | from __future__ import annotations
import argparse
from pathlib import Path
import torch
import torch.nn.functional as F
from searshorai.model import GPT, GPTConfig
from searshorai.tokenizer import TextTokenizer
# Prompt used at inference time. It has to stay in sync with the templates used
# to build the SFT data (make_xsum_sft.py / make_paragraph_sft.py); the first of
# those templates is the canonical one to use when generating.
# The single placeholder, {passage}, is filled with str.format.
DEFAULT_PROMPT_TEMPLATE = (
    "Read the article and write a one-sentence summary.\n"
    "\n"
    "Article:\n"
    "{passage}\n"
    "\n"
    "Summary:\n"
)
def strip_compile_prefix(state_dict):
    """Return a copy of ``state_dict`` with a leading ``_orig_mod.`` removed from keys.

    ``_orig_mod.`` is the key prefix left behind when a ``torch.compile``-wrapped
    module is saved. Only one leading occurrence is stripped; keys without it
    are kept verbatim. The input mapping is not modified. If a stripped key
    collides with an existing one, the entry that comes later in iteration
    order wins.
    """
    marker = "_orig_mod."
    return {
        (name[len(marker):] if name.startswith(marker) else name): tensor
        for name, tensor in state_dict.items()
    }
def parse_args() -> argparse.Namespace:
    """Define the command-line interface and parse ``sys.argv``.

    Returns:
        Namespace holding the checkpoint/tokenizer paths, the passage text, the
        prompt template, the sampling controls forwarded to ``generate``, the
        device choice and the RNG seed.
    """
    cli = argparse.ArgumentParser(description="Ask the paragraph-explainer model.")
    add = cli.add_argument

    # What to load and what to feed it.
    add("--checkpoint", type=Path, required=True)
    add("--tokenizer", type=Path, default=Path("data/wikitext103/tokenizer.json"))
    add("--text", type=str, required=True, help="The passage to explain.")
    add("--prompt_template", type=str, default=DEFAULT_PROMPT_TEMPLATE)

    # Decoding / sampling controls.
    add("--max_new_tokens", type=int, default=120)
    add("--temperature", type=float, default=0.7)
    add("--top_k", type=int, default=40)
    add(
        "--top_p",
        type=float,
        default=0.9,
        help="Nucleus sampling cutoff. Set 1.0 to disable.",
    )
    add(
        "--repetition_penalty",
        type=float,
        default=1.3,
        help="Penalty for re-emitting tokens already in the context. 1.0 = off.",
    )
    add(
        "--no_repeat_ngram_size",
        type=int,
        default=3,
        help="Block any n-gram of this size from appearing twice. 0 = off.",
    )

    # Runtime environment.
    add("--device", type=str, default="auto", choices=["auto", "cuda", "cpu"])
    add("--seed", type=int, default=0)

    return cli.parse_args()
def banned_tokens_from_ngrams(generated: list[int], n: int) -> set[int]:
    """
    Return the token ids that would close a previously-seen n-gram if emitted next.

    Used for no-repeat-ngram blocking. The last ``n - 1`` tokens of
    ``generated`` form the current prefix. At every earlier position where
    that same prefix occurred, the token that followed it is banned, because
    emitting it again would repeat an n-gram.

    Args:
        generated: Token ids produced so far. Only these are scanned; the
            caller decides whether the prompt is included.
        n: N-gram size. ``n <= 0`` disables blocking. ``n == 1`` bans every
            token that already appears in ``generated``.

    Returns:
        The set of banned next-token ids (empty when nothing is banned).
    """
    if n <= 0 or len(generated) < n - 1:
        return set()
    ctx_len = n - 1
    # Slice from an explicit start index: the previous `generated[-(n - 1):]`
    # became `generated[-0:]` (the whole list) when n == 1, so unigram
    # blocking silently banned nothing.
    prefix = tuple(generated[len(generated) - ctx_len:])
    # The window starting at i is followed by generated[i + ctx_len]; stop
    # before the window that would need a token not generated yet.
    return {
        generated[i + ctx_len]
        for i in range(len(generated) - ctx_len)
        if tuple(generated[i : i + ctx_len]) == prefix
    }
def _apply_repetition_penalty(logits, context, penalty, device):
    """Penalize every token id already present in ``context``; modifies ``logits`` in place.

    Positive scores are divided by ``penalty`` and non-positive scores are
    multiplied by it, so with ``penalty > 1`` a seen token's score never rises.
    Returns the same tensor for convenience.
    """
    if penalty == 1.0 or not context:
        return logits
    seen = torch.tensor(list(set(context)), dtype=torch.long, device=device)
    picked = logits[seen]
    logits[seen] = torch.where(picked > 0, picked / penalty, picked * penalty)
    return logits


def _mask_top_k(logits, top_k):
    """Set everything below the k-th largest score to -inf (ties at the cutoff are kept)."""
    if top_k is None or top_k <= 0:
        return logits
    k = min(top_k, logits.size(-1))
    threshold = torch.topk(logits, k).values[-1]
    return logits.masked_fill(logits < threshold, float("-inf"))


def _mask_top_p(logits, top_p):
    """Keep the smallest high-probability prefix whose cumulative mass exceeds ``top_p``.

    The first token that pushes the cumulative probability past ``top_p`` is
    kept, and the single most likely token is always kept.
    """
    if not (top_p < 1.0):
        return logits
    ordered, order = torch.sort(logits, descending=True)
    running_mass = torch.cumsum(F.softmax(ordered, dim=-1), dim=-1)
    # Shift the "over budget" flags right by one so the token that crosses the
    # threshold survives; slot 0 (the top token) is never dropped.
    drop = torch.roll(running_mass > top_p, shifts=1, dims=-1)
    drop[..., 0] = False
    kept = ordered.masked_fill(drop, float("-inf"))
    # Undo the sort: put each surviving score back at its vocabulary index.
    return torch.full_like(logits, float("-inf")).scatter(0, order, kept)


def _sample_token(logits):
    """Draw one token id from softmax(logits), falling back to argmax if the distribution is unusable."""
    probs = F.softmax(logits, dim=-1)
    if torch.isfinite(probs).all() and probs.sum() > 0:
        return int(torch.multinomial(probs, num_samples=1).item())
    return int(logits.argmax().item())


def generate(
    model: GPT,
    prompt_ids: list[int],
    max_new_tokens: int,
    temperature: float,
    top_k: int,
    top_p: float,
    repetition_penalty: float,
    no_repeat_ngram_size: int,
    eos_id: int | None,
    device: str,
) -> list[int]:
    """
    Autoregressively sample up to ``max_new_tokens`` tokens after ``prompt_ids``.

    Each step runs the full (cropped) context through ``model`` and then, in
    this order: applies the repetition penalty (over prompt + generated
    tokens), blocks repeated n-grams (over generated tokens only), divides by
    the temperature (floored at 1e-5), and applies top-k and top-p filtering.
    It then samples one token.

    Generation stops early when ``eos_id`` is sampled; the EOS token itself is
    not returned. The prompt is not included in the result.

    Note: ``model(x)`` is called with a ``(1, T)`` LongTensor and its first
    return value is indexed as ``[:, -1, :]``, presumably
    ``(batch, time, vocab)`` logits (TODO: confirm against ``GPT.forward``).
    There is no KV cache, so every step re-encodes the whole window.
    """
    window = model.config.block_size
    context = list(prompt_ids)
    new_tokens: list[int] = []

    for _step in range(max_new_tokens):
        # Only the most recent `window` tokens fit in the model.
        visible = context[-window:] if len(context) > window else context
        batch = torch.tensor([visible], dtype=torch.long, device=device)
        with torch.no_grad():
            out, _ = model(batch)
        scores = out[:, -1, :].squeeze(0).float()

        scores = _apply_repetition_penalty(scores, context, repetition_penalty, device)

        if no_repeat_ngram_size > 0 and len(new_tokens) >= no_repeat_ngram_size - 1:
            for blocked in banned_tokens_from_ngrams(new_tokens, no_repeat_ngram_size):
                scores[blocked] = float("-inf")

        scores = scores / max(temperature, 1e-5)
        scores = _mask_top_k(scores, top_k)
        scores = _mask_top_p(scores, top_p)

        choice = _sample_token(scores)
        if eos_id is not None and choice == eos_id:
            break
        context.append(choice)
        new_tokens.append(choice)

    return new_tokens
def _resolve_device(requested: str) -> str:
    """Turn the ``--device`` choice into a concrete device string ("auto" prefers CUDA)."""
    if requested != "auto":
        return requested
    return "cuda" if torch.cuda.is_available() else "cpu"


def _load_model(checkpoint: Path, device: str) -> GPT:
    """Rebuild a GPT from a checkpoint dict with ``config``/``model`` entries, ready for inference.

    Dropout and gradient checkpointing are switched off, the state dict
    (minus any ``_orig_mod.`` prefixes) is loaded strictly, and the model is
    moved to ``device`` and put in eval mode.
    """
    # NOTE(review): weights_only=False lets torch.load unpickle arbitrary
    # objects; only point --checkpoint at files you trust.
    payload = torch.load(checkpoint, map_location=device, weights_only=False)
    cfg = GPTConfig(**payload["config"])
    cfg.dropout = 0.0
    cfg.gradient_checkpointing = False
    net = GPT(cfg)
    net.load_state_dict(strip_compile_prefix(payload["model"]), strict=True)
    net.to(device)
    net.eval()
    return net


def main() -> None:
    """CLI entry point: load tokenizer and model, build the prompt, sample, and print the answer."""
    args = parse_args()

    # NOTE(review): the default seed is 0, which is falsy, so in that case the
    # RNG is left unseeded and sampling is not reproducible.
    if args.seed:
        torch.manual_seed(args.seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(args.seed)

    device = _resolve_device(args.device)
    tok = TextTokenizer(args.tokenizer)
    model = _load_model(args.checkpoint, device)

    if tok.vocab_size != model.config.vocab_size:
        raise RuntimeError(
            f"Tokenizer vocab_size {tok.vocab_size} != model vocab_size {model.config.vocab_size}. "
            "Use the same tokenizer.json that was used for pretrain/SFT."
        )

    prompt_text = args.prompt_template.format(passage=args.text.strip())
    prompt_ids = tok.encode(prompt_text, add_bos=True, add_eos=False)

    # Reserve room in the context window for the answer (+1 spare slot).
    budget = model.config.block_size - args.max_new_tokens - 1
    if budget < 16:
        raise RuntimeError(
            f"max_new_tokens={args.max_new_tokens} is too large for block_size={model.config.block_size}."
        )
    if len(prompt_ids) > budget:
        # Keep BOS (when present) plus the END of the prompt, so the tail of
        # the template (e.g. the default template's trailing "Summary:" cue)
        # is never the part that gets cut.
        head = [prompt_ids[0]] if prompt_ids and prompt_ids[0] == tok.bos_id else []
        prompt_ids = head + prompt_ids[-(budget - len(head)):]

    new_ids = generate(
        model=model,
        prompt_ids=prompt_ids,
        max_new_tokens=args.max_new_tokens,
        temperature=args.temperature,
        top_k=args.top_k,
        top_p=args.top_p,
        repetition_penalty=args.repetition_penalty,
        no_repeat_ngram_size=args.no_repeat_ngram_size,
        eos_id=tok.eos_id,
        device=device,
    )
    print(tok.decode(new_ids, skip_special_tokens=True).strip())


if __name__ == "__main__":
    main()