import torch
import torch.nn as nn
from transformers import BertConfig, BertModel


class AttentionPool(nn.Module):
    """Attention-based pooling layer.

    Learns a single attention vector that scores each token and returns a
    softmax-weighted sum of the token representations.
    """

    def __init__(self, hidden_size):
        super().__init__()
        self.attention_weights = nn.Parameter(torch.randn(hidden_size, 1))
        nn.init.xavier_uniform_(self.attention_weights)

    def forward(self, hidden_states):
        # hidden_states: (batch, seq_len, hidden) -> scores: (batch, seq_len, 1)
        attention_scores = torch.matmul(hidden_states, self.attention_weights)
        # Normalize scores over the sequence dimension. Note that padding
        # positions are not masked out; all tokens receive a weight.
        attention_scores = torch.softmax(attention_scores, dim=1)
        # Weighted sum over tokens -> (batch, hidden).
        pooled_output = torch.sum(hidden_states * attention_scores, dim=1)
        return pooled_output
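

# Minimal shape-check sketch for AttentionPool (illustrative addition; the
# helper name is hypothetical and it is never called by the model itself).
def _attention_pool_shape_check():
    pool = AttentionPool(hidden_size=8)
    pooled = pool(torch.randn(2, 5, 8))  # (batch=2, seq_len=5, hidden=8)
    assert pooled.shape == (2, 8)  # pooled over the sequence dimension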
class GeneformerMultiTask(nn.Module):
    """Multi-task classification model on top of a pretrained BERT encoder."""

    def __init__(
        self,
        pretrained_path,
        num_labels_list,
        dropout_rate=0.1,
        use_task_weights=False,
        task_weights=None,
        max_layers_to_freeze=0,
        use_attention_pooling=False,
    ):
        super().__init__()
        self.config = BertConfig.from_pretrained(pretrained_path)
        self.bert = BertModel.from_pretrained(pretrained_path)
        self.num_labels_list = num_labels_list
        self.use_task_weights = use_task_weights
        self.dropout = nn.Dropout(dropout_rate)
        self.use_attention_pooling = use_attention_pooling

        if use_task_weights and (
            task_weights is None or len(task_weights) != len(num_labels_list)
        ):
            raise ValueError(
                "Task weights must be defined and match the number of tasks when 'use_task_weights' is True."
            )
        self.task_weights = (
            task_weights if use_task_weights else [1.0] * len(num_labels_list)
        )

        # Freeze the first `max_layers_to_freeze` encoder layers; the
        # embeddings and remaining layers stay trainable.
        for layer in self.bert.encoder.layer[:max_layers_to_freeze]:
            for param in layer.parameters():
                param.requires_grad = False

        # Optional attention pooling; otherwise the [CLS] token is used.
        self.attention_pool = (
            AttentionPool(self.config.hidden_size) if use_attention_pooling else None
        )

        # One linear classification head per task.
        self.classification_heads = nn.ModuleList(
            [
                nn.Linear(self.config.hidden_size, num_labels)
                for num_labels in num_labels_list
            ]
        )
        for head in self.classification_heads:
            nn.init.xavier_uniform_(head.weight)
            nn.init.zeros_(head.bias)

    def forward(self, input_ids, attention_mask, labels=None):
        try:
            outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        except Exception as e:
            raise RuntimeError(f"Error during BERT forward pass: {e}") from e

        sequence_output = outputs.last_hidden_state

        try:
            # Pool over the sequence: attention pooling if enabled, else the
            # [CLS] token representation.
            pooled_output = (
                self.attention_pool(sequence_output)
                if self.use_attention_pooling
                else sequence_output[:, 0, :]
            )
            pooled_output = self.dropout(pooled_output)
        except Exception as e:
            raise RuntimeError(f"Error during pooling and dropout: {e}") from e

        total_loss = 0
        logits = []
        losses = []
        loss_fct = nn.CrossEntropyLoss()

        for task_id, (head, num_labels) in enumerate(
            zip(self.classification_heads, self.num_labels_list)
        ):
            try:
                task_logits = head(pooled_output)
            except Exception as e:
                raise RuntimeError(
                    f"Error during forward pass of classification head {task_id}: {e}"
                ) from e

            logits.append(task_logits)

            if labels is not None:
                try:
                    # labels[task_id] holds this task's label tensor.
                    task_loss = loss_fct(
                        task_logits.view(-1, num_labels), labels[task_id].view(-1)
                    )
                    if self.use_task_weights:
                        task_loss *= self.task_weights[task_id]
                    total_loss += task_loss
                    losses.append(task_loss.item())
                except Exception as e:
                    raise RuntimeError(
                        f"Error during loss computation for task {task_id}: {e}"
                    ) from e

        # With labels: (total_loss, logits, per-task losses); otherwise just logits.
        return (total_loss, logits, losses) if labels is not None else logits
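

# Hedged usage sketch: the checkpoint path is a placeholder for a real
# pretrained Geneformer/BERT directory, and the batch is synthetic. Labels are
# passed as a list of per-task tensors, matching the labels[task_id] indexing
# in forward().
if __name__ == "__main__":
    model = GeneformerMultiTask(
        pretrained_path="path/to/geneformer-checkpoint",  # placeholder path
        num_labels_list=[3, 5],  # two tasks: 3-way and 5-way classification
        use_task_weights=True,
        task_weights=[1.0, 0.5],
        max_layers_to_freeze=2,
        use_attention_pooling=True,
    )
    input_ids = torch.randint(0, model.config.vocab_size, (2, 16))
    attention_mask = torch.ones_like(input_ids)
    labels = [torch.randint(0, 3, (2,)), torch.randint(0, 5, (2,))]
    total_loss, logits, losses = model(input_ids, attention_mask, labels)
    print(total_loss.item(), [t.shape for t in logits], losses)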