# gapura-ai-api / data/transformer_architecture.py
# Last committed by Muhammad Ridzki Nugraha
# Commit message: "Upload folder using huggingface_hub"
# Commit: 13c3f2c (verified)
import torch.nn as nn
from transformers import DistilBertModel
class MultiTaskDistilBert(nn.Module):
    """Shared DistilBERT encoder with one linear classification head per task.

    Originally written to classify Report Category, Irregularity Category,
    Area and Root Cause at the same time. The actual set of tasks is exactly
    the keys of ``num_labels_dict``.

    Args:
        num_labels_dict: Mapping of task name -> number of classes for that
            task. Each entry becomes an ``nn.Linear`` head stored in
            ``self.heads``. Task names become ``nn.ModuleDict`` keys, so they
            must be non-empty strings without a "." character.
        pretrained_model_name: Checkpoint name or local path passed to
            ``DistilBertModel.from_pretrained``. Defaults to the checkpoint
            this class previously hard-coded.
        dropout_prob: Dropout probability applied to the first-token
            representation before it reaches the heads.
    """

    def __init__(
        self,
        num_labels_dict: dict,
        *,
        pretrained_model_name: str = "distilbert-base-uncased",
        dropout_prob: float = 0.3,
    ):
        super().__init__()
        self.distilbert = DistilBertModel.from_pretrained(pretrained_model_name)
        self.dropout = nn.Dropout(dropout_prob)
        # Every head reads the same encoder output, so each head's input size
        # is the encoder hidden size (DistilBertConfig.dim).
        hidden_size = self.distilbert.config.dim
        self.heads = nn.ModuleDict(
            {
                task: nn.Linear(hidden_size, num_labels)
                for task, num_labels in num_labels_dict.items()
            }
        )

    def forward(self, input_ids, attention_mask=None) -> dict:
        """Encode a batch once and return the logits of every task head.

        Args:
            input_ids: Token ids for DistilBERT; presumably shaped
                (batch, seq_len), TODO confirm against the data pipeline.
            attention_mask: Optional padding mask forwarded unchanged to
                DistilBERT. ``None`` is passed through as-is (in the Hugging
                Face implementation this means every position is attended).

        Returns:
            dict mapping each task name in ``self.heads`` to that head's
            logits, computed from the dropout-regularized first-token
            hidden state.
        """
        outputs = self.distilbert(input_ids=input_ids, attention_mask=attention_mask)
        # DistilBERT has no pooler layer, so the hidden state of the first
        # token ([CLS] with the standard tokenizer) represents the sequence.
        cls_hidden = outputs.last_hidden_state[:, 0]
        cls_hidden = self.dropout(cls_hidden)
        return {task: head(cls_hidden) for task, head in self.heads.items()}