import torch
import torch.nn as nn
from transformers import PreTrainedModel, PretrainedConfig


class MultiheadAttention(nn.Module):
    """Causal multi-head self-attention.

    Projects the input into queries, keys and values, splits each into
    ``num_heads`` heads of size ``d_out // num_heads``, applies scaled
    dot-product attention under a causal (upper-triangular) mask, merges the
    heads and applies a final output projection.
    """

    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        # forward() reshapes d_out into (num_heads, head_dim); reject configs
        # where that is impossible instead of failing later inside view().
        if d_out % num_heads != 0:
            raise ValueError(
                f"d_out ({d_out}) must be divisible by num_heads ({num_heads})"
            )
        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads

        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out)
        self.dropout = nn.Dropout(dropout)
        # Ones strictly above the diagonal mark the "future" positions a token
        # must not attend to. Registered as a (persistent) buffer so it follows
        # the module across devices and is part of the state dict.
        self.register_buffer(
            "mask",
            torch.triu(torch.ones(context_length, context_length), diagonal=1),
        )

    def forward(self, x):
        """Apply causal multi-head attention.

        Args:
            x: tensor of shape ``(batch, num_tokens, d_in)``; ``num_tokens``
                must not exceed ``context_length`` (the mask is sliced to it).

        Returns:
            Tensor of shape ``(batch, num_tokens, d_out)``.
        """
        b, num_tokens, _ = x.shape

        # Linear projections: (b, num_tokens, d_out)
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)

        # Split the last dimension into heads and move the head axis forward:
        # (b, num_tokens, d_out) -> (b, num_heads, num_tokens, head_dim)
        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim).transpose(1, 2)
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim).transpose(1, 2)
        values = values.view(b, num_tokens, self.num_heads, self.head_dim).transpose(1, 2)

        # Raw attention scores: (b, num_heads, num_tokens, num_tokens)
        attn_scores = queries @ keys.transpose(2, 3)

        # Causal mask: future positions get -inf so softmax assigns them zero weight.
        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
        attn_scores.masked_fill_(mask_bool, -torch.inf)

        # Scale by sqrt(head_dim) before softmax (scaled dot-product attention).
        attn_weights = torch.softmax(attn_scores / self.head_dim ** 0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # Weighted sum of values, then merge heads back:
        # (b, num_heads, num_tokens, head_dim) -> (b, num_tokens, d_out)
        ctx_vec = (attn_weights @ values).transpose(1, 2)
        ctx_vec = ctx_vec.contiguous().view(b, num_tokens, self.d_out)
        return self.out_proj(ctx_vec)


# ==========================================================================
class LayerNorm(nn.Module):
    """Layer normalization over the last dimension with learnable scale/shift."""

    def __init__(self, emb_dim):
        super().__init__()
        self.eps = 1e-5
        self.scale = nn.Parameter(torch.ones(emb_dim))
        self.shift = nn.Parameter(torch.zeros(emb_dim))

    def forward(self, x):
        """Normalize ``x`` to zero mean / unit (biased) variance on the last dim."""
        mean = x.mean(dim=-1, keepdim=True)
        # Biased variance (divide by N), matching the usual LayerNorm definition.
        var = x.var(dim=-1, keepdim=True, unbiased=False)
        norm_x = (x - mean) / torch.sqrt(var + self.eps)
        return self.scale * norm_x + self.shift


# ==========================================================================
class GeLU(nn.Module):
    """GELU activation using the tanh approximation."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return ``0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x**3)))``."""
        return 0.5 * x * (1 + torch.tanh(
            torch.sqrt(torch.tensor(2.0 / torch.pi)) * (x + 0.044715 * torch.pow(x, 3))
        ))


# ==========================================================================
class FeedForward(nn.Module):
    """Position-wise MLP: emb_dim -> 4*emb_dim -> GeLU -> emb_dim."""

    def __init__(self, cfg):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(cfg.emb_dim, 4 * cfg.emb_dim),
            GeLU(),
            nn.Linear(4 * cfg.emb_dim, cfg.emb_dim),
        )

    def forward(self, x):
        """Apply the MLP to the last dimension of ``x``."""
        return self.layers(x)


