Upload stance_classifier.pth

#1
by ben-jian - opened
Files changed (2) hide show
  1. app.py +29 -26
  2. model.py +0 -82
app.py CHANGED
@@ -1,27 +1,41 @@
1
  import torch
2
  import torch.nn as nn
3
- from transformers import AutoModel, BertTokenizerFast
4
  import gradio as gr
5
  import re
6
- from model import StanceClassifier
7
- import os
8
- import huggingface_hub
9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  torch.manual_seed(42)
11
- checkpoint = "ckiplab/bert-base-chinese"
12
- tokenizer = BertTokenizerFast.from_pretrained('bert-base-chinese')
13
  base_model = AutoModel.from_pretrained(checkpoint)
14
 
15
  model = StanceClassifier(base_model, num_classes=3)
16
-
17
- dict_path = huggingface_hub.hf_hub_download(repo_id="abcd1234davidchen/PolStanceBERT",filename="stance_classifier.pth",local_dir=".",local_dir_use_symlinks=False)
18
-
19
- model.load_state_dict(torch.load(dict_path, map_location=torch.device('cpu')))
20
  model.eval()
21
  labels = ['KMT', 'DPP', 'Neutral']
22
 
23
  def predict_stance(text):
24
- inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
25
  with torch.no_grad():
26
  outputs = model(
27
  input_ids=inputs["input_ids"],
@@ -34,22 +48,11 @@ def predict_stance(text):
34
  return labels[predicted_class], confidence
35
 
36
  def gradio_interface(text):
37
- singleSentenceMode = False
38
- if text[0:1]=="!" or text[0:1]=="！":
39
- text=text[1:]
40
- singleSentenceMode = True
41
-
42
  sentences = re.split(r"[。!?\n]", text)
43
- sentences = [s for idx, s in enumerate(sentences) if s.strip()]
44
- accumulate_sentence = [" ".join(sentences[:idx+1]) for idx, s in enumerate(sentences) if s.strip()]
45
  results = []
46
- if singleSentenceMode:
47
- for s in sentences:
48
- stance, conf = predict_stance(s)
49
- results.append((s + f" (Confidence: {conf:.4f})", stance))
50
- return results
51
- for s, acus in zip(sentences, accumulate_sentence):
52
- stance, conf = predict_stance(acus)
53
  results.append((s + f" (Confidence: {conf:.4f})", stance))
54
  return results
55
 
@@ -59,7 +62,7 @@ def ui():
59
  inputs=gr.Textbox(label="Input Text", placeholder="Enter text to predict political stance..."),
60
  outputs=gr.HighlightedText(label="Prediction Result",color_map={"KMT":"blue","DPP":"green","Neutral":"purple"}),
61
  title="Political Stance Prediction",
62
- description="Enter a text to predict its political stance (KMT, DPP, Neutral). Prefix a sentence with '!' or '！' to analyze each sentence individually.",
63
  ).launch()
64
 
65
  if __name__ == "__main__":
 
1
  import torch
2
  import torch.nn as nn
3
+ from transformers import AutoTokenizer, AutoModel
4
  import gradio as gr
5
  import re
 
 
 
6
 
