# ai1 / model.py
# Kleinpuki2's picture
# Update model.py
# 9125647 verified
import torch
import torch.nn as nn
from torch.nn import functional as F
import json
import re
class BPETokenizer:
    """Small convenience wrapper around a ``tiktoken`` encoding.

    The wrapped encoding object is stored on ``self.enc``; ``encode`` and
    ``decode`` simply delegate to it.
    """

    def __init__(self, model_type="gpt2"):
        """Look up the tiktoken encoding named ``model_type`` and keep it.

        ``tiktoken`` is imported here rather than at module level, so the
        package is only needed once a tokenizer is actually constructed.
        """
        import tiktoken

        self.enc = tiktoken.get_encoding(model_type)

    def encode(self, text):
        """Return the token ids produced by the wrapped encoding for ``text``.

        The literal ``<|endoftext|>`` marker is passed as an allowed special
        token to the underlying ``encode`` call.
        """
        permitted_specials = {'<|endoftext|>'}
        token_ids = self.enc.encode(text, allowed_special=permitted_specials)
        return token_ids

    def decode(self, ids):
        """Return the text the wrapped encoding decodes from token ``ids``."""
        decoded_text = self.enc.decode(ids)
        return decoded_text
class MiniTransformer(nn.Module):
    """Causal (GPT-style) language model built from ``nn.TransformerEncoderLayer``.

    The network is: token embedding + learned position embedding -> dropout ->
    ``n_layers`` pre-norm encoder layers (GELU, feed-forward width
    ``4 * emb_dim``) run under a causal mask -> final ``LayerNorm`` -> a
    bias-free linear head projecting to ``vocab_size`` logits.

    Args:
        vocab_size: Number of rows in the token embedding / size of the
            output distribution.
        emb_dim: Width of the embeddings and of every layer.
        n_layers: Number of stacked encoder layers.
        n_heads: Attention heads per layer.
        ctx_len: Maximum sequence length (size of the position table).
        dropout: Dropout probability used after the embeddings and inside
            each layer.
    """

    def __init__(self, vocab_size, emb_dim=768, n_layers=12, n_heads=12, ctx_len=1024, dropout=0.1):
        super().__init__()
        self.ctx_len = ctx_len
        self.n_heads = n_heads
        self.emb_dim = emb_dim
        self.n_layers = n_layers
        # NOTE: attribute names below are state-dict keys; renaming any of
        # them would break loading of existing checkpoints.
        self.token_embedding_table = nn.Embedding(vocab_size, emb_dim)
        self.position_embedding_table = nn.Embedding(ctx_len, emb_dim)
        self.drop = nn.Dropout(dropout)
        self.blocks = nn.ModuleList([
            nn.TransformerEncoderLayer(
                d_model=emb_dim,
                nhead=n_heads,
                dim_feedforward=emb_dim * 4,
                dropout=dropout,
                batch_first=True,
                norm_first=True,
                activation='gelu',
            )
            for _ in range(n_layers)
        ])
        self.ln_f = nn.LayerNorm(emb_dim)
        self.lm_head = nn.Linear(emb_dim, vocab_size, bias=False)

    def forward(self, idx, targets=None):
        """Compute next-token logits and, optionally, the training loss.

        Args:
            idx: Integer tensor of token ids, shape ``(batch, seq)``.
            targets: Optional integer tensor with the same number of
                elements as ``idx``; when given, the mean cross-entropy
                between the logits and ``targets`` is returned as well.

        Returns:
            ``(logits, loss)`` where ``logits`` has shape
            ``(batch, seq, vocab_size)`` and ``loss`` is a scalar tensor, or
            ``None`` when ``targets`` is ``None``.

        Raises:
            ValueError: If ``seq`` exceeds ``ctx_len``.
        """
        _, seq_len = idx.shape
        if seq_len > self.ctx_len:
            # Fail early with a readable message instead of an out-of-range
            # lookup in the position embedding table.
            raise ValueError(
                f"sequence length {seq_len} exceeds context length {self.ctx_len}"
            )
        device = idx.device

        # Best-effort: out-of-range ids are clamped into the vocabulary
        # rather than rejected.
        idx = torch.clamp(idx, 0, self.token_embedding_table.num_embeddings - 1)
        tok_emb = self.token_embedding_table(idx)
        pos_emb = self.position_embedding_table(torch.arange(seq_len, device=device))
        x = self.drop(tok_emb + pos_emb)

        # True strictly above the diagonal: position i may not attend to j > i.
        mask = torch.triu(torch.ones(seq_len, seq_len, device=device), diagonal=1).bool()
        for block in self.blocks:
            x = block(x, src_mask=mask, is_causal=True)
        logits = self.lm_head(self.ln_f(x))

        loss = None
        if targets is not None:
            targets = torch.clamp(targets, 0, self.lm_head.out_features - 1)
            # reshape (not view) so non-contiguous inputs are accepted too.
            loss = F.cross_entropy(
                logits.reshape(-1, logits.size(-1)), targets.reshape(-1)
            )
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None,
                 repetition_penalty=1.0, eos_token_id=50256):
        """Autoregressively sample up to ``max_new_tokens`` tokens.

        Args:
            idx: Integer tensor of prompt token ids, shape ``(batch, seq)``.
            max_new_tokens: Upper bound on the number of sampled tokens.
            temperature: Positive divisor applied to the logits before
                sampling.
            top_k: If given, only the ``top_k`` highest logits stay eligible.
            repetition_penalty: Values other than 1.0 rescale the logits of
                every token already present in the sequence (negative logits
                are multiplied by it, the others divided by it).
            eos_token_id: Token id that ends a sequence; defaults to 50256,
                the ``<|endoftext|>`` id of the GPT-2 encoding. Pass ``None``
                to always generate ``max_new_tokens`` tokens.

        Returns:
            The prompt with the sampled tokens appended along dim 1. With a
            batch, rows that already ended are padded with ``eos_token_id``
            and generation stops once every row has ended.

        Raises:
            ValueError: If ``temperature`` is not strictly positive.
        """
        if temperature <= 0:
            raise ValueError(f"temperature must be > 0, got {temperature}")

        # Per-row "has emitted EOS" flag; replaces a scalar .item() check
        # that only worked for batch size 1.
        finished = torch.zeros(idx.size(0), dtype=torch.bool, device=idx.device)

        for _ in range(max_new_tokens):
            # The model only accepts the last ctx_len tokens.
            idx_cond = idx[:, -self.ctx_len:]
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :] / temperature

            if repetition_penalty != 1.0:
                # Vectorised penalty over the full history. Ids are clamped
                # the same way forward() clamps them so the gather is in range.
                history = idx.long().clamp(0, logits.size(-1) - 1)
                seen = logits.gather(1, history)
                seen = torch.where(
                    seen < 0, seen * repetition_penalty, seen / repetition_penalty
                )
                # Duplicate ids write the same value, so the result is
                # independent of scatter order.
                logits.scatter_(1, history, seen)

            if top_k is not None:
                top_values, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                kth_value = top_values[:, [-1]]
                logits = logits.masked_fill(logits < kth_value, float('-inf'))

            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)

            if eos_token_id is not None:
                # Rows that already ended keep emitting EOS as padding.
                idx_next = idx_next.masked_fill(finished.unsqueeze(1), eos_token_id)
                finished |= idx_next.squeeze(1) == eos_token_id

            idx = torch.cat((idx, idx_next), dim=1)

            if eos_token_id is not None and bool(finished.all()):
                break
        return idx

    @classmethod
    def load(cls, path, device='cpu'):
        """Build a model from a checkpoint and return it in eval mode.

        The checkpoint may be either a bare state dict or a dict holding the
        state dict under ``'model_state'`` and, optionally, the constructor
        arguments under ``'config'``. Without a ``'config'`` entry a built-in
        default configuration is used. A leading ``'module.'`` is removed
        from parameter names.

        Args:
            path: Checkpoint location (anything ``torch.load`` accepts).
            device: Device to map the weights to and place the model on.

        Returns:
            The loaded model, moved to ``device`` and switched to eval mode.
        """
        import warnings

        # SECURITY: weights_only=False lets torch.load unpickle arbitrary
        # objects, which can execute code. Only load checkpoints from a
        # trusted source.
        ckpt = torch.load(path, map_location=device, weights_only=False)
        is_bundle = isinstance(ckpt, dict)
        state_dict = ckpt['model_state'] if is_bundle and 'model_state' in ckpt else ckpt

        if is_bundle and 'config' in ckpt:
            cfg = ckpt['config']
        else:
            cfg = {'vocab_size': 50257, 'emb_dim': 1024, 'n_layers': 24, 'n_heads': 16, 'ctx_len': 1024}
        model = cls(cfg['vocab_size'], cfg['emb_dim'], cfg['n_layers'], cfg['n_heads'], cfg['ctx_len'])

        prefix = 'module.'
        cleaned_state = {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in state_dict.items()
        }

        # Loading stays lenient (strict=False), but mismatches are reported
        # instead of being silently ignored.
        result = model.load_state_dict(cleaned_state, strict=False)
        if result.missing_keys or result.unexpected_keys:
            warnings.warn(
                "checkpoint does not fully match the model: "
                f"missing keys {list(result.missing_keys)}, "
                f"unexpected keys {list(result.unexpected_keys)}"
            )

        model.to(device)
        model.eval()
        return model