# gpt2-llm2vec-final/custom_code/models/bidirectional_gpt2.py
"""Bidirectional GPT-2 variants for LLM2Vec-style conversion."""

from __future__ import annotations

from typing import Optional, Tuple

import torch
from torch import nn
from transformers import GPT2Config, GPT2LMHeadModel, GPT2Model
from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2Block


class ModifiedGPT2Attention(GPT2Attention):
    """GPT-2 attention with causal masking removed."""

    def _attn(  # type: ignore[override]
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
        # Mirrors the stock GPT2Attention._attn, minus the causal-mask block below.
        attn_weights = torch.matmul(query, key.transpose(-1, -2))
if self.scale_attn_weights:
attn_weights = attn_weights / torch.full(
[],
value.size(-1) ** 0.5,
dtype=attn_weights.dtype,
device=attn_weights.device,
)
if self.scale_attn_by_inverse_layer_idx:
attn_weights = attn_weights / float(self.layer_idx + 1)
        # Key LLM2Vec-style change: GPT-2's causal mask is skipped here, so each
        # token can attend to both previous and future tokens; only the additive
        # padding mask (attention_mask) is applied.
        if attention_mask is not None:
            attn_weights = attn_weights + attention_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
attn_weights = attn_weights.type(value.dtype)
attn_weights = self.attn_dropout(attn_weights)
if head_mask is not None:
attn_weights = attn_weights * head_mask
attn_output = torch.matmul(attn_weights, value)
return attn_output, attn_weights


class ModifiedGPT2Block(GPT2Block):
    """GPT-2 block using ModifiedGPT2Attention for self-attention."""

    def __init__(self, config: GPT2Config, layer_idx: Optional[int] = None):
super().__init__(config, layer_idx=layer_idx)
        # Swap the causal self-attention built by the parent constructor for the
        # bidirectional variant.
        self.attn = ModifiedGPT2Attention(config=config, layer_idx=layer_idx)
if config.add_cross_attention:
self.crossattention = ModifiedGPT2Attention(
config=config,
is_cross_attention=True,
layer_idx=layer_idx,
)


class GPT2BiModel(GPT2Model):
    """GPT-2 encoder stack with bidirectional self-attention."""

    def __init__(self, config: GPT2Config):
super().__init__(config)
        # Rebuild every transformer block with the bidirectional variant.
        self.h = nn.ModuleList(
            [ModifiedGPT2Block(config, layer_idx=i) for i in range(config.num_hidden_layers)]
        )
        # Re-run weight initialization for the freshly constructed blocks.
        self.post_init()


class GPT2BiForMNTP(GPT2LMHeadModel):
    """GPT-2 LM-head model with a GPT2BiModel backbone, for masked next-token prediction (MNTP)."""

    def __init__(self, config: GPT2Config):
super().__init__(config)
        # Swap in the bidirectional backbone; post_init re-ties lm_head to the new
        # token embeddings when weight tying is enabled.
        self.transformer = GPT2BiModel(config)
        self.post_init()
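

if __name__ == "__main__":
    # Illustrative smoke test (a minimal sketch, not part of the uploaded module).
    # It assumes a transformers version whose GPT2Attention still routes through the
    # eager `_attn` method, and the tiny config sizes below are arbitrary choices made
    # only to keep the example fast. It builds a bidirectional GPT-2 from a fresh
    # config, runs a forward pass, and checks that an early token's hidden state
    # depends on a *later* token, which the original causal mask would forbid.
    config = GPT2Config(n_layer=2, n_head=2, n_embd=64, n_positions=32, vocab_size=1000)
    model = GPT2BiForMNTP(config)
    model.eval()

    input_ids = torch.randint(0, config.vocab_size, (1, 8))
    attention_mask = torch.ones_like(input_ids)

    with torch.no_grad():
        logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
    print(logits.shape)  # torch.Size([1, 8, 1000])

    # Bidirectionality check: perturb the last token and confirm the hidden state
    # at position 0 changes.
    perturbed = input_ids.clone()
    perturbed[0, -1] = (perturbed[0, -1] + 1) % config.vocab_size
    with torch.no_grad():
        h_orig = model.transformer(input_ids, attention_mask=attention_mask).last_hidden_state
        h_pert = model.transformer(perturbed, attention_mask=attention_mask).last_hidden_state
    print(bool((h_orig[0, 0] - h_pert[0, 0]).abs().max() > 0))  # True with bidirectional attention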