import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import LongformerPreTrainedModel, LongformerModel
|
|
class AttentionPooling(nn.Module):
    """Learned attention pooling over the token dimension of a hidden-state tensor."""

    def __init__(self, hidden_dim):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.query = nn.Linear(hidden_dim, hidden_dim)
        self.energy = nn.Linear(hidden_dim, 1)

        # Xavier initialization for the projection weights, zeros for the biases
        nn.init.xavier_uniform_(self.query.weight)
        nn.init.xavier_uniform_(self.energy.weight)
        self.query.bias.data.zero_()
        self.energy.bias.data.zero_()

    def forward(self, hidden_states, attention_mask=None):
        # Per-token attention scores: (batch, seq_len, hidden) -> (batch, seq_len)
        transformed = torch.tanh(self.query(hidden_states))
        scores = self.energy(transformed).squeeze(-1)

        # Exclude padding positions from the softmax
        if attention_mask is not None:
            scores = scores.masked_fill(attention_mask == 0, float('-inf'))

        weights = F.softmax(scores, dim=-1)

        # Attention-weighted sum over the sequence dimension: (batch, hidden)
        pooled = torch.sum(hidden_states * weights.unsqueeze(-1), dim=1)
        return pooled
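
# Shape sketch for AttentionPooling (illustrative only; the batch size, sequence
# length, and hidden size below are arbitrary examples, not taken from any config):
#
#     pooler = AttentionPooling(hidden_dim=768)
#     hidden = torch.randn(2, 128, 768)             # (batch, seq_len, hidden)
#     mask = torch.ones(2, 128, dtype=torch.long)   # 1 = real token, 0 = padding
#     pooled = pooler(hidden, attention_mask=mask)  # -> (2, 768)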
|
|
|
|
|
|
class CustomLongformerForSequenceClassification(LongformerPreTrainedModel):
    """Longformer for sequence classification with attention pooling.

    Instead of pooling the CLS token, each of the last four hidden layers is
    pooled with its own AttentionPooling head; the four pooled vectors are
    concatenated and fed to the classifier.
    """

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config

        self.longformer = LongformerModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        # One pooling head per layer, applied to the last four hidden layers
        self.attention_poolers = nn.ModuleList([
            AttentionPooling(config.hidden_size) for _ in range(4)
        ])

        # The classifier sees the concatenation of the four pooled vectors
        self.classifier = nn.Linear(config.hidden_size * 4, config.num_labels)

        self.post_init()

    def forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
        outputs = self.longformer(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
            **kwargs
        )

        # hidden_states is a tuple of (num_layers + 1) tensors; take the last four
        last_four_layers = outputs.hidden_states[-4:]

        # Pool each of the four layers with its own attention head
        pooled = []
        for layer, pooler in zip(last_four_layers, self.attention_poolers):
            pooled.append(pooler(layer, attention_mask=attention_mask))

        concatenated = torch.cat(pooled, dim=1)
        concatenated = self.dropout(concatenated)
        logits = self.classifier(concatenated)

        loss = None
        if labels is not None:
            if hasattr(self, 'loss_fct'):
                # A custom loss can be attached to the instance as `loss_fct`
                loss = self.loss_fct(logits, labels)
            elif self.num_labels == 1:
                # Single-output regression
                loss = F.mse_loss(logits.view(-1), labels.float().view(-1))
            else:
                # Multi-class classification
                loss = F.cross_entropy(logits.view(-1, self.num_labels), labels.view(-1))

        return {'loss': loss, 'logits': logits}
|
|
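# Minimal usage sketch (illustrative only; the checkpoint name and num_labels are
# assumptions for demonstration, not requirements of the model above). The pooling
# heads and classifier are freshly initialized, so expect a warning about newly
# initialized weights when loading the pretrained backbone.
if __name__ == "__main__":
    from transformers import LongformerConfig, LongformerTokenizerFast

    checkpoint = "allenai/longformer-base-4096"
    config = LongformerConfig.from_pretrained(checkpoint, num_labels=3)
    model = CustomLongformerForSequenceClassification.from_pretrained(checkpoint, config=config)
    tokenizer = LongformerTokenizerFast.from_pretrained(checkpoint)

    batch = tokenizer(["A long document to classify ..."], return_tensors="pt", padding=True)
    labels = torch.tensor([1])

    model.eval()
    with torch.no_grad():
        out = model(**batch, labels=labels)

    print(out["logits"].shape)  # torch.Size([1, 3])
    print(out["loss"])          # cross-entropy loss for the dummy label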