File size: 3,832 Bytes
714cf46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import torch
from torch import nn
from typing import Optional
from transformers import PreTrainedModel, PretrainedConfig
from transformers.modeling_outputs import SequenceClassifierOutput
try:
    from ..model_components.mlp import intermediate_correction_fn
except ImportError:
    try:
        from protify.model_components.mlp import intermediate_correction_fn
    except ImportError:
        from model_components.mlp import intermediate_correction_fn
from .losses import get_loss_fct


class LinearProbeConfig(PretrainedConfig):
    """Hyper-parameter container consumed by ``LinearProbe``.

    Every constructor argument is stored verbatim as an attribute of the same
    name; any additional keyword arguments are forwarded to
    ``PretrainedConfig.__init__``.

    Args:
        input_size: Width of the incoming embedding vectors.
        hidden_size: Width of the hidden linear layers.
        dropout: Dropout probability used after each hidden activation.
        num_labels: Size of the final output layer.
        n_layers: Number of extra hidden-to-hidden blocks the probe stacks
            after its input projection.
        task_type: Task name; the probe uses it to pick the loss function and
            how logits/labels are reshaped.
        pre_ln: If true, the probe starts with a LayerNorm over the input.
        use_bias: Whether the probe's linear layers carry a bias term.
    """

    model_type = "linear_probe"

    def __init__(
            self,
            input_size: int = 768,
            hidden_size: int = 8192,
            dropout: float = 0.2,
            num_labels: int = 2,
            n_layers: int = 1,
            task_type: str = 'singlelabel',
            pre_ln: bool = True,
            use_bias: bool = False,
            **kwargs,
    ):
        super().__init__(**kwargs)
        # Assigned one by one, in a fixed order, so attribute insertion order
        # (and therefore the serialized key order) stays stable.
        settings = (
            ("input_size", input_size),
            ("hidden_size", hidden_size),
            ("dropout", dropout),
            ("task_type", task_type),
            ("num_labels", num_labels),
            ("n_layers", n_layers),
            ("pre_ln", pre_ln),
            ("use_bias", use_bias),
        )
        for attr_name, attr_value in settings:
            setattr(self, attr_name, attr_value)


class LinearProbe(PreTrainedModel):
    """Feed-forward probe that maps embeddings to task logits.

    The network is a single ``nn.Sequential`` built as:

    1. optional ``LayerNorm(input_size)`` (when ``config.pre_ln`` is truthy),
    2. ``Linear(input_size -> hidden_size)`` + ReLU + Dropout,
    3. ``config.n_layers`` repetitions of
       ``Linear(hidden_size -> hidden_size)`` + ReLU + Dropout,
    4. ``LayerNorm(hidden_size)``,
    5. ``Linear(hidden_size -> proj_dim)`` + ReLU + Dropout,
    6. ``Linear(proj_dim -> num_labels)``.

    ``proj_dim`` comes from ``intermediate_correction_fn(2, num_labels)``.
    """

    config_class = LinearProbeConfig
    all_tied_weights_keys = {}

    def __init__(self, config: LinearProbeConfig):
        """Build the layer stack and select the loss from ``config``."""
        super().__init__(config)
        self.config = config
        self.task_type = config.task_type
        self.loss_fct = get_loss_fct(config.task_type)
        self.num_labels = config.num_labels

        bias = config.use_bias
        width = config.hidden_size

        def dense_block(in_features, out_features):
            # One Linear -> ReLU -> Dropout unit, shared by all hidden stages.
            return [
                nn.Linear(in_features, out_features, bias=bias),
                nn.ReLU(),
                nn.Dropout(config.dropout),
            ]

        # Optional input normalisation followed by the input projection.
        stack = [nn.LayerNorm(config.input_size)] if config.pre_ln else []
        stack.extend(dense_block(config.input_size, width))

        # Additional hidden-to-hidden stages.
        for _ in range(config.n_layers):
            stack.extend(dense_block(width, width))

        # Per the original author's note, this yields the nearest multiple of
        # 256 of 2 * num_labels (helper is defined outside this file).
        proj_dim = intermediate_correction_fn(2, config.num_labels)

        # Output head: normalise, squeeze to proj_dim, then emit the logits.
        stack.append(nn.LayerNorm(width))
        stack.extend(dense_block(width, proj_dim))
        stack.append(nn.Linear(proj_dim, config.num_labels, bias=bias))

        self.layers = nn.Sequential(*stack)

    def forward(self, embeddings: torch.Tensor, labels: Optional[torch.Tensor] = None) -> SequenceClassifierOutput:
        """Run the probe and, when ``labels`` are supplied, compute the loss.

        Args:
            embeddings: Input tensor; cast to the dtype of the probe's
                parameters before being passed through the layer stack.
            labels: Optional targets. When ``None`` no loss is computed.

        Returns:
            ``SequenceClassifierOutput`` with ``loss`` (or ``None``) and
            ``logits``; ``hidden_states`` and ``attentions`` are always
            ``None``. For ``'sigmoid_regression'`` the returned logits have
            already been passed through a sigmoid.
        """
        # Align the input dtype with the weights (e.g. fp32 embeddings fed to
        # an fp16 model, or the reverse) to avoid dtype mismatch errors.
        param_dtype = next(self.layers.parameters()).dtype
        logits = self.layers(embeddings.to(param_dtype))

        task = self.task_type
        if task == 'sigmoid_regression':
            logits = logits.sigmoid()

        loss = None
        if labels is not None:
            if task in ('regression', 'sigmoid_regression'):
                # Both regression flavours compare flattened float vectors.
                prediction = logits.view(-1)
                target = labels.view(-1).float()
            elif task == 'multilabel':
                # Keep the leading dimension, flatten everything after it.
                batch = logits.size(0)
                prediction = logits.view(batch, -1)
                target = labels.view(batch, -1).float()
            else:
                # Any other task type: rows of num_labels logits against
                # integer class labels.
                prediction = logits.view(-1, self.num_labels)
                target = labels.view(-1).long()
            loss = self.loss_fct(prediction, target)

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=None,
            attentions=None,
        )