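"""Sample text from a family of tiny "atto" character-level language models.

Each checkpoint under models/ is a JSON file holding the model config, the
raw weight matrices, and the id-to-character vocabulary; see load_model for
the expected layout.
"""
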
import json
import os

import torch
import torch.nn as nn
from torch.nn import functional as F


class AttoLM(nn.Module):
    """A tiny character-level language model: embed a fixed-length context,
    flatten it, mix with a single linear layer, and decode through the
    tied embedding matrix."""

    def __init__(self, vocab_size, embd_dim, context_len):
        super().__init__()
        self.vocab_size = vocab_size
        self.embd_dim = embd_dim
        self.context_len = context_len
        self.embedding = nn.Embedding(vocab_size, embd_dim)
        self.mix = nn.Linear(context_len * embd_dim, embd_dim, bias=False)

    def forward(self, idx):
        x = self.embedding(idx)               # (batch, context) -> (batch, context, embd)
        x = x.reshape(x.size(0), -1)          # flatten to (batch, context * embd)
        x = self.mix(x)                       # mix the whole window: (batch, embd)
        x = torch.tanh(x)
        logits = x @ self.embedding.weight.T  # tied output head: (batch, vocab)
        return logits
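
# Shape sketch (illustrative sizes, not from any shipped checkpoint):
#   m = AttoLM(vocab_size=70, embd_dim=8, context_len=16)
#   idx = torch.zeros(1, 16, dtype=torch.long)  # (batch, context)
#   m(idx).shape                                # -> torch.Size([1, 70])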


def load_model(path):
    with open(path, "r") as f:
        data = json.load(f)
    cfg = data["config"]
    model = AttoLM(cfg["vocab_size"], cfg["embd_dim"], cfg["context_len"])

    emb = torch.tensor(data["weights"]["embedding"])
    mix = torch.tensor(data["weights"]["mix"])
    model.embedding.weight.data.copy_(emb)
    model.mix.weight.data.copy_(mix)

    itos = {int(k): v for k, v in data["vocab"].items()}
    stoi = {v: k for k, v in itos.items()}
    return model, itos, stoi, cfg
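
# Checkpoint layout assumed above (inferred from the reads in load_model;
# the example vocab entries are illustrative):
# {
#   "config":  {"vocab_size": V, "embd_dim": E, "context_len": C},
#   "weights": {"embedding": V x E matrix, "mix": E x (C*E) matrix},
#   "vocab":   {"0": " ", "1": "e", ...}  # id -> character
# }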


@torch.no_grad()
def generate(model, stoi, itos, ctx_len, prompt=" ", length=200, temperature=0.8):
    model.eval()
    tokens = [stoi.get(c, 0) for c in prompt]
    # left-pad with token 0 if the prompt is shorter than the context
    if len(tokens) < ctx_len:
        tokens = [0] * (ctx_len - len(tokens)) + tokens
    start = len(tokens)  # generation starts here, even if the prompt exceeds ctx_len

    for _ in range(length):
        inp = torch.tensor(tokens[-ctx_len:], dtype=torch.long).unsqueeze(0)
        logits = model(inp)
        # temperature < 1 sharpens the distribution, > 1 flattens it
        probs = F.softmax(logits / temperature, dim=-1)
        nxt = torch.multinomial(probs, 1).item()
        tokens.append(nxt)

    # return only the generated text, not the (padded) prompt
    return "".join(itos.get(t, "?") for t in tokens[start:])


if __name__ == "__main__":
    models_dir = "models"

    model_names = [
        "atto-64", "atto-128", "atto-256", "atto-512", "atto-1024",
        "atto-2048", "atto-4096", "atto-8192", "atto-16384",
    ]
    for name in model_names:
        path = os.path.join(models_dir, f"{name}.json")
        if not os.path.exists(path):
            print(f"  {path} not found, skipping")
            continue

        model, itos, stoi, cfg = load_model(path)
        params = sum(p.numel() for p in model.parameters())

        print(f"\n{'='*60}")
        print(f"  {name}  |  {params} params  |  embd={cfg['embd_dim']}  ctx={cfg['context_len']}  vocab={cfg['vocab_size']}")
        print(f"{'='*60}")

        for prompt in [" the ", " to be", " Ham"]:
            clean_prompt = prompt.strip()
            # only use chars the model knows
            usable = "".join(c for c in prompt if c in stoi)
            if not usable:
                usable = " "
            text = generate(model, stoi, itos, cfg["context_len"], usable, length=150, temperature=0.8)
            print(f'  prompt="{clean_prompt}":')
            print(f"    {text[:150]}")
            print()