|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from transformers import DistilBertModel, DistilBertPreTrainedModel |
|
|
|
|
|
class MultiTaskDistilBert(DistilBertPreTrainedModel):
    """
    Multi-task DistilBERT classifier for child helpline case management.

    Performs simultaneous classification across 4 tasks:
    - Main category classification
    - Sub-category classification
    - Intervention recommendation
    - Priority assignment

    One DistilBERT encoder feeds a shared ``Linear -> ReLU -> Dropout``
    projection; the projected vector is then passed to four independent
    linear heads, one per task.
    """

    def __init__(self, config):
        """Build the encoder, the shared projection and the four task heads.

        Args:
            config: Model configuration. Besides what ``DistilBertModel``
                itself needs, this constructor reads ``dim``, ``dropout``,
                ``num_main_labels``, ``num_sub_labels``,
                ``num_intervention_labels`` and ``num_priority_labels``.
        """
        super().__init__(config)
        self.distilbert = DistilBertModel(config)
        self.pre_classifier = nn.Linear(config.dim, config.dim)

        # One output layer per task; all of them consume the same pooled vector.
        # Attribute names and registration order are kept as-is so existing
        # checkpoints keep loading.
        self.classifier_main = nn.Linear(config.dim, config.num_main_labels)
        self.classifier_sub = nn.Linear(config.dim, config.num_sub_labels)
        self.classifier_interv = nn.Linear(config.dim, config.num_intervention_labels)
        self.classifier_priority = nn.Linear(config.dim, config.num_priority_labels)

        self.dropout = nn.Dropout(config.dropout)
        self.init_weights()

    def forward(self, input_ids=None, attention_mask=None,
                main_category_id=None, sub_category_id=None,
                intervention_id=None, priority_id=None):
        """Run the encoder and all four heads, optionally computing the loss.

        Args:
            input_ids: Forwarded unchanged to the DistilBERT encoder.
            attention_mask: Forwarded unchanged to the DistilBERT encoder.
            main_category_id: Targets for the main-category head. Supplying
                this argument is what switches loss computation on.
            sub_category_id: Targets for the sub-category head.
            intervention_id: Targets for the intervention head.
            priority_id: Targets for the priority head.

        Returns:
            ``(loss, logits_main, logits_sub, logits_interv, logits_priority)``
            when ``main_category_id`` is given, where ``loss`` is the
            unweighted sum of the four cross-entropy losses; otherwise
            ``(logits_main, logits_sub, logits_interv, logits_priority)``.
            When ``main_category_id`` is ``None`` the other three label
            arguments are not used.

        Raises:
            ValueError: If ``main_category_id`` is given but any of the other
                three label arguments is ``None``.
        """
        distilbert_output = self.distilbert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True,
        )

        # Pool by taking index 0 along the second axis of the encoder output
        # (presumably the [CLS] position -- depends on the tokenizer in use).
        hidden_state = distilbert_output.last_hidden_state
        pooled_output = hidden_state[:, 0]
        pooled_output = self.pre_classifier(pooled_output)
        # Functional ReLU: same result as nn.ReLU() without building a new
        # module object on every forward pass.
        pooled_output = nn.functional.relu(pooled_output)
        pooled_output = self.dropout(pooled_output)

        logits_main = self.classifier_main(pooled_output)
        logits_sub = self.classifier_sub(pooled_output)
        logits_interv = self.classifier_interv(pooled_output)
        logits_priority = self.classifier_priority(pooled_output)

        if main_category_id is None:
            return (logits_main, logits_sub, logits_interv, logits_priority)

        # Fail early with the offending argument names instead of letting a
        # None target reach CrossEntropyLoss and error out inside torch.
        other_targets = {
            "sub_category_id": sub_category_id,
            "intervention_id": intervention_id,
            "priority_id": priority_id,
        }
        missing = [name for name, target in other_targets.items() if target is None]
        if missing:
            raise ValueError(
                "main_category_id was provided, so all label arguments are "
                "required to compute the loss; missing: " + ", ".join(missing)
            )

        loss_fct = nn.CrossEntropyLoss()
        loss_main = loss_fct(logits_main, main_category_id)
        loss_sub = loss_fct(logits_sub, sub_category_id)
        loss_interv = loss_fct(logits_interv, intervention_id)
        loss_priority = loss_fct(logits_priority, priority_id)
        # Unweighted sum: every task contributes equally to the gradient.
        loss = loss_main + loss_sub + loss_interv + loss_priority

        return (loss, logits_main, logits_sub, logits_interv, logits_priority)
|
|
|