|
|
import torch |
|
|
import torch.nn as nn |
|
|
from transformers import DebertaV2Tokenizer , DebertaV2Model |
|
|
from typing import Dict, Any |
|
|
import joblib |
|
|
import os |
|
|
|
|
|
|
|
|
class EndpointHandler:
    """Serve batched predictions from a multitask DeBERTa classifier.

    ``model_path`` is a directory holding the tokenizer files,
    ``pytorch_model.bin`` (a state dict for ``MultitaskDebertaModel``) and the
    joblib files ``emotion_encoder.pkl``, ``polarity_encoder.pkl`` and
    ``hate_speech_encoder.pkl``.

    Security: ``torch.load`` and ``joblib.load`` unpickle their input, which can
    execute arbitrary code -- only point ``model_path`` at trusted artifacts.
    """

    # Class-level defaults; ``__init__`` overrides them per instance.
    batch_size = 32
    max_length = 256

    def __init__(self, model_path="", *, batch_size=32, max_length=256):
        """Load tokenizer, model weights and label encoders from ``model_path``.

        Args:
            model_path: Directory containing the model artifacts listed above.
            batch_size: Number of texts tokenized and scored per forward pass.
            max_length: Maximum token length; longer texts are truncated.

        Raises:
            ValueError: If ``batch_size`` or ``max_length`` is not positive.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be positive, got {batch_size}")
        if max_length < 1:
            raise ValueError(f"max_length must be positive, got {max_length}")
        self.batch_size = batch_size
        self.max_length = max_length

        # Choose the device first so the checkpoint can be mapped straight onto it.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        self.tokenizer = DebertaV2Tokenizer.from_pretrained(model_path)

        # NOTE(review): label counts are hard-coded; they must agree with the
        # label encoders loaded below -- confirm when the model is retrained.
        self.model = MultitaskDebertaModel(num_emotion_labels=8, num_polarity_labels=4, num_hate_speech_labels=2)
        self.load_model(os.path.join(model_path, 'pytorch_model.bin'))
        self.model.to(self.device)
        self.model.eval()

        # Encoders must expose ``inverse_transform`` returning an array-like
        # with ``.tolist()`` (see ``_predict_batch``).
        self.emotion_encoder = joblib.load(os.path.join(model_path, 'emotion_encoder.pkl'))
        self.polarity_encoder = joblib.load(os.path.join(model_path, 'polarity_encoder.pkl'))
        self.hate_speech_encoder = joblib.load(os.path.join(model_path, 'hate_speech_encoder.pkl'))

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Predict emotion, polarity and hate-speech labels for ``data['inputs']``.

        Args:
            data: Request payload. ``data['inputs']`` may be a single string or
                a sequence of strings; a missing or ``None`` value means no texts.

        Returns:
            A dict with keys ``"emotions"``, ``"polarities"`` and
            ``"hate_speech"``, each a list with one decoded label per input
            text, in input order.
        """
        texts = self._as_text_list(data.get('inputs', []))
        results = {"emotions": [], "polarities": [], "hate_speech": []}
        for start in range(0, len(texts), self.batch_size):
            emotions, polarities, hate_speech = self._predict_batch(texts[start:start + self.batch_size])
            results["emotions"].extend(emotions)
            results["polarities"].extend(polarities)
            results["hate_speech"].extend(hate_speech)
        return results

    @staticmethod
    def _as_text_list(raw):
        """Normalize the ``inputs`` payload into a list of texts."""
        if raw is None:
            return []
        if isinstance(raw, str):
            # A bare string is one text; slicing it would split it into fragments.
            return [raw]
        return list(raw)

    def _predict_batch(self, batch_texts):
        """Score one batch and return decoded (emotions, polarities, hate_speech) lists."""
        encoded = self.tokenizer(batch_texts, return_tensors='pt', max_length=self.max_length, truncation=True, padding=True)
        # The model's forward() accepts only input_ids and attention_mask.
        inputs = {key: val.to(self.device) for key, val in encoded.items() if key != 'token_type_ids'}

        with torch.no_grad():
            outputs = self.model(**inputs)

        decoded = []
        for head, encoder in (
            ('emotion', self.emotion_encoder),
            ('polarity', self.polarity_encoder),
            ('hate_speech', self.hate_speech_encoder),
        ):
            preds = torch.argmax(outputs[head], dim=1).cpu().numpy().tolist()
            decoded.append(encoder.inverse_transform(preds).tolist())
        return tuple(decoded)

    def load_model(self, model_path):
        """Load a state dict from the file ``model_path`` into ``self.model``.

        Tensors are mapped onto ``self.device`` so GPU-saved checkpoints also
        load on CPU-only hosts.
        """
        state_dict = torch.load(model_path, map_location=self.device)
        self.model.load_state_dict(state_dict)
