Spaces:
Running
Running
| import torch | |
| import torch.nn as nn | |
| from torch.nn import functional as F | |
| import json | |
| import re | |
class BPETokenizer:
    """Thin wrapper around a tiktoken encoding exposing ``encode`` and ``decode``."""

    def __init__(self, model_type="gpt2"):
        """Fetch the tiktoken encoding named ``model_type`` and store it on ``self.enc``."""
        # tiktoken is imported here rather than at module import time, so the
        # module itself can be imported without tiktoken being present.
        from tiktoken import get_encoding

        self.enc = get_encoding(model_type)

    def encode(self, text):
        """Return the token ids for ``text``.

        The literal ``<|endoftext|>`` marker is whitelisted as a special token
        when calling the underlying encoder.
        """
        permitted_special = {'<|endoftext|>'}
        return self.enc.encode(text, allowed_special=permitted_special)

    def decode(self, ids):
        """Return the text that the underlying encoding decodes ``ids`` to."""
        decoded = self.enc.decode(ids)
        return decoded
class MiniTransformer(nn.Module):
    """Causal language model built from pre-norm ``nn.TransformerEncoderLayer`` blocks.

    Token and learned position embeddings are summed, passed through
    ``n_layers`` encoder layers under a causal mask, normalised by a final
    ``LayerNorm`` and projected to vocabulary logits by a bias-free linear head.
    """

    def __init__(self, vocab_size, emb_dim=768, n_layers=12, n_heads=12, ctx_len=1024, dropout=0.1):
        """Build the embeddings, transformer stack and output head.

        Args:
            vocab_size: Number of rows in the token embedding / output logits.
            emb_dim: Model width; the feed-forward width is ``4 * emb_dim``.
            n_layers: Number of transformer layers.
            n_heads: Attention heads per layer.
            ctx_len: Number of learned positions (maximum sequence length).
            dropout: Dropout probability for the embeddings and every layer.
        """
        super().__init__()
        self.ctx_len = ctx_len
        self.n_heads = n_heads
        self.emb_dim = emb_dim
        self.n_layers = n_layers
        self.token_embedding_table = nn.Embedding(vocab_size, emb_dim)
        self.position_embedding_table = nn.Embedding(ctx_len, emb_dim)
        self.drop = nn.Dropout(dropout)
        self.blocks = nn.ModuleList([
            nn.TransformerEncoderLayer(
                d_model=emb_dim,
                nhead=n_heads,
                dim_feedforward=emb_dim * 4,
                dropout=dropout,
                batch_first=True,
                norm_first=True,
                activation='gelu',
            )
            for _ in range(n_layers)
        ])
        self.ln_f = nn.LayerNorm(emb_dim)
        self.lm_head = nn.Linear(emb_dim, vocab_size, bias=False)

    def forward(self, idx, targets=None):
        """Compute next-token logits and, optionally, the cross-entropy loss.

        Args:
            idx: Integer tensor of token ids with shape ``(B, T)``.
            targets: Optional integer tensor with ``B * T`` elements; when
                given, the mean cross-entropy against the logits is returned.

        Returns:
            ``(logits, loss)`` where ``logits`` has shape ``(B, T, vocab)`` and
            ``loss`` is ``None`` when ``targets`` is ``None``.

        Raises:
            ValueError: If ``T`` exceeds ``ctx_len`` (there is no position
                embedding for those positions).
        """
        device = idx.device
        _, T = idx.shape
        if T > self.ctx_len:
            # Fail early with a readable message instead of an out-of-range
            # embedding lookup further down.
            raise ValueError(
                f"sequence length {T} exceeds the model context length {self.ctx_len}"
            )

        # Out-of-range ids are clamped into the vocabulary rather than rejected.
        idx = torch.clamp(idx, 0, self.token_embedding_table.num_embeddings - 1)
        tok_emb = self.token_embedding_table(idx)
        pos_emb = self.position_embedding_table(torch.arange(T, device=device))
        x = self.drop(tok_emb + pos_emb)

        # True above the diagonal: position i may not attend to positions > i.
        mask = torch.triu(torch.ones(T, T, device=device), diagonal=1).bool()
        for block in self.blocks:
            x = block(x, src_mask=mask, is_causal=True)

        x = self.ln_f(x)
        logits = self.lm_head(x)

        loss = None
        if targets is not None:
            targets = torch.clamp(targets, 0, self.lm_head.out_features - 1)
            # reshape (not view) so non-contiguous targets are accepted too.
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.reshape(-1))
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None,
                 repetition_penalty=1.0, eos_token_id=50256):
        """Autoregressively sample up to ``max_new_tokens`` tokens after ``idx``.

        Runs without gradient tracking. The module's train/eval mode is left
        as the caller set it.

        Args:
            idx: Integer tensor of prompt token ids with shape ``(B, T)``.
            max_new_tokens: Upper bound on the number of sampled tokens.
            temperature: Divisor applied to the last-position logits.
            top_k: If set, logits below the k-th largest are masked out.
            repetition_penalty: Values other than ``1.0`` rescale the logits of
                every token already present in the sequence (negative logits
                are multiplied by it, others divided by it).
            eos_token_id: Token id that ends a sequence; ``None`` disables
                early stopping. Defaults to 50256.

        Returns:
            Tensor of shape ``(B, T + n)`` with ``n <= max_new_tokens``. With
            more than one row, rows that finished early are padded with
            ``eos_token_id`` until every row has finished.
        """
        vocab_size = self.lm_head.out_features
        finished = torch.zeros(idx.size(0), dtype=torch.bool, device=idx.device)

        for _ in range(max_new_tokens):
            # The model only has ctx_len positions, so crop the conditioning window.
            idx_cond = idx[:, -self.ctx_len:]
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :] / temperature

            if repetition_penalty != 1.0:
                # Vectorised over batch and history. Duplicate ids scatter the
                # same value, so each seen token is penalised exactly once.
                # Ids are clamped the same way forward() clamps them.
                seen = idx.clamp(0, vocab_size - 1)
                seen_logits = torch.gather(logits, 1, seen)
                seen_logits = torch.where(
                    seen_logits < 0,
                    seen_logits * repetition_penalty,
                    seen_logits / repetition_penalty,
                )
                logits.scatter_(1, seen, seen_logits)

            if top_k is not None:
                top_values, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                kth_largest = top_values[:, [-1]]
                logits = logits.masked_fill(logits < kth_largest, float('-inf'))

            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)

            if eos_token_id is not None:
                # Rows that already ended keep emitting EOS as padding.
                idx_next = idx_next.masked_fill(finished.unsqueeze(1), eos_token_id)
                finished = finished | (idx_next.squeeze(1) == eos_token_id)

            idx = torch.cat((idx, idx_next), dim=1)

            # Works for any batch size (calling .item() on a multi-row tensor raises).
            if eos_token_id is not None and bool(finished.all()):
                break
        return idx

    @classmethod
    def load(cls, path, device='cpu'):
        """Build a model from a checkpoint file and return it in eval mode.

        The checkpoint may be a bare state dict or a dict holding
        ``'model_state'`` and, optionally, ``'config'`` (with the keys
        ``vocab_size``, ``emb_dim``, ``n_layers``, ``n_heads``, ``ctx_len``).
        A leading ``'module.'`` on parameter names is stripped.

        Args:
            path: Checkpoint location passed to ``torch.load``.
            device: Device the tensors are mapped to and the model is moved to.

        Returns:
            The loaded model, moved to ``device`` and switched to eval mode.
        """
        import warnings

        # SECURITY: weights_only=False lets torch.load unpickle arbitrary
        # objects, which can execute code. Only load checkpoints from a
        # trusted source.
        ckpt = torch.load(path, map_location=device, weights_only=False)
        is_mapping = isinstance(ckpt, dict)
        state_dict = ckpt['model_state'] if is_mapping and 'model_state' in ckpt else ckpt

        # Fallback architecture used when the checkpoint carries no config.
        cfg = {'vocab_size': 50257, 'emb_dim': 1024, 'n_layers': 24, 'n_heads': 16, 'ctx_len': 1024}
        if is_mapping and 'config' in ckpt:
            cfg = ckpt['config']
        model = cls(cfg['vocab_size'], cfg['emb_dim'], cfg['n_layers'], cfg['n_heads'], cfg['ctx_len'])

        prefix = 'module.'
        new_state_dict = {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in state_dict.items()
        }

        # Loading stays best-effort (strict=False), but mismatches are reported
        # instead of silently leaving parameters at their random initial values.
        result = model.load_state_dict(new_state_dict, strict=False)
        if result.missing_keys or result.unexpected_keys:
            warnings.warn(
                "checkpoint does not fully match the model: "
                f"missing keys {list(result.missing_keys)}, "
                f"unexpected keys {list(result.unexpected_keys)}"
            )

        model.to(device)
        model.eval()
        return model