import json
import re
import warnings

import torch
import torch.nn as nn
from torch.nn import functional as F

# Id of '<|endoftext|>' in tiktoken's "gpt2" encoding (the BPETokenizer default).
GPT2_EOT_TOKEN_ID = 50256
# Label value that F.cross_entropy skips by default (its ``ignore_index``).
_IGNORE_INDEX = -100


class BPETokenizer:
    """Thin wrapper around a tiktoken encoding."""

    def __init__(self, model_type="gpt2"):
        """Load the tiktoken encoding named *model_type* (e.g. ``"gpt2"``)."""
        # Imported lazily so the module can be used without tiktoken installed
        # as long as no tokenizer is constructed.
        import tiktoken
        self.enc = tiktoken.get_encoding(model_type)

    def encode(self, text):
        """Return token ids for *text*; '<|endoftext|>' is allowed as a special token."""
        return self.enc.encode(text, allowed_special={'<|endoftext|>'})

    def decode(self, ids):
        """Return the text for a sequence of token ids."""
        return self.enc.decode(ids)


class MiniTransformer(nn.Module):
    """Decoder-only (causal) language model built from pre-norm encoder layers.

    Token + learned position embeddings feed ``n_layers`` causally masked
    ``nn.TransformerEncoderLayer`` blocks, a final LayerNorm and a bias-free
    projection to vocabulary logits.
    """

    def __init__(self, vocab_size, emb_dim=768, n_layers=12, n_heads=12, ctx_len=1024, dropout=0.1):
        super().__init__()
        self.ctx_len = ctx_len
        self.n_heads = n_heads
        self.emb_dim = emb_dim
        self.n_layers = n_layers
        self.token_embedding_table = nn.Embedding(vocab_size, emb_dim)
        self.position_embedding_table = nn.Embedding(ctx_len, emb_dim)
        self.drop = nn.Dropout(dropout)
        self.blocks = nn.ModuleList([
            nn.TransformerEncoderLayer(
                d_model=emb_dim,
                nhead=n_heads,
                dim_feedforward=emb_dim * 4,
                dropout=dropout,
                batch_first=True,
                norm_first=True,
                activation='gelu',
            )
            for _ in range(n_layers)
        ])
        self.ln_f = nn.LayerNorm(emb_dim)
        self.lm_head = nn.Linear(emb_dim, vocab_size, bias=False)

    def forward(self, idx, targets=None):
        """Compute next-token logits and, optionally, the cross-entropy loss.

        Args:
            idx: integer token ids of shape ``(B, T)``; ``T`` must not exceed
                ``ctx_len`` (the position table has ``ctx_len`` rows).
            targets: optional labels with ``B * T`` elements. Entries equal to
                -100 are ignored by the loss (PyTorch's default ``ignore_index``).

        Returns:
            ``(logits, loss)``. ``logits`` has shape ``(B, T, vocab_size)``.
            ``loss`` is a scalar tensor, or ``None`` when *targets* is omitted.
        """
        device = idx.device
        _, T = idx.shape
        # NOTE(review): out-of-range ids are silently clamped into the vocab
        # rather than rejected; kept for backward compatibility, but this can
        # hide tokenizer/vocab-size mismatches.
        idx = torch.clamp(idx, 0, self.token_embedding_table.num_embeddings - 1)
        tok_emb = self.token_embedding_table(idx)
        pos_emb = self.position_embedding_table(torch.arange(T, device=device))
        x = self.drop(tok_emb + pos_emb)
        # Boolean mask: True above the diagonal marks future positions that a
        # query may not attend to.
        mask = torch.triu(torch.ones(T, T, device=device), diagonal=1).bool()
        for block in self.blocks:
            x = block(x, src_mask=mask, is_causal=True)
        x = self.ln_f(x)
        logits = self.lm_head(x)

        loss = None
        if targets is not None:
            clamped = torch.clamp(targets, 0, self.lm_head.out_features - 1)
            # Keep ignore_index labels as-is; clamping them to 0 would train
            # padding positions as token 0.
            targets = torch.where(targets == _IGNORE_INDEX, targets, clamped)
            B, T, C = logits.shape
            loss = F.cross_entropy(
                logits.reshape(B * T, C),
                targets.reshape(-1),
                ignore_index=_IGNORE_INDEX,
            )
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None,
                 repetition_penalty=1.0, eos_token_id=GPT2_EOT_TOKEN_ID):
        """Autoregressively append up to *max_new_tokens* tokens to *idx*.

        Args:
            idx: ``(B, T)`` integer token ids to continue.
            max_new_tokens: maximum number of tokens to append.
            temperature: logits are divided by this before sampling. ``0``
                selects greedy (argmax) decoding.
            top_k: if set, sample only among the ``top_k`` highest logits.
            repetition_penalty: every token already present in *idx* has its
                logit divided by this value (multiplied if the logit is
                negative). ``1.0`` disables the penalty.
            eos_token_id: stop once every row has produced this id. Rows that
                finish early are padded with it. ``None`` disables early stopping.

        Returns:
            The ``(B, T + n)`` tensor of ids, where ``n <= max_new_tokens``.
        """
        finished = torch.zeros(idx.size(0), dtype=torch.bool, device=idx.device)
        for _ in range(max_new_tokens):
            # The model only has ctx_len position embeddings, so feed at most
            # the last ctx_len tokens.
            idx_cond = idx[:, -self.ctx_len:]
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :]
            if temperature != 0:
                logits = logits / temperature

            if repetition_penalty != 1.0:
                # Vectorised penalty over the full history. Duplicate ids
                # gather and scatter the same value, so each token is
                # penalised exactly once.
                history = idx.long()
                seen = torch.gather(logits, 1, history)
                seen = torch.where(seen < 0, seen * repetition_penalty, seen / repetition_penalty)
                logits = logits.scatter(1, history, seen)

            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits = logits.masked_fill(logits < v[:, [-1]], -float('Inf'))

            if temperature == 0:
                idx_next = logits.argmax(dim=-1, keepdim=True)
            else:
                probs = F.softmax(logits, dim=-1)
                idx_next = torch.multinomial(probs, num_samples=1)

            if eos_token_id is not None:
                # Rows that already emitted EOS keep emitting it (padding).
                idx_next = idx_next.masked_fill(finished.unsqueeze(1), eos_token_id)
            idx = torch.cat((idx, idx_next), dim=1)

            if eos_token_id is not None:
                finished |= idx_next.squeeze(1) == eos_token_id
                if bool(finished.all()):
                    break
        return idx

    @classmethod
    def load(cls, path, device='cpu'):
        """Build a model from a checkpoint file and return it in eval mode.

        The checkpoint may be a raw ``state_dict`` or a dict holding
        ``'model_state'`` and, optionally, ``'config'``. Config entries
        override the defaults below; keys not present fall back to them.
        ``module.`` (DataParallel/DDP) and ``_orig_mod.`` (torch.compile)
        key prefixes are stripped. Keys that still do not match are reported
        with a warning rather than silently ignored.
        """
        # SECURITY: weights_only=False makes torch.load unpickle arbitrary
        # objects, which can execute code. Only load trusted checkpoints.
        ckpt = torch.load(path, map_location=device, weights_only=False)
        is_dict = isinstance(ckpt, dict)
        state_dict = ckpt['model_state'] if is_dict and 'model_state' in ckpt else ckpt

        cfg = {'vocab_size': 50257, 'emb_dim': 1024, 'n_layers': 24, 'n_heads': 16, 'ctx_len': 1024}
        if is_dict and 'config' in ckpt:
            cfg.update(ckpt['config'])
        model = cls(
            cfg['vocab_size'], cfg['emb_dim'], cfg['n_layers'], cfg['n_heads'], cfg['ctx_len'],
            dropout=cfg.get('dropout', 0.1),
        )

        wrapper_prefixes = ('module.', '_orig_mod.')
        new_state_dict = {}
        for key, value in state_dict.items():
            # Wrappers can be nested (e.g. DDP around a compiled model).
            while key.startswith(wrapper_prefixes):
                key = key.split('.', 1)[1]
            new_state_dict[key] = value

        result = model.load_state_dict(new_state_dict, strict=False)
        if result.missing_keys or result.unexpected_keys:
            warnings.warn(
                f"MiniTransformer.load({path!r}): {len(result.missing_keys)} missing keys "
                f"(e.g. {result.missing_keys[:5]}) and {len(result.unexpected_keys)} unexpected "
                f"keys (e.g. {result.unexpected_keys[:5]}); affected weights keep their "
                f"random initialisation.",
                stacklevel=2,
            )
        model.to(device)
        model.eval()
        return model