import torch
import torch.nn as nn
from typing import Literal

from transformers import PretrainedConfig, PreTrainedModel


class ProbeConfig(PretrainedConfig):
    """Configuration for a probe head that runs on precomputed embeddings."""

    model_type = "linear_probe"

    def __init__(
        self,
        embedding_dim: int = 768,
        dropout: float = 0.0,
        layer_index: int = -1,  # which hidden layer the embeddings are taken from
        probe_type: Literal["linear", "nonlinear"] = "linear",
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.embedding_dim = embedding_dim
        self.dropout = dropout
        self.layer_index = layer_index
        self.probe_type = probe_type


class ProbeModel(PreTrainedModel):
    """Linear probe that maps embeddings to a single probability."""

    config_class = ProbeConfig

    def __init__(self, config: ProbeConfig):
        super().__init__(config)
        # Optional dropout on the input embeddings, then a single linear head.
        self.dropout = nn.Dropout(config.dropout) if config.dropout > 0 else None
        self.linear = nn.Linear(config.embedding_dim, 1)

    def forward(
        self,
        embeddings: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        if self.dropout is not None:
            embeddings = self.dropout(embeddings)
        logits = self.linear(embeddings)
        # Return probabilities in (0, 1) rather than raw logits.
        return torch.sigmoid(logits)
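
# Usage sketch (not part of the original listing): instantiate the probe on a
# batch of pooled embeddings and round-trip it through save_pretrained /
# from_pretrained, which should work since config_class is set on the model.
# The directory name "probe_checkpoint" and the random embeddings are
# illustrative assumptions.
if __name__ == "__main__":
    config = ProbeConfig(embedding_dim=768, dropout=0.1)
    model = ProbeModel(config)

    embeddings = torch.randn(4, config.embedding_dim)  # e.g. pooled hidden states
    probs = model(embeddings)                          # shape (4, 1), values in (0, 1)

    model.save_pretrained("probe_checkpoint")
    reloaded = ProbeModel.from_pretrained("probe_checkpoint")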