File size: 3,197 Bytes
d7fbd1f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import os
import torch
import torch.nn.functional as F
import tiktoken
from model import Crimson, MAX_SEQ_LEN

# Files read by load_model_and_vocab, and the tiktoken encoding used for text.
MODEL_PATH = "crimson_8.5M.pt"
VOCAB_PATH = "vocab_map.pt"
TOKENIZER_NAME = "gpt2"

# Reserved model ids. Real vocabulary entries start at OFFSET, i.e. model id
# OFFSET + k corresponds to used_tokens[k].
PAD_ID, SEP_ID, EOS_ID = 0, 1, 2
OFFSET = 3

def load_model_and_vocab(device):
    """Load the Crimson model and its vocabulary remapping.

    Args:
        device: torch device (or device string) the model is moved to.

    Returns:
        ``(model, used_tokens, id2new)`` on success:
        - ``model`` has its weights loaded and is in eval mode.
        - ``used_tokens[new_id - OFFSET]`` gives the original tokenizer id.
        - ``id2new`` maps original tokenizer ids to model ids.

        Returns ``(None, None, None)`` if either ``VOCAB_PATH`` or
        ``MODEL_PATH`` does not exist.
    """
    # Check both files before doing any work, so we never build (and move to
    # the device) a model whose weights turn out to be missing.
    if not (os.path.exists(VOCAB_PATH) and os.path.exists(MODEL_PATH)):
        return None, None, None

    # NOTE(review): torch.load unpickles arbitrary objects; only load files
    # from a trusted source.
    vocab_data = torch.load(VOCAB_PATH, map_location="cpu")
    used_tokens = vocab_data["used_tokens"]
    id2new = vocab_data["id2new"]

    # Ids 0..OFFSET-1 are reserved (PAD/SEP/EOS); the rest mirror used_tokens.
    vocab_size = len(used_tokens) + OFFSET
    model = Crimson(vocab_size).to(device)
    model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
    model.eval()
    return model, used_tokens, id2new

@torch.no_grad()
def generate(model, prompt, tokenizer, id2new, used_tokens, device, max_new_tokens=200, temperature=0.8, top_k=50):
    """Generate a continuation of ``prompt`` and return it as decoded text.

    Args:
        model: Crimson model. It is called on a ``(1, seq)`` LongTensor of
            model ids, and its logits at the last position are used.
        prompt: text to condition on.
        tokenizer: base tokenizer exposing ``encode``/``decode``
            (e.g. a tiktoken encoding).
        id2new: mapping from base-tokenizer ids to model ids. Prompt tokens
            missing from it are dropped.
        used_tokens: sequence mapping ``model_id - OFFSET`` back to base ids.
        device: device for the input tensor.
        max_new_tokens: maximum number of tokens to produce. Generation also
            stops early at ``EOS_ID``.
        temperature: softmax temperature. ``None`` or ``<= 0`` selects greedy
            (argmax) decoding instead of sampling.
        top_k: keep only the ``top_k`` largest logits before sampling.
            ``None`` or ``<= 0`` disables the filter.

    Returns:
        The decoded continuation, without the prompt. Reserved ids below
        ``OFFSET`` are omitted from the text.
    """
    model.eval()
    input_ids = [id2new[rid] for rid in tokenizer.encode(prompt) if rid in id2new]
    if not input_ids:
        # Nothing in the prompt maps into the model vocabulary; fall back to a
        # single PAD token so the input tensor is never empty.
        input_ids = [PAD_ID]
    x = torch.tensor([input_ids], dtype=torch.long, device=device)

    # Dividing by a zero temperature yields inf/NaN probabilities, and
    # torch.topk with k=0 returns an empty tensor; handle both explicitly.
    greedy = temperature is None or temperature <= 0
    use_top_k = top_k is not None and top_k > 0

    generated = []
    for _ in range(max_new_tokens):
        # Keep at most MAX_SEQ_LEN tokens of context; the slice is a no-op
        # when x is already short enough.
        logits = model(x[:, -MAX_SEQ_LEN:])[:, -1, :]
        if greedy:
            next_token = torch.argmax(logits, dim=-1, keepdim=True)
        else:
            logits = logits / temperature
            if use_top_k:
                kth = torch.topk(logits, min(top_k, logits.size(-1))).values[:, [-1]]
                logits = logits.masked_fill(logits < kth, float("-inf"))
            probs = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
        idx = next_token.item()
        if idx == EOS_ID:
            break
        x = torch.cat((x, next_token), dim=1)
        generated.append(idx)
    return tokenizer.decode([used_tokens[i - OFFSET] for i in generated if i >= OFFSET])

if __name__ == "__main__":
    import codecs
    import sys

    # Sampling settings for the interactive demo.
    max_new_tokens = 900
    temperature = 0.8
    top_k = 50

    device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
    model, used_tokens, id2new = load_model_and_vocab(device)
    if model is None:
        # load_model_and_vocab returns None when a file is missing; report it
        # instead of exiting silently.
        sys.exit(f"Could not load model: expected {VOCAB_PATH!r} and {MODEL_PATH!r} to exist.")
    enc = tiktoken.get_encoding(TOKENIZER_NAME)
    # Seed every sample with a newline token. If "\n" is not in the remapped
    # vocab, fall back to the first non-reserved id.
    newline_id = id2new.get(enc.encode("\n")[0], OFFSET)

    while True:
        x = torch.tensor([[newline_id]], dtype=torch.long, device=device)
        # A single tiktoken token can hold only part of a multi-byte UTF-8
        # character. Decoding tokens one at a time would print U+FFFD for
        # such characters, so decode their raw bytes incrementally instead.
        utf8 = codecs.getincrementaldecoder("utf-8")(errors="replace")
        with torch.no_grad():
            for _ in range(max_new_tokens):
                logits = model(x[:, -MAX_SEQ_LEN:])[:, -1, :] / temperature
                kth = torch.topk(logits, min(top_k, logits.size(-1))).values[:, [-1]]
                logits = logits.masked_fill(logits < kth, float("-inf"))
                next_token = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
                idx = next_token.item()
                x = torch.cat((x, next_token), dim=1)
                if idx == EOS_ID:
                    break
                if idx >= OFFSET:
                    raw = enc.decode_single_token_bytes(int(used_tokens[idx - OFFSET]))
                    text = utf8.decode(raw)
                    if text:
                        print(text, end="", flush=True)
        # Flush any incomplete trailing byte sequence (rendered as U+FFFD).
        print(utf8.decode(b"", final=True), end="", flush=True)
        try:
            answer = input("\nPress [Enter] to generate again, or type 'exit': ")
        except (EOFError, KeyboardInterrupt):
            # stdin closed or Ctrl-C at the prompt: end the session cleanly.
            print()
            break
        if answer.strip().lower() == "exit":
            break