mpnet-EVT-classifier / modeling_evt.py
christiqn's picture
Upload EVTClassifier
d71d985 verified
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel, AutoModel, AutoConfig, PretrainedConfig
from transformers.modeling_outputs import SequenceClassifierOutput
# ==========================================
# Configuration Class
# ==========================================
class EVTConfig(PretrainedConfig):
    """Configuration class for the EVT Classifier.

    Holds the hyper-parameters of the classification head together with the
    identifier of the base transformer, and forwards every remaining keyword
    argument (including the label maps) to ``PretrainedConfig``.
    """

    model_type = "evt_classifier"

    def __init__(
        self,
        base_model_path="sentence-transformers/all-mpnet-base-v2",
        num_core_labels=5,
        temperature=20.0,
        hidden_size=768,
        dropout_rate=0.2,
        pos_weights=None,
        **kwargs
    ):
        """
        Args:
            base_model_path (str): The path to the base transformer model.
            num_core_labels (int): Number of output labels for the classifier.
            temperature (float): Scaling factor for the NormLinear layer.
            hidden_size (int): Hidden size of the base model.
            dropout_rate (float): Dropout rate for regularization.
            pos_weights (list): Positive class weights for handling class imbalance.
            **kwargs: Additional keyword arguments for PretrainedConfig.
                ``id2label`` / ``label2id`` given here take precedence over
                the defaults built below.
        """
        self.base_model_path = base_model_path
        self.num_core_labels = num_core_labels
        self.temperature = temperature
        self.hidden_size = hidden_size
        self.dropout_rate = dropout_rate
        self.pos_weights = pos_weights

        # Default labels for the EVT Classifier.
        if "id2label" not in kwargs:
            evt_labels = (
                "ATTAINMENT_VALUE",
                "COST",
                "EXPECTANCY",
                "INTRINSIC_VALUE",
                "UTILITY_VALUE",
            )
            if num_core_labels == len(evt_labels):
                label_names = evt_labels
            else:
                # The EVT names only describe the five-label head. For any
                # other head size, fall back to generic names so the label
                # map always has exactly ``num_core_labels`` entries.
                label_names = tuple(f"LABEL_{i}" for i in range(num_core_labels))
            kwargs["id2label"] = dict(enumerate(label_names))

        if "label2id" not in kwargs:
            # int() keeps the ids numeric even when a caller-supplied id2label
            # uses string keys (as a JSON round-trip produces).
            kwargs["label2id"] = {
                name: int(idx) for idx, name in kwargs["id2label"].items()
            }

        super().__init__(**kwargs)
# ==========================================
# Model Architecture Components
# ==========================================
class ConcatPooling(nn.Module):
    """Concatenates mean and max pooling of hidden states.

    Both statistics are taken over dimension 1 and skip the positions where
    ``attention_mask`` is zero.
    """

    def forward(self, hidden_states, attention_mask):
        """Pool ``hidden_states`` over dimension 1.

        Args:
            hidden_states: Floating-point tensor; dimension 1 is reduced.
                The call site in this file passes the encoder's
                ``last_hidden_state``.
            attention_mask: Tensor that, after a trailing ``unsqueeze``, is
                expanded to the shape of ``hidden_states``; zero marks the
                positions to ignore.

        Returns:
            Tensor in the dtype of ``hidden_states`` holding the masked mean
            followed by the masked max along the last dimension, which is
            therefore twice as wide as the input's last dimension.
        """
        out_dtype = hidden_states.dtype
        # Accumulate the mean in at least float32. The original float32 mask
        # already promoted half inputs implicitly; doing it explicitly lets
        # the result be cast back so both halves share one dtype.
        acc_dtype = torch.promote_types(out_dtype, torch.float32)

        keep = attention_mask.unsqueeze(-1).expand(hidden_states.size())
        weights = keep.to(acc_dtype)

        summed = torch.sum(hidden_states.to(acc_dtype) * weights, dim=1)
        # The clamp guards against division by zero for fully masked rows.
        counts = torch.clamp(weights.sum(dim=1), min=1e-9)
        mean_pool = (summed / counts).to(out_dtype)

        # -1e9 is not representable in float16 and makes masked_fill raise.
        # Cap the sentinel at the dtype's most negative finite value; wider
        # dtypes still use -1e9.
        fill_value = max(-1e9, torch.finfo(out_dtype).min)
        masked = hidden_states.masked_fill(keep == 0, fill_value)
        max_pool = torch.max(masked, dim=1)[0]

        return torch.cat([mean_pool, max_pool], dim=-1)
class NormLinear(nn.Module):
    """Linear layer using cosine similarity and a temperature scaling factor.

    Each output is the cosine similarity between the input vector and one
    weight row, multiplied by ``temperature``. The layer has no bias.
    """

    def __init__(self, in_features, out_features, temperature=20.0):
        """
        Args:
            in_features (int): Size of the last dimension of the input.
            out_features (int): Number of weight rows, i.e. outputs.
            temperature (float): Multiplier applied to the cosine similarities.
        """
        super().__init__()
        # torch.empty replaces the legacy torch.FloatTensor constructor and
        # follows the active default dtype; xavier init fills the values.
        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        nn.init.xavier_uniform_(self.weight)
        self.temperature = temperature

    def forward(self, x):
        """Return temperature-scaled cosine similarities.

        Args:
            x: Tensor whose last dimension has ``in_features`` entries.

        Returns:
            Tensor with the last dimension replaced by ``out_features``.
        """
        # Normalise along the feature axis (the one F.linear contracts) so
        # inputs with extra or missing leading dimensions are handled too.
        unit_x = F.normalize(x, p=2, dim=-1)
        unit_w = F.normalize(self.weight, p=2, dim=-1)
        return F.linear(unit_x, unit_w) * self.temperature
