import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel, AutoModel, AutoConfig, PretrainedConfig
from transformers.modeling_outputs import SequenceClassifierOutput


# ==========================================
# Configuration Class
# ==========================================
class EVTConfig(PretrainedConfig):
    """Configuration class for the EVT Classifier."""

    model_type = "evt_classifier"

    def __init__(
        self,
        base_model_path="sentence-transformers/all-mpnet-base-v2",
        num_core_labels=5,
        temperature=20.0,
        hidden_size=768,
        dropout_rate=0.2,
        pos_weights=None,
        **kwargs
    ):
        """
        Args:
            base_model_path (str): The path to the base transformer model.
            num_core_labels (int): Number of output labels for the classifier.
            temperature (float): Scaling factor for the NormLinear layer.
            hidden_size (int): Hidden size of the base model.
            dropout_rate (float): Dropout rate for regularization.
            pos_weights (list): Positive class weights for handling class imbalance.
            **kwargs: Additional keyword arguments for PretrainedConfig.
        """
        self.base_model_path = base_model_path
        self.num_core_labels = num_core_labels
        self.temperature = temperature
        self.hidden_size = hidden_size
        self.dropout_rate = dropout_rate
        self.pos_weights = pos_weights

        # Default labels for the EVT Classifier
        if "id2label" not in kwargs:
            kwargs["id2label"] = {
                0: "ATTAINMENT_VALUE",
                1: "COST",
                2: "EXPECTANCY",
                3: "INTRINSIC_VALUE",
                4: "UTILITY_VALUE",
            }
        if "label2id" not in kwargs:
            # id2label may arrive with string keys (e.g. straight from JSON);
            # coerce to int so the derived label2id always maps name -> int id.
            kwargs["label2id"] = {
                name: int(idx) for idx, name in kwargs["id2label"].items()
            }

        super().__init__(**kwargs)


# ==========================================
# Model Architecture Components
# ==========================================
class ConcatPooling(nn.Module):
    """Concatenates mean and max pooling of hidden states."""

    def forward(self, hidden_states, attention_mask):
        """Pool over dim 1 and return ``cat([mean_pool, max_pool], dim=-1)``.

        Positions where ``attention_mask`` is 0 are excluded from the mean and
        are filled with a large negative value before the max is taken. The
        result has the same dtype as ``hidden_states``.
        """
        out_dtype = hidden_states.dtype
        mask = attention_mask.unsqueeze(-1).expand(hidden_states.size()).float()

        # Masked mean; the clamp guards against division by zero for rows
        # whose mask is entirely 0.
        sum_emb = torch.sum(hidden_states * mask, 1)
        sum_mask = torch.clamp(mask.sum(1), min=1e-9)
        mean_pool = sum_emb / sum_mask

        # -1e9 is not representable in float16 and masked_fill refuses values
        # that overflow the tensor dtype, so cap the fill at the dtype minimum.
        # For float32/float64/bfloat16 this is still exactly -1e9.
        fill_value = max(-1e9, torch.finfo(out_dtype).min)
        masked = hidden_states.masked_fill(mask == 0, fill_value)
        max_pool = torch.max(masked, dim=1)[0]

        # The float32 mask promotes the mean for half-precision inputs; cast it
        # back so both halves (and downstream layers) share one dtype.
        return torch.cat([mean_pool.to(out_dtype), max_pool], dim=-1)


class NormLinear(nn.Module):
    """Linear layer using cosine similarity and a temperature scaling factor."""

    def __init__(self, in_features, out_features, temperature=20.0):
        super().__init__()
        # torch.empty replaces the legacy torch.FloatTensor constructor; the
        # values are overwritten by the Xavier initialisation right below.
        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        nn.init.xavier_uniform_(self.weight)
        self.temperature = temperature

    def forward(self, x):
        """Return ``temperature * cos(x, w_k)`` for every weight row ``w_k``."""
        # Normalise the feature (last) dimension of the input and each weight
        # row, so the linear map yields cosine similarities.
        unit_x = F.normalize(x, p=2, dim=-1)
        unit_w = F.normalize(self.weight, p=2, dim=1)
        return F.linear(unit_x, unit_w) * self.temperature


