hafizhaaarama commited on
Commit
31139e8
·
verified ·
1 Parent(s): fe8c60c

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +5 -5
model.py CHANGED
@@ -2,12 +2,12 @@ import torch.nn as nn
2
  from transformers import AutoModel
3
 
4
  class MultitaskModel(nn.Module):
5
- def __init__(self):
6
  super().__init__()
7
- self.encoder = AutoModel.from_pretrained("distilbert-base-uncased")
8
  self.dropout = nn.Dropout(0.1)
9
- self.classifier_sent = nn.Linear(768, 2) # num_sentiment_labels
10
- self.classifier_emo = nn.Linear(768, 7) # num_emotion_labels
11
 
12
  def forward(self, input_ids=None, attention_mask=None):
13
  outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
@@ -15,4 +15,4 @@ class MultitaskModel(nn.Module):
15
  pooled = self.dropout(pooled)
16
  sent = self.classifier_sent(pooled)
17
  emo = self.classifier_emo(pooled)
18
- return sent, emo
 
2
  from transformers import AutoModel
3
 
4
  class MultitaskModel(nn.Module):
5
+ def __init__(self, model_name, num_sentiment_labels, num_emotion_labels):
6
  super().__init__()
7
+ self.encoder = AutoModel.from_pretrained(model_name)
8
  self.dropout = nn.Dropout(0.1)
9
+ self.classifier_sent = nn.Linear(self.encoder.config.hidden_size, num_sentiment_labels)
10
+ self.classifier_emo = nn.Linear(self.encoder.config.hidden_size, num_emotion_labels)
11
 
12
  def forward(self, input_ids=None, attention_mask=None):
13
  outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
 
15
  pooled = self.dropout(pooled)
16
  sent = self.classifier_sent(pooled)
17
  emo = self.classifier_emo(pooled)
18
+ return sent, emo