VictorYeste committed
Commit 59a2ae6 · verified · 1 parent: 4c09c2a

Create modeling_enhanced_deberta.py

Files changed (1)
  1. modeling_enhanced_deberta.py +262 -0
modeling_enhanced_deberta.py ADDED
@@ -0,0 +1,262 @@
from typing import Optional

import torch
from torch import nn
from torch.nn.functional import binary_cross_entropy_with_logits

from transformers import PreTrainedModel
from transformers.models.deberta.configuration_deberta import DebertaConfig
from transformers.models.deberta.modeling_deberta import DebertaModel
from transformers.modeling_outputs import SequenceClassifierOutput


class ResidualBlock(nn.Module):
    def __init__(self, input_dim: int, output_dim: int, num_groups: int = 8):
        super().__init__()
        self.linear_layers = nn.Sequential(
            nn.Linear(input_dim, 512),
            nn.GroupNorm(num_groups, 512),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.Linear(512, output_dim),
            nn.GroupNorm(num_groups, output_dim),
            nn.ReLU(),
        )
        self.projection = (
            nn.Linear(input_dim, output_dim)
            if input_dim != output_dim
            else nn.Identity()
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear_layers(x) + self.projection(x)
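
# Illustrative shape check (not part of the original file): with
# input_dim=768 and output_dim=256 the block maps [batch, 768] to
# [batch, 256]; the projection keeps the skip connection
# dimension-compatible:
#
#   block = ResidualBlock(768, 256)
#   block(torch.randn(4, 768)).shape  # torch.Size([4, 256])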

class EnhancedDebertaForSequenceClassification(PreTrainedModel):
    """
    DeBERTa-based classifier with optional extra feature branches.

    This is a Hugging Face-compatible reimplementation of the original
    EnhancedDebertaModel. For the *baseline* model on the Hub, all extra
    feature dims are zero, so it behaves like "DeBERTa + linear
    multi-label head".
    """

    config_class = DebertaConfig
    # Optional: you can give it a custom type name if you like
    model_type = "enhanced-deberta"

    def __init__(self, config: DebertaConfig):
        super().__init__(config)
        self.config = config
        self.num_labels = config.num_labels

        # ---- Backbone ----
        # Keep the attribute name "transformer" so old state_dict keys match.
        self.transformer = DebertaModel(config)

        # Extra feature dimensions (defaults for baseline are all zero)
        num_categories = getattr(config, "num_categories", 0)
        ling_feature_dim = getattr(config, "ling_feature_dim", 0)
        ner_feature_dim = getattr(config, "ner_feature_dim", 0)
        topic_feature_dim = getattr(config, "topic_feature_dim", 0)
        multilayer = getattr(config, "multilayer", False)
        residualblock = getattr(config, "residualblock", False)
        previous_sentences = getattr(config, "previous_sentences", False)
        num_groups = getattr(config, "num_groups", 8)
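
        # Example (hypothetical values, not in the original file): to enable
        # the lexicon and linguistic branches, a config would carry e.g.
        #   config.num_categories = 64    # lexicon category counts
        #   config.ling_feature_dim = 20  # hand-crafted linguistic features
        # The baseline checkpoint leaves all of these at 0 / False.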

        # ---- Lexicon branch ----
        if num_categories > 0:
            self.lexicon_layer = nn.Sequential(
                nn.Linear(num_categories, 256),
                nn.ReLU(),
                nn.Dropout(0.4),
                nn.Linear(256, 128),
                nn.ReLU(),
            )
        else:
            self.lexicon_layer = None

        # ---- Linguistic branch ----
        if ling_feature_dim > 0:
            self.ling_layer = nn.Sequential(
                nn.Linear(ling_feature_dim, 128),
                nn.ReLU(),
                nn.Dropout(0.4),
            )
        else:
            self.ling_layer = None

        # ---- NER branch ----
        if ner_feature_dim > 0:
            self.ner_layer = nn.Sequential(
                nn.Linear(ner_feature_dim, 128),
                nn.ReLU(),
                nn.Dropout(0.4),
            )
        else:
            self.ner_layer = None

        # ---- Topic branch ----
        if topic_feature_dim > 0:
            self.topic_layer = nn.Sequential(
                nn.Linear(topic_feature_dim, 128),
                nn.ReLU(),
                nn.Dropout(0.4),
            )
        else:
            self.topic_layer = None

        # ---- Text embedding head (optional multilayer / residual) ----
        self.multilayer = multilayer
        self.residualblock = residualblock

        if multilayer:
            if residualblock:
                self.text_embedding_layer = ResidualBlock(
                    self.transformer.config.hidden_size, 256, num_groups=num_groups
                )
            else:
                self.text_embedding_layer = nn.Sequential(
                    nn.Linear(self.transformer.config.hidden_size, 512),
                    nn.GroupNorm(num_groups, 512),
                    nn.ReLU(),
                    nn.Dropout(0.4),
                    nn.Linear(512, 256),
                    nn.GroupNorm(num_groups, 256),
                    nn.ReLU(),
                )
            hidden_size = 256
        else:
            self.text_embedding_layer = None
            hidden_size = self.transformer.config.hidden_size

        # ---- Previous-sentence labels branch ----
        if previous_sentences:
            # 2 previous sentences × num_labels
            self.prev_label_size = 2 * self.num_labels
            self.prev_label_layer = nn.Sequential(
                nn.Linear(self.prev_label_size, 16),
                nn.ReLU(),
                nn.Dropout(0.4),
            )
        else:
            self.prev_label_size = 0
            self.prev_label_layer = None

        # ---- Final classification head ----
        input_dim = hidden_size
        if self.lexicon_layer is not None:
            input_dim += 128
        if self.ling_layer is not None:
            input_dim += 128
        if self.ner_layer is not None:
            input_dim += 128
        if self.topic_layer is not None:
            input_dim += 128
        if self.prev_label_layer is not None:
            input_dim += 16

        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classification_head = nn.Linear(input_dim, self.num_labels)

        # label mappings (already in config, but we mirror them here)
        self.id2label = getattr(config, "id2label", None)
        self.label2id = getattr(config, "label2id", None)

        # Initialize weights (will be overwritten by from_pretrained)
        self.post_init()
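
    # Worked example (illustrative, not in the original file): with
    # hidden_size = 768 and the lexicon, linguistic and previous-sentence
    # branches enabled, the head sees 768 + 128 + 128 + 16 = 1040 input
    # features; the baseline (all branches off, multilayer=False) keeps 768.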

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        lexicon_features: Optional[torch.Tensor] = None,
        linguistic_features: Optional[torch.Tensor] = None,
        ner_features: Optional[torch.Tensor] = None,
        topic_features: Optional[torch.Tensor] = None,
        prev_label_features: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> SequenceClassifierOutput:
        """
        Forward pass.

        Extra feature tensors (lexicon_features, linguistic_features, etc.)
        are expected to be of shape [batch_size, feat_dim] when used.
        """
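
        # Shape example (illustrative, not in the original file): with
        # batch_size=8, num_labels=2 and previous_sentences=True,
        # prev_label_features has shape [8, 4]: the labels of the two
        # preceding sentences, num_labels values each.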

        # Ensure integer token IDs
        if input_ids is not None:
            input_ids = input_ids.to(torch.long)

        # ---- Transformer backbone ----
        if inputs_embeds is not None:
            backbone_outputs = self.transformer(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
            )
        else:
            backbone_outputs = self.transformer(
                input_ids=input_ids,
                attention_mask=attention_mask,
            )

        # CLS representation
        hidden_state = backbone_outputs.last_hidden_state
        cls_embed = hidden_state[:, 0, :]  # [batch, hidden]

        # Optional multilayer / residual processing
        if self.text_embedding_layer is not None:
            text_embeddings = self.text_embedding_layer(cls_embed)
        else:
            text_embeddings = cls_embed

        combined = text_embeddings

        # ---- Lexicon branch ----
        if self.lexicon_layer is not None and lexicon_features is not None:
            # Move to the right device as well as casting, matching the
            # handling of the other feature branches.
            lexicon_features = lexicon_features.to(
                combined.device, dtype=torch.float32
            )
            lexicon_output = self.lexicon_layer(lexicon_features)
            combined = torch.cat([combined, lexicon_output], dim=-1)

        # ---- Linguistic branch ----
        if self.ling_layer is not None and linguistic_features is not None:
            linguistic_features = linguistic_features.to(combined.device)
            ling_output = self.ling_layer(linguistic_features)
            combined = torch.cat([combined, ling_output], dim=-1)

        # ---- NER branch ----
        if self.ner_layer is not None and ner_features is not None:
            ner_features = ner_features.to(combined.device)
            ner_output = self.ner_layer(ner_features)
            combined = torch.cat([combined, ner_output], dim=-1)

        # ---- Topic branch ----
        if self.topic_layer is not None and topic_features is not None:
            topic_features = topic_features.to(combined.device)
            topic_output = self.topic_layer(topic_features)
            combined = torch.cat([combined, topic_output], dim=-1)

        # ---- Previous-sentence labels branch ----
        if self.prev_label_layer is not None and prev_label_features is not None:
            prev_label_features = prev_label_features.to(combined.device).float()
            prev_output = self.prev_label_layer(prev_label_features)
            combined = torch.cat([combined, prev_output], dim=-1)
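
        # Note (illustrative): each enabled branch appends a fixed-width
        # projection (128 for lexicon/linguistic/NER/topic, 16 for
        # previous-sentence labels), so `combined` ends up with exactly the
        # `input_dim` computed in __init__.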

        combined = self.dropout(combined)
        logits = self.classification_head(combined)

        loss = None
        if labels is not None:
            labels = labels.float()
            if labels.dim() == 1:
                labels = labels.unsqueeze(1)
            loss = binary_cross_entropy_with_logits(logits, labels)

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=backbone_outputs.hidden_states,
            attentions=backbone_outputs.attentions,
        )
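
Usage sketch: because this is custom modeling code, loading the checkpoint requires trust_remote_code=True. The repo id below is a placeholder for whichever Hub repository ships this file, and the snippet assumes that repository's config.json maps the class via auto_map and that it includes a tokenizer:

    import torch
    from transformers import AutoModelForSequenceClassification, AutoTokenizer

    repo_id = "VictorYeste/<repo-name>"  # placeholder, substitute the real repo

    tokenizer = AutoTokenizer.from_pretrained(repo_id)
    model = AutoModelForSequenceClassification.from_pretrained(
        repo_id, trust_remote_code=True
    )
    model.eval()

    inputs = tokenizer("An example sentence.", return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits
    probs = torch.sigmoid(logits)  # multi-label head: per-label probabilities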