Add multitask_model.py
Browse files- multitask_model.py +122 -0
multitask_model.py
ADDED
|
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from transformers import BertModel, BertConfig, RobertaTokenizer, RobertaModel, RobertaConfig, PretrainedConfig
|
| 4 |
+
from transformers.modeling_outputs import SequenceClassifierOutput
|
| 5 |
+
from peft import LoraConfig, get_peft_model, LoraModel
|
| 6 |
+
|
| 7 |
+
# Single source of truth for the task setup: (task name, number of classes).
# Task ids are the positions in this tuple.
_TASKS = (("sentiment", 3), ("hate", 3), ("emotion", 4))

# Integer id assigned to each task (matches the per-example `task_id` values).
task_name_to_id = {name: idx for idx, (name, _) in enumerate(_TASKS)}

# Number of output classes for each task's classification head.
task_num_labels = {name: n_labels for name, n_labels in _TASKS}
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class MultiTaskModel(nn.Module):
    """RoBERTa encoder with one shared LoRA adapter and one classification
    head per task (sentiment, hate, emotion).

    Batches may mix tasks: ``forward`` routes each example to the head
    selected by its ``task_id`` and sums the cross-entropy losses of every
    task present in the batch.
    """

    def __init__(self):
        super().__init__()
        model_name = "roberta-base"
        config = RobertaConfig.from_pretrained(model_name)
        base_model = RobertaModel.from_pretrained(model_name, config=config)

        # One LoRA adapter shared by all tasks; only the adapter weights
        # (plus the task heads below) remain trainable.
        lora_config = LoraConfig(
            r=16,
            lora_alpha=32,
            target_modules=['query', 'value'],
            lora_dropout=0.05,
            bias='none',
            task_type='SEQ_CLS'
        )
        self.model = LoraModel(base_model, lora_config, adapter_name="shared")

        hidden_size = config.hidden_size
        dropout_prob = 0.1
        intermediate_size = 128

        # Heads are kept as individually named attributes (not a ModuleDict)
        # so existing state_dict / checkpoint keys stay stable.
        self.sentiment_head = self._make_head(
            hidden_size, intermediate_size, dropout_prob, task_num_labels["sentiment"])
        self.hate_head = self._make_head(
            hidden_size, intermediate_size, dropout_prob, task_num_labels["hate"])
        self.emotion_head = self._make_head(
            hidden_size, intermediate_size, dropout_prob, task_num_labels["emotion"])

        print(f"Trainable parameters (LoRA): {sum(p.numel() for p in self.model.parameters() if p.requires_grad)}")
        print(f"Total parameters: {sum(p.numel() for p in self.parameters() if p.requires_grad)}")
        self.loss_fct = nn.CrossEntropyLoss()

    @staticmethod
    def _make_head(hidden_size, intermediate_size, dropout_prob, num_labels):
        """Build one small MLP classification head: Linear -> ReLU -> Dropout -> Linear."""
        return nn.Sequential(
            nn.Linear(hidden_size, intermediate_size),
            nn.ReLU(),
            nn.Dropout(dropout_prob),
            nn.Linear(intermediate_size, num_labels),
        )

    def forward(self, input_ids=None, attention_mask=None, task_id=None, labels=None):
        """Run the shared encoder, then each task head on its own examples.

        Args:
            input_ids: token ids, shape (batch, seq) — assumed; confirm with caller.
            attention_mask: padding mask matching ``input_ids``.
            task_id: per-example integer task id (see ``task_name_to_id``).
            labels: per-example class label; when given, a summed loss is returned.

        Returns:
            ``{"loss": loss, "logits": {task: tensor}}`` when ``labels`` is given,
            otherwise ``{"logits": {task: tensor}}``. Tasks absent from the batch
            get an empty (0, num_labels) logits tensor.
        """
        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
        pooled = outputs.last_hidden_state[:, 0]  # first token (<s> / CLS)

        heads = {
            "sentiment": self.sentiment_head,
            "hate": self.hate_head,
            "emotion": self.emotion_head,
        }

        logits = {}
        loss = 0
        for task, head in heads.items():
            mask = task_id == task_name_to_id[task]
            if mask.any():
                task_logits = head(pooled[mask])
                logits[task] = task_logits
                if labels is not None:
                    # Unweighted sum across tasks present in the batch.
                    loss += self.loss_fct(task_logits, labels[mask])
            else:
                # pooled.device is always valid, even if inputs came as embeddings.
                logits[task] = torch.empty(0, task_num_labels[task], device=pooled.device)

        return {"loss": loss, "logits": logits} if labels is not None else {"logits": logits}
|
| 121 |
+
|
| 122 |
+
|