import torch
from torch import nn
from typing import Optional
from transformers import PreTrainedModel, PretrainedConfig
from transformers.modeling_outputs import SequenceClassifierOutput
try:
from ..model_components.attention import AttentionLogitsSequence, AttentionLogitsToken, Linear
except ImportError:
try:
from protify.model_components.attention import AttentionLogitsSequence, AttentionLogitsToken, Linear
except ImportError:
from model_components.attention import AttentionLogitsSequence, AttentionLogitsToken, Linear
try:
from ..model_components.transformer import TokenFormer, Transformer
except ImportError:
try:
from protify.model_components.transformer import TokenFormer, Transformer
except ImportError:
from model_components.transformer import TokenFormer, Transformer
from .losses import get_loss_fct
class RetrievalNetConfig(PretrainedConfig):
    """Configuration for RetrievalNet sequence/token classification models.

    Args:
        input_size: Dimensionality of the incoming (precomputed) embeddings.
        hidden_size: Width of the internal transformer, used when ``n_layers > 0``.
        dropout: Dropout probability inside the transformer.
        num_labels: Number of output labels (must be 1 for regression).
        n_layers: Number of transformer layers; 0 means the raw embeddings are
            fed straight to the logit head.
        sim_type: Similarity type used by the attention-logit head (e.g. 'dot').
        token_attention: If True the model uses the TokenFormer backbone,
            otherwise a standard Transformer.
        n_heads: Number of attention heads.
        task_type: One of 'singlelabel', 'multilabel', 'regression',
            'sigmoid_regression' — controls the loss and output activation.
        expansion_ratio: MLP expansion ratio inside the transformer blocks.
    """
    model_type = "retrievalnet"

    def __init__(
        self,
        input_size: int = 768,
        hidden_size: int = 512,
        dropout: float = 0.2,
        num_labels: int = 2,
        n_layers: int = 1,
        sim_type: str = 'dot',
        token_attention: bool = False,
        n_heads: int = 4,
        task_type: str = 'singlelabel',
        expansion_ratio: float = 8 / 3,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Raise a real exception rather than `assert`, which is stripped
        # when Python runs with optimizations (-O).
        if task_type == 'regression' and num_labels != 1:
            raise ValueError("Regression task must have exactly one label")
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.dropout = dropout
        self.task_type = task_type
        self.num_labels = num_labels
        self.n_layers = n_layers
        self.sim_type = sim_type
        self.token_attention = token_attention
        self.expansion_ratio = expansion_ratio
        self.n_heads = n_heads
class RetrievalNetForSequenceClassification(PreTrainedModel):
    """Sequence-level classifier over precomputed embeddings.

    Optionally refines the embeddings with a small transformer backbone
    (``config.n_layers > 0``), then produces per-sequence logits via an
    attention-based label-retrieval head.
    """
    config_class = RetrievalNetConfig
    all_tied_weights_keys = {}

    def __init__(self, config: RetrievalNetConfig):
        super().__init__(config)
        use_backbone = config.n_layers > 0
        # With n_layers == 0 we only learn how to distribute labels over the
        # raw embeddings — no projection or transformer is created.
        if use_backbone:
            self.input_proj = nn.Linear(config.input_size, config.hidden_size)
            backbone_cls = TokenFormer if config.token_attention else Transformer
            self.transformer = backbone_cls(
                hidden_size=config.hidden_size,
                n_heads=config.n_heads,
                n_layers=config.n_layers,
                expansion_ratio=config.expansion_ratio,
                dropout=config.dropout,
                rotary=True,
            )
        self.get_logits = AttentionLogitsSequence(
            hidden_size=config.hidden_size if use_backbone else config.input_size,
            num_labels=config.num_labels,
            sim_type=config.sim_type,
        )
        self.n_layers = config.n_layers
        self.num_labels = config.num_labels
        self.task_type = config.task_type
        self.loss_fct = get_loss_fct(config.task_type)

    def forward(
        self,
        embeddings,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
        output_hidden_states: Optional[bool] = False,
    ) -> SequenceClassifierOutput:
        """Run the model; returns logits plus loss when labels are given."""
        # Cast inputs to the model's dtype so fp16/fp32 mismatches
        # (embeddings fp32 vs model fp16, or vice versa) do not error.
        if self.n_layers > 0:
            model_dtype = next(self.input_proj.parameters()).dtype
            hidden = self.input_proj(embeddings.to(model_dtype))  # (bs, seq_len, hidden_size)
            hidden = self.transformer(hidden, attention_mask)  # (bs, seq_len, hidden_size)
        else:
            # No backbone: feed the raw embeddings straight to the logit head,
            # matched to its parameter dtype.
            model_dtype = next(self.get_logits.parameters()).dtype
            hidden = embeddings.to(model_dtype)
        logits, sims, hidden = self.get_logits(hidden, attention_mask)  # (bs, num_labels)
        if self.task_type == 'sigmoid_regression':
            logits = logits.sigmoid()
        loss = None
        if labels is not None:
            if self.task_type in ('regression', 'sigmoid_regression'):
                loss = self.loss_fct(logits.flatten(), labels.view(-1).float())
            elif self.task_type == 'multilabel':
                loss = self.loss_fct(logits, labels.float())
            else:
                loss = self.loss_fct(logits.view(-1, self.num_labels), labels.view(-1).long())
        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=hidden if output_hidden_states else None,
            attentions=sims if output_attentions else None,
        )
