SeragAmin committed
Commit 644b065 · 1 Parent(s): bb38e03

Add multitask_model.py

Files changed (1)
  1. multitask_model.py +122 -0
multitask_model.py ADDED
@@ -0,0 +1,122 @@
+ import torch
+ import torch.nn as nn
+ from transformers import RobertaConfig, RobertaModel
+ from peft import LoraConfig, LoraModel
+
+ # Maps each task name to the integer id carried in the batch's task_id tensor.
+ task_name_to_id = {"sentiment": 0, "hate": 1, "emotion": 2}
+
+ # Number of classes for each task
+ task_num_labels = {
+     "sentiment": 3,
+     "hate": 3,
+     "emotion": 4,
+ }
+
+
+ class MultiTaskModel(nn.Module):
+     def __init__(self):
+         super().__init__()
+         model_name = "roberta-base"
+         config = RobertaConfig.from_pretrained(model_name)
+         base_model = RobertaModel.from_pretrained(model_name, config=config)
+
+         # Optional per-task loss weights (currently unused):
+         # self.task_weights = {"sentiment": 1, "hate": 2, "emotion": 1}
+
+         # LoRA adapters on the attention query/value projections; only the
+         # adapter weights (and the task heads below) are trainable.
+         lora_config = LoraConfig(
+             r=16,
+             lora_alpha=32,
+             target_modules=["query", "value"],
+             lora_dropout=0.05,
+             bias="none",
+             task_type="SEQ_CLS",
+         )
+
+         self.model = LoraModel(base_model, lora_config, adapter_name="shared")
+
+         # A per-task adapter could be added instead of the single shared one:
+         # self.model.add_adapter(lora_config, adapter_name="hate")  # For hate
+         # self.model.set_adapter("shared")  # Set default
+
+         # One small classification head per task on top of the shared encoder.
+         hidden_size = config.hidden_size
+         dropout_prob = 0.1
+         intermediate_size = 128
+
+         self.sentiment_head = nn.Sequential(
+             nn.Linear(hidden_size, intermediate_size),
+             nn.ReLU(),
+             nn.Dropout(dropout_prob),
+             nn.Linear(intermediate_size, task_num_labels["sentiment"]),
+         )
+
+         self.hate_head = nn.Sequential(
+             nn.Linear(hidden_size, intermediate_size),
+             nn.ReLU(),
+             nn.Dropout(dropout_prob),
+             nn.Linear(intermediate_size, task_num_labels["hate"]),
+         )
+
+         self.emotion_head = nn.Sequential(
+             nn.Linear(hidden_size, intermediate_size),
+             nn.ReLU(),
+             nn.Dropout(dropout_prob),
+             nn.Linear(intermediate_size, task_num_labels["emotion"]),
+         )
+
+         print(f"Trainable parameters (LoRA): {sum(p.numel() for p in self.model.parameters() if p.requires_grad)}")
+         print(f"Trainable parameters (LoRA + heads): {sum(p.numel() for p in self.parameters() if p.requires_grad)}")
+         self.loss_fct = nn.CrossEntropyLoss()
+
+     def forward(self, input_ids=None, attention_mask=None, task_id=None, labels=None):
+         outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
+         pooled = outputs.last_hidden_state[:, 0]  # first token (<s>, RoBERTa's CLS equivalent)
+
+         # Boolean masks selecting the rows of the batch that belong to each task.
+         sentiment_mask = task_id == task_name_to_id["sentiment"]
+         hate_mask = task_id == task_name_to_id["hate"]
+         emotion_mask = task_id == task_name_to_id["emotion"]
+
+         logits = {}
+         loss = 0.0
+
+         # Sentiment task
+         if sentiment_mask.any():
+             sentiment_logits = self.sentiment_head(pooled[sentiment_mask])
+             logits["sentiment"] = sentiment_logits
+             if labels is not None:
+                 loss += self.loss_fct(sentiment_logits, labels[sentiment_mask])
+         else:
+             logits["sentiment"] = torch.empty(0, task_num_labels["sentiment"], device=input_ids.device)
+
+         # Hate task
+         if hate_mask.any():
+             hate_logits = self.hate_head(pooled[hate_mask])
+             logits["hate"] = hate_logits
+             if labels is not None:
+                 loss += self.loss_fct(hate_logits, labels[hate_mask])
+         else:
+             logits["hate"] = torch.empty(0, task_num_labels["hate"], device=input_ids.device)
+
+         # Emotion task
+         if emotion_mask.any():
+             emotion_logits = self.emotion_head(pooled[emotion_mask])
+             logits["emotion"] = emotion_logits
+             if labels is not None:
+                 loss += self.loss_fct(emotion_logits, labels[emotion_mask])
+         else:
+             logits["emotion"] = torch.empty(0, task_num_labels["emotion"], device=input_ids.device)
+
+         return {"loss": loss, "logits": logits} if labels is not None else {"logits": logits}
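For reference, a minimal usage sketch of the forward pass above, assuming the file is importable as multitask_model and that the tokenizer matches the roberta-base encoder; the example texts, task assignments, and label ids below are purely illustrative, not taken from any dataset in this repo:

import torch
from transformers import RobertaTokenizer

from multitask_model import MultiTaskModel, task_name_to_id

tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
model = MultiTaskModel()

# Hypothetical mixed-task batch: one example per task.
texts = ["great movie", "you people are awful", "I am so happy today"]
tasks = ["sentiment", "hate", "emotion"]
gold = [2, 1, 3]  # made-up label ids, each valid for its task's label count

enc = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
task_id = torch.tensor([task_name_to_id[t] for t in tasks])
labels = torch.tensor(gold)

out = model(input_ids=enc["input_ids"], attention_mask=enc["attention_mask"],
            task_id=task_id, labels=labels)
out["loss"].backward()  # single scalar loss summed over the tasks present in the batch
print({name: logit.shape for name, logit in out["logits"].items()})

Because the loss is summed only over the tasks actually present in a batch, batches can freely mix tasks; a task with no examples contributes a zero-row logits tensor rather than an error.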