# directionality_probe/protify/probes/retrievalnet.py
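"""RetrievalNet probe heads.

Sequence- and token-level classification models that operate on precomputed
embeddings (rather than raw token ids), scoring them with an attention-based
logits head.
"""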
import torch
from torch import nn
from typing import Optional
from transformers import PreTrainedModel, PretrainedConfig
from transformers.modeling_outputs import SequenceClassifierOutput, TokenClassifierOutput
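# Resolve model_components whether this file runs inside the package, from an
# installed `protify`, or as a flat script next to `model_components/`.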
try:
from ..model_components.attention import AttentionLogitsSequence, AttentionLogitsToken, Linear
except ImportError:
try:
from protify.model_components.attention import AttentionLogitsSequence, AttentionLogitsToken, Linear
except ImportError:
from model_components.attention import AttentionLogitsSequence, AttentionLogitsToken, Linear
try:
from ..model_components.transformer import TokenFormer, Transformer
except ImportError:
try:
from protify.model_components.transformer import TokenFormer, Transformer
except ImportError:
from model_components.transformer import TokenFormer, Transformer
from .losses import get_loss_fct


class RetrievalNetConfig(PretrainedConfig):
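    """Configuration for RetrievalNet probes.

    With n_layers == 0 the input projection and transformer stack are skipped,
    so the probe only learns to distribute labels over the raw embeddings.
    `sim_type` selects the similarity used by the attention-logits head
    (default 'dot'); `token_attention` swaps the Transformer for a TokenFormer.
    """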
model_type = "retrievalnet"

    def __init__(
self,
input_size: int = 768,
hidden_size: int = 512,
dropout: float = 0.2,
num_labels: int = 2,
n_layers: int = 1,
sim_type: str = 'dot',
token_attention: bool = False,
n_heads: int = 4,
task_type: str = 'singlelabel',
expansion_ratio: float = 8 / 3,
**kwargs,
):
super().__init__(**kwargs)
assert task_type != 'regression' or num_labels == 1, "Regression task must have exactly one label"
self.input_size = input_size
self.hidden_size = hidden_size
self.dropout = dropout
self.task_type = task_type
self.num_labels = num_labels
self.n_layers = n_layers
self.sim_type = sim_type
self.token_attention = token_attention
self.expansion_ratio = expansion_ratio
self.n_heads = n_heads


class RetrievalNetForSequenceClassification(PreTrainedModel):
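    """Sequence-classification probe: optionally refines the input embeddings
    with a Transformer/TokenFormer stack, then pools them into per-sequence
    logits via AttentionLogitsSequence."""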
config_class = RetrievalNetConfig
all_tied_weights_keys = {}

    def __init__(self, config: RetrievalNetConfig):
super().__init__(config)
# If n_layers == 0, only learn how to distribute labels over the raw embeddings
if config.n_layers > 0:
self.input_proj = nn.Linear(config.input_size, config.hidden_size)
transformer_class = TokenFormer if config.token_attention else Transformer
self.transformer = transformer_class(
hidden_size=config.hidden_size,
n_heads=config.n_heads,
n_layers=config.n_layers,
expansion_ratio=config.expansion_ratio,
dropout=config.dropout,
rotary=True,
)
self.get_logits = AttentionLogitsSequence(
hidden_size=config.hidden_size if config.n_layers > 0 else config.input_size,
num_labels=config.num_labels,
sim_type=config.sim_type,
)
self.n_layers = config.n_layers
self.num_labels = config.num_labels
self.task_type = config.task_type
self.loss_fct = get_loss_fct(config.task_type)

    def forward(
self,
        embeddings: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = False,
output_hidden_states: Optional[bool] = False,
) -> SequenceClassifierOutput:
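        """Run the probe on precomputed embeddings.

        Args:
            embeddings: (bs, seq_len, input_size) tensor of embeddings.
            attention_mask: optional (bs, seq_len) mask of valid positions.
            labels: optional targets; when given, a task-appropriate loss is computed.
        """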
# Convert embeddings to match model's dtype to avoid dtype mismatch errors
# This handles cases where embeddings are fp32 but model is fp16 (or vice versa)
if self.n_layers > 0:
embeddings = embeddings.to(next(self.input_proj.parameters()).dtype)
x = self.input_proj(embeddings) # (bs, seq_len, hidden_size)
x = self.transformer(x, attention_mask) # (bs, seq_len, hidden_size)
else:
# If no layers, still need to match dtype for get_logits
embeddings = embeddings.to(next(self.get_logits.parameters()).dtype)
x = embeddings
logits, sims, x = self.get_logits(x, attention_mask) # (bs, num_labels)
if self.task_type == 'sigmoid_regression':
logits = logits.sigmoid()
loss = None
if labels is not None:
            if self.task_type in ('regression', 'sigmoid_regression'):
                loss = self.loss_fct(logits.flatten(), labels.view(-1).float())
elif self.task_type == 'multilabel':
loss = self.loss_fct(logits, labels.float())
else:
loss = self.loss_fct(logits.view(-1, self.num_labels), labels.view(-1).long())
return SequenceClassifierOutput(
loss=loss,
logits=logits,
hidden_states=x if output_hidden_states else None,
attentions=sims if output_attentions else None
)


class RetrievalNetForTokenClassification(PreTrainedModel):
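    """Token-classification probe: same structure as the sequence variant, but
    always uses a TokenFormer stack and emits one logit vector per token."""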
config_class = RetrievalNetConfig
all_tied_weights_keys = {}

    def __init__(self, config: RetrievalNetConfig):
super().__init__(config)
if config.n_layers > 0:
self.input_proj = nn.Linear(config.input_size, config.hidden_size)
self.transformer = TokenFormer(
hidden_size=config.hidden_size,
n_heads=config.n_heads,
n_layers=config.n_layers,
expansion_ratio=config.expansion_ratio,
dropout=config.dropout,
                rotary=True,  # RetrievalNetConfig defines no `rotary` field; match the sequence head
)
self.get_logits = AttentionLogitsToken(
hidden_size=config.hidden_size if config.n_layers > 0 else config.input_size,
num_labels=config.num_labels,
sim_type=config.sim_type,
)
self.n_layers = config.n_layers
self.num_labels = config.num_labels
self.task_type = config.task_type
self.loss_fct = get_loss_fct(config.task_type)

    def forward(
self,
        embeddings: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
**kwargs,
    ) -> TokenClassifierOutput:
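        """Run the probe on precomputed embeddings, classifying every token.

        Args:
            embeddings: (bs, seq_len, input_size) tensor of embeddings.
            attention_mask: optional (bs, seq_len) mask of valid positions.
            labels: optional per-token targets; when given, a loss is computed.
        """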
# Convert embeddings to match model's dtype to avoid dtype mismatch errors
# This handles cases where embeddings are fp32 but model is fp16 (or vice versa)
if self.n_layers > 0:
embeddings = embeddings.to(next(self.input_proj.parameters()).dtype)
x = self.input_proj(embeddings) # (bs, seq_len, hidden_size)
x = self.transformer(x, attention_mask) # (bs, seq_len, hidden_size)
else:
# If no layers, still need to match dtype for get_logits
embeddings = embeddings.to(next(self.get_logits.parameters()).dtype)
x = embeddings
logits = self.get_logits(x, attention_mask)
if self.task_type == 'sigmoid_regression':
logits = logits.sigmoid()
loss = None
if labels is not None:
            if self.task_type in ('regression', 'sigmoid_regression'):
                loss = self.loss_fct(logits.flatten(), labels.view(-1).float())
elif self.task_type == 'multilabel':
loss = self.loss_fct(logits, labels.float())
else:
loss = self.loss_fct(logits.view(-1, self.num_labels), labels.view(-1).long())
        return TokenClassifierOutput(
loss=loss,
logits=logits,
hidden_states=None,
attentions=None,
)
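

if __name__ == "__main__":
    # Minimal smoke test (illustrative sketch): exercises the sequence probe
    # with random embeddings. Assumes `model_components` resolves via one of
    # the import fallbacks above; the shapes below are arbitrary choices,
    # not requirements of the probe.
    config = RetrievalNetConfig(input_size=768, hidden_size=512, num_labels=3, n_layers=1)
    model = RetrievalNetForSequenceClassification(config)
    embeddings = torch.randn(2, 16, config.input_size)   # (bs, seq_len, input_size)
    attention_mask = torch.ones(2, 16, dtype=torch.long)
    labels = torch.tensor([0, 2])
    out = model(embeddings, attention_mask=attention_mask, labels=labels)
    print(out.logits.shape, out.loss)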