"""Bidirectional GPT-2 variants for LLM2Vec-style conversion."""

from __future__ import annotations

from typing import Optional, Tuple

import torch
from torch import nn
from transformers import GPT2Config, GPT2LMHeadModel, GPT2Model
from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2Block


class ModifiedGPT2Attention(GPT2Attention):
    """GPT-2 attention with causal masking removed."""

    def _attn(  # type: ignore[override]
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        attn_weights = torch.matmul(query, key.transpose(-1, -2))

        if self.scale_attn_weights:
            attn_weights = attn_weights / torch.full(
                [],
                value.size(-1) ** 0.5,
                dtype=attn_weights.dtype,
                device=attn_weights.device,
            )

        if self.scale_attn_by_inverse_layer_idx:
            attn_weights = attn_weights / float(self.layer_idx + 1)

        # Key LLM2Vec-style change: skip GPT-2 causal mask so each token can
        # attend to both previous and future tokens.
        if attention_mask is not None:
            attn_weights = attn_weights + attention_mask

        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
        attn_weights = attn_weights.type(value.dtype)
        attn_weights = self.attn_dropout(attn_weights)

        if head_mask is not None:
            attn_weights = attn_weights * head_mask

        attn_output = torch.matmul(attn_weights, value)
        return attn_output, attn_weights


class ModifiedGPT2Block(GPT2Block):
    """GPT-2 block using ModifiedGPT2Attention for self-attention."""

    def __init__(self, config: GPT2Config, layer_idx: Optional[int] = None):
        super().__init__(config, layer_idx=layer_idx)
        self.attn = ModifiedGPT2Attention(config=config, layer_idx=layer_idx)
        if config.add_cross_attention:
            self.crossattention = ModifiedGPT2Attention(
                config=config,
                is_cross_attention=True,
                layer_idx=layer_idx,
            )


class GPT2BiModel(GPT2Model):
    """GPT-2 encoder stack with bidirectional self-attention."""

    def __init__(self, config: GPT2Config):
        super().__init__(config)
        self.h = nn.ModuleList(
            [ModifiedGPT2Block(config, layer_idx=i) for i in range(config.num_hidden_layers)]
        )
        self.post_init()


class GPT2BiForMNTP(GPT2LMHeadModel):
    """GPT-2 LM-head model whose backbone is GPT2BiModel."""

    def __init__(self, config: GPT2Config):
        super().__init__(config)
        self.transformer = GPT2BiModel(config)
        self.post_init()
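

if __name__ == "__main__":
    # Minimal usage sketch (assumes the standard "gpt2" checkpoint is
    # available locally or downloadable; any GPT-2 checkpoint should work):
    # load pretrained weights into the bidirectional LM-head variant and run
    # a single forward pass as a sanity check.
    from transformers import GPT2TokenizerFast

    tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
    model = GPT2BiForMNTP.from_pretrained("gpt2")
    model.eval()

    inputs = tokenizer("Bidirectional GPT-2 sanity check.", return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)

    # Logits have shape (batch_size, sequence_length, vocab_size); with the
    # causal mask removed, each position's logits depend on the full sequence.
    print(outputs.logits.shape)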