import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel


class DeBERTaLSTMClassifier(nn.Module):
    """Frozen DeBERTa encoder followed by a bidirectional-LSTM classification head."""

    def __init__(self, hidden_dim=128, num_labels=2):
        super().__init__()

        # Load the pretrained encoder and freeze it; only the head is trained.
        self.deberta = AutoModel.from_pretrained("microsoft/deberta-base")
        for param in self.deberta.parameters():
            param.requires_grad = False

        self.lstm = nn.LSTM(
            input_size=self.deberta.config.hidden_size,
            hidden_size=hidden_dim,
            batch_first=True,
            bidirectional=True,
        )

        # The bidirectional LSTM doubles the feature dimension.
        self.fc = nn.Linear(hidden_dim * 2, num_labels)

        # Additive attention: one learned score per time step.
        self.attention = nn.Linear(hidden_dim * 2, 1)

    def forward(self, input_ids, attention_mask, return_attention=False):
        # The encoder is frozen, so skip building its autograd graph; only
        # request the encoder's attentions when the caller asks for them.
        with torch.no_grad():
            outputs = self.deberta(
                input_ids=input_ids,
                attention_mask=attention_mask,
                output_attentions=return_attention,
            )

        lstm_out, _ = self.lstm(outputs.last_hidden_state)  # (batch, seq_len, 2 * hidden_dim)

        if return_attention:
            # Score each time step, then normalize across the sequence.
            attention_weights = self.attention(lstm_out)  # (batch, seq_len, 1)
            attention_weights = F.softmax(attention_weights.squeeze(-1), dim=-1)

            # Zero out padding positions, then renormalize so the weights
            # still sum to 1 over the real tokens.
            attention_weights = attention_weights * attention_mask.float()
            attention_weights = attention_weights / (
                attention_weights.sum(dim=-1, keepdim=True) + 1e-8
            )

            # Attention-weighted sum of LSTM states -> one vector per sequence.
            attended_output = torch.sum(lstm_out * attention_weights.unsqueeze(-1), dim=1)
            logits = self.fc(attended_output)

            return logits, attention_weights, outputs.attentions
        # Without the attention head, pool with the hidden state at the last
        # *real* token of each sequence; lstm_out[:, -1, :] would read padding
        # positions whenever the batch contains sequences of unequal length.
        last_idx = attention_mask.long().sum(dim=1) - 1  # (batch,)
        last_idx = last_idx.view(-1, 1, 1).expand(-1, 1, lstm_out.size(-1))
        final_hidden = lstm_out.gather(dim=1, index=last_idx).squeeze(1)
        logits = self.fc(final_hidden)
        return logits
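

# Minimal usage sketch covering both forward paths. Assumes the tokenizer that
# matches the "microsoft/deberta-base" checkpoint; the texts below are
# placeholder examples, not data from the original source.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("microsoft/deberta-base")
    model = DeBERTaLSTMClassifier(hidden_dim=128, num_labels=2)
    model.eval()

    batch = tokenizer(
        ["a short example", "a second, slightly longer example sentence"],
        padding=True,
        truncation=True,
        return_tensors="pt",
    )

    with torch.no_grad():
        # Plain classification path: logits of shape (batch, num_labels).
        logits = model(batch["input_ids"], batch["attention_mask"])

        # Interpretability path: also returns per-token LSTM attention weights
        # and the frozen encoder's self-attention tensors.
        logits, token_weights, encoder_attentions = model(
            batch["input_ids"], batch["attention_mask"], return_attention=True
        )
        print(logits.shape, token_weights.shape, len(encoder_attentions))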