| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | from transformers import AutoModel |
| |
|
class DeBERTaLSTMClassifier(nn.Module):
    """Sequence classifier: frozen DeBERTa encoder -> BiLSTM -> attention pooling -> linear head.

    The DeBERTa encoder is used purely as a fixed feature extractor; only the
    LSTM, the attention scorer, and the classification head are trained.

    Args:
        hidden_dim: Hidden size of the LSTM per direction (the pooled
            representation has size ``hidden_dim * 2`` because the LSTM is
            bidirectional).
        num_labels: Number of output classes.
    """

    def __init__(self, hidden_dim=128, num_labels=2):
        super().__init__()
        self.deberta = AutoModel.from_pretrained("microsoft/deberta-base")

        # Freeze the encoder: it acts as a fixed feature extractor.
        for param in self.deberta.parameters():
            param.requires_grad = False
        # FIX: keep the frozen encoder in eval mode so its dropout layers stay
        # disabled even while the rest of the model trains. Frozen weights
        # combined with active dropout would only inject noise into the
        # features, since no gradient ever reaches the encoder.
        self.deberta.eval()

        self.lstm = nn.LSTM(
            input_size=self.deberta.config.hidden_size,
            hidden_size=hidden_dim,
            batch_first=True,
            bidirectional=True,
        )

        # Scores each LSTM timestep for additive attention pooling
        # (input is hidden_dim * 2 because the LSTM is bidirectional).
        self.attention = nn.Linear(hidden_dim * 2, 1)

        self.fc = nn.Linear(hidden_dim * 2, num_labels)

    def forward(self, input_ids, attention_mask, return_attention=False):
        """Classify a batch of token sequences.

        Args:
            input_ids: (batch, seq_len) token ids.
            attention_mask: (batch, seq_len) mask; 0 marks padding positions.
            return_attention: If True, also return the pooling weights and
                the encoder's per-layer self-attention tensors.

        Returns:
            logits of shape (batch, num_labels); or, if ``return_attention``,
            a tuple ``(logits, attn_weights, encoder_attentions)``.
        """
        # Encoder is frozen, so skip autograd bookkeeping entirely.
        # FIX: only request the (large) per-layer attention tensors when the
        # caller actually asked for them — the original passed
        # output_attentions=True unconditionally, materializing them on every
        # forward pass and wasting memory and compute.
        with torch.no_grad():
            outputs = self.deberta(
                input_ids=input_ids,
                attention_mask=attention_mask,
                output_attentions=return_attention,
            )

        # (batch, seq_len, hidden_dim * 2)
        lstm_out, _ = self.lstm(outputs.last_hidden_state)

        # One attention score per timestep: (batch, seq_len).
        attn_scores = self.attention(lstm_out).squeeze(-1)

        # Exclude padding positions from the softmax via a large negative
        # score. (-1e9 rather than -inf avoids NaNs if a row were all-pad.)
        # The integer mask compares against 0 directly; no float cast needed.
        attn_scores = attn_scores.masked_fill(attention_mask == 0, -1e9)

        # Normalized pooling weights over the sequence dimension.
        attn_weights = F.softmax(attn_scores, dim=-1)

        # Weighted sum over time -> fixed-size sequence representation.
        context_vector = torch.sum(attn_weights.unsqueeze(-1) * lstm_out, dim=1)

        logits = self.fc(context_vector)

        if return_attention:
            return logits, attn_weights, outputs.attentions
        return logits