"""Interactive sampling script for the Crimson language model.

Loads a trained checkpoint plus a compact-vocabulary mapping, then repeatedly
generates text from a newline prompt, streaming decoded tokens to stdout.
"""

import os

import tiktoken
import torch
import torch.nn.functional as F

from model import Crimson, MAX_SEQ_LEN

MODEL_PATH = "crimson_8.5M.pt"
VOCAB_PATH = "vocab_map.pt"
TOKENIZER_NAME = "gpt2"

# Special ids occupy the first OFFSET slots of the compact vocabulary;
# remapped tokenizer tokens start at OFFSET.
PAD_ID = 0
SEP_ID = 1
EOS_ID = 2
OFFSET = 3


def load_model_and_vocab(device):
    """Load the Crimson checkpoint and the vocabulary remapping tables.

    Args:
        device: torch device string the model weights are mapped to.

    Returns:
        ``(model, used_tokens, id2new)`` on success, where ``used_tokens``
        maps a compact id (minus ``OFFSET``) back to a raw tokenizer id and
        ``id2new`` maps raw tokenizer ids to compact ids; ``(None, None,
        None)`` when either the vocab file or the checkpoint is missing.
    """
    if not os.path.exists(VOCAB_PATH):
        return None, None, None
    vocab_data = torch.load(VOCAB_PATH, map_location="cpu")
    used_tokens = vocab_data["used_tokens"]
    id2new = vocab_data["id2new"]
    # The model's vocab covers every used token plus the special-id slots.
    vocab_size = len(used_tokens) + OFFSET
    model = Crimson(vocab_size).to(device)
    if not os.path.exists(MODEL_PATH):
        return None, None, None
    model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
    model.eval()
    return model, used_tokens, id2new


@torch.no_grad()
def generate(model, prompt, tokenizer, id2new, used_tokens, device,
             max_new_tokens=200, temperature=0.8, top_k=50, on_token=None):
    """Sample a continuation of *prompt* from the model.

    Args:
        model: the Crimson model, already placed on ``device``.
        prompt: conditioning text; tokens unknown to ``id2new`` are dropped.
        tokenizer: tiktoken encoding used for encode/decode.
        id2new: raw tokenizer id -> compact model id.
        used_tokens: compact id (minus ``OFFSET``) -> raw tokenizer id.
        device: device to run inference on.
        max_new_tokens: upper bound on the number of sampled tokens.
        temperature: logit divisor; must be > 0.
        top_k: keep only the k most likely tokens per step (``None`` disables).
        on_token: optional callback invoked with each decoded text piece as it
            is sampled — enables streaming output. Defaults to ``None`` so
            existing callers are unaffected.

    Returns:
        The decoded continuation (the prompt is not included).
    """
    model.eval()
    raw_ids = tokenizer.encode(prompt)
    # Remap into the compact vocabulary; tokens outside it are skipped.
    input_ids = [id2new[rid] for rid in raw_ids if rid in id2new]
    if not input_ids:
        input_ids = [PAD_ID]  # never feed an empty sequence
    x = torch.tensor([input_ids], dtype=torch.long, device=device)
    generated = []
    for _ in range(max_new_tokens):
        # Crop the context to the model's maximum sequence length.
        ctx = x[:, -MAX_SEQ_LEN:] if x.size(1) > MAX_SEQ_LEN else x
        logits = model(ctx)
        next_token_logits = logits[:, -1, :] / temperature
        if top_k is not None:
            v, _ = torch.topk(next_token_logits,
                              min(top_k, next_token_logits.size(-1)))
            # Mask out everything below the k-th largest logit.
            next_token_logits[next_token_logits < v[:, [-1]]] = -float('Inf')
        probs = F.softmax(next_token_logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
        idx = next_token.item()
        if idx == EOS_ID:
            break
        x = torch.cat((x, next_token), dim=1)
        generated.append(idx)
        # Special ids (< OFFSET) have no tokenizer counterpart to decode.
        if on_token is not None and idx >= OFFSET:
            on_token(tokenizer.decode([used_tokens[idx - OFFSET]]))
    return tokenizer.decode(
        [used_tokens[i - OFFSET] for i in generated if i >= OFFSET])


def _pick_device():
    """Return the best available device: cuda, then mps, then cpu."""
    if torch.cuda.is_available():
        return "cuda"
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"


if __name__ == "__main__":
    device = _pick_device()
    model, used_tokens, id2new = load_model_and_vocab(device)
    enc = tiktoken.get_encoding(TOKENIZER_NAME)
    if model is not None:
        while True:
            # Stream tokens to stdout as they are sampled. This previously
            # duplicated generate() inline; the on_token callback removes
            # the copy while keeping the streaming behavior.
            generate(model, "\n", enc, id2new, used_tokens, device,
                     max_new_tokens=900, temperature=0.8, top_k=50,
                     on_token=lambda piece: print(piece, end="", flush=True))
            if input("\nPress [Enter] to generate again, or type 'exit': ").lower() == 'exit':
                break