|
|
import os |
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
import tiktoken |
|
|
from model import Crimson, MAX_SEQ_LEN |
|
|
|
|
|
# Artifacts produced by training: model checkpoint and vocab compaction map.
MODEL_PATH = "crimson_8.5M.pt"

VOCAB_PATH = "vocab_map.pt"

# tiktoken encoding name used to convert raw text <-> raw token ids.
TOKENIZER_NAME = "gpt2"


# Special ids in the compacted vocabulary (everything below OFFSET is special).
PAD_ID = 0  # fallback id when a prompt maps to no known tokens

SEP_ID = 1  # separator token — presumably used during training; unused in this script

EOS_ID = 2  # end-of-sequence: sampling stops when this id is drawn

# First "real" token id: compacted id i corresponds to raw id used_tokens[i - OFFSET].
OFFSET = 3
|
|
|
|
|
def load_model_and_vocab(device):
    """Load the vocab compaction map and the model checkpoint.

    Args:
        device: torch device string the model is moved to ("cuda"/"mps"/"cpu").

    Returns:
        (model, used_tokens, id2new) on success, where `used_tokens` maps
        compacted ids (minus OFFSET) back to raw tokenizer ids and `id2new`
        maps raw tokenizer ids to compacted ids; (None, None, None) if either
        artifact file is missing.
    """
    # Check both files up front: the original built (and moved to the device)
    # a Crimson model only to discard it when the checkpoint was absent.
    if not (os.path.exists(VOCAB_PATH) and os.path.exists(MODEL_PATH)):
        return None, None, None

    vocab_data = torch.load(VOCAB_PATH, map_location="cpu")
    used_tokens = vocab_data["used_tokens"]
    id2new = vocab_data["id2new"]

    # Model vocab covers every used token plus the special ids below OFFSET.
    vocab_size = len(used_tokens) + OFFSET
    model = Crimson(vocab_size).to(device)
    model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
    model.eval()
    return model, used_tokens, id2new
|
|
|
|
|
@torch.no_grad()
def generate(model, prompt, tokenizer, id2new, used_tokens, device, max_new_tokens=200, temperature=0.8, top_k=50):
    """Sample a continuation of `prompt` and return it as decoded text.

    Raw tokenizer ids are remapped into the model's compacted vocabulary
    (ids absent from `id2new` are dropped); sampling stops early on EOS_ID.
    """
    model.eval()

    # Remap the prompt into compacted ids; fall back to a single PAD token
    # so the model always receives a non-empty context.
    mapped = [id2new[t] for t in tokenizer.encode(prompt) if t in id2new]
    if not mapped:
        mapped = [PAD_ID]

    seq = torch.tensor([mapped], dtype=torch.long, device=device)
    sampled = []

    for _ in range(max_new_tokens):
        # Keep only the most recent MAX_SEQ_LEN tokens as context.
        window = seq if seq.size(1) <= MAX_SEQ_LEN else seq[:, -MAX_SEQ_LEN:]
        step_logits = model(window)[:, -1, :] / temperature

        if top_k is not None:
            # Zero out (to -inf) everything below the k-th largest logit.
            k = min(top_k, step_logits.size(-1))
            kth = torch.topk(step_logits, k).values[:, [-1]]
            step_logits = step_logits.masked_fill(step_logits < kth, -float('Inf'))

        choice = torch.multinomial(F.softmax(step_logits, dim=-1), num_samples=1)
        tok = choice.item()
        if tok == EOS_ID:
            break
        seq = torch.cat((seq, choice), dim=1)
        sampled.append(tok)

    # Translate compacted ids back to raw tokenizer ids, skipping specials.
    return tokenizer.decode([used_tokens[i - OFFSET] for i in sampled if i >= OFFSET])
|
|
|
|
|
if __name__ == "__main__":
    # Prefer CUDA, then Apple MPS, then CPU.
    device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"

    model, used_tokens, id2new = load_model_and_vocab(device)
    enc = tiktoken.get_encoding(TOKENIZER_NAME)

    if model is None:
        # Previously the script exited silently when artifacts were missing.
        print(f"Could not load model: '{MODEL_PATH}' and/or '{VOCAB_PATH}' not found.")
    else:
        # Seed each sample with a newline token; fall back to the first real id.
        newline_id = id2new.get(enc.encode("\n")[0], OFFSET)
        while True:
            x = torch.tensor([[newline_id]], dtype=torch.long, device=device)
            with torch.no_grad():
                for _ in range(900):
                    # Same sampling scheme as generate(): temperature 0.8, top-k 50,
                    # but streaming each token to stdout as it is drawn.
                    ctx = x[:, -MAX_SEQ_LEN:] if x.size(1) > MAX_SEQ_LEN else x
                    logits = model(ctx)[:, -1, :] / 0.8
                    v, _ = torch.topk(logits, min(50, logits.size(-1)))
                    logits[logits < v[:, [-1]]] = -float('Inf')
                    probs = F.softmax(logits, dim=-1)
                    next_token = torch.multinomial(probs, num_samples=1)
                    idx = next_token.item()
                    x = torch.cat((x, next_token), dim=1)
                    if idx == EOS_ID:
                        break
                    if idx >= OFFSET:
                        # Map compacted id back to the raw tokenizer id and stream it.
                        print(enc.decode([used_tokens[idx - OFFSET]]), end="", flush=True)
            if input("\nPress [Enter] to generate again, or type 'exit': ").lower() == 'exit':
                break
|