abcd1234davidchen commited on
Commit
1d43ce2
·
1 Parent(s): 2a23b9d

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +63 -0
app.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from transformers import AutoTokenizer, AutoModel
4
+ import gradio as gr
5
+
6
+ class StanceClassifier(nn.Module):
7
+ def __init__(self,transformer_model, num_classes, dropout_rate=0.6):
8
+ super(StanceClassifier, self).__init__()
9
+ self.transformer = transformer_model
10
+ self.dropout = nn.Dropout(dropout_rate)
11
+ self.layer_norm = nn.LayerNorm(transformer_model.config.hidden_size)
12
+ self.classifier = nn.Sequential(
13
+ nn.Dropout(dropout_rate),
14
+ nn.Linear(transformer_model.config.hidden_size, transformer_model.config.hidden_size//2),
15
+ nn.ReLU(),
16
+ nn.Dropout(dropout_rate),
17
+ nn.Linear(transformer_model.config.hidden_size//2, num_classes)
18
+ )
19
+ def forward(self, input_ids, attention_mask):
20
+ outputs = self.transformer(input_ids=input_ids, attention_mask=attention_mask)
21
+ pooled_output = outputs.last_hidden_state[:, 0]
22
+ pooled_output = self.layer_norm(pooled_output)
23
+ logits = self.classifier(pooled_output)
24
+ return logits
25
+
26
+ torch.manual_seed(42)
27
+ checkpoint = "bert-base-chinese"
28
+ tokenizer = AutoTokenizer.from_pretrained(checkpoint)
29
+ base_model = AutoModel.from_pretrained(checkpoint)
30
+
31
+ model = StanceClassifier(base_model, num_classes=3)
32
+ model.load_state_dict(torch.load("stance_classifier.pth", map_location=torch.device('cpu')))
33
+ model.eval()
34
+ labels = ['KMT', 'DPP', 'Neutral']
35
+
36
+ def predict_stance(text):
37
+ inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=64)
38
+ with torch.no_grad():
39
+ outputs = model(
40
+ input_ids=inputs["input_ids"],
41
+ attention_mask=inputs["attention_mask"]
42
+ )
43
+ probs = nn.Softmax(dim=1)(outputs)
44
+ print(probs)
45
+ predicted_class = torch.argmax(probs, dim=1).item()
46
+ confidence = probs[0][predicted_class].item()
47
+ return labels[predicted_class], confidence
48
+
49
+ def gradio_interface(text):
50
+ stance, conf = predict_stance(text)
51
+ return f"Predicted Stance: {stance} with confidence {conf:.4f}"
52
+
53
+ def ui():
54
+ gr.Interface(
55
+ fn=gradio_interface,
56
+ inputs=gr.Textbox(label="Input Text", placeholder="Enter text to predict political stance..."),
57
+ outputs=gr.Textbox(label="Prediction Result"),
58
+ title="Political Stance Prediction",
59
+ description="Enter a text to predict its political stance (KMT, DPP, Neutral)."
60
+ ).launch()
61
+
62
+ if __name__ == "__main__":
63
+ ui()