Files changed (1) hide show
  1. app.py +9 -3
app.py CHANGED
@@ -10,12 +10,18 @@ class StanceClassifier(nn.Module):
10
  self.transformer = transformer_model
11
  self.dropout = nn.Dropout(dropout_rate)
12
  self.layer_norm = nn.LayerNorm(transformer_model.config.hidden_size)
 
 
 
13
  self.classifier = nn.Sequential(
14
  nn.Dropout(dropout_rate),
15
- nn.Linear(transformer_model.config.hidden_size, transformer_model.config.hidden_size//2),
16
  nn.ReLU(),
17
  nn.Dropout(dropout_rate),
18
- nn.Linear(transformer_model.config.hidden_size//2, num_classes)
 
 
 
19
  )
20
  def forward(self, input_ids, attention_mask):
21
  outputs = self.transformer(input_ids=input_ids, attention_mask=attention_mask)
@@ -25,7 +31,7 @@ class StanceClassifier(nn.Module):
25
  return logits
26
 
27
  torch.manual_seed(42)
28
- checkpoint = "bert-base-chinese"
29
  tokenizer = AutoTokenizer.from_pretrained(checkpoint)
30
  base_model = AutoModel.from_pretrained(checkpoint)
31
 
 
10
  self.transformer = transformer_model
11
  self.dropout = nn.Dropout(dropout_rate)
12
  self.layer_norm = nn.LayerNorm(transformer_model.config.hidden_size)
13
+ l0 = transformer_model.config.hidden_size
14
+ l1 = transformer_model.config.hidden_size //2
15
+ l2 = l1 //2
16
  self.classifier = nn.Sequential(
17
  nn.Dropout(dropout_rate),
18
+ nn.Linear(l0,l1),
19
  nn.ReLU(),
20
  nn.Dropout(dropout_rate),
21
+ nn.Linear(l1,l2),
22
+ nn.Softmax(dim=1),
23
+ nn.Dropout(dropout_rate),
24
+ nn.Linear(l2, num_classes),
25
  )
26
  def forward(self, input_ids, attention_mask):
27
  outputs = self.transformer(input_ids=input_ids, attention_mask=attention_mask)
 
31
  return logits
32
 
33
  torch.manual_seed(42)
34
+ checkpoint = "hfl/chinese-roberta-wwm-ext"
35
  tokenizer = AutoTokenizer.from_pretrained(checkpoint)
36
  base_model = AutoModel.from_pretrained(checkpoint)
37