class EVTClassifier(PreTrainedModel):
    """Custom classifier model for Expectancy-Value Theory (EVT) text classification.

    5-class OvR sigmoid classifier with cosine head for multi-label classification.
    """

    config_class = EVTConfig
    model_type = "evt_classifier"
    # We don't set base_model_prefix to avoid conflicts with PreTrainedModel.base_model property

    # Required for transformers 5.x compatibility
    _tied_weights_keys = []

    @property
    def all_tied_weights_keys(self):
        """Property for compatibility with transformers 5.x. Returns empty dict since no weights are tied."""
        return {}

    def __init__(self, config):
        super().__init__(config)

        # Initialize the base transformer model
        # Using 'transformer' as the attribute name to avoid conflict with PreTrainedModel
        base_config = AutoConfig.from_pretrained(config.base_model_path)
        self.transformer = AutoModel.from_config(base_config)
        self.config = config

        h = config.hidden_size if hasattr(config, "hidden_size") else self.transformer.config.hidden_size

        # Custom layers matching training architecture
        self.pooler = ConcatPooling()
        self.dropout = nn.Dropout(getattr(config, "dropout_rate", 0.2))
        self.dense = nn.Linear(h * 2, h)
        self.gelu = nn.GELU()
        self.norm_head = NormLinear(
            h,
            config.num_core_labels,
            temperature=getattr(config, "temperature", 20.0),
        )

        # Register pos_weights buffer for class imbalance handling
        pos_weights = getattr(config, "pos_weights", None)
        if pos_weights is not None:
            self.register_buffer("pos_weights", torch.tensor(pos_weights, dtype=torch.float32))
        else:
            self.register_buffer("pos_weights", torch.ones(config.num_core_labels))

    def forward(self, input_ids, attention_mask, labels=None, **kwargs):
        """
        Forward pass matching training architecture.

        Args:
            input_ids: Token ids forwarded to the base transformer.
            attention_mask: Mask forwarded to the base transformer and used by
                the pooling layer.
            labels: Optional multi-label targets. Assumed to be multi-hot with
                the same shape as the logits (TODO confirm against the training
                script). When given, a BCE-with-logits loss weighted by the
                ``pos_weights`` buffer is returned alongside the logits.
            **kwargs: Extra keyword arguments forwarded to the base transformer.

        Returns SequenceClassifierOutput with logits (sigmoid-based, not softmax)
        and ``loss`` (``None`` when ``labels`` is not provided).
        """
        outputs = self.transformer(input_ids=input_ids, attention_mask=attention_mask, **kwargs)
        pooled = self.pooler(outputs.last_hidden_state, attention_mask)

        x = self.dropout(pooled)
        x = self.gelu(self.dense(x))
        x = self.dropout(x)
        logits = self.norm_head(x)

        loss = None
        if labels is not None:
            # Previously `labels` was accepted but ignored and the pos_weights
            # buffer was never read; compute the weighted sigmoid loss here.
            loss = F.binary_cross_entropy_with_logits(
                logits,
                labels.to(device=logits.device, dtype=logits.dtype),
                pos_weight=self.pos_weights.to(device=logits.device, dtype=logits.dtype),
            )

        return SequenceClassifierOutput(loss=loss, logits=logits)

    def load_state_dict(self, state_dict, strict=True, **kwargs):
        """
        Custom loading to map 'base_model.' keys from training to 'transformer.' in this model.
        This handles the mismatch between training script and refactored architecture.

        Extra keyword arguments (for example ``assign``) are forwarded unchanged
        to ``nn.Module.load_state_dict``.
        """
        old_prefix = "base_model."
        new_prefix = "transformer."
        # Map base_model.* to transformer.*; keep other keys as-is
        # (dense, norm_head, pos_weights).
        remapped = {
            (new_prefix + key[len(old_prefix):] if key.startswith(old_prefix) else key): value
            for key, value in state_dict.items()
        }
        return super().load_state_dict(remapped, strict=strict, **kwargs)


# Register both classes for Hugging Face AutoClasses
EVTConfig.register_for_auto_class()
EVTClassifier.register_for_auto_class("AutoModel")