RON-110M / code /ask.py
endurasolution's picture
Upload Ron-110M: pretrain + summarizer + tokenizer + code
3b97420 verified
from __future__ import annotations
import argparse
from pathlib import Path
import torch
import torch.nn.functional as F
from searshorai.model import GPT, GPTConfig
from searshorai.tokenizer import TextTokenizer
# Prompt wording shared with make_xsum_sft.py / make_paragraph_sft.py and it
# must stay in sync with them. Of the templates used there, the first one is
# the canonical choice at inference time; it is reproduced here verbatim.
# ``{passage}`` is filled in with ``str.format``.
DEFAULT_PROMPT_TEMPLATE = "\n\n".join(
    (
        "Read the article and write a one-sentence summary.",
        "Article:\n{passage}",
        "Summary:\n",
    )
)
def strip_compile_prefix(state_dict):
    """Return a copy of ``state_dict`` with a leading ``_orig_mod.`` removed from keys.

    ``torch.compile``-wrapped modules save their parameters under that prefix;
    stripping it lets the weights load into a plain, uncompiled model. Only one
    leading occurrence is removed, keys without it are kept unchanged, and the
    input mapping is not modified.
    """
    prefix = "_orig_mod."
    return {
        (name[len(prefix):] if name.startswith(prefix) else name): tensor
        for name, tensor in state_dict.items()
    }
def parse_args() -> argparse.Namespace:
    """Define the command-line interface and parse ``sys.argv``.

    ``--checkpoint`` and ``--text`` are required; every other option has a
    default. The decoding options are passed by ``main`` to ``generate``.
    """
    cli = argparse.ArgumentParser(description="Ask the paragraph-explainer model.")

    # What to load and what to summarize.
    cli.add_argument("--checkpoint", type=Path, required=True)
    cli.add_argument("--tokenizer", type=Path, default=Path("data/wikitext103/tokenizer.json"))
    cli.add_argument("--text", type=str, required=True, help="The passage to explain.")
    cli.add_argument("--prompt_template", type=str, default=DEFAULT_PROMPT_TEMPLATE)

    # Decoding knobs as (flag, type, default, help), registered in this order;
    # a help of None is argparse's own default for "no help text".
    decoding_options = (
        ("--max_new_tokens", int, 120, None),
        ("--temperature", float, 0.7, None),
        ("--top_k", int, 40, None),
        ("--top_p", float, 0.9, "Nucleus sampling cutoff. Set 1.0 to disable."),
        (
            "--repetition_penalty",
            float,
            1.3,
            "Penalty for re-emitting tokens already in the context. 1.0 = off.",
        ),
        (
            "--no_repeat_ngram_size",
            int,
            3,
            "Block any n-gram of this size from appearing twice. 0 = off.",
        ),
    )
    for flag, value_type, default, help_text in decoding_options:
        cli.add_argument(flag, type=value_type, default=default, help=help_text)

    # Runtime placement and reproducibility.
    cli.add_argument("--device", type=str, default="auto", choices=["auto", "cuda", "cpu"])
    cli.add_argument("--seed", type=int, default=0)

    namespace = cli.parse_args()
    return namespace
def banned_tokens_from_ngrams(generated: list[int], n: int) -> set[int]:
    """Return the token ids that, if emitted next, would repeat a seen ``n``-gram.

    Used for no-repeat-ngram blocking. Take the last ``n - 1`` tokens of
    ``generated`` as the current prefix. Wherever that prefix occurred earlier
    as the first ``n - 1`` tokens of a complete ``n``-gram, the token that
    completed that ``n``-gram is banned.

    Args:
        generated: token ids produced so far, oldest first.
        n: n-gram size. ``n <= 0`` disables blocking. ``n == 1`` bans every
            token already present in ``generated``.

    Returns:
        The set of banned next-token ids (empty when nothing is blocked).
    """
    if n <= 0:
        return set()
    if n == 1:
        # A unigram has an empty prefix, so every token seen so far is a
        # repeat. The generic path below cannot express this case:
        # generated[-0:] is the whole list, not an empty suffix.
        return set(generated)
    if len(generated) < n - 1:
        # Not enough history to form a prefix yet.
        return set()
    prefix = tuple(generated[-(n - 1):])
    # Scan every complete n-gram; the last window ends at generated[-1].
    return {
        generated[start + n - 1]
        for start in range(len(generated) - n + 1)
        if tuple(generated[start : start + n - 1]) == prefix
    }
