Spaces: Runtime error
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from transformers import AutoModel | |
class DeBERTaLSTMClassifier(nn.Module):
    """Sequence classifier: frozen DeBERTa encoder -> BiLSTM -> linear head.

    ``microsoft/deberta-base`` is loaded with every parameter frozen and is run
    under ``torch.no_grad()``. Only the BiLSTM, the linear classifier and the
    token-attention scorer are trainable.

    Args:
        hidden_dim: Hidden size of each LSTM direction. The BiLSTM output (and
            therefore the classifier input) is ``2 * hidden_dim`` wide.
        num_labels: Number of output classes (width of the logits).
    """

    def __init__(self, hidden_dim=128, num_labels=2):
        super().__init__()
        self.deberta = AutoModel.from_pretrained("microsoft/deberta-base")
        # Freeze DeBERTa: there are not enough resources to fine-tune it, so it
        # is used purely as a fixed feature extractor.
        for param in self.deberta.parameters():
            param.requires_grad = False
        self.lstm = nn.LSTM(
            input_size=self.deberta.config.hidden_size,
            hidden_size=hidden_dim,
            batch_first=True,
            bidirectional=True,
        )
        self.fc = nn.Linear(hidden_dim * 2, num_labels)
        # Attention layer used to compute per-token importance scores; only
        # used by forward(..., return_attention=True).
        self.attention = nn.Linear(hidden_dim * 2, 1)

    def forward(self, input_ids, attention_mask, return_attention=False):
        """Classify a batch of token sequences.

        Args:
            input_ids: Token ids, passed straight to DeBERTa.
            attention_mask: Passed to DeBERTa. It is also multiplied into the
                token-attention weights, so positions where it is 0 get zero
                weight. Presumably ``[batch, seq_len]`` with 1 = real token and
                0 = padding; TODO confirm against the tokenizer call site.
            return_attention: If True, pool the BiLSTM outputs with the learned
                token-attention layer. Also return the pooling weights and
                DeBERTa's own attention maps.

        Returns:
            If ``return_attention`` is False: ``logits`` of shape
            ``[batch, num_labels]``.
            Otherwise: the tuple ``(logits, attention_weights,
            deberta_attentions)``. ``attention_weights`` is ``[batch, seq_len]``
            and sums to approximately 1 over unmasked positions.

        NOTE(review):
            - The two modes pool differently (attention-weighted sum vs. the
              output at the final position). They therefore generally return
              different logits for the same input.
            - ``lstm_out[:, -1, :]`` is the absolute last position. For
              right-padded batches that is a padding token.
            - At the last position, the backward LSTM direction has processed
              only that single step.
            - Fixing the pooling changes the features ``fc`` was trained on and
              needs a retrain, so it is left unchanged here.
        """
        # DeBERTa is frozen, so skip autograd bookkeeping. Its attention maps
        # are only returned in the return_attention branch, so only request
        # them there instead of computing and discarding them on every call.
        with torch.no_grad():
            outputs = self.deberta(
                input_ids=input_ids,
                attention_mask=attention_mask,
                output_attentions=return_attention,
            )
        # [batch, seq_len, hidden*2]
        lstm_out, _ = self.lstm(outputs.last_hidden_state)

        if not return_attention:
            final_hidden = lstm_out[:, -1, :]  # output at the last position
            return self.fc(final_hidden)

        # One score per position, softmax over the sequence: [batch, seq_len].
        scores = self.attention(lstm_out).squeeze(-1)
        attention_weights = F.softmax(scores, dim=-1)
        # Zero out masked positions and renormalise; this equals a softmax
        # restricted to the unmasked positions. Cast the mask to the weights'
        # dtype so half-precision weights are not promoted to float32 (which
        # would then mismatch fc's half-precision parameters).
        attention_weights = attention_weights * attention_mask.to(attention_weights.dtype)
        attention_weights = attention_weights / (
            attention_weights.sum(dim=-1, keepdim=True) + 1e-8
        )
        # Attention-weighted sum of the BiLSTM outputs: [batch, hidden*2].
        attended_output = torch.sum(lstm_out * attention_weights.unsqueeze(-1), dim=1)
        logits = self.fc(attended_output)
        return logits, attention_weights, outputs.attentions