| | import json |
| | import torch |
| | import torch.nn as nn |
| | from transformers import GPT2Tokenizer |
| |
|
| |
|
def load_tinylm(model_dir, device="cpu"):
    """Build a TinyLM from the checkpoint directory *model_dir* and return it.

    Expects the directory to contain:
      - config.json          (model hyperparameters)
      - pytorch_model.bin    (a plain tensor state dict)
      - GPT-2 tokenizer files loadable by ``GPT2Tokenizer.from_pretrained``

    Args:
        model_dir: path to the checkpoint directory.
        device: torch device string the model is moved to (default "cpu").

    Returns:
        (model, tokenizer, config) — the model in eval mode, the tokenizer
        with its pad token set, and the raw config dict.
    """
    with open(f"{model_dir}/config.json") as f:
        config = json.load(f)

    VOCAB_SIZE = config["vocab_size"]
    EMBED_RANK = config["embed_rank"]
    D_MODEL = config["d_model"]
    N_HEADS = config["n_heads"]
    FFN_DIM = config["ffn_dim"]
    N_LAYERS = config["n_layers"]
    MAX_SEQ_LEN = config["max_seq_len"]
    DROPOUT = config["dropout"]

    class FactoredEmbedding(nn.Module):
        """Low-rank token embedding: Embedding(vocab, rank) -> Linear(rank, d_model)."""

        def __init__(self, vocab_size, rank, d_model):
            super().__init__()
            self.in_proj = nn.Embedding(vocab_size, rank)
            self.out_proj = nn.Linear(rank, d_model, bias=False)

        def forward(self, x):
            return self.out_proj(self.in_proj(x))

    class TransformerBlock(nn.Module):
        """Pre-norm transformer block: self-attention + GELU FFN, residual adds."""

        def __init__(self):
            super().__init__()
            self.ln1 = nn.LayerNorm(D_MODEL)
            self.attn = nn.MultiheadAttention(D_MODEL, N_HEADS, dropout=DROPOUT, batch_first=True)
            self.ln2 = nn.LayerNorm(D_MODEL)
            self.ffn = nn.Sequential(
                nn.Linear(D_MODEL, FFN_DIM),
                nn.GELU(),
                nn.Linear(FFN_DIM, D_MODEL),
                nn.Dropout(DROPOUT),
            )

        def forward(self, x, attn_mask=None, key_padding_mask=None):
            x_norm = self.ln1(x)
            # need_weights=False: the attention weights were being computed and
            # discarded; skipping them also allows the fused attention path.
            # is_causal=True is a hint that attn_mask is the causal mask
            # (it is — see TinyLM.forward).
            attn_out, _ = self.attn(x_norm, x_norm, x_norm,
                                    attn_mask=attn_mask,
                                    key_padding_mask=key_padding_mask,
                                    need_weights=False,
                                    is_causal=True)
            x = x + attn_out
            x = x + self.ffn(self.ln2(x))
            return x

    class TinyLM(nn.Module):
        """Decoder-only LM with factored embeddings and a tied factored head."""

        def __init__(self):
            super().__init__()
            self.tok_emb = FactoredEmbedding(VOCAB_SIZE, EMBED_RANK, D_MODEL)
            self.pos_emb = nn.Embedding(MAX_SEQ_LEN, D_MODEL)
            self.drop = nn.Dropout(DROPOUT)
            self.blocks = nn.ModuleList([TransformerBlock() for _ in range(N_LAYERS)])
            self.ln_final = nn.LayerNorm(D_MODEL)
            self.head_down = nn.Linear(D_MODEL, EMBED_RANK, bias=False)
            self.head_vocab = nn.Linear(EMBED_RANK, VOCAB_SIZE, bias=False)
            # Weight tying (both are (VOCAB_SIZE, EMBED_RANK)). BUG FIX: the
            # original wrapped the tensor in a fresh nn.Parameter, creating a
            # second independent autograd leaf that only shared storage —
            # gradients did not flow through one shared parameter and an
            # optimizer would register/update both over the same memory.
            # Assigning the same Parameter object ties them properly.
            self.head_vocab.weight = self.tok_emb.in_proj.weight

        def forward(self, idx):
            B, T = idx.shape
            if T > MAX_SEQ_LEN:
                # Silently crop over-long inputs to the context window.
                idx = idx[:, :MAX_SEQ_LEN]
                T = idx.shape[1]
            positions = torch.arange(T, device=idx.device).unsqueeze(0)
            x = self.drop(self.tok_emb(idx) + self.pos_emb(positions))
            mask = nn.Transformer.generate_square_subsequent_mask(T, device=idx.device)
            for block in self.blocks:
                x = block(x, attn_mask=mask)
            x = self.ln_final(x)
            x = self.head_down(x)
            return self.head_vocab(x)  # (B, T, VOCAB_SIZE) logits

    model = TinyLM().to(device)
    # weights_only=True: the checkpoint is a plain tensor state dict, so avoid
    # torch.load's arbitrary-pickle code execution on an untrusted file.
    state_dict = torch.load(f"{model_dir}/pytorch_model.bin",
                            map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()

    tokenizer = GPT2Tokenizer.from_pretrained(model_dir)
    tokenizer.pad_token = tokenizer.eos_token  # GPT-2 ships with no pad token

    return model, tokenizer, config
| |
|
| |
|
def generate(model, tokenizer, prompt, max_new_tokens=100, temperature=0.1, top_k=25, device="cpu"):
    """Autoregressively sample a continuation of *prompt* from *model*.

    Args:
        model: TinyLM-style module — callable on a (B, T) id tensor returning
            (B, T, vocab) logits, with a ``pos_emb`` embedding whose size gives
            the context window.
        tokenizer: HF-style tokenizer (callable encode, ``decode``,
            ``eos_token_id``).
        prompt: text to continue.
        max_new_tokens: maximum number of tokens to append.
        temperature: softmax temperature; values <= 0 now select greedily
            (argmax) instead of dividing by zero.
        top_k: keep only the k highest logits before sampling; None disables.
            Values larger than the vocabulary are clamped.
        device: torch device string for the input ids.

    Returns:
        The decoded prompt plus continuation (special tokens stripped).
        Generation stops early when EOS is sampled (EOS is not appended).
    """
    max_seq_len = model.pos_emb.num_embeddings  # context window size
    model.eval()
    ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)

    with torch.no_grad():
        for _ in range(max_new_tokens):
            idx_cond = ids[:, -max_seq_len:]  # crop to the context window
            logits = model(idx_cond)[:, -1, :]  # next-token logits only
            if temperature <= 0:
                # Greedy decoding: temperature 0 previously produced inf/NaN
                # logits and crashed torch.multinomial.
                next_id = torch.argmax(logits, dim=-1, keepdim=True)
            else:
                logits = logits / temperature
                if top_k is not None:
                    # Clamp: torch.topk raises if k exceeds the vocab size.
                    k = min(top_k, logits.size(-1))
                    values, _ = torch.topk(logits, k)
                    logits[logits < values[:, -1:]] = -float("inf")
                probs = torch.softmax(logits, dim=-1)
                next_id = torch.multinomial(probs, num_samples=1)
            if next_id.item() == tokenizer.eos_token_id:
                break
            ids = torch.cat([ids, next_id], dim=1)

    return tokenizer.decode(ids[0], skip_special_tokens=True)
| |
|
| |
|
| | if __name__ == "__main__": |
| | model, tokenizer, config = load_tinylm("./tinylm") |
| | print("Model loaded!") |
| | print("Use 'module.generate(model, tokenizer, \"Once upon a time\")' to generate.") |
| |
|