import torch
import torch.nn as nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
import torch.nn.functional as F
from transformers import PreTrainedModel, AutoModel, PretrainedConfig
from transformers.modeling_outputs import SequenceClassifierOutput


class DualModernBERTConfig(PretrainedConfig):
    model_type = "dual-modernbert"

    def __init__(
        self,
        base_model_name="answerdotai/ModernBERT-base",
        fusion_hidden_dim=512,
        ordinal_dropout=0.5,
        encoder_dropout=0.35,
        fusion_dropout_rates=None,
        num_labels=5,
        num_ordinal_labels=4,
        freeze_base_encoder_layers=5,
        problem_type="ordinal_regression",
        loss_beta=0.9999,
        loss_base_boundary_weight=0.1,
        loss_boundary_weights=None,
        loss_smoothing=0.1,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.base_model_name = base_model_name
        self.fusion_hidden_dim = fusion_hidden_dim
        self.ordinal_dropout = ordinal_dropout
        self.encoder_dropout = encoder_dropout
        # Avoid mutable default arguments: resolve the default dropout rates here.
        self.fusion_dropout_rates = fusion_dropout_rates if fusion_dropout_rates is not None else {
            'cross_attn': 0.3,
            'transform_dropout': 0.4,
            'fusion_dropout1': 0.5,
            'fusion_dropout2': 0.45,
            'gate_dropout': 0.3,
            'output_dropout': 0.45,
        }
        self.num_labels = num_labels
        self.num_ordinal_labels = num_ordinal_labels
        self.freeze_base_encoder_layers = freeze_base_encoder_layers
        self.problem_type = problem_type
        self.loss_beta = loss_beta
        self.loss_base_boundary_weight = loss_base_boundary_weight
        self.loss_boundary_weights = loss_boundary_weights if loss_boundary_weights is not None else [1.0, 1.2, 1.2, 1.0]
        self.loss_smoothing = loss_smoothing


class EnhancedFusion(nn.Module):
    def __init__(self, config: DualModernBERTConfig):
        super().__init__()
        hidden_dim = config.fusion_hidden_dim
        dropout_rates = config.fusion_dropout_rates
        base_hidden_size = getattr(config, 'hidden_size', 768)

        # Cross- and self-attention blocks over the pooled title/text features.
        self.cross_attention = nn.ModuleDict({
            'title2text': nn.MultiheadAttention(embed_dim=base_hidden_size, num_heads=12, dropout=dropout_rates['cross_attn']),
            'text2title': nn.MultiheadAttention(embed_dim=base_hidden_size, num_heads=12, dropout=dropout_rates['cross_attn']),
            'self_title': nn.MultiheadAttention(embed_dim=base_hidden_size, num_heads=12, dropout=dropout_rates['cross_attn']),
            'self_text': nn.MultiheadAttention(embed_dim=base_hidden_size, num_heads=12, dropout=dropout_rates['cross_attn'])
        })

        # Multi-scale linear projections of the raw pooled features
        # (only 'scale1' is consumed in forward below).
        self.scale_projections = nn.ModuleDict({
            'scale1': nn.Linear(base_hidden_size, hidden_dim),
            'scale2': nn.Linear(base_hidden_size, hidden_dim // 2),
            'scale3': nn.Linear(base_hidden_size, hidden_dim // 4)
        })

        # Two-layer MLP transforms applied to the attended title/text features.
        self.feature_transform = nn.ModuleDict({
            'title': nn.Sequential(
                nn.Linear(base_hidden_size, hidden_dim), nn.LayerNorm(hidden_dim), nn.GELU(),
                nn.Dropout(dropout_rates['transform_dropout']),
                nn.Linear(hidden_dim, hidden_dim), nn.LayerNorm(hidden_dim), nn.GELU(),
                nn.Dropout(dropout_rates['transform_dropout'])
            ),
            'text': nn.Sequential(
                nn.Linear(base_hidden_size, hidden_dim), nn.LayerNorm(hidden_dim), nn.GELU(),
                nn.Dropout(dropout_rates['transform_dropout']),
                nn.Linear(hidden_dim, hidden_dim), nn.LayerNorm(hidden_dim), nn.GELU(),
                nn.Dropout(dropout_rates['transform_dropout'])
            )
        })

        # Fusion MLP over the concatenated (title_cross, text_cross, title_scale1, text_scale1) features.
        self.fusion_network = nn.Sequential(
            nn.Linear(hidden_dim * 4, hidden_dim * 3), nn.LayerNorm(hidden_dim * 3), nn.GELU(),
            nn.Dropout(dropout_rates['fusion_dropout1']),
            nn.Linear(hidden_dim * 3, hidden_dim * 2), nn.LayerNorm(hidden_dim * 2), nn.GELU(),
            nn.Dropout(dropout_rates['fusion_dropout2']),
            nn.Linear(hidden_dim * 2, hidden_dim), nn.LayerNorm(hidden_dim)
        )

        # Combine cross-attended and self-attended features per branch.
        self.cross_connections = nn.ModuleDict({
            'title': nn.Linear(hidden_dim * 2, hidden_dim),
            'text': nn.Linear(hidden_dim * 2, hidden_dim)
        })

        # Residual path projected directly from the concatenated raw pooled features.
        self.residual_proj = nn.Sequential(
            nn.Linear(base_hidden_size * 2, hidden_dim), nn.LayerNorm(hidden_dim), nn.GELU()
        )

        # Sigmoid gate that mixes the fused path with the residual path.
        self.gate = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim), nn.LayerNorm(hidden_dim), nn.GELU(),
            nn.Dropout(dropout_rates['gate_dropout']),
            nn.Linear(hidden_dim, hidden_dim), nn.Sigmoid()
        )

        # Final projection of the gated fusion.
        self.output_layer = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim), nn.LayerNorm(hidden_dim), nn.GELU(),
            nn.Dropout(dropout_rates['output_dropout']),
            nn.Linear(hidden_dim, hidden_dim), nn.LayerNorm(hidden_dim)
        )

    def forward(self, title, text):
        # title/text: pooled [CLS] features of shape (batch, hidden). nn.MultiheadAttention
        # defaults to (seq_len, batch, embed), so add a length-1 sequence dimension.
        title_q = title.unsqueeze(0)
        text_q = text.unsqueeze(0)

        # Cross-attention between branches plus self-attention within each branch.
        title2text, _ = self.cross_attention['title2text'](text_q, title_q, title_q)
        text2title, _ = self.cross_attention['text2title'](title_q, text_q, text_q)
        title_self, _ = self.cross_attention['self_title'](title_q, title_q, title_q)
        text_self, _ = self.cross_attention['self_text'](text_q, text_q, text_q)

        # Transform the attended features back to (batch, hidden_dim).
        title_feats = self.feature_transform['title'](title2text.squeeze(0))
        text_feats = self.feature_transform['text'](text2title.squeeze(0))
        title_self_feats = self.feature_transform['title'](title_self.squeeze(0))
        text_self_feats = self.feature_transform['text'](text_self.squeeze(0))

        # Merge cross- and self-attended features per branch.
        title_cross = self.cross_connections['title'](torch.cat([title_feats, title_self_feats], dim=-1))
        text_cross = self.cross_connections['text'](torch.cat([text_feats, text_self_feats], dim=-1))

        # Multi-scale projections of the raw pooled features.
        title_scales = {scale: proj(title) for scale, proj in self.scale_projections.items()}
        text_scales = {scale: proj(text) for scale, proj in self.scale_projections.items()}

        # Fuse the attended features with the scale-1 projections.
        fused_features = torch.cat([
            title_cross, text_cross,
            title_scales['scale1'], text_scales['scale1']
        ], dim=-1)
        fused = self.fusion_network(fused_features)

        # Residual path from the raw pooled features.
        residual = self.residual_proj(torch.cat([title, text], dim=-1))

        # Gated mix of the fused and residual paths.
        gate_input = torch.cat([fused, residual], dim=-1)
        gate = self.gate(gate_input)
        gated_fusion = gate * fused + (1 - gate) * residual

        output = self.output_layer(gated_fusion)
        return output
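
# Illustrative note (assumption, not from the original file): EnhancedFusion expects two
# pooled feature tensors of shape (batch, hidden_size), e.g. the [CLS] vectors produced by
# the two encoders below. A standalone shape check might look like:
#
#     cfg = DualModernBERTConfig()
#     fusion = EnhancedFusion(cfg)                      # uses the 768-dim fallback hidden size
#     out = fusion(torch.randn(2, 768), torch.randn(2, 768))
#     assert out.shape == (2, cfg.fusion_hidden_dim)    # (batch, 512)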


class OrdinalLayer(nn.Module):
    def __init__(self, config: DualModernBERTConfig):
        super().__init__()
        input_dim = config.fusion_hidden_dim
        # One logit per ordinal threshold (num_ordinal_labels of them).
        self.ordinal = nn.Sequential(
            nn.Dropout(config.ordinal_dropout),
            nn.Linear(input_dim, config.num_ordinal_labels)
        )

    def forward(self, x):
        return self.ordinal(x)


class SimpleEnhancedOrdinalLoss(nn.Module):
    def __init__(self, config: DualModernBERTConfig):
        super().__init__()
        self.num_ordinal_labels = config.num_ordinal_labels
        self.smoothing = config.loss_smoothing
        self.base_boundary_weight = config.loss_base_boundary_weight
        # Register the per-boundary weights as a buffer so they follow the module's device.
        self.register_buffer(
            'boundary_weights_tensor',
            torch.tensor(config.loss_boundary_weights, dtype=torch.float)
        )

    def get_boundary_weight(self, pos):
        """Return the weight applied to the boundary at position `pos`."""
        if pos < len(self.boundary_weights_tensor):
            return self.base_boundary_weight * self.boundary_weights_tensor[pos]
        else:
            # Fall back to the unweighted base value for out-of-range positions.
            return self.base_boundary_weight

    def forward(self, predictions, targets):
        # Smooth the 0/1 threshold targets towards 0.5.
        smoothed_targets = targets * (1 - self.smoothing) + 0.5 * self.smoothing

        # Per-threshold binary cross-entropy.
        bce_loss = F.binary_cross_entropy_with_logits(predictions, smoothed_targets, reduction='none')

        probs = torch.sigmoid(predictions)
        boundary_penalty = torch.zeros_like(bce_loss)

        # Safety net: keep the buffer on the same device as the predictions.
        self.boundary_weights_tensor = self.boundary_weights_tensor.to(predictions.device)

        # Penalise adjacent threshold probabilities that sit too close together (blurred boundaries).
        for i in range(predictions.size(1) - 1):
            diff = torch.abs(probs[:, i] - probs[:, i + 1])
            penalty = torch.exp(-diff) * 0.5
            adaptive_weight = self.get_boundary_weight(i)
            boundary_penalty[:, i] = adaptive_weight * penalty

        final_loss = bce_loss + boundary_penalty
        return final_loss.mean()
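

# The loss above scores one sigmoid logit per ordinal threshold, which implies a
# cumulative ("CORAL"-style) target encoding. The original file does not show how
# integer labels are converted, so the helpers below are an illustrative sketch of
# one encoding consistent with the defaults (num_labels=5 classes, num_ordinal_labels=4
# thresholds); they are assumptions, not part of the original model code.
def encode_ordinal_targets(labels, num_ordinal_labels=4):
    """Map integer class labels (0..K) to cumulative 0/1 threshold targets (K thresholds)."""
    thresholds = torch.arange(num_ordinal_labels, device=labels.device)
    return (labels.unsqueeze(-1) > thresholds).float()


def decode_ordinal_logits(logits, threshold=0.5):
    """Map threshold logits back to integer class predictions by counting crossed thresholds."""
    return (torch.sigmoid(logits) > threshold).sum(dim=-1)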


class DualModernBERTModel(PreTrainedModel):
    config_class = DualModernBERTConfig

    def __init__(self, config: DualModernBERTConfig):
        super().__init__(config)
        self.config = config

| print(f"Initializing title encoder from: {config.base_model_name}") |
| self.title_encoder = AutoModel.from_pretrained( |
| config.base_model_name, |
| add_pooling_layer=False, |
| trust_remote_code=True, |
| |
| ) |
| print(f"Initializing text encoder from: {config.base_model_name}") |
| self.text_encoder = AutoModel.from_pretrained( |
| config.base_model_name, |
| add_pooling_layer=False, |
| trust_remote_code=True, |
| |
| ) |
| |
|
|
| |
| if not hasattr(config, 'hidden_size'): |
| self.config.hidden_size = self.title_encoder.config.hidden_size |
|
|
        self.title_dropout = nn.Dropout(config.encoder_dropout)
        self.text_dropout = nn.Dropout(config.encoder_dropout)
        self.fusion = EnhancedFusion(config)
        self.ordinal_layer = OrdinalLayer(config)

        # Ordinal loss with boundary-aware weighting.
        self.criterion = SimpleEnhancedOrdinalLoss(config)

        # Optionally freeze the bottom encoder layers.
        self._freeze_encoder_layers(config.freeze_base_encoder_layers)

        self.post_init()

    def _freeze_encoder_layers(self, num_layers_to_freeze):
        """Freeze the bottom layers of both encoders."""
        if num_layers_to_freeze > 0:
            print(f"Freezing first {num_layers_to_freeze} layers of both encoders.")
            for encoder in [self.title_encoder, self.text_encoder]:
                if hasattr(encoder, 'layers'):
                    # ModernBERT-style models expose the transformer blocks as `encoder.layers`.
                    num_actual_layers = len(encoder.layers)
                    layers_to_freeze_count = min(num_layers_to_freeze, num_actual_layers)
                    for i in range(layers_to_freeze_count):
                        for param in encoder.layers[i].parameters():
                            param.requires_grad = False
                elif hasattr(encoder, 'encoder') and hasattr(encoder.encoder, 'layer'):
                    # BERT-style models expose them as `encoder.encoder.layer`.
                    num_actual_layers = len(encoder.encoder.layer)
                    layers_to_freeze_count = min(num_layers_to_freeze, num_actual_layers)
                    for i in range(layers_to_freeze_count):
                        for param in encoder.encoder.layer[i].parameters():
                            param.requires_grad = False
                else:
                    print(f"Warning: could not find a layer list to freeze in {encoder.__class__.__name__}; skipping this encoder.")

    def forward(
        self,
        title_input_ids=None,
        title_attention_mask=None,
        title_token_type_ids=None,
        text_input_ids=None,
        text_attention_mask=None,
        text_token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if title_input_ids is None or text_input_ids is None:
            raise ValueError("Both title_input_ids and text_input_ids must be provided.")
        if title_attention_mask is None:
            title_attention_mask = torch.ones_like(title_input_ids)
        if text_attention_mask is None:
            text_attention_mask = torch.ones_like(text_input_ids)

        title_outputs = self.title_encoder(
            input_ids=title_input_ids,
            attention_mask=title_attention_mask,
            token_type_ids=title_token_type_ids,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # Pool the [CLS] token representation.
        title_features = title_outputs[0][:, 0]

        text_outputs = self.text_encoder(
            input_ids=text_input_ids,
            attention_mask=text_attention_mask,
            token_type_ids=text_token_type_ids,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # Pool the [CLS] token representation.
        text_features = text_outputs[0][:, 0]

        title_features_dropped = self.title_dropout(title_features)
        text_features_dropped = self.text_dropout(text_features)

        fused_features = self.fusion(title_features_dropped, text_features_dropped)
        logits = self.ordinal_layer(fused_features)

        loss = None
        if labels is not None:
            # Labels are cast to float and used directly as the per-threshold targets of the ordinal loss.
            loss = self.criterion(logits, labels.float())

        if not return_dict:
            output = (logits,)
            return ((loss,) + output) if loss is not None else output

        # Hidden states / attentions come from two separate encoders and are not merged here.
        merged_hidden_states = None
        merged_attentions = None

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=merged_hidden_states,
            attentions=merged_attentions,
        )


# Register the custom config and model so they can be loaded through the Auto* classes.
from transformers import AutoConfig, AutoModelForSequenceClassification

AutoConfig.register("dual-modernbert", DualModernBERTConfig)
AutoModelForSequenceClassification.register(DualModernBERTConfig, DualModernBERTModel)
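

# Minimal smoke-test sketch (not part of the original file). It assumes the
# "answerdotai/ModernBERT-base" checkpoint and its tokenizer are available, and uses
# arbitrary dummy inputs and labels purely for illustration.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    config = DualModernBERTConfig()
    model = DualModernBERTModel(config)  # loads the two base encoders
    tokenizer = AutoTokenizer.from_pretrained(config.base_model_name)

    titles = ["Example title"]
    texts = ["Example body text for the second encoder."]
    title_enc = tokenizer(titles, return_tensors="pt", padding=True, truncation=True)
    text_enc = tokenizer(texts, return_tensors="pt", padding=True, truncation=True)

    # One cumulative 0/1 target vector per example (see encode_ordinal_targets above).
    labels = encode_ordinal_targets(torch.tensor([2]), config.num_ordinal_labels)

    outputs = model(
        title_input_ids=title_enc["input_ids"],
        title_attention_mask=title_enc["attention_mask"],
        text_input_ids=text_enc["input_ids"],
        text_attention_mask=text_enc["attention_mask"],
        labels=labels,
    )
    print(outputs.loss, outputs.logits.shape)  # logits: (batch, num_ordinal_labels)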