| import json |
| import os |
| import math |
| import torch |
| import torch.nn as nn |
| from torch.nn import functional as F |
|
|
|
|
class AttoLM(nn.Module):
    """A minimal fixed-context character language model.

    The model embeds a window of `context_len` token ids, flattens the
    window into one vector, mixes it down to `embd_dim` with a single
    bias-free linear layer, squashes with tanh, and scores the vocabulary
    by dotting against the (tied) embedding matrix.
    """

    def __init__(self, vocab_size, embd_dim, context_len):
        super().__init__()
        self.vocab_size = vocab_size
        self.embd_dim = embd_dim
        self.context_len = context_len
        # Token embedding table; also reused as the output projection (weight tying).
        self.embedding = nn.Embedding(vocab_size, embd_dim)
        # Mixes the flattened context window (context_len * embd_dim) down to embd_dim.
        self.mix = nn.Linear(context_len * embd_dim, embd_dim, bias=False)

    def forward(self, idx):
        """Map a (batch, context_len) batch of token ids to (batch, vocab_size) logits."""
        emb = self.embedding(idx)                  # (batch, context_len, embd_dim)
        flat = torch.flatten(emb, start_dim=1)     # (batch, context_len * embd_dim)
        hidden = torch.tanh(self.mix(flat))        # (batch, embd_dim)
        # Tied output head: score each vocab entry by similarity to its embedding.
        return torch.matmul(hidden, self.embedding.weight.T)
|
|
|
|
def load_model(path):
    """Rebuild an AttoLM from a JSON checkpoint.

    The file is expected to hold three sections: "config" (vocab_size,
    embd_dim, context_len), "weights" ("embedding" and "mix" matrices as
    nested lists), and "vocab" (token-id string -> character).

    Returns (model, itos, stoi, cfg) where itos maps int id -> char and
    stoi is its inverse.
    """
    with open(path, "r") as f:
        data = json.load(f)

    cfg = data["config"]
    model = AttoLM(cfg["vocab_size"], cfg["embd_dim"], cfg["context_len"])

    # Copy the serialized matrices into the freshly built parameters.
    weights = data["weights"]
    with torch.no_grad():
        model.embedding.weight.copy_(torch.tensor(weights["embedding"]))
        model.mix.weight.copy_(torch.tensor(weights["mix"]))

    # JSON object keys are strings; restore the int token ids, then invert.
    itos = {int(token_id): ch for token_id, ch in data["vocab"].items()}
    stoi = {ch: token_id for token_id, ch in itos.items()}
    return model, itos, stoi, cfg
|
|
|
|
@torch.no_grad()
def generate(model, stoi, itos, ctx_len, prompt=" ", length=200, temperature=0.8):
    """Sample `length` characters from `model`, seeded with `prompt`.

    Unknown prompt characters fall back to token 0, and the prompt is
    left-padded with token 0 up to `ctx_len`. Only the newly sampled
    portion (everything after the first `ctx_len` tokens) is returned,
    decoded through `itos` with "?" for any unmapped id.
    """
    model.eval()
    tokens = [stoi.get(ch, 0) for ch in prompt]

    # Left-pad short prompts with token 0 so a full context window exists.
    if len(tokens) < ctx_len:
        tokens = [0] * (ctx_len - len(tokens)) + tokens

    for _ in range(length):
        window = torch.tensor(tokens[-ctx_len:], dtype=torch.long).unsqueeze(0)
        logits = model(window)
        # Temperature-scaled categorical sampling over the vocabulary.
        probs = F.softmax(logits / temperature, dim=-1)
        tokens.append(torch.multinomial(probs, num_samples=1).item())

    return "".join(itos.get(t, "?") for t in tokens[ctx_len:])
|
|
|
|
if __name__ == "__main__":
    models_dir = "models"
    banner = "=" * 60

    # Walk the fixed ladder of checkpoint sizes, skipping any that are absent.
    for model_name in ["atto-64", "atto-128", "atto-256", "atto-512", "atto-1024", "atto-2048", "atto-4096", "atto-8192", "atto-16384"]:
        ckpt_path = os.path.join(models_dir, f"{model_name}.json")
        if not os.path.exists(ckpt_path):
            print(f" {ckpt_path} not found, skipping")
            continue

        model, itos, stoi, cfg = load_model(ckpt_path)
        n_params = sum(p.numel() for p in model.parameters())

        # Header with the model's size and configuration.
        print(f"\n{banner}")
        print(f" {model_name} | {n_params} params | embd={cfg['embd_dim']} ctx={cfg['context_len']} vocab={cfg['vocab_size']}")
        print(f"{banner}")

        for prompt in [" the ", " to be", " Ham"]:
            label = prompt.strip()
            # Drop characters the checkpoint's vocabulary cannot encode;
            # fall back to a single space if nothing survives.
            encodable = "".join(ch for ch in prompt if ch in stoi)
            if not encodable:
                encodable = " "
            sample = generate(model, stoi, itos, cfg["context_len"], encodable, length=150, temperature=0.8)
            print(f' prompt="{label}":')
            print(f" {sample[:150]}")
            print()
|
|