import torch
import torch.nn as nn
from transformers import DistilBertModel, DistilBertPreTrainedModel

class MultiTaskDistilBert(DistilBertPreTrainedModel):
    """
    Multi-task DistilBERT classifier for child helpline case management.
    
    Performs simultaneous classification across 4 tasks:
    - Main category classification
    - Sub-category classification  
    - Intervention recommendation
    - Priority assignment
    """
    
    def __init__(self, config):
        super().__init__(config)
        self.distilbert = DistilBertModel(config)
        self.pre_classifier = nn.Linear(config.dim, config.dim)
        
        # Task-specific classification heads; the num_*_labels values are
        # custom attributes that must be set on the config before instantiation
        self.classifier_main = nn.Linear(config.dim, config.num_main_labels)
        self.classifier_sub = nn.Linear(config.dim, config.num_sub_labels)
        self.classifier_interv = nn.Linear(config.dim, config.num_intervention_labels)
        self.classifier_priority = nn.Linear(config.dim, config.num_priority_labels)
        
        self.dropout = nn.Dropout(config.dropout)
        self.init_weights()  # newer transformers versions prefer self.post_init()

    def forward(self, input_ids=None, attention_mask=None, 
                main_category_id=None, sub_category_id=None, 
                intervention_id=None, priority_id=None):
        
        # Shared DistilBERT encoder
        distilbert_output = self.distilbert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True
        )
        
        # Feature extraction: [CLS] pooling followed by a projection
        hidden_state = distilbert_output.last_hidden_state
        pooled_output = hidden_state[:, 0]  # [CLS] token
        pooled_output = self.pre_classifier(pooled_output)
        pooled_output = torch.relu(pooled_output)
        pooled_output = self.dropout(pooled_output)
        
        # Multi-task predictions
        logits_main = self.classifier_main(pooled_output)
        logits_sub = self.classifier_sub(pooled_output)
        logits_interv = self.classifier_interv(pooled_output)
        logits_priority = self.classifier_priority(pooled_output)
        
        # Multi-task loss (training only): equal-weighted sum of the four
        # cross-entropy losses. Gating on main_category_id assumes all four
        # label tensors are provided together during training.
        loss = None
        if main_category_id is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss_main = loss_fct(logits_main, main_category_id)
            loss_sub = loss_fct(logits_sub, sub_category_id)
            loss_interv = loss_fct(logits_interv, intervention_id)
            loss_priority = loss_fct(logits_priority, priority_id)
            loss = loss_main + loss_sub + loss_interv + loss_priority
        
        if loss is not None:
            return (loss, logits_main, logits_sub, logits_interv, logits_priority)
        else:
            return (logits_main, logits_sub, logits_interv, logits_priority)
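The pattern above — one shared encoder feeding several linear heads, with the training loss taken as the sum of the per-task cross-entropy losses — can be exercised without downloading DistilBERT weights. The sketch below swaps in a toy embedding "encoder" so it runs standalone; the class name, dimensions, and label counts are all illustrative, not part of the original model.

```python
import torch
import torch.nn as nn

class ToyMultiTaskClassifier(nn.Module):
    """Illustrative stand-in: a tiny encoder with four task heads."""

    def __init__(self, vocab_size=100, dim=32,
                 num_main=5, num_sub=12, num_interv=4, num_priority=3):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, dim)  # toy "encoder"
        self.pre_classifier = nn.Linear(dim, dim)
        self.heads = nn.ModuleDict({
            "main": nn.Linear(dim, num_main),
            "sub": nn.Linear(dim, num_sub),
            "interv": nn.Linear(dim, num_interv),
            "priority": nn.Linear(dim, num_priority),
        })

    def forward(self, input_ids, labels=None):
        pooled = self.embed(input_ids)[:, 0]  # "[CLS]"-style first-token pooling
        pooled = torch.relu(self.pre_classifier(pooled))
        logits = {name: head(pooled) for name, head in self.heads.items()}
        if labels is not None:
            # Training: equal-weighted sum of per-task cross-entropy losses
            loss_fct = nn.CrossEntropyLoss()
            loss = sum(loss_fct(logits[n], labels[n]) for n in logits)
            return loss, logits
        return logits  # inference: logits only

model = ToyMultiTaskClassifier()
batch = torch.randint(0, 100, (2, 8))  # 2 sequences of 8 token ids
labels = {"main": torch.tensor([0, 1]), "sub": torch.tensor([3, 7]),
          "interv": torch.tensor([2, 0]), "priority": torch.tensor([1, 2])}
loss, logits = model(batch, labels)  # loss is a scalar tensor
```

As in the DistilBERT model, inference (no labels) returns only logits, while training returns the summed loss first; a dict of named heads merely replaces the four separate attributes for brevity.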