Spaces:
Build error
Build error
| import torch.nn as nn | |
| from transformers import DistilBertModel | |
class MultiTaskDistilBert(nn.Module):
    """Shared DistilBERT encoder with one linear classification head per task.

    Originally built to classify Report Category, Irregularity Category, Area,
    and Root Cause simultaneously. The actual set of tasks is whatever keys
    ``num_labels_dict`` contains.

    Args:
        num_labels_dict: Mapping of task name -> number of classes for that
            task. Each entry gets its own ``nn.Linear`` head inside an
            ``nn.ModuleDict``, so keys must be valid ModuleDict names
            (e.g. no ``"."``).
        pretrained_model_name_or_path: Checkpoint passed to
            ``DistilBertModel.from_pretrained``. Defaults to the original
            ``'distilbert-base-uncased'``.
        dropout_prob: Dropout probability applied to the pooled first-token
            representation before the heads. Defaults to the original 0.3.
    """

    def __init__(
        self,
        num_labels_dict,
        pretrained_model_name_or_path: str = 'distilbert-base-uncased',
        dropout_prob: float = 0.3,
    ):
        super().__init__()
        self.distilbert = DistilBertModel.from_pretrained(pretrained_model_name_or_path)
        self.dropout = nn.Dropout(dropout_prob)
        # Head input width is read from the loaded encoder's config, so it
        # matches whichever DistilBERT checkpoint was loaded.
        hidden_dim = self.distilbert.config.dim
        # One classification head per task, all sharing the same encoder.
        self.heads = nn.ModuleDict({
            task: nn.Linear(hidden_dim, num_labels)
            for task, num_labels in num_labels_dict.items()
        })

    def forward(self, input_ids, attention_mask=None):
        """Encode the batch once and run every task head on the pooled output.

        Args:
            input_ids: Token ids forwarded to the DistilBERT encoder
                (presumably shaped (batch, seq_len); TODO confirm with callers).
            attention_mask: Optional padding mask forwarded unchanged to the
                encoder. ``None`` lets DistilBERT use its own default, which
                attends to every position.

        Returns:
            dict mapping each task key in ``self.heads`` to that head's raw
            logits (no softmax applied).
        """
        outputs = self.distilbert(input_ids=input_ids, attention_mask=attention_mask)
        # Use the first token's final hidden state ([CLS]) as the sequence summary.
        pooled_output = outputs.last_hidden_state[:, 0]
        pooled_output = self.dropout(pooled_output)
        return {task: head(pooled_output) for task, head in self.heads.items()}