# atto/sample.py
import json
import os

import torch
import torch.nn as nn
from torch.nn import functional as F
class AttoLM(nn.Module):
    """Tiny character-level LM: embed the context window, flatten it, mix it
    with a single linear layer, and score against the tied embedding matrix."""

    def __init__(self, vocab_size, embd_dim, context_len):
        super().__init__()
        self.vocab_size = vocab_size
        self.embd_dim = embd_dim
        self.context_len = context_len
        self.embedding = nn.Embedding(vocab_size, embd_dim)
        # Mixes the whole flattened context down to one embedding-sized vector.
        self.mix = nn.Linear(context_len * embd_dim, embd_dim, bias=False)

    def forward(self, idx):
        x = self.embedding(idx)               # (B, context_len, embd_dim)
        x = x.reshape(x.size(0), -1)          # (B, context_len * embd_dim)
        x = torch.tanh(self.mix(x))           # (B, embd_dim)
        logits = x @ self.embedding.weight.T  # tied output head: (B, vocab_size)
        return logits
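
# Shape walkthrough for AttoLM.forward (illustrative; "B" is a batch
# dimension, not a name from the code above):
#   idx                  -> (B, context_len)            integer token ids
#   after embedding      -> (B, context_len, embd_dim)
#   after reshape        -> (B, context_len * embd_dim)
#   after mix + tanh     -> (B, embd_dim)
#   logits               -> (B, vocab_size)             via the tied embedding matrix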
def load_model(path):
    """Rebuild an AttoLM from a JSON checkpoint and return it with its vocab maps."""
    with open(path, "r") as f:
        data = json.load(f)
    cfg = data["config"]
    model = AttoLM(cfg["vocab_size"], cfg["embd_dim"], cfg["context_len"])
    # Copy the stored weight matrices into the freshly built modules.
    emb = torch.tensor(data["weights"]["embedding"])
    mix = torch.tensor(data["weights"]["mix"])
    model.embedding.weight.data.copy_(emb)
    model.mix.weight.data.copy_(mix)
    # JSON object keys are strings: itos maps id -> char, stoi maps char -> id.
    itos = {int(k): v for k, v in data["vocab"].items()}
    stoi = {v: k for k, v in itos.items()}
    return model, itos, stoi, cfg
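
# Checkpoint layout assumed by load_model (inferred from the key accesses
# above; exact matrix shapes depend on the saved model):
# {
#   "config":  {"vocab_size": V, "embd_dim": E, "context_len": C},
#   "weights": {"embedding": V x E nested lists,
#               "mix":      E x (C*E) nested lists},   # nn.Linear is (out, in)
#   "vocab":   {"0": "<char>", "1": "<char>", ...}     # token id -> character
# }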
@torch.no_grad()
def generate(model, stoi, itos, ctx_len, prompt=" ", length=200, temperature=0.8):
    model.eval()
    tokens = [stoi.get(c, 0) for c in prompt]
    # Left-pad with token 0 if the prompt is shorter than the context window.
    while len(tokens) < ctx_len:
        tokens = [0] + tokens
    start = len(tokens)  # everything before this is prompt/padding, not generation
    for _ in range(length):
        inp = torch.tensor(tokens[-ctx_len:], dtype=torch.long).unsqueeze(0)
        logits = model(inp)
        # Temperature-scaled sampling from the next-token distribution.
        probs = F.softmax(logits / temperature, dim=-1)
        nxt = torch.multinomial(probs, 1).item()
        tokens.append(nxt)
    return "".join(itos.get(t, "?") for t in tokens[start:])
if __name__ == "__main__":
    models_dir = "models"
    names = [
        "atto-64", "atto-128", "atto-256", "atto-512", "atto-1024",
        "atto-2048", "atto-4096", "atto-8192", "atto-16384",
    ]
    for name in names:
        path = os.path.join(models_dir, f"{name}.json")
        if not os.path.exists(path):
            print(f" {path} not found, skipping")
            continue
        model, itos, stoi, cfg = load_model(path)
        params = sum(p.numel() for p in model.parameters())
        print(f"\n{'='*60}")
        print(f" {name} | {params} params | embd={cfg['embd_dim']} ctx={cfg['context_len']} vocab={cfg['vocab_size']}")
        print(f"{'='*60}")
        for prompt in [" the ", " to be", " Ham"]:
            clean_prompt = prompt.strip()
            # Only keep prompt characters the model's vocab knows.
            usable = "".join(c for c in prompt if c in stoi)
            if not usable:
                usable = " "
            text = generate(model, stoi, itos, cfg["context_len"], usable, length=150, temperature=0.8)
            print(f' prompt="{clean_prompt}":')
            print(f" {text[:150]}")
            print()