class EVTClassifier(PreTrainedModel):
    """Custom classifier model for Expectancy-Value Theory (EVT) text classification.

    5-class OvR sigmoid classifier with cosine head for multi-label classification.

    Pipeline: base transformer -> mean/max ConcatPooling -> dropout ->
    dense + GELU -> dropout -> NormLinear cosine head.
    """

    config_class = EVTConfig
    model_type = "evt_classifier"
    # We don't set base_model_prefix to avoid conflicts with PreTrainedModel.base_model property
    # Required for transformers 5.x compatibility
    _tied_weights_keys = []

    @property
    def all_tied_weights_keys(self):
        """Property for compatibility with transformers 5.x. Returns empty dict since no weights are tied."""
        return {}

    def __init__(self, config):
        """Build the encoder and the classification head from ``config``.

        Args:
            config (EVTConfig): Supplies ``base_model_path``,
                ``num_core_labels`` and, optionally, ``hidden_size``,
                ``dropout_rate``, ``temperature`` and ``pos_weights``.
        """
        super().__init__(config)
        # Initialize the base transformer model
        # Using 'transformer' as the attribute name to avoid conflict with PreTrainedModel
        # Only the architecture is created here (from_config); the weights
        # are expected to come from a loaded state dict.
        base_config = AutoConfig.from_pretrained(config.base_model_path)
        self.transformer = AutoModel.from_config(base_config)
        self.config = config
        h = config.hidden_size if hasattr(config, "hidden_size") else self.transformer.config.hidden_size

        # Custom layers matching training architecture
        self.pooler = ConcatPooling()
        self.dropout = nn.Dropout(getattr(config, "dropout_rate", 0.2))
        # ConcatPooling doubles the feature width (mean + max), hence h * 2.
        self.dense = nn.Linear(h * 2, h)
        self.gelu = nn.GELU()
        self.norm_head = NormLinear(
            h,
            config.num_core_labels,
            temperature=getattr(config, "temperature", 20.0)
        )

        # Register pos_weights buffer for class imbalance handling
        pos_weights = getattr(config, "pos_weights", None)
        if pos_weights is not None:
            self.register_buffer("pos_weights", torch.tensor(pos_weights, dtype=torch.float32))
        else:
            self.register_buffer("pos_weights", torch.ones(config.num_core_labels))

    def forward(self, input_ids, attention_mask, labels=None, **kwargs):
        """
        Forward pass matching training architecture.

        Args:
            input_ids: Token ids forwarded to the base transformer.
            attention_mask: Mask forwarded to the base transformer and used
                by the pooling layer to ignore padded positions.
            labels: Optional multi-hot targets with the same shape as the
                logits. When given, a binary cross-entropy-with-logits loss
                weighted by the ``pos_weights`` buffer is returned as well.
            **kwargs: Extra keyword arguments passed to the base transformer.

        Returns SequenceClassifierOutput with logits (sigmoid-based, not softmax)
        and ``loss`` (``None`` when ``labels`` is not provided).
        """
        outputs = self.transformer(input_ids=input_ids, attention_mask=attention_mask, **kwargs)
        pooled = self.pooler(outputs.last_hidden_state, attention_mask)
        x = self.dropout(pooled)
        x = self.gelu(self.dense(x))
        x = self.dropout(x)
        logits = self.norm_head(x)

        loss = None
        if labels is not None:
            # Previously `labels` was accepted but ignored, so no loss was
            # ever returned and the pos_weights buffer went unused.
            loss = F.binary_cross_entropy_with_logits(
                logits,
                labels.to(logits.dtype),
                pos_weight=self.pos_weights.to(logits.dtype),
            )
        return SequenceClassifierOutput(loss=loss, logits=logits)

    def load_state_dict(self, state_dict, strict=True, *args, **kwargs):
        """
        Custom loading to map 'base_model.' keys from training to 'transformer.' in this model.
        This handles the mismatch between training script and refactored architecture.

        Args:
            state_dict: Mapping of parameter names to tensors.
            strict: Forwarded to the parent implementation.
            *args, **kwargs: Any further arguments the parent
                ``load_state_dict`` accepts (e.g. ``assign``), forwarded
                unchanged so callers that pass them do not fail.

        Returns:
            Whatever the parent ``load_state_dict`` returns.
        """
        old_prefix, new_prefix = "base_model.", "transformer."
        remapped = {}
        for key, value in state_dict.items():
            if key.startswith(old_prefix):
                # Map base_model.* to transformer.* (leading prefix only)
                key = new_prefix + key[len(old_prefix):]
            # Other keys are kept as-is (dense, norm_head, pos_weights)
            remapped[key] = value
        return super().load_state_dict(remapped, strict, *args, **kwargs)
# Register both classes for Hugging Face AutoClasses.
# EVTConfig is registered with the method's default auto class (no argument
# is passed); EVTClassifier is registered explicitly under "AutoModel".
EVTConfig.register_for_auto_class()
EVTClassifier.register_for_auto_class("AutoModel")