import torch
import torch.nn as nn
from transformers import DistilBertModel, DistilBertPreTrainedModel


class MultiTaskDistilBert(DistilBertPreTrainedModel):
    """
    Multi-task DistilBERT classifier for child helpline case management.

    Performs simultaneous classification across 4 tasks:
    - Main category classification
    - Sub-category classification
    - Intervention recommendation
    - Priority assignment

    All four linear heads share one DistilBERT encoder and one
    pre-classifier layer (Linear -> ReLU -> Dropout) applied to the
    first-token representation.

    Besides the standard DistilBERT fields (``dim``, ``dropout``), the config
    must define ``num_main_labels``, ``num_sub_labels``,
    ``num_intervention_labels`` and ``num_priority_labels``.
    """

    def __init__(self, config):
        super().__init__(config)
        self.distilbert = DistilBertModel(config)
        self.pre_classifier = nn.Linear(config.dim, config.dim)

        # Task-specific classification heads
        self.classifier_main = nn.Linear(config.dim, config.num_main_labels)
        self.classifier_sub = nn.Linear(config.dim, config.num_sub_labels)
        self.classifier_interv = nn.Linear(config.dim, config.num_intervention_labels)
        self.classifier_priority = nn.Linear(config.dim, config.num_priority_labels)

        self.dropout = nn.Dropout(config.dropout)
        self.init_weights()

    def forward(self, input_ids=None, attention_mask=None,
                main_category_id=None, sub_category_id=None,
                intervention_id=None, priority_id=None):
        """
        Run the shared encoder and all four task heads.

        Args:
            input_ids: Token ids forwarded to ``DistilBertModel``.
            attention_mask: Attention mask forwarded to ``DistilBertModel``.
            main_category_id: Targets for the main-category head. When given,
                a loss is computed and returned.
            sub_category_id: Optional targets for the sub-category head.
            intervention_id: Optional targets for the intervention head.
            priority_id: Optional targets for the priority head.
            All targets are passed to ``nn.CrossEntropyLoss`` (presumably
            class-index tensors -- confirm against the data collator).

        Returns:
            ``(loss, logits_main, logits_sub, logits_interv, logits_priority)``
            when ``main_category_id`` is given, otherwise the four logits
            tensors only.
        """
        # Shared DistilBERT encoder
        distilbert_output = self.distilbert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True,
        )

        # Pool by taking the first token ([CLS] for standard BERT-style
        # tokenization), then project, activate, and regularize.
        hidden_state = distilbert_output.last_hidden_state
        pooled_output = hidden_state[:, 0]
        pooled_output = self.pre_classifier(pooled_output)
        # Functional ReLU avoids constructing a new nn.ReLU module per call.
        pooled_output = torch.relu(pooled_output)
        pooled_output = self.dropout(pooled_output)

        # Multi-task predictions
        logits_main = self.classifier_main(pooled_output)
        logits_sub = self.classifier_sub(pooled_output)
        logits_interv = self.classifier_interv(pooled_output)
        logits_priority = self.classifier_priority(pooled_output)

        # Multi-task loss: computed whenever main-category targets are given
        # (in training or evaluation alike).
        loss = None
        if main_category_id is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits_main, main_category_id)
            # Auxiliary tasks contribute only when their targets are present;
            # previously a missing target made CrossEntropyLoss raise.
            # Summation order matches the original left-to-right sum.
            for task_logits, task_labels in (
                (logits_sub, sub_category_id),
                (logits_interv, intervention_id),
                (logits_priority, priority_id),
            ):
                if task_labels is not None:
                    loss = loss + loss_fct(task_logits, task_labels)

        logits = (logits_main, logits_sub, logits_interv, logits_priority)
        if loss is not None:
            return (loss,) + logits
        return logits