from typing import Optional, Union, Tuple
import torch
from torch import nn
from torch.nn.functional import binary_cross_entropy_with_logits
from transformers import PreTrainedModel
from transformers.models.deberta.configuration_deberta import DebertaConfig
from transformers.models.deberta.modeling_deberta import DebertaModel
from transformers.modeling_outputs import SequenceClassifierOutput
class ResidualBlock(nn.Module):
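    """
    Two-layer MLP block (Linear -> GroupNorm -> ReLU -> Dropout -> Linear ->
    GroupNorm -> ReLU) with a skip connection. The skip path uses a linear
    projection when input and output dims differ, otherwise an identity.
    """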
def __init__(self, input_dim: int, output_dim: int, num_groups: int = 8):
super().__init__()
self.linear_layers = nn.Sequential(
nn.Linear(input_dim, 512),
nn.GroupNorm(num_groups, 512),
nn.ReLU(),
nn.Dropout(0.4),
nn.Linear(512, output_dim),
nn.GroupNorm(num_groups, output_dim),
nn.ReLU(),
)
self.projection = (
nn.Linear(input_dim, output_dim)
if input_dim != output_dim
else nn.Identity()
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.linear_layers(x) + self.projection(x)
class EnhancedDebertaForSequenceClassification(PreTrainedModel):
"""
DeBERTa-based classifier with optional extra feature branches.
    A Hugging Face-compatible reimplementation of the original EnhancedDebertaModel.
For the *baseline* model on the Hub, all extra feature dims are zero,
so it behaves like "DeBERTa + linear multi-label head".
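
    Example (a sketch; these attributes are read with ``getattr`` below, so
    they are optional and default to the baseline behaviour)::

        config = DebertaConfig.from_pretrained("microsoft/deberta-base", num_labels=5)
        config.ling_feature_dim = 12   # enables the linguistic-feature branch
        config.multilayer = True       # adds the hidden_size -> 512 -> 256 text head
        model = EnhancedDebertaForSequenceClassification(config)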
"""
config_class = DebertaConfig
# Optional: you can give it a custom type name if you like
model_type = "enhanced-deberta"
def __init__(self, config: DebertaConfig):
super().__init__(config)
self.config = config
self.num_labels = config.num_labels
# ---- Backbone ----
# Keep the attribute name "transformer" so old state_dict keys match.
self.transformer = DebertaModel(config)
# Extra feature dimensions (defaults for baseline are all zero)
num_categories = getattr(config, "num_categories", 0)
ling_feature_dim = getattr(config, "ling_feature_dim", 0)
ner_feature_dim = getattr(config, "ner_feature_dim", 0)
topic_feature_dim = getattr(config, "topic_feature_dim", 0)
multilayer = getattr(config, "multilayer", False)
residualblock = getattr(config, "residualblock", False)
previous_sentences = getattr(config, "previous_sentences", False)
num_groups = getattr(config, "num_groups", 8)
# ---- Lexicon branch ----
if num_categories > 0:
self.lexicon_layer = nn.Sequential(
nn.Linear(num_categories, 256),
nn.ReLU(),
nn.Dropout(0.4),
nn.Linear(256, 128),
nn.ReLU(),
)
else:
self.lexicon_layer = None
# ---- Linguistic branch ----
if ling_feature_dim > 0:
self.ling_layer = nn.Sequential(
nn.Linear(ling_feature_dim, 128),
nn.ReLU(),
nn.Dropout(0.4),
)
else:
self.ling_layer = None
# ---- NER branch ----
if ner_feature_dim > 0:
self.ner_layer = nn.Sequential(
nn.Linear(ner_feature_dim, 128),
nn.ReLU(),
nn.Dropout(0.4),
)
else:
self.ner_layer = None
# ---- Topic branch ----
if topic_feature_dim > 0:
self.topic_layer = nn.Sequential(
nn.Linear(topic_feature_dim, 128),
nn.ReLU(),
nn.Dropout(0.4),
)
else:
self.topic_layer = None
# ---- Text embedding head (optional multilayer / residual) ----
self.multilayer = multilayer
self.residualblock = residualblock
if multilayer:
if residualblock:
self.text_embedding_layer = ResidualBlock(
self.transformer.config.hidden_size, 256, num_groups=num_groups
)
else:
self.text_embedding_layer = nn.Sequential(
nn.Linear(self.transformer.config.hidden_size, 512),
nn.GroupNorm(num_groups, 512),
nn.ReLU(),
nn.Dropout(0.4),
nn.Linear(512, 256),
nn.GroupNorm(num_groups, 256),
nn.ReLU(),
)
hidden_size = 256
else:
self.text_embedding_layer = None
hidden_size = self.transformer.config.hidden_size
# ---- Previous-sentence labels branch ----
if previous_sentences:
# 2 previous sentences × num_labels
self.prev_label_size = 2 * self.num_labels
self.prev_label_layer = nn.Sequential(
nn.Linear(self.prev_label_size, 16),
nn.ReLU(),
nn.Dropout(0.4),
)
else:
self.prev_label_size = 0
self.prev_label_layer = None
# ---- Final classification head ----
input_dim = hidden_size
if self.lexicon_layer is not None:
input_dim += 128
if self.ling_layer is not None:
input_dim += 128
if self.ner_layer is not None:
input_dim += 128
if self.topic_layer is not None:
input_dim += 128
if self.prev_label_layer is not None:
input_dim += 16
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.classification_head = nn.Linear(input_dim, self.num_labels)
# label mappings (already in config, but we mirror them here)
self.id2label = getattr(config, "id2label", None)
self.label2id = getattr(config, "label2id", None)
# Initialize weights (will be overwritten by from_pretrained)
self.post_init()
def forward(
self,
input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
lexicon_features: Optional[torch.Tensor] = None,
linguistic_features: Optional[torch.Tensor] = None,
ner_features: Optional[torch.Tensor] = None,
topic_features: Optional[torch.Tensor] = None,
prev_label_features: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
**kwargs,
) -> SequenceClassifierOutput:
"""
Forward pass.
Extra feature tensors (lexicon_features, linguistic_features, etc.)
are expected to be of shape [batch_size, feat_dim] when used.
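
        Shape sketch (illustrative values: ``num_labels == 5`` and a lexicon
        branch with ``num_categories == 10``)::

            input_ids:        [batch_size, seq_len]  (integer token IDs)
            attention_mask:   [batch_size, seq_len]
            lexicon_features: [batch_size, 10]       (float)
            labels:           [batch_size, 5]        (multi-hot, float)
            logits:           [batch_size, 5]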
"""
# Ensure integer token IDs
if input_ids is not None:
input_ids = input_ids.to(torch.long)
# ---- Transformer backbone ----
if inputs_embeds is not None:
backbone_outputs = self.transformer(
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
)
else:
backbone_outputs = self.transformer(
input_ids=input_ids,
attention_mask=attention_mask,
)
# CLS representation
hidden_state = backbone_outputs.last_hidden_state
cls_embed = hidden_state[:, 0, :] # [batch, hidden]
# Optional multilayer / residual processing
if self.text_embedding_layer is not None:
text_embeddings = self.text_embedding_layer(cls_embed)
else:
text_embeddings = cls_embed
combined = text_embeddings
# ---- Lexicon branch ----
if self.lexicon_layer is not None and lexicon_features is not None:
            lexicon_features = lexicon_features.to(device=combined.device, dtype=torch.float32)
lexicon_output = self.lexicon_layer(lexicon_features)
combined = torch.cat([combined, lexicon_output], dim=-1)
# ---- Linguistic branch ----
if self.ling_layer is not None and linguistic_features is not None:
            linguistic_features = linguistic_features.to(device=combined.device, dtype=torch.float32)
ling_output = self.ling_layer(linguistic_features)
combined = torch.cat([combined, ling_output], dim=-1)
# ---- NER branch ----
if self.ner_layer is not None and ner_features is not None:
            ner_features = ner_features.to(device=combined.device, dtype=torch.float32)
ner_output = self.ner_layer(ner_features)
combined = torch.cat([combined, ner_output], dim=-1)
# ---- Topic branch ----
if self.topic_layer is not None and topic_features is not None:
            topic_features = topic_features.to(device=combined.device, dtype=torch.float32)
topic_output = self.topic_layer(topic_features)
combined = torch.cat([combined, topic_output], dim=-1)
# ---- Previous-sentence labels branch ----
if self.prev_label_layer is not None and prev_label_features is not None:
prev_label_features = prev_label_features.to(combined.device).float()
prev_output = self.prev_label_layer(prev_label_features)
combined = torch.cat([combined, prev_output], dim=-1)
combined = self.dropout(combined)
logits = self.classification_head(combined)
loss = None
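        # Multi-label objective: each label gets an independent sigmoid, so
        # the loss is binary cross-entropy with logits across all labels.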
if labels is not None:
labels = labels.float()
if labels.dim() == 1:
labels = labels.unsqueeze(1)
loss = binary_cross_entropy_with_logits(logits, labels)
return SequenceClassifierOutput(
loss=loss,
logits=logits,
hidden_states=backbone_outputs.hidden_states,
attentions=backbone_outputs.attentions,
        )
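

# ---------------------------------------------------------------------------
# Minimal usage sketch. Assumptions: the repo id below is a placeholder for
# wherever the weights end up being published, and its tokenizer matches the
# DeBERTa backbone used at training time.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import AutoTokenizer

    repo_id = "your-namespace/enhanced-deberta-baseline"  # hypothetical repo id
    tokenizer = AutoTokenizer.from_pretrained(repo_id)
    model = EnhancedDebertaForSequenceClassification.from_pretrained(repo_id)
    model.eval()

    inputs = tokenizer("An example sentence.", return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)

    # The head is trained with BCE-with-logits, so per-label probabilities
    # come from a sigmoid rather than a softmax over the label axis.
    probs = torch.sigmoid(outputs.logits)
    print(probs)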