"""Gradio demo that classifies Chinese text sentiment with a BERT model.

The tokenizer, config and sequence-classification weights are loaded from
the Hugging Face Hub at import time and moved to the GPU when one is
available.
"""

import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from transformers import BertConfig, BertTokenizer, BertForSequenceClassification

# Hub repository id and the subfolder inside it that holds the checkpoint
# files (tokenizer, config and weights are all read from the same place).
MODEL_REPO = "robot4/emotion"
MODEL_SUBFOLDER = "checkpoints_finally_end"

print(f"正在加载 BERT 模型: {MODEL_REPO} ...")

try:
    tokenizer = BertTokenizer.from_pretrained(MODEL_REPO, subfolder=MODEL_SUBFOLDER)
    config = BertConfig.from_pretrained(MODEL_REPO, subfolder=MODEL_SUBFOLDER)
    model = BertForSequenceClassification.from_pretrained(
        MODEL_REPO,
        config=config,
        subfolder=MODEL_SUBFOLDER,
    )
except Exception as e:
    # Surface the failure in the console, then re-raise with the original
    # traceback intact: the app cannot run without the model.
    print(f"Error: {e}")
    raise

# Run inference on the GPU when available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)


def predict(text):
    """Classify the sentiment of ``text`` and report per-class probabilities.

    Args:
        text: The review text entered by the user. ``None``, an empty
            string, or a whitespace-only string is treated as "no input".

    Returns:
        A ``(scores, summary)`` tuple. ``scores`` maps the three class names
        ('积极', '中性', '消极') to their softmax probabilities, and
        ``summary`` is a two-line string with the predicted label and its
        confidence. When there is no input, ``scores`` is ``None`` and
        ``summary`` is a prompt asking the user to enter some text.
    """
    # Guard clause: blank or whitespace-only input has nothing to classify,
    # so skip the model call and ask for input instead.
    if not text or not text.strip():
        return None, "请输入内容"

    # Tokenize to PyTorch tensors (truncated to 128 tokens) and move every
    # tensor to the same device as the model.
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=128, padding=True)
    inputs = {k: v.to(device) for k, v in inputs.items()}

    # Gradients are not needed for inference.
    with torch.no_grad():
        logits = model(**inputs).logits
        probs = torch.nn.functional.softmax(logits, dim=-1)

    # Only one text is passed per call, so work on the first (only) row.
    # Taking argmax over that row keeps the index a class id rather than an
    # offset into the flattened batch tensor.
    row = probs[0]
    pred_idx = torch.argmax(row).item()
    confidence = row[pred_idx].item()

    # Class index -> display label used in the textual summary.
    id2label = {0: '😡 消极 (Negative)', 1: '😐 中性 (Neutral)', 2: '😊 积极 (Positive)'}
    label = id2label.get(pred_idx, "Unknown")

    scores = {
        '积极': row[2].item(),
        '中性': row[1].item(),
        '消极': row[0].item(),
    }
    return scores, f"预测结果: {label}\n置信度: {confidence:.4f}"


# Build the web UI: a text input with an "analyze" button on the left, and
# the class probabilities plus a textual summary on the right.
with gr.Blocks(title="中文情感分析") as demo:
    gr.Markdown("# 🎭 中文情感分析演示 (BERT)")
    # Link to the full Hub URL; a bare repo id would be resolved by the
    # browser as a path relative to the app itself.
    gr.Markdown(f"模型加载自: [Hugging Face Hub](https://huggingface.co/{MODEL_REPO})")

    with gr.Row():
        with gr.Column():
            inp = gr.Textbox(label="输入中文评论", lines=4, placeholder="比如:这家店真的太好吃了,强烈推荐!")
            btn = gr.Button("开始分析", variant="primary")

        with gr.Column():
            out_label = gr.Label(label="情感概率")
            out_text = gr.Textbox(label="详细结果")

    # predict returns (probabilities, summary), matching the two outputs.
    btn.click(predict, inputs=inp, outputs=[out_label, out_text])

    # Clickable sample sentences that fill the input box.
    gr.Examples(
        examples=["这家店太难吃了,避雷!", "还可以,中规中矩。", "超级好评,下次还来!", "物流稍微有点慢,但东西不错。"],
        inputs=inp,
    )


if __name__ == "__main__":
    demo.launch()