# ==========================================================================
class TransformerBlock(nn.Module):
    """Pre-norm transformer block: attention and MLP, each with a residual."""

    def __init__(self, cfg):
        super().__init__()
        self.att = MultiheadAttention(
            d_in=cfg.emb_dim,
            d_out=cfg.emb_dim,
            context_length=cfg.context_length,
            dropout=cfg.drop_rate,
            num_heads=cfg.n_heads,
            qkv_bias=cfg.qkv_bias,
        )
        self.ff = FeedForward(cfg)
        self.norm1 = LayerNorm(cfg.emb_dim)
        self.norm2 = LayerNorm(cfg.emb_dim)
        # Shared dropout applied to each sub-layer output before the residual add.
        self.drop_shortcut = nn.Dropout(cfg.drop_rate)

    def forward(self, x):
        """Run ``x -> x + drop(att(norm1(x))) -> x + drop(ff(norm2(x)))``."""
        # Attention sub-layer with residual connection.
        shortcut = x
        x = self.norm1(x)
        x = self.att(x)
        x = self.drop_shortcut(x)
        x = x + shortcut

        # Feed-forward sub-layer with residual connection.
        shortcut = x
        x = self.norm2(x)
        x = self.ff(x)
        x = self.drop_shortcut(x)
        x = x + shortcut
        return x


# ==========================================================================
class CustomGPTConfig(PretrainedConfig):
    """Hugging Face config holding the GPT hyper-parameters used by CustomGPT."""

    model_type = "custom_gpt"  # Unique identifier for the AutoClass

    def __init__(self, context_length=1024, drop_rate=0.0, emb_dim=1024, n_heads=16,
                 n_layers=24, qkv_bias=True, vocab_size=50257, **kwargs):
        super().__init__(**kwargs)
        self.context_length = context_length
        self.drop_rate = drop_rate
        self.emb_dim = emb_dim
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.qkv_bias = qkv_bias
        self.vocab_size = vocab_size


