asmashayea commited on
Commit
bb464bb
·
1 Parent(s): 7ebac28
araberta_setting/modeling_bilstm_crf.py CHANGED
@@ -6,7 +6,8 @@ class BERT_BiLSTM_CRF(nn.Module):
6
  def __init__(self, base_model, config, dropout_rate=0.2, rnn_dim=256):
7
  super().__init__()
8
  self.bert = base_model
9
- self.label2id = config.label2id # <-- pulled from config
 
10
  self.id2label = config.id2label
11
  self.num_labels = config.num_labels
12
 
@@ -16,7 +17,7 @@ class BERT_BiLSTM_CRF(nn.Module):
16
  num_layers=2,
17
  batch_first=True,
18
  bidirectional=True,
19
- dropout=0.2
20
  )
21
  self.dropout = nn.Dropout(dropout_rate)
22
  self.classifier = nn.Linear(rnn_dim * 2, self.num_labels)
 
6
  def __init__(self, base_model, config, dropout_rate=0.2, rnn_dim=256):
7
  super().__init__()
8
  self.bert = base_model
9
+ self.config = config # add this line
10
+ self.label2id = config.label2id
11
  self.id2label = config.id2label
12
  self.num_labels = config.num_labels
13
 
 
17
  num_layers=2,
18
  batch_first=True,
19
  bidirectional=True,
20
+ dropout=dropout_rate
21
  )
22
  self.dropout = nn.Dropout(dropout_rate)
23
  self.classifier = nn.Linear(rnn_dim * 2, self.num_labels)