"""A minimal GPT-style, decoder-only transformer implemented in PyTorch."""

import inspect
from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.nn.functional as F


@dataclass
class Config:
    """GPT model hyperparameters."""
    context_length: int = 1024
    vocab_size: int = 50257
    num_layers: int = 12
    embedding_dim: int = 768
    num_heads: int = 12
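
# The defaults above correspond to the GPT-2 small (~124M parameter) configuration.
# Larger variants can be described by overriding fields, e.g. (illustrative values,
# roughly GPT-2 medium):
#
#     config = Config(num_layers=24, embedding_dim=1024, num_heads=16)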


class MultiHeadAttention(nn.Module):
    def __init__(self, config: Config, masked=False):
        super().__init__()
        self.num_heads = config.num_heads
        self.masked = masked
        self.embedding_dim = config.embedding_dim
        # One linear layer produces the concatenated query, key and value projections.
        self.c_attention = nn.Linear(config.embedding_dim, 3 * config.embedding_dim)
        self.c_projection = nn.Linear(config.embedding_dim, config.embedding_dim)
        # Flag read by GPT._init_weights to scale down residual-path projections.
        self.c_projection.SCALE_INIT = 1.0

    def forward(self, x):
        B, T, C = x.shape
        qkv = self.c_attention(x)
        q, k, v = qkv.split(self.embedding_dim, dim=-1)
        # Reshape to (B, num_heads, T, head_dim) for scaled dot-product attention.
        head_dim = self.embedding_dim // self.num_heads
        q = q.view(B, T, self.num_heads, head_dim).transpose(1, 2)
        k = k.view(B, T, self.num_heads, head_dim).transpose(1, 2)
        v = v.view(B, T, self.num_heads, head_dim).transpose(1, 2)
        # is_causal applies the upper-triangular mask needed for decoder self-attention.
        out = F.scaled_dot_product_attention(q, k, v, is_causal=self.masked)
        # Re-assemble the heads and project back to the embedding dimension.
        out = out.transpose(1, 2).contiguous().view(B, T, C)
        return self.c_projection(out)
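
# A minimal shape sanity check for MultiHeadAttention (illustrative only; the
# batch size and sequence length below are arbitrary):
#
#     config = Config()
#     attention = MultiHeadAttention(config, masked=True)
#     x = torch.randn(2, 16, config.embedding_dim)
#     assert attention(x).shape == (2, 16, config.embedding_dim)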


class MLP(nn.Module):
    def __init__(self, config: Config):
        super().__init__()
        # Position-wise feed-forward network with the standard 4x hidden expansion.
        self.c_fc = nn.Linear(config.embedding_dim, 4 * config.embedding_dim)
        self.gelu = nn.GELU(approximate='tanh')
        self.c_projection = nn.Linear(4 * config.embedding_dim, config.embedding_dim)
        self.c_projection.SCALE_INIT = 1.0

    def forward(self, x):
        x = self.c_fc(x)
        x = self.gelu(x)
        x = self.c_projection(x)
        return x


class DecoderBlock(nn.Module):
    """Decoder block without encoder cross-attention (decoder-only, GPT-style)."""

    def __init__(self, config: Config):
        super().__init__()
        self.masked_attention = MultiHeadAttention(config, masked=True)
        self.layer_norm1 = nn.LayerNorm(config.embedding_dim)
        self.mlp = MLP(config)
        self.layer_norm2 = nn.LayerNorm(config.embedding_dim)

    def forward(self, x):
        # Pre-norm residual connections around the attention and MLP sub-layers.
        x = x + self.masked_attention(self.layer_norm1(x))
        x = x + self.mlp(self.layer_norm2(x))
        return x


class TransformerDecoder(nn.Module):
    def __init__(self, config: Config):
        super().__init__()
        self.config = config
        self.word_token_embedding = nn.Embedding(self.config.vocab_size, self.config.embedding_dim)
        self.word_position_embedding = nn.Embedding(self.config.context_length, self.config.embedding_dim)
        layers = [DecoderBlock(config) for _ in range(config.num_layers)]
        self.hidden_layers = nn.Sequential(*layers)
        self.layer_norm = nn.LayerNorm(self.config.embedding_dim)

    def forward(self, idx):
        B, T = idx.shape
        assert T <= self.config.context_length, \
            f'sequence length {T} exceeds context length {self.config.context_length}'
        pos = torch.arange(0, T, dtype=torch.long, device=idx.device)
        # Learned positional embeddings, broadcast across the batch dimension.
        pos_embed = self.word_position_embedding(pos)
        token_embed = self.word_token_embedding(idx)
        x = token_embed + pos_embed
        x = self.hidden_layers(x)
        # Final LayerNorm over the hidden states before the language-modeling head.
        x = self.layer_norm(x)
        return x


class GPT(nn.Module):
    def __init__(self, config: Config):
        super().__init__()
        self.config = config
        self.transformerDecoder = TransformerDecoder(config)
        self.language_modeling_head = nn.Linear(config.embedding_dim, config.vocab_size, bias=False)
        # Weight tying: the token embedding and the LM head share one parameter tensor.
        self.transformerDecoder.word_token_embedding.weight = self.language_modeling_head.weight
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            std = 0.02
            # Residual-path projections (flagged with SCALE_INIT) are scaled down by
            # 1/sqrt(2 * num_layers), following the GPT-2 initialization scheme.
            if hasattr(module, 'SCALE_INIT'):
                std /= (2 * self.config.num_layers) ** 0.5
            torch.nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        x = self.transformerDecoder(idx)
        logits = self.language_modeling_head(x)
        loss = None
        if targets is not None:
            # Flatten (B, T, vocab_size) logits and (B, T) targets for cross-entropy.
            loss = F.cross_entropy(logits.view(-1, logits.shape[-1]), targets.view(-1))
        return logits, loss
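
    # Shape reference: given idx of shape (B, T), forward returns logits of shape
    # (B, T, vocab_size) and, when (B, T) targets are provided, a scalar loss.
    # A minimal sketch with assumed dummy inputs:
    #
    #     model = GPT(Config())
    #     idx = torch.randint(0, Config().vocab_size, (2, 8))
    #     logits, loss = model(idx, targets=idx)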

    @torch.no_grad()
    def generate(self, idx, max_new_tokens=50, temperature=0.8, top_k=None, do_sample=False, eos_token_id=None):
        self.eval()

        context_len = self.config.context_length

        # Crop an over-long prompt to the most recent context_len tokens.
        if idx.shape[1] > context_len:
            idx = idx[:, -context_len:]

        generated = idx.clone()

        for _ in range(max_new_tokens):
            # Condition only on the last context_len tokens of what has been generated.
            input_ids = generated[:, -context_len:]
            logits, _ = self.forward(input_ids, targets=None)
            # Only the logits at the final position are needed to predict the next token.
            next_logits = logits[:, -1, :]

            if temperature != 1.0 and temperature > 0.0:
                next_logits = next_logits / temperature

            if do_sample:
                if top_k is not None and top_k > 0:
                    # Keep only the top_k highest-scoring tokens; mask out the rest.
                    vals, _ = next_logits.topk(top_k, dim=-1)
                    min_vals = vals[:, -1].unsqueeze(-1)
                    next_logits = next_logits.masked_fill(next_logits < min_vals, float('-inf'))
                probs = torch.softmax(next_logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
            else:
                # Greedy decoding.
                next_token = torch.argmax(next_logits, dim=-1, keepdim=True)

            generated = torch.cat([generated, next_token], dim=1)

            # Stop early once every sequence in the batch contains the EOS token.
            if eos_token_id is not None and (generated == eos_token_id).any(dim=1).all():
                break

        return generated
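
    # A minimal generation sketch (illustrative; the prompt tokens and decoding
    # settings below are arbitrary assumptions):
    #
    #     model = GPT(Config())
    #     prompt = torch.randint(0, Config().vocab_size, (1, 5))
    #     out = model.generate(prompt, max_new_tokens=20, do_sample=True, top_k=50)
    #     # out has shape (1, 25): the prompt plus 20 newly sampled tokens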

    def configure_optimizer(self, weight_decay, lr, device_type, master_process):
        # Only parameters that require gradients are optimized.
        param_dict = {pn: p for pn, p in self.named_parameters() if p.requires_grad}

        # Weight decay is applied to 2D+ tensors (matmul weights, embeddings) but not
        # to 1D tensors such as biases and LayerNorm parameters.
        decay_params = [p for p in param_dict.values() if p.dim() >= 2]
        nodecay_params = [p for p in param_dict.values() if p.dim() < 2]
        optim_groups = [
            {'params': decay_params, 'weight_decay': weight_decay},
            {'params': nodecay_params, 'weight_decay': 0.0},
        ]
        num_decay_params = sum(p.numel() for p in decay_params)
        num_nodecay_params = sum(p.numel() for p in nodecay_params)
        if master_process:
            print(f'num decay parameter tensors: {len(decay_params)} with {num_decay_params:,} parameters')
            print(f'num nodecay parameter tensors: {len(nodecay_params)} with {num_nodecay_params:,} parameters')

        # Use the fused AdamW kernel only when this torch build supports it and we are on CUDA.
        fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
        use_fused = fused_available and device_type == 'cuda'
        if master_process:
            print(f'using fused AdamW optimizer: {use_fused}')
        extra_args = dict(fused=True) if use_fused else dict()
        optimizer = torch.optim.AdamW(optim_groups, lr=lr, betas=(0.9, 0.95), eps=1e-8, **extra_args)
        return optimizer
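

# A minimal end-to-end usage sketch (illustrative only; the batch size, sequence
# length, learning rate and device below are assumptions, not prescribed values):
if __name__ == '__main__':
    config = Config()
    model = GPT(config)

    # Dummy batch: random token ids stand in for real training data.
    idx = torch.randint(0, config.vocab_size, (2, 32))
    targets = torch.randint(0, config.vocab_size, (2, 32))

    optimizer = model.configure_optimizer(weight_decay=0.1, lr=3e-4,
                                          device_type='cpu', master_process=True)

    # One training step.
    model.train()
    logits, loss = model(idx, targets)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    print(f'logits: {tuple(logits.shape)}, loss: {loss.item():.4f}')

    # Greedy generation from a short dummy prompt.
    prompt = idx[:, :5]
    generated = model.generate(prompt, max_new_tokens=10)
    print(f'generated: {tuple(generated.shape)}')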