def _apply_repetition_penalty(
    logits: torch.Tensor, context: list[int], penalty: float, device: str
) -> torch.Tensor:
    """Penalize every token id that already appears in ``context``.

    Positive logits are divided by ``penalty`` and negative ones multiplied by
    it; for ``penalty > 1`` both moves make the token less likely. Modifies
    ``logits`` in place and returns it. A penalty of exactly 1.0 or an empty
    context is a no-op.
    """
    if penalty == 1.0 or not context:
        return logits
    seen_ids = torch.tensor(list(set(context)), dtype=torch.long, device=device)
    picked = logits[seen_ids]
    logits[seen_ids] = torch.where(picked > 0, picked / penalty, picked * penalty)
    return logits


def _keep_top_k(logits: torch.Tensor, k: int) -> torch.Tensor:
    """Set every logit strictly below the k-th largest one to -inf (ties are kept)."""
    kth_best = torch.topk(logits, min(k, logits.size(-1)))[0][-1]
    return logits.masked_fill(logits < kth_best, -float("inf"))


def _keep_top_p(logits: torch.Tensor, p: float) -> torch.Tensor:
    """Nucleus filter: keep the smallest high-probability prefix whose mass exceeds ``p``."""
    ordered, order = torch.sort(logits, descending=True)
    mass = torch.cumsum(F.softmax(ordered, dim=-1), dim=-1)
    # Drop a token only if the mass *before* it already exceeds p. This keeps
    # the token that crosses the threshold, and always keeps the most likely one.
    drop = torch.zeros_like(mass, dtype=torch.bool)
    drop[..., 1:] = mass[..., :-1] > p
    ordered = ordered.masked_fill(drop, -float("inf"))
    # Undo the sort: write each filtered value back to its original vocab slot.
    return torch.full_like(logits, -float("inf")).scatter(0, order, ordered)


def _draw(logits: torch.Tensor) -> int:
    """Sample one id from softmax(logits); use argmax when the distribution is degenerate."""
    probs = F.softmax(logits, dim=-1)
    degenerate = (not torch.isfinite(probs).all()) or probs.sum() <= 0
    if degenerate:
        return int(logits.argmax().item())
    return int(torch.multinomial(probs, num_samples=1).item())


def generate(
    model: GPT,
    prompt_ids: list[int],
    max_new_tokens: int,
    temperature: float,
    top_k: int,
    top_p: float,
    repetition_penalty: float,
    no_repeat_ngram_size: int,
    eos_id: int | None,
    device: str,
) -> list[int]:
    """Autoregressively sample up to ``max_new_tokens`` tokens after ``prompt_ids``.

    Each step feeds the model the running context, truncated from the left to
    ``model.config.block_size``. It then reshapes the last-position logits in
    this order:

    1. repetition penalty over the whole context (prompt included);
    2. no-repeat-ngram blocking over the generated tokens only;
    3. temperature (floored at 1e-5);
    4. top-k;
    5. top-p.

    Finally it samples one token. Generation stops early when ``eos_id`` is
    drawn; the EOS token itself is not returned.

    Returns:
        Only the newly generated token ids (the prompt is not included).
    """
    window = model.config.block_size
    tokens = list(prompt_ids)
    produced: list[int] = []
    for _step in range(max_new_tokens):
        visible = tokens[-window:] if len(tokens) > window else tokens
        batch = torch.tensor([visible], dtype=torch.long, device=device)
        with torch.no_grad():
            out, _ = model(batch)
        scores = out[:, -1, :].squeeze(0).float()

        scores = _apply_repetition_penalty(scores, tokens, repetition_penalty, device)
        if no_repeat_ngram_size > 0 and len(produced) >= no_repeat_ngram_size - 1:
            for blocked in banned_tokens_from_ngrams(produced, no_repeat_ngram_size):
                scores[blocked] = -float("inf")

        scores = scores / max(temperature, 1e-5)
        if top_k is not None and top_k > 0:
            scores = _keep_top_k(scores, top_k)
        if top_p < 1.0:
            scores = _keep_top_p(scores, top_p)

        choice = _draw(scores)
        if eos_id is not None and choice == eos_id:
            break
        tokens.append(choice)
        produced.append(choice)
    return produced
