jmccardle commited on
Commit
fa35377
·
verified ·
1 Parent(s): 80adc5b

Upload folder using huggingface_hub

Browse files
__pycache__/load_model.cpython-311.pyc ADDED
Binary file (10.1 kB). View file
 
config.json ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "base_model": "answerdotai/ModernBERT-large",
3
+ "nli_hidden_dim": 512,
4
+ "nli_classes": 3,
5
+ "abstention_hidden_dim": 128,
6
+ "abstention_classes": 2,
7
+ "nli_labels": [
8
+ "entailment",
9
+ "neutral",
10
+ "contradiction"
11
+ ],
12
+ "abstention_labels": [
13
+ "confident",
14
+ "uncertain"
15
+ ],
16
+ "training": {
17
+ "nli_epochs": 5,
18
+ "nli_accuracy": 0.708,
19
+ "abstention_epochs": 3,
20
+ "abstention_accuracy": 0.6546,
21
+ "abstention_recall": 0.766
22
+ }
23
+ }
load_model.py ADDED
@@ -0,0 +1,198 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Load ModernBERT-NLI from HuggingFace base model + task heads.
3
+
4
+ Usage:
5
+ from load_model import load_modernbert_nli
6
+
7
+ model, tokenizer = load_modernbert_nli("path/to/task_heads.pt")
8
+
9
+ # NLI classification
10
+ logits = model(**tokenizer(premise, hypothesis, return_tensors="pt"), mode="nli")
11
+
12
+ # With abstention
13
+ nli_logits, abstention_logits = model(**inputs, mode="abstention")
14
+ """
15
+
16
+ import torch
17
+ import torch.nn as nn
18
+ from transformers import AutoModel, AutoTokenizer
19
+
20
+
21
class ModernBERTWithNLI(nn.Module):
    """ModernBERT encoder with NLI and abstention classification heads.

    The encoder is loaded from HuggingFace and frozen; only the two heads
    hold trainable parameters. Head sizes default to the values shipped in
    this repo's config.json (nli_hidden_dim=512, nli_classes=3,
    abstention_hidden_dim=128, abstention_classes=2) but may be overridden
    for differently-shaped checkpoints.

    NOTE: the attribute names (`encoder`, `nli_hidden`, `nli_output`,
    `abstention_head`) and the ordering inside each nn.Sequential are part
    of the saved state-dict contract — do not rename or reorder them.
    """

    def __init__(
        self,
        base_model_name: str = "answerdotai/ModernBERT-large",
        nli_hidden_dim: int = 512,
        nli_classes: int = 3,
        abstention_hidden_dim: int = 128,
        abstention_classes: int = 2,
    ):
        super().__init__()

        # Load base encoder from HuggingFace (downloads on first use).
        self.encoder = AutoModel.from_pretrained(base_model_name)
        hidden_size = self.encoder.config.hidden_size  # 1024 for -large

        # NLI head, split into hidden projection + output layer so the
        # abstention head can consume the intermediate representation.
        self.nli_hidden = nn.Sequential(
            nn.Linear(hidden_size, nli_hidden_dim),
            nn.LayerNorm(nli_hidden_dim),
            nn.GELU(),
            nn.Dropout(0.1),
        )
        self.nli_output = nn.Linear(nli_hidden_dim, nli_classes)

        # Abstention head: input is [nli_hidden, nli_logits] concatenated.
        self.abstention_head = nn.Sequential(
            nn.Linear(nli_hidden_dim + nli_classes, abstention_hidden_dim),
            nn.LayerNorm(abstention_hidden_dim),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(abstention_hidden_dim, abstention_classes),
        )

        # Freeze encoder by default: only the heads were fine-tuned.
        for param in self.encoder.parameters():
            param.requires_grad = False

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor = None,
        mode: str = "nli",
    ):
        """
        Forward pass with multiple modes.

        Args:
            input_ids: Token IDs
            attention_mask: Attention mask
            mode: One of "embed", "late_interaction", "nli", "abstention"

        Returns:
            Depends on mode:
            - "embed": (batch, hidden_size) CLS embeddings
            - "late_interaction": (batch, seq_len, hidden_size) all token embeddings
            - "nli": (batch, nli_classes) NLI logits
            - "abstention": tuple of (nli_logits, abstention_logits)

        Raises:
            ValueError: if `mode` is not one of the four supported modes.
        """
        outputs = self.encoder(input_ids, attention_mask=attention_mask)
        hidden_states = outputs.last_hidden_state

        if mode == "embed":
            return hidden_states[:, 0]  # CLS token

        elif mode == "late_interaction":
            return hidden_states  # All tokens

        elif mode == "nli":
            cls_hidden = hidden_states[:, 0]
            nli_hidden = self.nli_hidden(cls_hidden)
            return self.nli_output(nli_hidden)

        elif mode == "abstention":
            cls_hidden = hidden_states[:, 0]
            nli_hidden = self.nli_hidden(cls_hidden)
            nli_logits = self.nli_output(nli_hidden)

            # Concat hidden and logits so the abstention head can condition
            # on both the representation and the NLI decision itself.
            abstention_input = torch.cat([nli_hidden, nli_logits], dim=-1)
            abstention_logits = self.abstention_head(abstention_input)

            return nli_logits, abstention_logits

        else:
            raise ValueError(f"Unknown mode: {mode}")
101
+
102
+
103
def load_modernbert_nli(
    task_heads_path: str,
    base_model: str = "answerdotai/ModernBERT-large",
    device: str = "cuda" if torch.cuda.is_available() else "cpu",
):
    """
    Load ModernBERT-NLI model.

    Args:
        task_heads_path: Path to task_heads.pt file
        base_model: HuggingFace model ID for base encoder
        device: Device to load model on

    Returns:
        (model, tokenizer) tuple
    """
    # Create model (downloads base from HuggingFace if needed)
    model = ModernBERTWithNLI(base_model)

    # Load task heads. weights_only=True prevents arbitrary-code execution
    # from a tampered pickle file; the checkpoint holds only tensors.
    task_heads = torch.load(task_heads_path, map_location=device, weights_only=True)

    # strict=False is intentional: the file contains only the head weights,
    # while the encoder comes from HuggingFace. Unexpected keys, however,
    # indicate a mismatched checkpoint and should not be silently ignored.
    load_result = model.load_state_dict(task_heads, strict=False)
    if load_result.unexpected_keys:
        raise ValueError(
            f"Checkpoint contains unexpected keys: {load_result.unexpected_keys}"
        )

    model = model.to(device)
    model.eval()

    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(base_model)

    return model, tokenizer
133
+
134
+
135
+ # Convenience functions
136
def predict_nli(model, tokenizer, premise: str, hypothesis: str, device=None):
    """Predict NLI label for a premise-hypothesis pair.

    Args:
        model: Module supporting ``model(**inputs, mode="nli")``.
        tokenizer: HuggingFace tokenizer matching the model.
        premise: Premise sentence.
        hypothesis: Hypothesis sentence.
        device: Device for the input tensors. Defaults to the device of the
            model's own parameters (falling back to "cpu"), so the previous
            hard-coded "cuda" default no longer crashes CPU-only machines.

    Returns:
        dict with "label" (str), "confidence" (float), and "probs"
        (per-class probability dict).
    """
    if device is None:
        # Infer the device from the model so inputs land where the weights are.
        try:
            device = next(model.parameters()).device
        except StopIteration:
            device = "cpu"  # parameter-less model (e.g. a stub)

    inputs = tokenizer(premise, hypothesis, return_tensors="pt", truncation=True, max_length=512)
    inputs = {k: v.to(device) for k, v in inputs.items()}

    with torch.no_grad():
        logits = model(**inputs, mode="nli")

    probs = torch.softmax(logits, dim=-1)[0]
    pred = probs.argmax().item()
    labels = ["entailment", "neutral", "contradiction"]

    return {
        "label": labels[pred],
        "confidence": probs[pred].item(),
        "probs": {l: p.item() for l, p in zip(labels, probs)}
    }
153
+
154
+
155
def predict_with_abstention(
    model, tokenizer, premise: str, hypothesis: str,
    device=None, threshold: float = 0.5
):
    """Predict NLI label with an abstention flag.

    Args:
        model: Module supporting ``model(**inputs, mode="abstention")``,
            returning ``(nli_logits, abstention_logits)``.
        tokenizer: HuggingFace tokenizer matching the model.
        premise: Premise sentence.
        hypothesis: Hypothesis sentence.
        device: Device for the input tensors. Defaults to the device of the
            model's own parameters (falling back to "cpu"), so the previous
            hard-coded "cuda" default no longer crashes CPU-only machines.
        threshold: Abstain when P(uncertain) exceeds this value.

    Returns:
        dict with "label", "confidence", "abstain" (bool), "uncertainty"
        (P(uncertain)), and per-class "probs".
    """
    if device is None:
        # Infer the device from the model so inputs land where the weights are.
        try:
            device = next(model.parameters()).device
        except StopIteration:
            device = "cpu"  # parameter-less model (e.g. a stub)

    inputs = tokenizer(premise, hypothesis, return_tensors="pt", truncation=True, max_length=512)
    inputs = {k: v.to(device) for k, v in inputs.items()}

    with torch.no_grad():
        nli_logits, abstention_logits = model(**inputs, mode="abstention")

    nli_probs = torch.softmax(nli_logits, dim=-1)[0]
    abstention_probs = torch.softmax(abstention_logits, dim=-1)[0]

    pred = nli_probs.argmax().item()
    labels = ["entailment", "neutral", "contradiction"]
    # Index 1 of the abstention head is the "uncertain" class (see config.json).
    should_abstain = abstention_probs[1].item() > threshold

    return {
        "label": labels[pred],
        "confidence": nli_probs[pred].item(),
        "abstain": should_abstain,
        "uncertainty": abstention_probs[1].item(),
        "probs": {l: p.item() for l, p in zip(labels, nli_probs)}
    }
180
+
181
+
182
+ if __name__ == "__main__":
183
+ # Example usage
184
+ model, tokenizer = load_modernbert_nli("task_heads.pt")
185
+
186
+ examples = [
187
+ ("A man is playing guitar.", "A person is making music."),
188
+ ("The cat is sleeping.", "The cat is running outside."),
189
+ ("A woman walks down the street.", "She is going to work."),
190
+ ]
191
+
192
+ print("NLI Predictions with Abstention:\n")
193
+ for premise, hypothesis in examples:
194
+ result = predict_with_abstention(model, tokenizer, premise, hypothesis)
195
+ status = "ABSTAIN" if result["abstain"] else "CONFIDENT"
196
+ print(f"P: {premise}")
197
+ print(f"H: {hypothesis}")
198
+ print(f"-> {result['label']} ({result['confidence']:.1%}) [{status}]\n")
task_heads.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5ab7c429576ee9f7ae98b3ac7b9b66ceaf73d2cc9b74f3cde854ec324b6e8391
3
+ size 2380277