import json
import torch
import torch.nn as nn
from transformers import GPT2Tokenizer
def load_tinylm(model_dir, device="cpu"):
    """Load a TinyLM checkpoint and return ``(model, tokenizer, config)``.

    Expects ``config.json``, ``pytorch_model.bin`` and GPT-2 tokenizer files
    inside *model_dir*.  The model is moved to *device* and returned in eval
    mode.  The architecture classes are defined as closures so the file keeps
    no module-level dependence on the config values.
    """
    # Hyperparameters come from the checkpoint's config.json.
    with open(f"{model_dir}/config.json") as f:
        config = json.load(f)
    VOCAB_SIZE = config["vocab_size"]
    EMBED_RANK = config["embed_rank"]
    D_MODEL = config["d_model"]
    N_HEADS = config["n_heads"]
    FFN_DIM = config["ffn_dim"]
    N_LAYERS = config["n_layers"]
    MAX_SEQ_LEN = config["max_seq_len"]
    DROPOUT = config["dropout"]

    class FactoredEmbedding(nn.Module):
        """Low-rank factored token embedding: vocab -> rank -> d_model."""

        def __init__(self, vocab_size, rank, d_model):
            super().__init__()
            self.in_proj = nn.Embedding(vocab_size, rank)
            self.out_proj = nn.Linear(rank, d_model, bias=False)

        def forward(self, x):
            return self.out_proj(self.in_proj(x))

    class TransformerBlock(nn.Module):
        """Pre-LN transformer block: causal self-attention + GELU feed-forward."""

        def __init__(self):
            super().__init__()
            self.ln1 = nn.LayerNorm(D_MODEL)
            self.attn = nn.MultiheadAttention(D_MODEL, N_HEADS, dropout=DROPOUT,
                                              batch_first=True)
            self.ln2 = nn.LayerNorm(D_MODEL)
            self.ffn = nn.Sequential(
                nn.Linear(D_MODEL, FFN_DIM),
                nn.GELU(),
                nn.Linear(FFN_DIM, D_MODEL),
                nn.Dropout(DROPOUT),
            )

        def forward(self, x, attn_mask=None, key_padding_mask=None):
            x_norm = self.ln1(x)
            # need_weights=False skips computing/averaging the attention map we
            # were discarding anyway.  is_causal is only a *hint* that must
            # match attn_mask, so pass it only when a mask is actually given
            # (the original hard-coded True even for attn_mask=None).
            attn_out, _ = self.attn(x_norm, x_norm, x_norm,
                                    attn_mask=attn_mask,
                                    key_padding_mask=key_padding_mask,
                                    need_weights=False,
                                    is_causal=attn_mask is not None)
            x = x + attn_out
            x = x + self.ffn(self.ln2(x))
            return x

    class TinyLM(nn.Module):
        """Decoder-only LM with factored embeddings and a factored, tied head."""

        def __init__(self):
            super().__init__()
            self.tok_emb = FactoredEmbedding(VOCAB_SIZE, EMBED_RANK, D_MODEL)
            self.pos_emb = nn.Embedding(MAX_SEQ_LEN, D_MODEL)
            self.drop = nn.Dropout(DROPOUT)
            self.blocks = nn.ModuleList([TransformerBlock() for _ in range(N_LAYERS)])
            self.ln_final = nn.LayerNorm(D_MODEL)
            self.head_down = nn.Linear(D_MODEL, EMBED_RANK, bias=False)
            self.head_vocab = nn.Linear(EMBED_RANK, VOCAB_SIZE, bias=False)
            # Tie the output head to the input embedding by direct assignment
            # (the standard tying idiom).  The original wrapped it in a fresh
            # nn.Parameter, which still shares storage but registers a
            # duplicate, separately-tracked parameter.
            self.head_vocab.weight = self.tok_emb.in_proj.weight

        def forward(self, idx):
            B, T = idx.shape
            if T > MAX_SEQ_LEN:
                # Keep the most RECENT tokens: for next-token prediction the
                # tail of the context matters (the original kept the head).
                idx = idx[:, -MAX_SEQ_LEN:]
                T = idx.shape[1]
            positions = torch.arange(T, device=idx.device).unsqueeze(0)
            x = self.drop(self.tok_emb(idx) + self.pos_emb(positions))
            mask = nn.Transformer.generate_square_subsequent_mask(T, device=idx.device)
            for block in self.blocks:
                x = block(x, attn_mask=mask)
            x = self.ln_final(x)
            x = self.head_down(x)
            return self.head_vocab(x)  # (B, T, vocab) logits

    # Build and load weights.  weights_only=True restricts the pickle to
    # tensors/containers, closing the arbitrary-code-execution hole; a plain
    # state_dict checkpoint loads unchanged.
    model = TinyLM().to(device)
    state_dict = torch.load(f"{model_dir}/pytorch_model.bin",
                            map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()

    # GPT-2 ships without a pad token; reuse EOS so padding is well-defined.
    tokenizer = GPT2Tokenizer.from_pretrained(model_dir)
    tokenizer.pad_token = tokenizer.eos_token
    return model, tokenizer, config
def generate(model, tokenizer, prompt, max_new_tokens=100, temperature=0.1, top_k=25, device="cpu"):
    """Autoregressively sample up to *max_new_tokens* tokens after *prompt*.

    Applies temperature scaling and optional top-k filtering at each step.
    Sampling stops early when EOS is drawn (EOS itself is not appended).
    Returns the decoded string, prompt included.
    """
    context_len = model.pos_emb.num_embeddings  # positional capacity of the model
    model.eval()
    ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
    with torch.no_grad():
        for _ in range(max_new_tokens):
            # Crop the running sequence to the model's context window.
            window = ids[:, -context_len:]
            step_logits = model(window)[:, -1, :] / temperature
            if top_k is not None:
                # Zero out (send to -inf) everything below the k-th logit.
                kth = torch.topk(step_logits, top_k).values[:, -1:]
                step_logits = step_logits.masked_fill(step_logits < kth, -float("inf"))
            probs = torch.softmax(step_logits, dim=-1)
            next_id = torch.multinomial(probs, num_samples=1)
            if next_id.item() == tokenizer.eos_token_id:
                break  # stop without emitting EOS
            ids = torch.cat([ids, next_id], dim=1)
    return tokenizer.decode(ids[0], skip_special_tokens=True)
if __name__ == "__main__":
    # Smoke-test entry point: load the checkpoint from the default directory
    # and tell the user how to sample from it.
    loaded_model, loaded_tokenizer, loaded_config = load_tinylm("./tinylm")
    print("Model loaded!")
    print("Use 'module.generate(model, tokenizer, \"Once upon a time\")' to generate.")