import torch
import torch.nn as nn
from transformers import DistilBertModel, DistilBertPreTrainedModel
class MultiTaskDistilBert(DistilBertPreTrainedModel):
    """
    Multi-task DistilBERT classifier for child helpline case management.

    A shared DistilBERT encoder feeds four independent linear heads that
    classify each case simultaneously along:
      - Main category
      - Sub-category
      - Recommended intervention
      - Priority level

    The ``config`` is expected to carry ``num_main_labels``,
    ``num_sub_labels``, ``num_intervention_labels`` and
    ``num_priority_labels`` in addition to the standard DistilBERT fields
    (``dim``, ``dropout``).
    """

    def __init__(self, config):
        super().__init__(config)
        self.distilbert = DistilBertModel(config)
        # Shared projection applied to the [CLS] representation before the heads.
        self.pre_classifier = nn.Linear(config.dim, config.dim)
        # Task-specific classification heads.
        self.classifier_main = nn.Linear(config.dim, config.num_main_labels)
        self.classifier_sub = nn.Linear(config.dim, config.num_sub_labels)
        self.classifier_interv = nn.Linear(config.dim, config.num_intervention_labels)
        self.classifier_priority = nn.Linear(config.dim, config.num_priority_labels)
        self.dropout = nn.Dropout(config.dropout)
        # ReLU is stateless/parameter-free; instantiate once here instead of
        # constructing a new nn.ReLU module on every forward() call.
        self.activation = nn.ReLU()
        self.init_weights()

    def forward(self, input_ids=None, attention_mask=None,
                main_category_id=None, sub_category_id=None,
                intervention_id=None, priority_id=None):
        """
        Run the shared encoder and all four classification heads.

        Args:
            input_ids: Token-id tensor for the DistilBERT encoder.
            attention_mask: Attention mask matching ``input_ids``.
            main_category_id, sub_category_id, intervention_id, priority_id:
                Optional gold labels. When all four are provided, the summed
                cross-entropy loss over the four tasks is computed.

        Returns:
            ``(loss, logits_main, logits_sub, logits_interv, logits_priority)``
            when all four labels are given, otherwise the four logits only.
        """
        # Shared DistilBERT encoder.
        distilbert_output = self.distilbert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True,
        )

        # [CLS] token representation -> shared projection -> ReLU -> dropout.
        hidden_state = distilbert_output.last_hidden_state
        pooled_output = hidden_state[:, 0]  # [CLS] token
        pooled_output = self.pre_classifier(pooled_output)
        pooled_output = self.activation(pooled_output)
        pooled_output = self.dropout(pooled_output)

        # Independent per-task predictions off the shared features.
        logits_main = self.classifier_main(pooled_output)
        logits_sub = self.classifier_sub(pooled_output)
        logits_interv = self.classifier_interv(pooled_output)
        logits_priority = self.classifier_priority(pooled_output)

        # Multi-task loss: unweighted sum of the four task losses (training
        # only). Guard on ALL four labels — the original checked only
        # main_category_id and crashed with a TypeError when the other
        # labels were omitted.
        loss = None
        labels = (main_category_id, sub_category_id, intervention_id, priority_id)
        if all(label is not None for label in labels):
            loss_fct = nn.CrossEntropyLoss()
            loss = (
                loss_fct(logits_main, main_category_id)
                + loss_fct(logits_sub, sub_category_id)
                + loss_fct(logits_interv, intervention_id)
                + loss_fct(logits_priority, priority_id)
            )

        if loss is not None:
            return (loss, logits_main, logits_sub, logits_interv, logits_priority)
        return (logits_main, logits_sub, logits_interv, logits_priority)