RoBERTa_Multitask / multitask_model.py
SeragAmin's picture
Add multitask_model.py
644b065
import torch
import torch.nn as nn
from transformers import BertModel, BertConfig, RobertaTokenizer, RobertaModel, RobertaConfig, PretrainedConfig
from transformers.modeling_outputs import SequenceClassifierOutput
from peft import LoraConfig, get_peft_model, LoraModel
task_name_to_id = {"sentiment": 0, "hate": 1, "emotion": 2}
# Number of classes for each task
task_num_labels = {
"sentiment": 3,
"hate": 3,
"emotion": 4
}
class MultiTaskModel(nn.Module):
def __init__(self):
super().__init__()
model_name="roberta-base"
config = RobertaConfig.from_pretrained(model_name)
base_model = RobertaModel.from_pretrained(model_name, config=config)
# self.task_weights = {
# "sentiment": 1,
# "hate": 2,
# "emotion": 1
# }
lora_config = LoraConfig(
r=16,
lora_alpha=32,
target_modules=['query', 'value'],
lora_dropout=0.05,
bias='none',
task_type='SEQ_CLS'
)
self.model = LoraModel(base_model, lora_config, adapter_name="shared")
# self.model.add_adapter(lora_config, adapter_name= 'hate') # For hate
# self.model.set_adapter("shared") # Set default
hidden_size = config.hidden_size
dropout_prob = 0.1
intermediate_size = 128
self.sentiment_head = nn.Sequential(
nn.Linear(hidden_size, intermediate_size),
nn.ReLU(),
nn.Dropout(dropout_prob),
nn.Linear(intermediate_size, task_num_labels["sentiment"])
)
self.hate_head = nn.Sequential(
nn.Linear(hidden_size, intermediate_size),
nn.ReLU(),
nn.Dropout(dropout_prob),
nn.Linear(intermediate_size, task_num_labels["hate"])
)
self.emotion_head = nn.Sequential(
nn.Linear(hidden_size, intermediate_size),
nn.ReLU(),
nn.Dropout(dropout_prob),
nn.Linear(intermediate_size, task_num_labels["emotion"])
)
#self.bert.print_trainable_parameters()r
print(f"Trainable parameters (LoRA): {sum(p.numel() for p in self.model.parameters() if p.requires_grad)}")
print(f"Total parameters: {sum(p.numel() for p in self.parameters() if p.requires_grad)}")
self.loss_fct = nn.CrossEntropyLoss()
def forward(self, input_ids=None, attention_mask=None, task_id=None, labels=None):
outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
pooled = outputs.last_hidden_state[:, 0] # Use first token (CLS)
sentiment_mask = task_id == task_name_to_id["sentiment"]
hate_mask = task_id == task_name_to_id["hate"]
emotion_mask = task_id == task_name_to_id["emotion"]
logits = {}
loss = 0
# Sentiment task
if sentiment_mask.any():
sentiment_pooled = pooled[sentiment_mask]
sentiment_logits = self.sentiment_head(sentiment_pooled)
logits["sentiment"] = sentiment_logits
if labels is not None:
sentiment_labels = labels[sentiment_mask]
loss += self.loss_fct(sentiment_logits, sentiment_labels)
else:
logits["sentiment"] = torch.empty(0, task_num_labels["sentiment"], device=input_ids.device)
# Hate task
if hate_mask.any():
hate_pooled = pooled[hate_mask]
hate_logits = self.hate_head(hate_pooled)
logits["hate"] = hate_logits
if labels is not None:
hate_labels = labels[hate_mask]
loss += self.loss_fct(hate_logits, hate_labels)
else:
logits["hate"] = torch.empty(0, task_num_labels["hate"], device=input_ids.device)
# Emotion task
if emotion_mask.any():
emotion_pooled = pooled[emotion_mask]
emotion_logits = self.emotion_head(emotion_pooled)
logits["emotion"] = emotion_logits
if labels is not None:
emotion_labels = labels[emotion_mask]
loss += self.loss_fct(emotion_logits, emotion_labels)
else:
logits["emotion"] = torch.empty(0, task_num_labels["emotion"], device=input_ids.device)
return {"loss": loss, "logits": logits} if labels is not None else {"logits": logits}