dpe1's picture
download
raw
8.59 kB
import torch
import torch.nn.functional as F
from src.tokenizer import CharTokenizer
from src.model import TinyReasonerModel
from src.capabilities import dispatch_capability
class Sampler:
    """Token-by-token generation with inline capability (tool) calls.

    The model is called as ``model(input_ids, hidden) -> (logits, hidden)`` and
    driven one token at a time after an initial prompt prefill. When it emits a
    token listed in ``CAPABILITY_TOKENS`` it keeps generating a payload until it
    emits ``tokenizer.stop_token_id``. The decoded payload is handed to a
    capability function, and the encoded result followed by a stop token is
    appended and fed back. The resulting sequence is
    ``[CAP] payload [STOP] result [STOP]``.
    """

    # Token strings that open a capability call; stripping the surrounding
    # brackets gives the capability type passed to the capability function.
    CAPABILITY_TOKENS = frozenset({"[DEFINE]", "[SYMPY]"})

    def __init__(self, model, tokenizer, device="cpu", capability_fn=None):
        """Store the model/tokenizer, move the model to ``device`` and set eval mode.

        Args:
            model: callable ``model(input_ids, hidden) -> (logits, hidden)``.
            tokenizer: provides ``encode``, ``decode``, ``itos``,
                ``eos_token_id`` and ``stop_token_id``.
            device: device on which input id tensors are created.
            capability_fn: optional ``fn(cap_type, payload) -> str`` used to
                execute tool calls; ``None`` uses ``dispatch_capability``.
        """
        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        self.capability_fn = capability_fn
        self.model.to(device)
        self.model.eval()

    # ------------------------------------------------------------------ helpers

    def _step(self, token_id, hidden):
        """Feed one token id to the model; return ``(logits, hidden)``."""
        input_ids = torch.tensor([[token_id]], dtype=torch.long, device=self.device)
        return self.model(input_ids, hidden)

    def _feed(self, token_ids, hidden):
        """Feed ``token_ids`` one by one; return the last ``(logits, hidden)``."""
        logits = None
        for token_id in token_ids:
            logits, hidden = self._step(token_id, hidden)
        return logits, hidden

    @staticmethod
    def _perturb(hidden, noise_std):
        """Add Gaussian noise to an ``(h, c)`` hidden state when ``noise_std > 0``."""
        if noise_std <= 0:
            return hidden
        h, c = hidden
        h = h + torch.randn_like(h) * noise_std
        c = c + torch.randn_like(c) * noise_std
        return (h, c)

    @staticmethod
    def _choose(logits, temperature, top_k=0):
        """Pick the next token id from the last position of ``logits``.

        ``temperature == 0`` selects the argmax. Otherwise the logits are
        divided by ``temperature`` (by 1.0 when it is negative) and sampled.
        When ``top_k > 0``, only the ``top_k`` largest logits are kept.
        """
        last = logits[:, -1, :]
        if temperature == 0:
            return torch.argmax(last, dim=-1).item()
        last = last / (temperature if temperature > 0 else 1.0)
        if top_k > 0:
            k = min(top_k, last.size(-1))
            kth_largest = torch.topk(last, k, dim=-1).values[..., -1:]
            last = last.masked_fill(last < kth_largest, float("-inf"))
        return torch.multinomial(F.softmax(last, dim=-1), num_samples=1).item()

    @staticmethod
    def _sample_with_log_prob(logits, temperature):
        """Sample from the last position and return ``(token_id, log_prob)``.

        ``temperature <= 0`` is treated as 1.0 (rollouts always sample). The
        log-probability is that of the temperature-scaled distribution and
        keeps its autograd history.
        """
        last = logits[:, -1, :] / (temperature if temperature > 0 else 1.0)
        token_id = torch.multinomial(F.softmax(last, dim=-1), num_samples=1).item()
        # log_softmax is exact for small probabilities, unlike log(p + eps).
        log_prob = F.log_softmax(last, dim=-1)[0, token_id]
        return token_id, log_prob

    def _call_capability(self, cap_type, payload_tokens):
        """Run a capability on the decoded payload.

        Returns the encoded result followed by the stop token id.
        """
        payload = self.tokenizer.decode(payload_tokens)
        fn = getattr(self, "capability_fn", None)
        if fn is None:
            fn = dispatch_capability
        result_tokens = list(self.tokenizer.encode(fn(cap_type, payload)))
        result_tokens.append(self.tokenizer.stop_token_id)
        return result_tokens

    # --------------------------------------------------------------- public API

    def sample(self, prompt, max_len=512, temperature=1.0, top_k=0, stop_at_eos=True):
        """Generate a completion for ``prompt`` and return the decoded text.

        Args:
            prompt: text to condition on; it is included in the returned string.
            max_len: generation continues while the sequence (prompt included)
                is shorter than this. Tool results are injected in full, so the
                final sequence may exceed it.
            temperature: 0 selects greedy decoding; otherwise the logits are
                divided by it before sampling (negative values act as 1.0).
            top_k: when > 0 and sampling, restrict choices to the ``top_k`` most
                likely tokens.
            stop_at_eos: stop at the EOS token. When False, EOS is fed back and
                generation continues.

        Returns:
            The decoded prompt + completion, including any tool-call spans.
        """
        stop_id = self.tokenizer.stop_token_id
        # Copy so the tokenizer's returned list is never mutated.
        generated = list(self.tokenizer.encode(prompt))

        # Only a string leaves this method, so no autograd graph is needed.
        with torch.no_grad():
            input_ids = torch.tensor([generated], dtype=torch.long, device=self.device)
            logits, hidden = self.model(input_ids, None)  # prefill the prompt

            while len(generated) < max_len:
                next_token = self._choose(logits, temperature, top_k)
                generated.append(next_token)
                token_str = self.tokenizer.itos.get(next_token, "")

                if token_str in self.CAPABILITY_TOKENS:
                    payload_tokens = []
                    found_stop = False
                    while len(generated) < max_len:
                        logits, hidden = self._step(generated[-1], hidden)
                        nt = self._choose(logits, temperature, top_k)
                        generated.append(nt)
                        if nt == stop_id:
                            found_stop = True
                            break
                        payload_tokens.append(nt)
                    if not found_stop:
                        break  # length budget exhausted inside the tool call
                    result_tokens = self._call_capability(token_str[1:-1], payload_tokens)
                    # Feed the model's own stop token before the tool result so the
                    # hidden state matches the returned "[CAP] payload [STOP] result [STOP]".
                    logits, hidden = self._feed([generated[-1], *result_tokens], hidden)
                    generated.extend(result_tokens)
                elif next_token == self.tokenizer.eos_token_id and stop_at_eos:
                    break
                else:
                    logits, hidden = self._step(next_token, hidden)

        return self.tokenizer.decode(generated)

    def grpo_rollout(self, prompt, num_rollouts=8, max_len=512, temperature=1.0, noise_std=0.0):
        """Perform multiple sampled rollouts of ``prompt`` for GRPO.

        Args:
            prompt: text every rollout starts from.
            num_rollouts: number of independent rollouts.
            max_len: generation continues while a rollout (prompt included) is
                shorter than this. Tool results may push it past the limit.
            temperature: logits are divided by it; values <= 0 act as 1.0.
            noise_std: when > 0, Gaussian noise is added to the ``(h, c)``
                hidden state before each sampled token.

        Returns:
            ``(completions, log_probs, masks)``, one entry per rollout:

            - ``completions``: decoded prompt + completion strings.
            - ``log_probs``: 1-D tensors with one log-probability per
              model-generated token (empty when nothing was generated). They are
              computed with autograd enabled.
            - ``masks``: long tensors aligned with ``tokens[1:]``; 1 marks
              model-generated tokens, 0 marks prompt or tool-injected tokens.
        """
        prompt_tokens = list(self.tokenizer.encode(prompt))
        stop_id = self.tokenizer.stop_token_id
        all_completions = []
        all_log_probs = []
        all_masks = []

        for _ in range(num_rollouts):
            tokens = list(prompt_tokens)
            log_probs = []  # only for tokens whose mask entry is 1
            # mask[i] labels tokens[i + 1]; the first prompt token is never a target.
            mask = [0] * (len(prompt_tokens) - 1)
            input_ids = torch.tensor([prompt_tokens], dtype=torch.long, device=self.device)
            logits, hidden = self.model(input_ids, None)

            while len(tokens) < max_len:
                hidden = self._perturb(hidden, noise_std)
                next_token, lp = self._sample_with_log_prob(logits, temperature)
                tokens.append(next_token)
                log_probs.append(lp)
                mask.append(1)  # model generated
                token_str = self.tokenizer.itos.get(next_token, "")

                if token_str in self.CAPABILITY_TOKENS:
                    payload_tokens = []
                    found_stop = False
                    while len(tokens) < max_len:
                        logits, hidden = self._step(tokens[-1], hidden)
                        hidden = self._perturb(hidden, noise_std)
                        nt, lp = self._sample_with_log_prob(logits, temperature)
                        tokens.append(nt)
                        log_probs.append(lp)
                        mask.append(1)  # payload and its stop token are model generated
                        if nt == stop_id:
                            found_stop = True
                            break
                        payload_tokens.append(nt)
                    if not found_stop:
                        break
                    result_tokens = self._call_capability(token_str[1:-1], payload_tokens)
                    # Feed the model's stop token first, then the tool output.
                    logits, hidden = self._feed([tokens[-1], *result_tokens], hidden)
                    tokens.extend(result_tokens)
                    mask.extend([0] * len(result_tokens))  # tool injected
                elif next_token == self.tokenizer.eos_token_id:
                    break
                else:
                    logits, hidden = self._step(next_token, hidden)

            all_completions.append(self.tokenizer.decode(tokens))
            # torch.stack([]) raises, so fall back to an empty tensor.
            all_log_probs.append(
                torch.stack(log_probs) if log_probs else torch.zeros(0, device=self.device)
            )
            all_masks.append(torch.tensor(mask, dtype=torch.long, device=self.device))
        return all_completions, all_log_probs, all_masks
if __name__ == "__main__":
    import os
    import sys

    # The first CLI argument, when present, overrides the default checkpoint.
    checkpoint = sys.argv[1] if len(sys.argv) > 1 else "models/sft_model.pt"

    tokenizer = CharTokenizer()
    model = TinyReasonerModel(tokenizer.vocab_size)
    if not os.path.exists(checkpoint):
        print(f"Model {checkpoint} not found, using random weights.")
    else:
        model.load_state_dict(torch.load(checkpoint, map_location="cpu"))
        print(f"Loaded model from {checkpoint}")

    sampler = Sampler(model, tokenizer)
    demo_prompts = (
        "[BOS]What is the definition of apple?",
        "[BOS]What is the sum of 10 and 20?",
    )
    # Greedy decoding (temperature=0) for each demo prompt.
    for prompt in demo_prompts:
        print(f"\nPrompt: {prompt}")
        print(f"Output: {sampler.sample(prompt, max_len=256, temperature=0)}")

Xet Storage Details

Size:
8.59 kB
·
Xet hash:
d690b99ae5798998fb2e519cf4b5d91c58458fe0e007179f5e91b3872804dc6b

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.