# ==========================================================================
class CustomGPT(
    PreTrainedModel,
):
    """GPT-style decoder-only language model wrapped as a PreTrainedModel."""

    config_class = CustomGPTConfig

    def __init__(self, config):
        super().__init__(config)
        self.tok_emb = nn.Embedding(config.vocab_size, config.emb_dim)
        self.pos_emb = nn.Embedding(config.context_length, config.emb_dim)
        self.drop_emb = nn.Dropout(config.drop_rate)
        self.trf_blocks = nn.Sequential(
            *[TransformerBlock(config) for _ in range(config.n_layers)]
        )
        self.final_norm = LayerNorm(config.emb_dim)
        self.out_head = nn.Linear(config.emb_dim, config.vocab_size, bias=False)

    def forward(self, x):
        """Compute next-token logits.

        Args:
            x: integer token ids of shape ``(batch, seq_len)`` with
                ``seq_len <= config.context_length``.

        Returns:
            Logits of shape ``(batch, seq_len, vocab_size)``.
        """
        _, seq_len = x.shape
        tok_embeddings = self.tok_emb(x)  # (batch, seq_len, emb_dim)
        pos_embeddings = self.pos_emb(torch.arange(seq_len, device=x.device))  # (seq_len, emb_dim)
        x = tok_embeddings + pos_embeddings  # broadcast over batch
        x = self.drop_emb(x)
        x = self.trf_blocks(x)
        x = self.final_norm(x)
        return self.out_head(x)  # (batch, seq_len, vocab_size)

    def format_input(self, entry):
        """Build an instruction-style prompt from a dict with 'instruction'/'input'."""
        instruction_text = (
            f"Below is an instruction that describes a task. "
            f"Write a response that appropriately completes the request."
            f"\n\n### Instruction:\n{entry['instruction']}"
        )
        # NOTE(review): the "### Input:" section is always emitted, even when
        # 'input' is missing or empty. Confirm this matches the prompt format
        # the model was fine-tuned with (a common variant omits the section
        # when the input is empty).
        input_text = f"\n\n### Input:\n{entry.get('input', '')}"
        return instruction_text + input_text

    def text_to_token_ids(self, text, tokenizer):
        """Encode ``text`` into a ``(1, num_tokens)`` tensor of token ids."""
        encoded = tokenizer.encode(text, allowed_special={'<|endoftext|>'})
        encoded_tensor = torch.tensor(encoded).unsqueeze(0)  # add batch dimension
        return encoded_tensor

    def token_ids_to_text(self, token_ids, tokenizer):
        """Decode a ``(1, num_tokens)`` tensor of token ids back to text."""
        flat = token_ids.squeeze(0)  # remove batch dimension
        return tokenizer.decode(flat.tolist())

    def generate(self, idx, max_new_tokens, context_size, temperature=0.0, top_k=None, eos_id=None):
        """Autoregressively extend ``idx`` by up to ``max_new_tokens`` tokens.

        Args:
            idx: token ids of shape ``(batch, num_tokens)``.
            max_new_tokens: maximum number of tokens to append.
            context_size: only the last ``context_size`` tokens are fed to
                the model at each step.
            temperature: ``> 0`` samples from the softmax of
                ``logits / temperature``; ``0`` (default) decodes greedily.
            top_k: if given, only the ``top_k`` highest logits per row are
                kept before sampling / argmax.
            eos_id: if given, generation stops once every sequence in the
                batch produces this id in the same step; that token is not
                appended.

        Returns:
            Tensor of shape ``(batch, num_tokens + generated)``.
        """
        for _ in range(max_new_tokens):
            # Crop to the supported context so positional embeddings stay in range.
            idx_cond = idx[:, -context_size:]
            with torch.no_grad():
                logits = self(idx_cond)
            # Only the prediction for the last position matters: (batch, vocab_size)
            logits = logits[:, -1, :]

            if top_k is not None:
                top_logits, _ = torch.topk(logits, top_k)
                # Smallest kept logit per row, shaped (batch, 1) so it
                # broadcasts row-wise against (batch, vocab_size). A 1-D
                # (batch,) threshold would broadcast along the vocab axis
                # instead and break for batch > 1.
                min_val = top_logits[:, -1:]
                logits = torch.where(
                    logits < min_val,
                    torch.tensor(float("-inf"), device=logits.device),
                    logits,
                )

            if temperature > 0.0:
                logits = logits / temperature
                probs = torch.softmax(logits, dim=-1)  # (batch, vocab_size)
                idx_next = torch.multinomial(probs, num_samples=1)  # (batch, 1)
            else:
                # Greedy decoding: highest-scoring vocab entry.
                idx_next = torch.argmax(logits, dim=-1, keepdim=True)  # (batch, 1)

            # Explicit None check and .all(): a bare `tensor == eos_id` in an
            # `if` is ambiguous for batch > 1.
            if eos_id is not None and bool((idx_next == eos_id).all()):
                break

            idx = torch.cat((idx, idx_next), dim=1)  # (batch, num_tokens + 1)
        return idx

    def generate_response(self, input_entry, tokenizer, temperature=0.0, topk=None,
                          max_new_tokens=256, eos_id=50256):
        """Generate a text response for an instruction entry.

        Puts the model in eval mode, formats ``input_entry`` with
        :meth:`format_input`, generates a continuation and returns it with the
        prompt and any ``"### Response:"`` marker removed.

        Args:
            input_entry: dict with an ``'instruction'`` key and optional ``'input'``.
            tokenizer: object providing ``encode(text, allowed_special=...)``
                and ``decode(list_of_ids)``.
            temperature: sampling temperature passed to :meth:`generate`.
            topk: top-k cutoff passed to :meth:`generate` as ``top_k``.
            max_new_tokens: maximum number of tokens to generate.
            eos_id: stop token id (default 50256 is GPT-2's ``<|endoftext|>``).

        Returns:
            The stripped response string.
        """
        current_device = next(self.parameters()).device
        self.eval()
        input_text = self.format_input(input_entry)
        token_ids = self.generate(
            idx=self.text_to_token_ids(input_text, tokenizer).to(current_device),
            max_new_tokens=max_new_tokens,
            # Use the model's real context length rather than a hard-coded
            # 1024, which would overrun pos_emb for shorter-context configs.
            context_size=self.config.context_length,
            temperature=temperature,
            top_k=topk,
            eos_id=eos_id,
        )
        generated_text = self.token_ids_to_text(token_ids, tokenizer)
        # The decoded text starts with the prompt; keep only the continuation.
        response_text = generated_text[len(input_text):].replace("### Response:", "")
        return response_text.strip()