7
+ class StanceClassifier(nn.Module):
8
+ def __init__(self,transformer_model, num_classes, dropout_rate=0.6):
9
+ super(StanceClassifier, self).__init__()
10
+ self.transformer = transformer_model
11
+ self.dropout = nn.Dropout(dropout_rate)
12
+ self.layer_norm = nn.LayerNorm(transformer_model.config.hidden_size)
13
+ self.classifier = nn.Sequential(
14
+ nn.Dropout(dropout_rate),
15
+ nn.Linear(transformer_model.config.hidden_size, transformer_model.config.hidden_size//2),
16
+ nn.ReLU(),
17
+ nn.Dropout(dropout_rate),
18
+ nn.Linear(transformer_model.config.hidden_size//2, num_classes)
19
+ )
20
+ def forward(self, input_ids, attention_mask):
21
+ outputs = self.transformer(input_ids=input_ids, attention_mask=attention_mask)
22
+ pooled_output = outputs.last_hidden_state[:, 0]
23
+ pooled_output = self.layer_norm(pooled_output)
24
+ logits = self.classifier(pooled_output)
25
+ return logits
26
+
27
  torch.manual_seed(42)
28
+ checkpoint = "bert-base-chinese"
29
+ tokenizer = AutoTokenizer.from_pretrained(checkpoint)
30
  base_model = AutoModel.from_pretrained(checkpoint)
31
 
32
  model = StanceClassifier(base_model, num_classes=3)
33
+ model.load_state_dict(torch.load("stance_classifier.pth", map_location=torch.device('cpu')))
 
 
 
34
  model.eval()
35
  labels = ['KMT', 'DPP', 'Neutral']
36
 
37
  def predict_stance(text):
38
+ inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=64)
39
  with torch.no_grad():
40
  outputs = model(
41
  input_ids=inputs["input_ids"],
 
48
  return labels[predicted_class], confidence
49
 
50
  def gradio_interface(text):
 
 
 
 
 
51
  sentences = re.split(r"[。!?\n]", text)
52
+ sentences = [s for s in sentences if s.strip()]
 
53
  results = []
54
+ for s in sentences:
55
+ stance, conf = predict_stance(s)
 
 
 
 
 
56
  results.append((s + f" (Confidence: {conf:.4f})", stance))
57
  return results
58
 
 
62
  inputs=gr.Textbox(label="Input Text", placeholder="Enter text to predict political stance..."),
63
  outputs=gr.HighlightedText(label="Prediction Result",color_map={"KMT":"blue","DPP":"green","Neutral":"purple"}),
64
  title="Political Stance Prediction",
65
+ description="Enter a text to predict its political stance (KMT, DPP, Neutral)."
66
  ).launch()
67
 
68
  if __name__ == "__main__":
model.py DELETED
@@ -1,82 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
-
4
-
5
- class StanceClassifier(nn.Module):
6
- def __init__(self, transformer_model, num_classes, dropout_rate=0.6):
7
- super(StanceClassifier, self).__init__()
8
- self.transformer = transformer_model
9
- self.dropout = nn.Dropout(dropout_rate)
10
- self.layer_norm = nn.LayerNorm(transformer_model.config.hidden_size)
11
-
12
- l0 = transformer_model.config.hidden_size
13
- l1 = transformer_model.config.hidden_size * 2
14
- l2 = l1 // 2
15
- l3 = l2 // 2
16
- # classifier expects pooled token representation (batch, hidden)
17
- self.classifier = nn.Sequential(
18
- nn.Linear(l0, l1),
19
- nn.LayerNorm(l1),
20
- nn.GELU(),
21
- nn.Dropout(dropout_rate),
22
- nn.Linear(l1, l2),
23
- nn.LayerNorm(l2),
24
- nn.GELU(),
25
- nn.Dropout(dropout_rate),
26
- nn.Linear(l2, l3),
27
- nn.LayerNorm(l3),
28
- nn.GELU(),
29
- nn.Linear(l3, num_classes),
30
- )
31
-
32
- self.attention_vector = nn.Linear(l0, 1)
33
- nn.init.xavier_uniform_(self.attention_vector.weight)
34
-
35
-
36
- self.freeze_transformer()
37
-
38
- def freeze_transformer(self):
39
- for param in self.transformer.parameters():
40
- param.requires_grad = False
41
-
42
- def unfreeze_transformer(self):
43
- for param in self.transformer.parameters():
44
- param.requires_grad = True
45
-
46
- def forward(self, input_ids, attention_mask):
47
- if not any(p.requires_grad for p in self.transformer.parameters()):
48
- with torch.no_grad():
49
- outputs = self.transformer(
50
- input_ids=input_ids, attention_mask=attention_mask
51
- )
52
- else:
53
- outputs = self.transformer(
54
- input_ids=input_ids, attention_mask=attention_mask
55
- )
56
-
57
- # token-level hidden states: (batch, seq_len, hidden)
58
- token_states = outputs.last_hidden_state
59
-
60
- scores = self.attention_vector(token_states).squeeze(-1) # (batch, seq_len)
61
- mask = attention_mask.to(dtype=torch.bool) # (batch, seq_len)
62
- scores = scores.masked_fill(~mask, -1e9)
63
- weights = torch.softmax(scores, dim=1) # (batch, seq_len)
64
- pooled_output = (weights.unsqueeze(-1) * token_states).sum(dim=1) # (batch, hidden)
65
-
66
- if torch.isnan(pooled_output).any() or torch.isinf(pooled_output).any():
67
- print("WARNING: Transformer output NaN/Inf")
68
- pooled_output = torch.where(
69
- torch.isnan(pooled_output) | torch.isinf(pooled_output),
70
- torch.zeros_like(pooled_output),
71
- pooled_output,
72
- )
73
-
74
- pooled_output = self.layer_norm(pooled_output)
75
- logits = self.classifier(pooled_output)
76
- return logits
77
-
78
- def classifier_params(self):
79
- return list(self.classifier.parameters())
80
-
81
- def transformer_params(self):
82
- return list(self.transformer.parameters())