|
|
|
|
|
|
|
|
class MultitaskDebertaModel(nn.Module):
    """DeBERTa-v3 encoder with three task-specific heads and a shared fusion layer.

    Each task (emotion, polarity, hate speech) passes the encoder's token states
    through its own BiLSTM and self-attention. It then mean-pools over the
    sequence and projects to 128 features. The three task vectors are
    concatenated with the first-token ([CLS]) vector, fused to 128 features,
    layer-normalized, dropped out, and fed to one linear classifier per task.

    NOTE(review): ``attention_mask`` is only given to the DeBERTa encoder. The
    LSTMs, the self-attention (no ``key_padding_mask``) and the mean pooling all
    include padded positions. A text's logits can therefore depend on how much
    padding its batch adds -- confirm this matches how the weights were trained.
    """

    def __init__(self, num_emotion_labels, num_polarity_labels, num_hate_speech_labels):
        super().__init__()
        self.deberta = DebertaV2Model.from_pretrained('microsoft/deberta-v3-base')

        # Freeze every parameter of the first five encoder layers.
        for frozen_layer in self.deberta.encoder.layer[:5]:
            frozen_layer.requires_grad_(False)

        encoder_dim = 768           # assumes the backbone hidden size is 768 -- TODO confirm via config
        lstm_units = 128
        task_dim = 2 * lstm_units   # bidirectional LSTM concatenates both directions
        head_dim = 128

        # Creation order (all LSTMs, then all attention blocks, then all
        # projections, ...) fixes the RNG order of random initialisation; keep
        # it stable for reproducible seeded runs.
        self.emotion_lstm = nn.LSTM(encoder_dim, lstm_units, bidirectional=True, batch_first=True)
        self.polarity_lstm = nn.LSTM(encoder_dim, lstm_units, bidirectional=True, batch_first=True)
        self.hate_speech_lstm = nn.LSTM(encoder_dim, lstm_units, bidirectional=True, batch_first=True)

        self.emotion_attention = nn.MultiheadAttention(embed_dim=task_dim, num_heads=8, batch_first=True)
        self.polarity_attention = nn.MultiheadAttention(embed_dim=task_dim, num_heads=8, batch_first=True)
        self.hate_speech_attention = nn.MultiheadAttention(embed_dim=task_dim, num_heads=8, batch_first=True)

        self.emotion_dense = nn.Linear(task_dim, head_dim)
        self.polarity_dense = nn.Linear(task_dim, head_dim)
        self.hate_speech_dense = nn.Linear(task_dim, head_dim)

        # Three task vectors plus the [CLS] vector.
        self.fusion_dense = nn.Linear(3 * head_dim + encoder_dim, head_dim)

        self.emotion_classifier = nn.Linear(head_dim, num_emotion_labels)
        self.polarity_classifier = nn.Linear(head_dim, num_polarity_labels)
        self.hate_speech_classifier = nn.Linear(head_dim, num_hate_speech_labels)

        self.layer_norm = nn.LayerNorm(head_dim)
        self.dropout = nn.Dropout(p=0.3)
        self.relu = nn.ReLU()

    def forward(self, input_ids, attention_mask):
        """Return raw logits keyed by ``'emotion'``, ``'polarity'`` and ``'hate_speech'``."""
        hidden_states = self.deberta(input_ids, attention_mask=attention_mask).last_hidden_state
        cls_vector = hidden_states[:, 0, :]

        task_vectors = [
            self._task_features(self.emotion_lstm, self.emotion_attention, self.emotion_dense, hidden_states),
            self._task_features(self.polarity_lstm, self.polarity_attention, self.polarity_dense, hidden_states),
            self._task_features(self.hate_speech_lstm, self.hate_speech_attention, self.hate_speech_dense, hidden_states),
        ]

        fused = self.relu(self.fusion_dense(torch.cat(task_vectors + [cls_vector], dim=-1)))
        fused = self.dropout(self.layer_norm(fused))

        return {
            'emotion': self.emotion_classifier(fused),
            'polarity': self.polarity_classifier(fused),
            'hate_speech': self.hate_speech_classifier(fused),
        }

    def _task_features(self, lstm, attention, dense, hidden_states):
        """Apply BiLSTM, self-attention, mean over dim 1 (sequence), then ReLU(dense)."""
        recurrent, _ = lstm(hidden_states)
        attended, _ = attention(recurrent, recurrent, recurrent)
        return self.relu(dense(attended.mean(dim=1)))
|
|
|