def _resolve_device(choice: str) -> str:
    """Map the ``--device`` choice to a concrete device string."""
    if choice == "auto":
        return "cuda" if torch.cuda.is_available() else "cpu"
    return choice


def _load_model(checkpoint: Path, device: str) -> GPT:
    """Rebuild the GPT stored in ``checkpoint`` and return it in eval mode on ``device``.

    The checkpoint is read as a mapping with a ``"config"`` entry (kwargs for
    ``GPTConfig``) and a ``"model"`` entry (the state dict).
    """
    # NOTE(security): weights_only=False lets torch.load unpickle arbitrary
    # Python objects, so a malicious checkpoint can execute code on load.
    # Only pass --checkpoint files from a trusted source.
    ckpt = torch.load(checkpoint, map_location=device, weights_only=False)
    config = GPTConfig(**ckpt["config"])
    # Force inference settings regardless of what training used.
    config.dropout = 0.0
    config.gradient_checkpointing = False
    model = GPT(config)
    model.load_state_dict(strip_compile_prefix(ckpt["model"]), strict=True)
    model.to(device)
    model.eval()
    return model


def _encode_prompt(tok: TextTokenizer, template: str, passage: str, max_len: int) -> list[int]:
    """Render ``template`` around ``passage`` and encode it (with BOS) in <= ``max_len`` tokens.

    When the full prompt is too long, the passage is shortened from its end.
    This keeps the template text intact (the instruction header and the
    trailing cue the model answers after) and keeps the passage's opening.
    If even a one-token passage does not fit, it falls back to keeping BOS
    plus the tail of the full prompt.
    """

    def render(text: str) -> list[int]:
        return tok.encode(template.format(passage=text), add_bos=True, add_eos=False)

    full_ids = render(passage)
    if len(full_ids) <= max_len:
        return full_ids

    # Cut the passage at the token level, then re-render and re-encode. Token
    # counts at the template/passage seams need not add up exactly, so loop
    # until the result fits.
    passage_ids = tok.encode(passage, add_bos=False, add_eos=False)
    keep = len(passage_ids) - (len(full_ids) - max_len)
    while keep > 0:
        ids = render(tok.decode(passage_ids[:keep], skip_special_tokens=True))
        if len(ids) <= max_len:
            return ids
        keep -= max(1, len(ids) - max_len)

    # No non-empty prefix of the passage fits (e.g. a very long custom
    # template): keep BOS, if present, plus the last tokens of the full prompt.
    bos = [full_ids[0]] if full_ids and full_ids[0] == tok.bos_id else []
    return bos + full_ids[-(max_len - len(bos)):]


def main() -> None:
    """CLI entry point: load tokenizer and model, build the prompt, sample, print the answer."""
    args = parse_args()
    # NOTE: --seed 0 (the default) is falsy, so no explicit seeding happens and
    # sampling relies on torch's default RNG state. Pass a non-zero seed to
    # seed the RNG explicitly.
    if args.seed:
        torch.manual_seed(args.seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(args.seed)
    device = _resolve_device(args.device)

    tok = TextTokenizer(args.tokenizer)
    model = _load_model(args.checkpoint, device)
    if tok.vocab_size != model.config.vocab_size:
        raise RuntimeError(
            f"Tokenizer vocab_size {tok.vocab_size} != model vocab_size {model.config.vocab_size}. "
            "Use the same tokenizer.json that was used for pretrain/SFT."
        )

    # Reserve room in the context window for the tokens we are about to generate.
    max_prompt_len = model.config.block_size - args.max_new_tokens - 1
    if max_prompt_len < 16:
        raise RuntimeError(
            f"max_new_tokens={args.max_new_tokens} is too large for block_size={model.config.block_size}."
        )
    prompt_ids = _encode_prompt(tok, args.prompt_template, args.text.strip(), max_prompt_len)

    new_ids = generate(
        model=model,
        prompt_ids=prompt_ids,
        max_new_tokens=args.max_new_tokens,
        temperature=args.temperature,
        top_k=args.top_k,
        top_p=args.top_p,
        repetition_penalty=args.repetition_penalty,
        no_repeat_ngram_size=args.no_repeat_ngram_size,
        eos_id=tok.eos_id,
        device=device,
    )
    answer = tok.decode(new_ids, skip_special_tokens=True).strip()
    print(answer)
if __name__ == "__main__":
main()