# CrimsonGCLM / sample.py
# Provenance: uploaded by AGofficial ("Upload 9 files", commit d7fbd1f, verified).
import os
import torch
import torch.nn.functional as F
import tiktoken
from model import Crimson, MAX_SEQ_LEN
MODEL_PATH = "crimson_8.5M.pt"  # trained model weights checkpoint
VOCAB_PATH = "vocab_map.pt"  # compact-vocabulary mapping: {"used_tokens", "id2new"}
TOKENIZER_NAME = "gpt2"  # tiktoken encoding used for raw text <-> raw token ids
# Special ids in the compact vocabulary; real tokens are shifted up by OFFSET,
# so compact id i >= OFFSET corresponds to raw id used_tokens[i - OFFSET].
PAD_ID = 0
SEP_ID = 1
EOS_ID = 2
OFFSET = 3
def load_model_and_vocab(device):
    """Load the compact-vocabulary mapping and model weights.

    Returns a ``(model, used_tokens, id2new)`` triple, or ``(None, None, None)``
    when either checkpoint file is missing on disk.
    """
    if not os.path.exists(VOCAB_PATH):
        return None, None, None
    vocab = torch.load(VOCAB_PATH, map_location="cpu")
    used_tokens = vocab["used_tokens"]
    id2new = vocab["id2new"]
    # Model vocabulary = kept tokens plus the reserved special ids.
    model = Crimson(len(used_tokens) + OFFSET).to(device)
    if not os.path.exists(MODEL_PATH):
        return None, None, None
    model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
    model.eval()
    return model, used_tokens, id2new
@torch.no_grad()
def generate(model, prompt, tokenizer, id2new, used_tokens, device, max_new_tokens=200, temperature=0.8, top_k=50):
    """Autoregressively sample a continuation of ``prompt``.

    Args:
        model: the Crimson language model (called as ``model(ids)`` -> logits).
        prompt: raw text to condition on.
        tokenizer: tiktoken-style encoder with ``encode``/``decode``.
        id2new: raw-token-id -> compact-id mapping; unknown raw ids are dropped.
        used_tokens: compact-id-minus-OFFSET -> raw-token-id table.
        device: torch device string for the input tensor.
        max_new_tokens: upper bound on sampled tokens.
        temperature: softmax temperature; values <= 0 select greedy argmax
            decoding (fix: previously a zero temperature produced inf logits
            and NaN probabilities, making ``torch.multinomial`` raise).
        top_k: keep only the k highest logits before sampling; None disables.

    Returns:
        The decoded continuation (special ids below OFFSET are skipped).
    """
    model.eval()
    raw_ids = tokenizer.encode(prompt)
    # Map raw tokenizer ids into the model's compact vocabulary.
    input_ids = [id2new[rid] for rid in raw_ids if rid in id2new]
    if not input_ids:
        input_ids = [PAD_ID]  # never feed an empty sequence
    x = torch.tensor([input_ids], dtype=torch.long, device=device)
    generated = []
    for _ in range(max_new_tokens):
        ctx = x[:, -MAX_SEQ_LEN:]  # crop to the model's context window
        next_token_logits = model(ctx)[:, -1, :]
        if temperature <= 0:
            # Greedy decoding: deterministic argmax, no temperature scaling.
            next_token = next_token_logits.argmax(dim=-1, keepdim=True)
        else:
            next_token_logits = next_token_logits / temperature
            if top_k is not None:
                v, _ = torch.topk(next_token_logits, min(top_k, next_token_logits.size(-1)))
                # Mask everything below the k-th largest logit.
                next_token_logits[next_token_logits < v[:, [-1]]] = -float('Inf')
            probs = F.softmax(next_token_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
        idx = next_token.item()
        if idx == EOS_ID:
            break
        x = torch.cat((x, next_token), dim=1)
        generated.append(idx)
    # Map compact ids back to raw tokenizer ids; ids < OFFSET are special and skipped.
    return tokenizer.decode([used_tokens[i - OFFSET] for i in generated if i >= OFFSET])
if __name__ == "__main__":
    # Prefer CUDA, then Apple MPS, then CPU.
    device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
    model, used_tokens, id2new = load_model_and_vocab(device)
    enc = tiktoken.get_encoding(TOKENIZER_NAME)
    if model is None:
        # Fix: previously a missing checkpoint exited silently with no output.
        print(f"Could not load model: make sure {MODEL_PATH} and {VOCAB_PATH} exist.")
    else:
        # Seed each generation with a newline token (fall back to the first real id).
        newline_id = id2new.get(enc.encode("\n")[0], OFFSET)
        temperature, top_k, max_tokens = 0.8, 50, 900  # sampling hyperparameters
        while True:
            x = torch.tensor([[newline_id]], dtype=torch.long, device=device)
            with torch.no_grad():
                for _ in range(max_tokens):
                    ctx = x[:, -MAX_SEQ_LEN:]  # crop to the model's context window
                    logits = model(ctx)[:, -1, :] / temperature
                    v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                    logits[logits < v[:, [-1]]] = -float('Inf')  # top-k mask
                    probs = F.softmax(logits, dim=-1)
                    next_token = torch.multinomial(probs, num_samples=1)
                    idx = next_token.item()
                    x = torch.cat((x, next_token), dim=1)
                    if idx == EOS_ID:
                        break
                    if idx >= OFFSET:
                        # Stream each decoded token to stdout as it is sampled.
                        print(enc.decode([used_tokens[idx - OFFSET]]), end="", flush=True)
            # Fix: strip whitespace so inputs like "exit " or "EXIT" also quit.
            if input("\nPress [Enter] to generate again, or type 'exit': ").strip().lower() == 'exit':
                break