abcd1234davidchen committed
Commit 9ee1e28 · 1 Parent(s): 4296e9b

Upload 3 files

Files changed (3)
  1. app.py +9 -34
  2. model.py +82 -0
  3. stance_classifier.pth +3 -0
app.py CHANGED
@@ -1,39 +1,13 @@
 import torch
 import torch.nn as nn
-from transformers import AutoTokenizer, AutoModel
+from transformers import AutoModel, BertTokenizerFast
 import gradio as gr
 import re
+from model import StanceClassifier
 
-class StanceClassifier(nn.Module):
-    def __init__(self,transformer_model, num_classes, dropout_rate=0.6):
-        super(StanceClassifier, self).__init__()
-        self.transformer = transformer_model
-        self.dropout = nn.Dropout(dropout_rate)
-        self.layer_norm = nn.LayerNorm(transformer_model.config.hidden_size)
-        l0 = transformer_model.config.hidden_size
-        l1 = l0 // 2
-        l2 = l1 // 2
-        self.classifier = nn.Sequential(
-            nn.Linear(l0, l1),
-            nn.LayerNorm(l1),
-            nn.GELU(),
-            nn.Dropout(dropout_rate),
-            nn.Linear(l1, l2),
-            nn.LayerNorm(l2),
-            nn.GELU(),
-            nn.Dropout(dropout_rate),
-            nn.Linear(l2, num_classes),
-        )
-    def forward(self, input_ids, attention_mask):
-        outputs = self.transformer(input_ids=input_ids, attention_mask=attention_mask)
-        pooled_output = outputs.last_hidden_state[:, 0]
-        pooled_output = self.layer_norm(pooled_output)
-        logits = self.classifier(pooled_output)
-        return logits
-
 torch.manual_seed(42)
-checkpoint = "hfl/chinese-roberta-wwm-ext"
-tokenizer = AutoTokenizer.from_pretrained(checkpoint)
+checkpoint = "ckiplab/bert-base-chinese"
+tokenizer = BertTokenizerFast.from_pretrained('bert-base-chinese')
 base_model = AutoModel.from_pretrained(checkpoint)
 
 model = StanceClassifier(base_model, num_classes=3)
@@ -42,7 +16,7 @@ model.eval()
 labels = ['KMT', 'DPP', 'Neutral']
 
 def predict_stance(text):
-    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=64)
+    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
     with torch.no_grad():
         outputs = model(
             input_ids=inputs["input_ids"],
@@ -56,10 +30,11 @@ def predict_stance(text):
 
 def gradio_interface(text):
     sentences = re.split(r"[。!?\n]", text)
-    sentences = [s for s in sentences if s.strip()]
+    sentences = [s for idx, s in enumerate(sentences) if s.strip()]
+    accumulate_sentence = [" ".join(sentences[:idx+1]) for idx, s in enumerate(sentences) if s.strip()]
     results = []
-    for s in sentences:
-        stance, conf = predict_stance(s)
+    for s, acus in zip(sentences, accumulate_sentence):
+        stance, conf = predict_stance(acus)
         results.append((s + f" (Confidence: {conf:.4f})", stance))
     return results
 
 
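Note on the gradio_interface change above: each sentence is now scored with the accumulated text up to and including it, rather than in isolation, so earlier sentences supply topical context for later ones. A minimal sketch of that cumulative-context pattern, with a hypothetical predict_stance_stub standing in for the real tokenizer-plus-model call:

    import re

    def predict_stance_stub(text):
        # placeholder for the real predict_stance, which tokenizes `text`,
        # runs the StanceClassifier, and returns (label, confidence)
        return "Neutral", 0.5

    def classify_with_context(text):
        # split on Chinese sentence-ending punctuation, as in app.py
        sentences = [s for s in re.split(r"[。!?\n]", text) if s.strip()]
        results = []
        for idx, s in enumerate(sentences):
            # sentence idx is classified together with everything before it
            context = " ".join(sentences[:idx + 1])
            stance, conf = predict_stance_stub(context)
            results.append((f"{s} (Confidence: {conf:.4f})", stance))
        return results

One trade-off worth flagging: as the accumulated prefix grows toward the tokenizer's max_length of 512, late sentences contribute proportionally less to their own prediction.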
model.py ADDED
@@ -0,0 +1,82 @@
+import torch
+import torch.nn as nn
+
+
+class StanceClassifier(nn.Module):
+    def __init__(self, transformer_model, num_classes, dropout_rate=0.6):
+        super(StanceClassifier, self).__init__()
+        self.transformer = transformer_model
+        self.dropout = nn.Dropout(dropout_rate)
+        self.layer_norm = nn.LayerNorm(transformer_model.config.hidden_size)
+
+        l0 = transformer_model.config.hidden_size
+        l1 = transformer_model.config.hidden_size * 2
+        l2 = l1 // 2
+        l3 = l2 // 2
+        # classifier expects pooled token representation (batch, hidden)
+        self.classifier = nn.Sequential(
+            nn.Linear(l0, l1),
+            nn.LayerNorm(l1),
+            nn.GELU(),
+            nn.Dropout(dropout_rate),
+            nn.Linear(l1, l2),
+            nn.LayerNorm(l2),
+            nn.GELU(),
+            nn.Dropout(dropout_rate),
+            nn.Linear(l2, l3),
+            nn.LayerNorm(l3),
+            nn.GELU(),
+            nn.Linear(l3, num_classes),
+        )
+
+        self.attention_vector = nn.Linear(l0, 1)
+        nn.init.xavier_uniform_(self.attention_vector.weight)
+
+
+        self.freeze_transformer()
+
+    def freeze_transformer(self):
+        for param in self.transformer.parameters():
+            param.requires_grad = False
+
+    def unfreeze_transformer(self):
+        for param in self.transformer.parameters():
+            param.requires_grad = True
+
+    def forward(self, input_ids, attention_mask):
+        if not any(p.requires_grad for p in self.transformer.parameters()):
+            with torch.no_grad():
+                outputs = self.transformer(
+                    input_ids=input_ids, attention_mask=attention_mask
+                )
+        else:
+            outputs = self.transformer(
+                input_ids=input_ids, attention_mask=attention_mask
+            )
+
+        # token-level hidden states: (batch, seq_len, hidden)
+        token_states = outputs.last_hidden_state
+
+        scores = self.attention_vector(token_states).squeeze(-1)  # (batch, seq_len)
+        mask = attention_mask.to(dtype=torch.bool)  # (batch, seq_len)
+        scores = scores.masked_fill(~mask, -1e9)
+        weights = torch.softmax(scores, dim=1)  # (batch, seq_len)
+        pooled_output = (weights.unsqueeze(-1) * token_states).sum(dim=1)  # (batch, hidden)
+
+        if torch.isnan(pooled_output).any() or torch.isinf(pooled_output).any():
+            print("WARNING: Transformer output NaN/Inf")
+            pooled_output = torch.where(
+                torch.isnan(pooled_output) | torch.isinf(pooled_output),
+                torch.zeros_like(pooled_output),
+                pooled_output,
+            )
+
+        pooled_output = self.layer_norm(pooled_output)
+        logits = self.classifier(pooled_output)
+        return logits
+
+    def classifier_params(self):
+        return list(self.classifier.parameters())
+
+    def transformer_params(self):
+        return list(self.transformer.parameters())
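For review context, the committed module replaces the old CLS pooling with a learned attention pooling over all token states (a masked softmax over per-token scores, then a weighted sum) ahead of the deeper MLP head. A quick smoke test, reusing the checkpoint names from app.py above; an illustrative sketch, not part of the commit:

    import torch
    from transformers import AutoModel, BertTokenizerFast

    from model import StanceClassifier

    tokenizer = BertTokenizerFast.from_pretrained("bert-base-chinese")
    base_model = AutoModel.from_pretrained("ckiplab/bert-base-chinese")
    model = StanceClassifier(base_model, num_classes=3)  # transformer starts frozen
    model.eval()

    inputs = tokenizer("今天立法院表決結果出爐。", return_tensors="pt",
                       padding=True, truncation=True, max_length=512)
    with torch.no_grad():
        logits = model(input_ids=inputs["input_ids"],
                       attention_mask=inputs["attention_mask"])
    print(logits.shape)  # torch.Size([1, 3]), one logit per stance class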
stance_classifier.pth ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0674d3874bfd18e814d48820cffc501d586178802fbf7d044a96f6dcc0241b3d
+size 419826179
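The weights themselves live in Git LFS; the pointer above resolves to roughly 420 MB. The hunks in this commit do not show how app.py restores them; a plausible load step, with the filename taken from this commit and map_location assumed, would be:

    import torch

    # assumed restore step, not shown in this commit's visible hunks
    state_dict = torch.load("stance_classifier.pth", map_location="cpu")
    model.load_state_dict(state_dict)
    model.eval()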