# Rogendo's picture
# Upload Multi-task classification model
# 459ef90 verified
import torch
import torch.nn as nn
from transformers import DistilBertModel, DistilBertPreTrainedModel
class MultiTaskDistilBert(DistilBertPreTrainedModel):
    """Multi-task DistilBERT classifier for child helpline case management.

    A single shared DistilBERT encoder feeds four independent linear heads,
    so one forward pass produces logits for all four tasks:

    - Main category classification
    - Sub-category classification
    - Intervention recommendation
    - Priority assignment

    Besides ``dim`` and ``dropout``, the config must carry the head sizes
    ``num_main_labels``, ``num_sub_labels``, ``num_intervention_labels`` and
    ``num_priority_labels``.
    """

    def __init__(self, config):
        """Build the shared encoder, the pre-classifier and the four heads.

        Args:
            config: Model configuration. Read for ``dim``, ``dropout`` and
                the four ``num_*_labels`` attributes that size the heads.
        """
        super().__init__(config)
        self.distilbert = DistilBertModel(config)
        self.pre_classifier = nn.Linear(config.dim, config.dim)

        # Task-specific classification heads
        self.classifier_main = nn.Linear(config.dim, config.num_main_labels)
        self.classifier_sub = nn.Linear(config.dim, config.num_sub_labels)
        self.classifier_interv = nn.Linear(
            config.dim, config.num_intervention_labels
        )
        self.classifier_priority = nn.Linear(
            config.dim, config.num_priority_labels
        )

        self.dropout = nn.Dropout(config.dropout)
        self.init_weights()

    def forward(self, input_ids=None, attention_mask=None,
                main_category_id=None, sub_category_id=None,
                intervention_id=None, priority_id=None):
        """Run the encoder and all four heads, optionally computing the loss.

        Args:
            input_ids: Token ids forwarded to the DistilBERT encoder.
            attention_mask: Attention mask forwarded to the encoder.
            main_category_id: Targets for the main-category head. Supplying
                this is what switches loss computation on.
            sub_category_id: Optional targets for the sub-category head.
            intervention_id: Optional targets for the intervention head.
            priority_id: Optional targets for the priority head.

        Returns:
            ``(loss, logits_main, logits_sub, logits_interv, logits_priority)``
            when ``main_category_id`` is given, otherwise
            ``(logits_main, logits_sub, logits_interv, logits_priority)``.
            ``loss`` is the unweighted sum of the cross-entropy losses of
            every task whose targets were provided.
        """
        # Shared DistilBERT encoder
        encoder_output = self.distilbert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True,
        )

        # Feature extraction and processing
        pooled_output = encoder_output.last_hidden_state[:, 0]  # [CLS] token
        pooled_output = self.pre_classifier(pooled_output)
        # Functional ReLU: avoids building a throwaway nn.ReLU module on
        # every forward pass.
        pooled_output = torch.relu(pooled_output)
        pooled_output = self.dropout(pooled_output)

        # Multi-task predictions
        logits_main = self.classifier_main(pooled_output)
        logits_sub = self.classifier_sub(pooled_output)
        logits_interv = self.classifier_interv(pooled_output)
        logits_priority = self.classifier_priority(pooled_output)

        # Multi-task loss calculation (training only)
        if main_category_id is None:
            return (logits_main, logits_sub, logits_interv, logits_priority)

        loss_fct = nn.CrossEntropyLoss()
        loss = loss_fct(logits_main, main_category_id)
        # Add the remaining tasks in the original order, skipping any task
        # without targets so a partially labelled batch no longer crashes
        # inside CrossEntropyLoss on a None target.
        optional_tasks = (
            (logits_sub, sub_category_id),
            (logits_interv, intervention_id),
            (logits_priority, priority_id),
        )
        for task_logits, task_target in optional_tasks:
            if task_target is not None:
                loss = loss + loss_fct(task_logits, task_target)

        return (loss, logits_main, logits_sub, logits_interv, logits_priority)