import json
import os

import torch
import torch.nn as nn
from torch.nn import functional as F


class AttoLM(nn.Module):
    """A minimal character-level language model: the last `context_len`
    token embeddings are concatenated, mixed by one linear layer, squashed
    with tanh, and projected back to vocabulary logits through the (tied)
    embedding matrix."""

    def __init__(self, vocab_size, embd_dim, context_len):
        super().__init__()
        self.vocab_size = vocab_size
        self.embd_dim = embd_dim
        self.context_len = context_len
        self.embedding = nn.Embedding(vocab_size, embd_dim)
        # Mixes the flattened context window down to a single embedding.
        self.mix = nn.Linear(context_len * embd_dim, embd_dim, bias=False)

    def forward(self, idx):
        # idx: (batch, context_len) token ids
        x = self.embedding(idx)               # (batch, context_len, embd_dim)
        x = x.reshape(x.size(0), -1)          # (batch, context_len * embd_dim)
        x = torch.tanh(self.mix(x))           # (batch, embd_dim)
        # Weight tying: reuse the embedding matrix as the output projection.
        logits = x @ self.embedding.weight.T  # (batch, vocab_size)
        return logits


def load_model(path):
    """Load an AttoLM checkpoint from JSON.

    Expected layout: {"config": {"vocab_size", "embd_dim", "context_len"},
    "weights": {"embedding", "mix"}, "vocab": {token_id_str: char}}.
    """
    with open(path, "r") as f:
        data = json.load(f)
    cfg = data["config"]
    model = AttoLM(cfg["vocab_size"], cfg["embd_dim"], cfg["context_len"])
    emb = torch.tensor(data["weights"]["embedding"])
    mix = torch.tensor(data["weights"]["mix"])
    model.embedding.weight.data.copy_(emb)
    model.mix.weight.data.copy_(mix)
    itos = {int(k): v for k, v in data["vocab"].items()}
    stoi = {v: k for k, v in itos.items()}
    return model, itos, stoi, cfg


@torch.no_grad()
def generate(model, stoi, itos, ctx_len, prompt=" ", length=200, temperature=0.8):
    """Sample `length` characters autoregressively from the model."""
    model.eval()
    tokens = [stoi.get(c, 0) for c in prompt]
    # Left-pad with token 0 if the prompt is shorter than the context window.
    while len(tokens) < ctx_len:
        tokens = [0] + tokens
    for _ in range(length):
        inp = torch.tensor(tokens[-ctx_len:], dtype=torch.long).unsqueeze(0)
        logits = model(inp)
        probs = F.softmax(logits / temperature, dim=-1)
        nxt = torch.multinomial(probs, 1).item()
        tokens.append(nxt)
    # Drop the (padded) context prefix; return only the sampled continuation.
    return "".join(itos.get(t, "?") for t in tokens[ctx_len:])


if __name__ == "__main__":
    models_dir = "models"
    for name in ["atto-64", "atto-128", "atto-256", "atto-512", "atto-1024",
                 "atto-2048", "atto-4096", "atto-8192", "atto-16384"]:
        path = os.path.join(models_dir, f"{name}.json")
        if not os.path.exists(path):
            print(f"  {path} not found, skipping")
            continue
        model, itos, stoi, cfg = load_model(path)
        params = sum(p.numel() for p in model.parameters())
        print(f"\n{'=' * 60}")
        print(f"  {name} | {params} params | embd={cfg['embd_dim']} "
              f"ctx={cfg['context_len']} vocab={cfg['vocab_size']}")
        print(f"{'=' * 60}")
        for prompt in [" the ", " to be", " Ham"]:
            clean_prompt = prompt.strip()
            # Keep only characters the model's vocabulary actually contains.
            usable = "".join(c for c in prompt if c in stoi)
            if not usable:
                usable = " "
            text = generate(model, stoi, itos, cfg["context_len"], usable,
                            length=150, temperature=0.8)
            print(f'  prompt="{clean_prompt}":')
            print(f"    {text[:150]}")
            print()
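

# ----------------------------------------------------------------------
# Optional smoke test (a sketch, not part of the checkpoint workflow):
# builds a randomly initialized AttoLM and samples from it, which is
# handy for checking the forward()/generate() plumbing when no
# models/*.json checkpoint is on disk. The vocabulary and sizes below
# are illustrative assumptions, not values from any real atto-* model.
# Call smoke_test() from a REPL; it is not wired into __main__ above.
def smoke_test():
    chars = "\n abcdefghijklmnopqrstuvwxyz"  # assumed toy vocabulary
    itos = dict(enumerate(chars))
    stoi = {c: i for i, c in itos.items()}
    model = AttoLM(vocab_size=len(chars), embd_dim=8, context_len=4)
    # Weights are untrained, so the output is noise; it should still run.
    print(generate(model, stoi, itos, ctx_len=4, prompt=" the ", length=40))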