import math
from dataclasses import dataclass
from typing import Optional, Tuple

import torch
from torch import nn
from torch.nn import CrossEntropyLoss

from transformers.activations import ACT2FN
from transformers.pytorch_utils import Conv1D
from transformers.utils import (
    ModelOutput,
    logging,
)
from transformers.models.gpt2.modeling_gpt2 import GPT2Model, GPT2PreTrainedModel, GenerationMixin
from transformers.cache_utils import Cache

from .configuration_backpack_gpt2 import BackpackGPT2Config

logger = logging.get_logger(__name__)


### Backpack-Specific
class BackpackGPT2PreTrainedModel(GPT2PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    _keys_to_ignore_on_load_missing = [r"attn.masked_bias", r"attn.bias"]
    config_class = BackpackGPT2Config
    base_model_prefix = "backpack"
    is_parallelizable = True
    supports_gradient_checkpointing = False
    _no_split_modules = ["GPT2Block", "BackpackNoMixBlock"]

    def __init__(self, *inputs, **kwargs):
        super().__init__(*inputs, **kwargs)


class BackpackMLP(nn.Module):
    def __init__(self, embed_dim, intermediate_dim, out_dim, config):
        super().__init__()
        self.c_fc = Conv1D(intermediate_dim, embed_dim)
        self.c_proj = Conv1D(out_dim, intermediate_dim)
        self.act = ACT2FN[config.activation_function]
        self.dropout = nn.Dropout(config.resid_pdrop)

    def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
        hidden_states = self.c_fc(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states = self.c_proj(hidden_states)
        hidden_states = self.dropout(hidden_states)
        return hidden_states


class BackpackNoMixBlock(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
        self.ln_2 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
        self.mlp = BackpackMLP(config.n_embd, config.n_embd * 4, config.n_embd, config)
        self.resid_dropout1 = nn.Dropout(config.resid_pdrop)
        self.resid_dropout2 = nn.Dropout(config.resid_pdrop)

    def forward(self, hidden_states, residual):
        residual = self.resid_dropout1(hidden_states) + residual
        hidden_states = self.ln_1(residual)
        mlp_out = self.mlp(hidden_states)
        residual = self.resid_dropout2(mlp_out) + residual
        hidden_states = self.ln_2(residual)
        return hidden_states


class BackpackSenseNetwork(nn.Module):
    """Maps each (non-contextual) word embedding to `num_senses` sense vectors."""

    def __init__(self, config, num_senses, device=None, dtype=None):
        super().__init__()
        self.num_senses = num_senses
        self.n_embd = config.n_embd

        self.dropout = nn.Dropout(config.embd_pdrop)
        self.block = BackpackNoMixBlock(config)
        self.ln = nn.LayerNorm(self.n_embd, eps=config.layer_norm_epsilon)
        self.final_mlp = BackpackMLP(
            embed_dim=config.n_embd,
            intermediate_dim=config.sense_intermediate_scale * config.n_embd,
            out_dim=config.n_embd * config.num_senses,
            config=config,
        )

    def forward(self, input_embeds):
        residual = self.dropout(input_embeds)
        hidden_states = self.ln(residual)
        hidden_states = self.block(hidden_states, residual)
        senses = self.final_mlp(hidden_states)
        bs, s, nvd = senses.shape
        return senses.reshape(bs, s, self.num_senses, self.n_embd).transpose(1, 2)  # (bs, nv, s, d)


class BackpackWeightNetwork(nn.Module):
    """Computes per-sense, causally masked contextualization weights from GPT-2 hidden states."""

    def __init__(self, num_senses, embed_dim):
        super().__init__()
        self.n_embd = embed_dim
        self.num_senses = num_senses
        self.embed_per_sense = embed_dim // num_senses
        self.c_attn = nn.Linear(embed_dim, 2 * num_senses * self.embed_per_sense)
        self.softmax_scale = None

    def forward(self, encoded):
        """
        Args:
            encoded: contextual hidden states of shape (b, s, d).

        Returns:
            Causal attention-style weights of shape (b, num_senses, s, s).
        """
        b, s, d = encoded.shape
        x = self.c_attn(encoded)  # (b, s, 2*d)
        x = x.reshape(b, s, 2, self.num_senses, self.embed_per_sense)  # (b, s, 2, nv, d//nv)

        # q, k: (b, s, nv, d//nv)
        q, k = x.unbind(dim=2)

        # Scale; default to 1/sqrt(head_dim) when no explicit scale is set.
        scale = self.softmax_scale if self.softmax_scale is not None else 1.0 / math.sqrt(q.shape[-1])

        # einsum gives (b, nv, s, s); keep the native dtype here.
        scores = torch.einsum('bthd,bshd->bhts', q, k) * scale

        # Boolean causal mask: True = mask out. Shape (s, s), broadcasts to (b, nv, s, s).
        causal_mask = torch.ones(s, s, device=scores.device, dtype=torch.bool).triu_(1)

        scores = scores.float()  # do the numerically sensitive bits in fp32
        scores = scores.masked_fill(causal_mask, float('-inf'))
        attn = torch.softmax(scores, dim=-1)  # fp32 softmax
        attn = attn.to(q.dtype)  # cast back (fp16/bf16) for downstream
        return attn


@dataclass
class BackpackGPT2BaseModelOutput(ModelOutput):
    hidden_states: Optional[torch.FloatTensor] = None
    contextualization: Optional[torch.FloatTensor] = None
    senses: Optional[torch.FloatTensor] = None
    past_key_values: Optional[Tuple] = None  # include cache in the base output too


class BackpackGPT2Model(BackpackGPT2PreTrainedModel):
    _keys_to_ignore_on_load_missing = [r".*attn.masked_bias", r".*attn.bias"]

    def __init__(self, config):
        super().__init__(config)

        self.embed_dim = config.n_embd
        self.num_senses = config.num_senses

        self.gpt2_model = GPT2Model(config)
        self.sense_network = BackpackSenseNetwork(config, self.num_senses)
        self.word_embeddings = self.gpt2_model.wte
        self.position_embeddings = self.gpt2_model.wpe
        self.sense_weight_net = BackpackWeightNetwork(self.num_senses, self.embed_dim)

        # Model parallel
        self.model_parallel = False
        self.device_map = None
        self.gradient_checkpointing = False

    def get_num_senses(self):
        return self.num_senses

    def get_word_embeddings(self):
        return self.word_embeddings

    def get_sense_network(self):
        return self.sense_network

    def get_input_embeddings(self):
        return self.word_embeddings

    def set_input_embeddings(self, new_embeddings):
        self.word_embeddings = new_embeddings

    def forward(
        self,
        input_ids,
        position_ids=None,
        cache_position: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ):
        # Compute senses from the non-contextual word embeddings
        sense_input_embeds = self.word_embeddings(input_ids)
        senses = self.sense_network(sense_input_embeds)  # (bs, nv, s, d)

        # Compute contextualization weights from the GPT-2 hidden states
        gpt2_out = self.gpt2_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            past_key_values=past_key_values,
            use_cache=use_cache,
            cache_position=cache_position,
            return_dict=True,
            **kwargs,
        )
        contextl_hidden_states = gpt2_out.last_hidden_state  # (bs, s, d)
        contextualization = self.sense_weight_net(contextl_hidden_states)  # (bs, nv, s, s)

        # Compose the outputs: contextualization-weighted sum of sense vectors
        hidden_states = torch.sum(contextualization @ senses, dim=1)  # (bs, nv, s, d) -> (bs, s, d)

        return BackpackGPT2BaseModelOutput(
            hidden_states=hidden_states,
            contextualization=contextualization,
            senses=senses,
            past_key_values=gpt2_out.past_key_values,
        )

    def run_with_custom_contextualization(self, input_ids, contextualization):
        # Compute senses
        sense_input_embeds = self.word_embeddings(input_ids)
        senses = self.sense_network(sense_input_embeds)  # (bs, nv, s, d)

        # Compose the outputs with the caller-provided contextualization
        hidden_states = torch.sum(contextualization @ senses, dim=1)  # (bs, nv, s, d) -> (bs, s, d)
        return BackpackGPT2BaseModelOutput(
            hidden_states=hidden_states,
            contextualization=contextualization,
            senses=senses,
        )


@dataclass
class BackpackGPT2LMHeadModelOutput(ModelOutput):
    # Make the FIRST field Optional so HF won't enforce "only one required field"
    logits: Optional[torch.FloatTensor] = None
    contextualization: Optional[torch.FloatTensor] = None
    backpack_hidden_states: Optional[torch.FloatTensor] = None
    loss: Optional[torch.Tensor] = None  # smoothed (for training)
    loss_unsmoothed: Optional[torch.Tensor] = None  # raw CE for logging
    senses: Optional[torch.FloatTensor] = None
    past_key_values: Optional[Tuple] = None  # <<< required for GenerationMixin


class BackpackGPT2LMHeadModel(BackpackGPT2PreTrainedModel, GenerationMixin):
    _keys_to_ignore_on_load_missing = [r".*attn.masked_bias", r".*attn.bias"]
    accepts_loss_kwargs = False

    def __init__(self, config):
        super().__init__(config)
        self.backpack = BackpackGPT2Model(config)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        # Model parallel
        self.model_parallel = False
        self.device_map = None

        self.tie_weights()

    def tie_weights(self):
        # Also ties the LM head with the underlying transformer's input embeddings
        self.lm_head.weight = self.backpack.word_embeddings.weight

    def get_lm_head(self):
        return self.lm_head

    def get_input_embeddings(self):
        return self.backpack.word_embeddings

    def can_generate(self):
        # Hint to GenerationMixin that this model is generative
        return True

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **kwargs):
        # GPT-2 style incremental decoding: if we have a cache, only feed the last token
        if past_key_values is not None:
            input_ids = input_ids[:, -1:]
        return {
            "input_ids": input_ids,
            "past_key_values": past_key_values,
            "attention_mask": attention_mask,
            "use_cache": kwargs.get("use_cache", True),
        }

    def forward(
        self,
        input_ids,
        position_ids=None,
        labels: Optional[torch.LongTensor] = None,
        label_smoothing: float = 0.0,
        cache_position: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ):
        outputs = self.backpack(
            input_ids=input_ids,
            position_ids=position_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            past_key_values=past_key_values,
            cache_position=cache_position,
            return_dict=True,
            **kwargs,
        )
        hidden_states, contextualization = outputs.hidden_states, outputs.contextualization
        senses = outputs.senses
        lm_logits = self.lm_head(hidden_states)  # (bs, s, V)

        loss = None
        loss_unsmoothed = None
        if labels is not None:
            labels = labels.to(lm_logits.device)
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = CrossEntropyLoss(ignore_index=-100, reduction='mean', label_smoothing=label_smoothing)
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

            # Reporting loss: **unsmoothed** (no grad) in case we used label smoothing
            with torch.no_grad():
                ce_raw = CrossEntropyLoss(ignore_index=-100, reduction="mean")
                loss_unsmoothed = ce_raw(
                    shift_logits.detach().view(-1, shift_logits.size(-1)),
                    shift_labels.view(-1),
                )

        return BackpackGPT2LMHeadModelOutput(
            logits=lm_logits,
            contextualization=contextualization,
            backpack_hidden_states=hidden_states,
            loss=loss,
            loss_unsmoothed=loss_unsmoothed,
            senses=senses,
            past_key_values=outputs.past_key_values,
        )

    def run_with_custom_contextualization(self, input_ids, contextualization):
        outputs = self.backpack.run_with_custom_contextualization(input_ids, contextualization)
        hidden_states, contextualization = outputs.hidden_states, outputs.contextualization
        lm_logits = self.lm_head(hidden_states)
        return BackpackGPT2LMHeadModelOutput(
            logits=lm_logits,
            contextualization=contextualization,
            backpack_hidden_states=hidden_states,
        )
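

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the model definition; guarded so it never
# runs on import). Assumptions: BackpackGPT2Config accepts GPT2Config-style
# keyword arguments plus `num_senses` and `sense_intermediate_scale`; the tiny
# sizes below are illustrative only. Because of the relative config import,
# run it as a module, e.g. `python -m <package>.modeling_backpack_gpt2`.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)

    # Assumed constructor signature; adjust if the actual config class differs.
    config = BackpackGPT2Config(
        vocab_size=128,
        n_positions=64,
        n_embd=64,
        n_layer=2,
        n_head=2,
        num_senses=4,
        sense_intermediate_scale=2,
    )
    model = BackpackGPT2LMHeadModel(config).eval()

    input_ids = torch.randint(0, config.vocab_size, (1, 8))
    with torch.no_grad():
        out = model(input_ids=input_ids)

    print(out.logits.shape)             # (1, 8, vocab_size)
    print(out.senses.shape)             # (1, num_senses, 8, n_embd)
    print(out.contextualization.shape)  # (1, num_senses, 8, 8)

    # Backpack composition check: hidden states are the contextualization-weighted
    # sum of sense vectors, i.e. sum over senses of A[:, l] @ S[:, l].
    recomputed = torch.einsum("blts,blsd->btd", out.contextualization, out.senses)
    assert torch.allclose(recomputed, out.backpack_hidden_states, atol=1e-5)

    # Probe the model with a hand-built contextualization: uniform weights over
    # the (causal) prefix for every sense.
    bs, nv, s, _ = out.contextualization.shape
    uniform = torch.tril(torch.ones(bs, nv, s, s))
    uniform = uniform / uniform.sum(dim=-1, keepdim=True)
    with torch.no_grad():
        probed = model.run_with_custom_contextualization(input_ids, uniform)
    print(probed.logits.shape)          # (1, 8, vocab_size)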