import torch
import torch.nn as nn
from transformers import RobertaConfig, RobertaModel, RobertaTokenizer
from peft import LoraConfig, LoraModel

# Integer id for each task; batches are expected to carry a per-example
# task_id tensor so the forward pass can route each example to its head.
task_name_to_id = {"sentiment": 0, "hate": 1, "emotion": 2}

# Number of output classes for each task.
task_num_labels = {
    "sentiment": 3,
    "hate": 3,
    "emotion": 4,
}


class MultiTaskModel(nn.Module):
    """Shared RoBERTa encoder with LoRA adapters and one classification head per task."""

    def __init__(self):
        super().__init__()
        model_name = "roberta-base"
        config = RobertaConfig.from_pretrained(model_name)
        base_model = RobertaModel.from_pretrained(model_name, config=config)

        # LoRA: inject trainable low-rank adapters into the attention query/value
        # projections; the rest of the RoBERTa backbone stays frozen.
        lora_config = LoraConfig(
            r=16,
            lora_alpha=32,
            target_modules=["query", "value"],
            lora_dropout=0.05,
            bias="none",
            task_type="SEQ_CLS",
        )
        self.model = LoraModel(base_model, lora_config, adapter_name="shared")

        hidden_size = config.hidden_size
        dropout_prob = 0.1
        intermediate_size = 128

        # One small MLP classification head per task on top of the pooled encoder output.
        self.sentiment_head = nn.Sequential(
            nn.Linear(hidden_size, intermediate_size),
            nn.ReLU(),
            nn.Dropout(dropout_prob),
            nn.Linear(intermediate_size, task_num_labels["sentiment"]),
        )
        self.hate_head = nn.Sequential(
            nn.Linear(hidden_size, intermediate_size),
            nn.ReLU(),
            nn.Dropout(dropout_prob),
            nn.Linear(intermediate_size, task_num_labels["hate"]),
        )
        self.emotion_head = nn.Sequential(
            nn.Linear(hidden_size, intermediate_size),
            nn.ReLU(),
            nn.Dropout(dropout_prob),
            nn.Linear(intermediate_size, task_num_labels["emotion"]),
        )

        print(f"Trainable parameters (LoRA): {sum(p.numel() for p in self.model.parameters() if p.requires_grad)}")
        print(f"Total trainable parameters (LoRA + heads): {sum(p.numel() for p in self.parameters() if p.requires_grad)}")
        self.loss_fct = nn.CrossEntropyLoss()

    def forward(self, input_ids=None, attention_mask=None, task_id=None, labels=None):
        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
        # Use the first (<s>/CLS) token's hidden state as the pooled representation.
        pooled = outputs.last_hidden_state[:, 0]

        # Boolean masks selecting this batch's examples for each task.
        sentiment_mask = task_id == task_name_to_id["sentiment"]
        hate_mask = task_id == task_name_to_id["hate"]
        emotion_mask = task_id == task_name_to_id["emotion"]

        logits = {}
        loss = 0

        # Route each task's examples through its own head; tasks absent from the
        # batch get empty logits so downstream code can rely on the keys existing.
        if sentiment_mask.any():
            sentiment_pooled = pooled[sentiment_mask]
            sentiment_logits = self.sentiment_head(sentiment_pooled)
            logits["sentiment"] = sentiment_logits
            if labels is not None:
                sentiment_labels = labels[sentiment_mask]
                loss += self.loss_fct(sentiment_logits, sentiment_labels)
        else:
            logits["sentiment"] = torch.empty(0, task_num_labels["sentiment"], device=input_ids.device)

        if hate_mask.any():
            hate_pooled = pooled[hate_mask]
            hate_logits = self.hate_head(hate_pooled)
            logits["hate"] = hate_logits
            if labels is not None:
                hate_labels = labels[hate_mask]
                loss += self.loss_fct(hate_logits, hate_labels)
        else:
            logits["hate"] = torch.empty(0, task_num_labels["hate"], device=input_ids.device)

        if emotion_mask.any():
            emotion_pooled = pooled[emotion_mask]
            emotion_logits = self.emotion_head(emotion_pooled)
            logits["emotion"] = emotion_logits
            if labels is not None:
                emotion_labels = labels[emotion_mask]
                loss += self.loss_fct(emotion_logits, emotion_labels)
        else:
            logits["emotion"] = torch.empty(0, task_num_labels["emotion"], device=input_ids.device)

        return {"loss": loss, "logits": logits} if labels is not None else {"logits": logits}
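

# A minimal usage sketch, not part of the training pipeline defined elsewhere:
# it assumes the roberta-base tokenizer, and the sentences, task_id values, and
# label indices below are illustrative placeholders used only to exercise the
# multi-task forward pass.
if __name__ == "__main__":
    tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
    model = MultiTaskModel()

    # A mixed batch: one sentiment example and one emotion example.
    texts = ["I loved this movie!", "I am so angry right now"]
    batch = tokenizer(texts, padding=True, return_tensors="pt")
    task_id = torch.tensor([task_name_to_id["sentiment"], task_name_to_id["emotion"]])
    labels = torch.tensor([2, 0])  # placeholder class indices for each task

    out = model(
        input_ids=batch["input_ids"],
        attention_mask=batch["attention_mask"],
        task_id=task_id,
        labels=labels,
    )
    # Loss is the sum of the per-task cross-entropy terms; tasks not present
    # in the batch (here: hate) come back with empty logits.
    print(out["loss"], {task: t.shape for task, t in out["logits"].items()})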