import json

import torch
import torch.nn as nn
from transformers import GPT2Tokenizer


def load_tinylm(model_dir, device="cpu"):
    """Build a TinyLM from a saved checkpoint directory.

    Expects ``{model_dir}/config.json``, ``{model_dir}/pytorch_model.bin``
    and GPT-2 tokenizer files inside ``model_dir``.

    Args:
        model_dir: Directory holding config, weights and tokenizer files.
        device: Torch device string the model is moved to.

    Returns:
        Tuple ``(model, tokenizer, config)`` with the model in eval mode.
    """
    # Load config
    with open(f"{model_dir}/config.json") as f:
        config = json.load(f)

    VOCAB_SIZE = config["vocab_size"]
    EMBED_RANK = config["embed_rank"]
    D_MODEL = config["d_model"]
    N_HEADS = config["n_heads"]
    FFN_DIM = config["ffn_dim"]
    N_LAYERS = config["n_layers"]
    MAX_SEQ_LEN = config["max_seq_len"]
    DROPOUT = config["dropout"]

    class FactoredEmbedding(nn.Module):
        """Low-rank factored token embedding: vocab -> rank -> d_model."""

        def __init__(self, vocab_size, rank, d_model):
            super().__init__()
            self.in_proj = nn.Embedding(vocab_size, rank)
            self.out_proj = nn.Linear(rank, d_model, bias=False)

        def forward(self, x):
            return self.out_proj(self.in_proj(x))

    class TransformerBlock(nn.Module):
        """Pre-norm transformer block: causal self-attention + GELU FFN."""

        def __init__(self):
            super().__init__()
            self.ln1 = nn.LayerNorm(D_MODEL)
            self.attn = nn.MultiheadAttention(
                D_MODEL, N_HEADS, dropout=DROPOUT, batch_first=True
            )
            self.ln2 = nn.LayerNorm(D_MODEL)
            self.ffn = nn.Sequential(
                nn.Linear(D_MODEL, FFN_DIM),
                nn.GELU(),
                nn.Linear(FFN_DIM, D_MODEL),
                nn.Dropout(DROPOUT),
            )

        def forward(self, x, attn_mask=None, key_padding_mask=None):
            x_norm = self.ln1(x)
            # is_causal=True is only a fast-path hint to MultiheadAttention;
            # callers must still pass the causal attn_mask for correctness.
            attn_out, _ = self.attn(
                x_norm, x_norm, x_norm,
                attn_mask=attn_mask,
                key_padding_mask=key_padding_mask,
                is_causal=True,
            )
            x = x + attn_out
            x = x + self.ffn(self.ln2(x))
            return x

    class TinyLM(nn.Module):
        """Decoder-only LM with factored embeddings and a tied output head."""

        def __init__(self):
            super().__init__()
            self.tok_emb = FactoredEmbedding(VOCAB_SIZE, EMBED_RANK, D_MODEL)
            self.pos_emb = nn.Embedding(MAX_SEQ_LEN, D_MODEL)
            self.drop = nn.Dropout(DROPOUT)
            self.blocks = nn.ModuleList(
                [TransformerBlock() for _ in range(N_LAYERS)]
            )
            self.ln_final = nn.LayerNorm(D_MODEL)
            self.head_down = nn.Linear(D_MODEL, EMBED_RANK, bias=False)
            self.head_vocab = nn.Linear(EMBED_RANK, VOCAB_SIZE, bias=False)
            # Weight tying: assign the embedding's Parameter directly so both
            # modules share the *same* Parameter object. Wrapping it in a
            # fresh nn.Parameter (as before) registers a second parameter and
            # only ties storage incidentally.
            self.head_vocab.weight = self.tok_emb.in_proj.weight

        def forward(self, idx):
            """Return (B, T, VOCAB_SIZE) next-token logits for token ids."""
            B, T = idx.shape
            if T > MAX_SEQ_LEN:
                # FIX: keep the most recent tokens, not the first ones — for
                # causal next-token prediction the tail of the context is what
                # matters. The original `idx[:, :MAX_SEQ_LEN]` silently
                # predicted from a stale prefix.
                idx = idx[:, -MAX_SEQ_LEN:]
                T = idx.shape[1]
            positions = torch.arange(T, device=idx.device).unsqueeze(0)
            x = self.drop(self.tok_emb(idx) + self.pos_emb(positions))
            mask = nn.Transformer.generate_square_subsequent_mask(
                T, device=idx.device
            )
            for block in self.blocks:
                x = block(x, attn_mask=mask)
            x = self.ln_final(x)
            x = self.head_down(x)  # project back down to EMBED_RANK
            return self.head_vocab(x)

    # Build and load weights.
    model = TinyLM().to(device)
    # weights_only=True: a state dict contains only tensors, so refuse to
    # unpickle arbitrary (potentially malicious) objects from the file.
    state_dict = torch.load(
        f"{model_dir}/pytorch_model.bin",
        map_location=device,
        weights_only=True,
    )
    model.load_state_dict(state_dict)
    model.eval()

    # Load tokenizer; GPT-2 ships without a pad token, so reuse EOS.
    tokenizer = GPT2Tokenizer.from_pretrained(model_dir)
    tokenizer.pad_token = tokenizer.eos_token
    return model, tokenizer, config


def generate(model, tokenizer, prompt, max_new_tokens=100, temperature=0.1,
             top_k=25, device="cpu"):
    """Autoregressively sample a continuation of ``prompt``.

    Args:
        model: A TinyLM instance (must expose ``pos_emb.num_embeddings``).
        tokenizer: Matching tokenizer with an ``eos_token_id``.
        prompt: Seed text.
        max_new_tokens: Upper bound on generated tokens.
        temperature: Softmax temperature (> 0); lower is greedier.
        top_k: If not None, sample only from the k highest logits.
        device: Device the input ids are placed on.

    Returns:
        The decoded prompt plus generated continuation (EOS excluded).
    """
    MAX_SEQ_LEN = model.pos_emb.num_embeddings
    model.eval()
    ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
    with torch.no_grad():
        for _ in range(max_new_tokens):
            # Slide a window over the most recent MAX_SEQ_LEN tokens.
            idx_cond = ids[:, -MAX_SEQ_LEN:]
            logits = model(idx_cond)
            logits = logits[:, -1, :] / temperature
            if top_k is not None:
                # FIX: clamp k so a top_k larger than the vocabulary does not
                # make torch.topk raise.
                k = min(top_k, logits.size(-1))
                values, _ = torch.topk(logits, k)
                logits[logits < values[:, -1:]] = -float("inf")
            probs = torch.softmax(logits, dim=-1)
            next_id = torch.multinomial(probs, num_samples=1)
            if next_id.item() == tokenizer.eos_token_id:
                break  # stop without appending EOS itself
            ids = torch.cat([ids, next_id], dim=1)
    return tokenizer.decode(ids[0], skip_special_tokens=True)


if __name__ == "__main__":
    model, tokenizer, config = load_tinylm("./tinylm")
    print("Model loaded!")
    print("Use 'module.generate(model, tokenizer, \"Once upon a time\")' to generate.")