# PolStance / app.py
# Author: abcd1234davidchen
# Commit: "Update app.py" (7d6eeac, verified)
import torch
import torch.nn as nn
from transformers import AutoModel, BertTokenizerFast
import gradio as gr
import re
from model import StanceClassifier
import os
import huggingface_hub
# Seed PyTorch's global RNG with a fixed value before the model is built.
torch.manual_seed(42)
# Hub id of the pretrained BERT backbone loaded below.
checkpoint = "ckiplab/bert-base-chinese"
# NOTE(review): the tokenizer is loaded from 'bert-base-chinese', not from
# `checkpoint` -- presumably intentional (shared vocabulary); confirm.
tokenizer = BertTokenizerFast.from_pretrained('bert-base-chinese')
base_model = AutoModel.from_pretrained(checkpoint)
# Wrap the backbone in the project's classifier, configured for 3 classes
# (see `labels` below).
model = StanceClassifier(base_model, num_classes=3)
# Fetch the fine-tuned weights file from the Hub into the current directory.
# NOTE(review): `local_dir_use_symlinks` is reportedly deprecated in recent
# huggingface_hub releases -- confirm against the pinned version.
dict_path = huggingface_hub.hf_hub_download(repo_id="abcd1234davidchen/PolStanceBERT",filename="stance_classifier.pth",local_dir=".",local_dir_use_symlinks=False)
# Load the downloaded state dict, mapping all tensors to CPU.
# NOTE(security): torch.load unpickles the downloaded file; consider passing
# weights_only=True if the installed torch version supports it.
model.load_state_dict(torch.load(dict_path, map_location=torch.device('cpu')))
# Put the model in evaluation mode for inference.
model.eval()
# Class index -> label name, indexed by the argmax in predict_stance.
# NOTE(review): order presumably matches the training label ids -- confirm.
labels = ['KMT', 'DPP', 'Neutral']
def predict_stance(text):
    """Classify a single piece of text and return ``(label, confidence)``.

    The text is tokenized (truncated to at most 512 tokens), passed through
    the module-level ``model`` with gradient tracking disabled, and the
    output is turned into probabilities with a softmax over dimension 1.

    Args:
        text: The text to classify.  The ``.item()`` call on the argmax
            requires exactly one row, so this handles one text at a time.

    Returns:
        A tuple ``(label, confidence)`` where ``label`` is the entry of the
        module-level ``labels`` list at the most probable class index and
        ``confidence`` is that class's probability as a Python float.
    """
    encoded = tokenizer(
        text,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=512,
    )

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        # presumably logits of shape (1, num_classes) -- TODO confirm
        logits = model(
            input_ids=encoded["input_ids"],
            attention_mask=encoded["attention_mask"],
        )
        probs = torch.softmax(logits, dim=1)

    # Kept from the original: the full probability tensor is written to stdout.
    print(probs)

    best_index = probs.argmax(dim=1).item()
    best_prob = probs[0, best_index].item()
    return labels[best_index], best_prob
def gradio_interface(text):
    """Split *text* into sentences and label each one with a predicted stance.

    By default every sentence is classified together with all sentences
    that precede it (joined by single spaces), so earlier context feeds into
    later predictions.  When the input starts with ``!`` (ASCII) or ``!``
    (full-width), that marker is stripped and each sentence is classified on
    its own instead.

    Sentences are delimited by ``。``, ``!``, ``?``, their full-width
    counterparts ``!`` and ``?``, or a newline.  Blank fragments are
    dropped.

    Args:
        text: Raw user input from the Gradio textbox.  ``None`` is treated
            like an empty string.

    Returns:
        A list of ``(display_text, stance)`` tuples, one per non-blank
        sentence, where ``display_text`` is the sentence followed by
        ``" (Confidence: <conf to 4 decimals>)"``.  The list is empty when
        the input contains no non-blank sentence.
    """
    text = text or ""

    # The original compared against the same ASCII "!" twice; the second
    # variant is the full-width exclamation mark common in Chinese input.
    single_sentence_mode = text.startswith(("!", "!"))
    if single_sentence_mode:
        text = text[1:]

    # Split on both half-width and full-width sentence terminators.
    fragments = re.split(r"[。!?!?\n]", text)
    sentences = [fragment for fragment in fragments if fragment.strip()]

    results = []
    running_context = ""
    for sentence in sentences:
        if single_sentence_mode:
            model_input = sentence
        else:
            # Grow the context incrementally: sentence 1, then 1+2, etc.
            running_context = (
                f"{running_context} {sentence}" if running_context else sentence
            )
            model_input = running_context
        stance, conf = predict_stance(model_input)
        results.append((sentence + f" (Confidence: {conf:.4f})", stance))
    return results
def ui():
    """Build the Gradio interface around ``gradio_interface`` and launch it.

    The input is a single textbox and the output is a highlighted-text
    component that colours each returned span by its stance label.
    """
    text_input = gr.Textbox(
        label="Input Text",
        placeholder="Enter text to predict political stance...",
    )
    stance_output = gr.HighlightedText(
        label="Prediction Result",
        color_map={"KMT": "blue", "DPP": "green", "Neutral": "purple"},
    )
    demo = gr.Interface(
        fn=gradio_interface,
        inputs=text_input,
        outputs=stance_output,
        title="Political Stance Prediction",
        description="Enter a text to predict its political stance (KMT, DPP, Neutral). Prefix a sentence with '!' or '!' to analyze each sentence individually.",
    )
    demo.launch()


# Start the web UI only when this file is run as a script.
if __name__ == "__main__":
    ui()