"""Bidirectional GPT-2 variants for LLM2Vec-style conversion.""" from __future__ import annotations from typing import Optional, Tuple import torch from torch import nn from transformers import GPT2Config, GPT2LMHeadModel, GPT2Model from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2Block class ModifiedGPT2Attention(GPT2Attention): """GPT-2 attention with causal masking removed.""" def _attn( # type: ignore[override] self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, head_mask: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: attn_weights = torch.matmul(query, key.transpose(-1, -2)) if self.scale_attn_weights: attn_weights = attn_weights / torch.full( [], value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device, ) if self.scale_attn_by_inverse_layer_idx: attn_weights = attn_weights / float(self.layer_idx + 1) # Key LLM2Vec-style change: skip GPT-2 causal mask so each token can # attend to both previous and future tokens. if attention_mask is not None: attn_weights = attn_weights + attention_mask attn_weights = nn.functional.softmax(attn_weights, dim=-1) attn_weights = attn_weights.type(value.dtype) attn_weights = self.attn_dropout(attn_weights) if head_mask is not None: attn_weights = attn_weights * head_mask attn_output = torch.matmul(attn_weights, value) return attn_output, attn_weights class ModifiedGPT2Block(GPT2Block): """GPT-2 block using ModifiedGPT2Attention for self-attention.""" def __init__(self, config: GPT2Config, layer_idx: Optional[int] = None): super().__init__(config, layer_idx=layer_idx) self.attn = ModifiedGPT2Attention(config=config, layer_idx=layer_idx) if config.add_cross_attention: self.crossattention = ModifiedGPT2Attention( config=config, is_cross_attention=True, layer_idx=layer_idx, ) class GPT2BiModel(GPT2Model): """GPT-2 encoder stack with bidirectional self-attention.""" def __init__(self, config: GPT2Config): super().__init__(config) self.h = nn.ModuleList( [ModifiedGPT2Block(config, layer_idx=i) for i in range(config.num_hidden_layers)] ) self.post_init() class GPT2BiForMNTP(GPT2LMHeadModel): """GPT-2 LM-head model whose backbone is GPT2BiModel.""" def __init__(self, config: GPT2Config): super().__init__(config) self.transformer = GPT2BiModel(config) self.post_init()