tytodd commited on
Commit
d4857ae
·
verified ·
1 Parent(s): f253122

Upload probe.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. probe.py +41 -0
probe.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from typing import Literal
4
+ from transformers import PretrainedConfig, PreTrainedModel
5
+
6
+
7
+ class ProbeConfig(PretrainedConfig):
8
+ model_type = "linear_probe"
9
+
10
+ def __init__(
11
+ self,
12
+ embedding_dim: int = 768,
13
+ dropout: float = 0.0,
14
+ layer_index: int = -1,
15
+ probe_type: Literal["linear", "nonlinear"] = "linear",
16
+ **kwargs,
17
+ ):
18
+ super().__init__(**kwargs)
19
+ self.embedding_dim = embedding_dim
20
+ self.dropout = dropout
21
+ self.layer_index = layer_index
22
+ self.probe_type = probe_type
23
+
24
+
25
+ class ProbeModel(PreTrainedModel):
26
+ config_class = ProbeConfig
27
+
28
+ def __init__(self, config: ProbeConfig):
29
+ super().__init__(config)
30
+ self.dropout = nn.Dropout(config.dropout) if config.dropout > 0 else None
31
+ self.linear = nn.Linear(config.embedding_dim, 1)
32
+
33
+ def forward(
34
+ self,
35
+ embeddings: torch.Tensor,
36
+ **kwargs,
37
+ ) -> torch.Tensor:
38
+ if self.dropout is not None:
39
+ embeddings = self.dropout(embeddings)
40
+ logits = self.linear(embeddings)
41
+ return torch.sigmoid(logits)