# Scraped from a Hugging Face Spaces file view (Space status: Sleeping; file size: 2,540 bytes).
# Hash-like tokens from the page gutter (presumably commit ids): 1d43ce2, 9ee1e28, 41bb39e, 3aa6477, 7d6eeac.
import torch
import torch.nn as nn
from transformers import AutoModel, BertTokenizerFast
import gradio as gr
import re
from model import StanceClassifier
import os
import huggingface_hub
# ---- Module-level setup: runs once at import time ----

# Fix PyTorch's global RNG seed before the classifier is constructed.
torch.manual_seed(42)
# Hub id of the pretrained encoder wrapped by StanceClassifier.
checkpoint = "ckiplab/bert-base-chinese"
# NOTE(review): the tokenizer is loaded from 'bert-base-chinese', not from
# `checkpoint` — presumably intentional (shared vocabulary); confirm.
tokenizer = BertTokenizerFast.from_pretrained('bert-base-chinese')
base_model = AutoModel.from_pretrained(checkpoint)
# Three output classes; must stay in sync with len(labels) below.
model = StanceClassifier(base_model, num_classes=3)
# Download the fine-tuned weights from the Hub into the current directory.
dict_path = huggingface_hub.hf_hub_download(repo_id="abcd1234davidchen/PolStanceBERT",filename="stance_classifier.pth",local_dir=".",local_dir_use_symlinks=False)
# NOTE(review): torch.load unpickles the downloaded file; consider
# weights_only=True if the installed torch version supports it — confirm.
model.load_state_dict(torch.load(dict_path, map_location=torch.device('cpu')))
# Switch the model to evaluation mode for inference.
model.eval()
# Class names indexed by the predicted class id in predict_stance;
# presumably the same order used during training — verify.
labels = ['KMT', 'DPP', 'Neutral']
def predict_stance(text):
    """Classify *text* and return the winning label with its probability.

    The text is tokenized (truncated to at most 512 tokens), passed through
    the module-level ``model`` with gradient tracking disabled, and the
    model output is turned into probabilities with a softmax over dim 1.

    Args:
        text: The text to classify.

    Returns:
        A ``(label, confidence)`` tuple: the entry of ``labels`` with the
        highest probability, and that probability as a Python float.
    """
    encoded = tokenizer(
        text,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=512,
    )
    with torch.no_grad():
        raw_scores = model(
            input_ids=encoded["input_ids"],
            attention_mask=encoded["attention_mask"],
        )
    probabilities = torch.softmax(raw_scores, dim=1)
    # Kept from the original: dumps the probabilities to stdout on each call.
    print(probabilities)
    best_index = torch.argmax(probabilities, dim=1).item()
    # Only the first row is read, i.e. a single input text per call.
    best_probability = probabilities[0][best_index].item()
    return labels[best_index], best_probability
def gradio_interface(text):
    """Split *text* into sentences and predict a stance for each one.

    Two modes are supported:

    * Cumulative (default): every sentence is classified together with all
      sentences before it, joined by single spaces.
    * Single-sentence: selected by prefixing the input with ``!`` (ASCII or
      full-width); every sentence is classified on its own. The prefix is
      removed before splitting.

    Sentences are delimited by ``。``, ``!``, ``?`` (ASCII or full-width)
    and newlines; blank fragments are dropped.

    Args:
        text: Raw user input. ``None`` or an empty string yields ``[]``.

    Returns:
        A list of ``(display_text, stance)`` tuples, one per sentence, where
        ``display_text`` is the sentence followed by the confidence. This is
        the pair format fed to the ``gr.HighlightedText`` output.
    """
    if not text:
        return []

    # Accept both the ASCII and the full-width exclamation mark as the
    # mode switch; the original compared against the same character twice.
    single_sentence_mode = text.startswith(("!", "\uff01"))
    if single_sentence_mode:
        text = text[1:]

    # Full-width "!" / "?" are added so Chinese punctuation also ends a
    # sentence, consistent with the full-width period already handled.
    fragments = re.split(r"[。!?\uff01\uff1f\n]", text)
    sentences = [fragment for fragment in fragments if fragment.strip()]

    results = []
    seen = []
    for sentence in sentences:
        seen.append(sentence)
        # NOTE(review): predict_stance truncates its input, so in cumulative
        # mode very long inputs may stop reflecting the latest sentences.
        model_input = sentence if single_sentence_mode else " ".join(seen)
        stance, conf = predict_stance(model_input)
        results.append((f"{sentence} (Confidence: {conf:.4f})", stance))
    return results
def ui():
    """Build the Gradio interface around ``gradio_interface`` and launch it.

    A single text box is the input; the output is a highlighted-text view
    in which each stance label has a fixed colour.
    """
    text_input = gr.Textbox(
        label="Input Text",
        placeholder="Enter text to predict political stance...",
    )
    stance_view = gr.HighlightedText(
        label="Prediction Result",
        color_map={"KMT": "blue", "DPP": "green", "Neutral": "purple"},
    )
    demo = gr.Interface(
        fn=gradio_interface,
        inputs=text_input,
        outputs=stance_view,
        title="Political Stance Prediction",
        description="Enter a text to predict its political stance (KMT, DPP, Neutral). Prefix a sentence with '!' or '!' to analyze each sentence individually.",
    )
    demo.launch()
if __name__ == "__main__":
    # Start the Gradio app only when this file is executed as a script.
    ui()