class RetrievalNetForTokenClassification(PreTrainedModel):
    """Token-level classifier over precomputed embeddings.

    Optionally refines the embeddings with a TokenFormer backbone
    (``config.n_layers > 0``), then produces per-token logits via an
    attention-based label-retrieval head.
    """
    config_class = RetrievalNetConfig
    all_tied_weights_keys = {}

    def __init__(self, config: RetrievalNetConfig):
        super().__init__(config)
        # With n_layers == 0 we only learn how to distribute labels over the
        # raw embeddings — no projection or transformer is created.
        if config.n_layers > 0:
            self.input_proj = nn.Linear(config.input_size, config.hidden_size)
            self.transformer = TokenFormer(
                hidden_size=config.hidden_size,
                n_heads=config.n_heads,
                n_layers=config.n_layers,
                expansion_ratio=config.expansion_ratio,
                dropout=config.dropout,
                # BUG FIX: RetrievalNetConfig never defines `rotary`, so
                # `config.rotary` raised AttributeError here. Default to True,
                # matching RetrievalNetForSequenceClassification, while still
                # honoring an explicitly-set config.rotary if present.
                rotary=getattr(config, 'rotary', True),
            )
        self.get_logits = AttentionLogitsToken(
            hidden_size=config.hidden_size if config.n_layers > 0 else config.input_size,
            num_labels=config.num_labels,
            sim_type=config.sim_type,
        )
        self.n_layers = config.n_layers
        self.num_labels = config.num_labels
        self.task_type = config.task_type
        self.loss_fct = get_loss_fct(config.task_type)

    def forward(
        self,
        embeddings,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> SequenceClassifierOutput:
        """Run the model; returns per-token logits plus loss when labels are given."""
        # Cast inputs to the model's dtype so fp16/fp32 mismatches
        # (embeddings fp32 vs model fp16, or vice versa) do not error.
        if self.n_layers > 0:
            embeddings = embeddings.to(next(self.input_proj.parameters()).dtype)
            x = self.input_proj(embeddings)  # (bs, seq_len, hidden_size)
            x = self.transformer(x, attention_mask)  # (bs, seq_len, hidden_size)
        else:
            # No backbone: match dtype for the logit head and use raw embeddings.
            embeddings = embeddings.to(next(self.get_logits.parameters()).dtype)
            x = embeddings
        logits = self.get_logits(x, attention_mask)
        if self.task_type == 'sigmoid_regression':
            logits = logits.sigmoid()
        loss = None
        if labels is not None:
            # Both regression variants use the same flattened-target form.
            if self.task_type in ('regression', 'sigmoid_regression'):
                loss = self.loss_fct(logits.flatten(), labels.view(-1).float())
            elif self.task_type == 'multilabel':
                loss = self.loss_fct(logits, labels.float())
            else:
                loss = self.loss_fct(logits.view(-1, self.num_labels), labels.view(-1).long())
        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=None,
            attentions=None,
        )