# FemiMatch / app.py
# Author: JulieH0524 — last commit 7ca3d8a "Update app.py" (verified).
import torch
from transformers import BertTokenizerFast, BertForSequenceClassification
import pandas as pd
import gradio as gr
# Both the fine-tuned model and its tokenizer are loaded from the current
# directory ("."), i.e. the repository root when the app is started there.
model = BertForSequenceClassification.from_pretrained(".")
tokenizer = BertTokenizerFast.from_pretrained(".")
# Prefer the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# nn.Module.to() and .eval() both return the module itself, so chaining them
# keeps `model` bound to the same object; eval() disables dropout for inference.
model = model.to(device).eval()
# Load the representative figures and works for each feminist school.
figures_df = pd.read_csv("figures.csv")

# Map each school name (column 流派) to a ready-to-display text block.
# pandas reads empty cells as NaN, which an f-string would render as the
# literal text "nan"; fillna("") shows them as blank instead. Keys are
# stripped so stray whitespace in the CSV still matches the label strings
# used for lookup. As before, if a school appears on several rows, the last
# row wins.
figure_info = {
    str(school).strip(): (
        f"代表人物:{person}\n"
        f"著作:{works}\n"
        f"身份/简介:{bio}"
    )
    for school, person, works, bio in figures_df[
        ["流派", "代表人物", "代表著作", "身份/简介"]
    ].fillna("").itertuples(index=False, name=None)
}
# Classifier output index -> school name. analyze_paragraph looks up the
# argmax index of the logits here, so this order must match the label order
# the model was fine-tuned with (not verifiable from this file).
id2label = dict(
    enumerate(
        [
            "交叉性女性主义",  # intersectional feminism
            "差异女性主义",  # difference feminism
            "激进女性主义",  # radical feminism
            "自由女性主义",  # liberal feminism
        ]
    )
)
def analyze_paragraph(paragraph):
    """Classify each sentence of *paragraph* and summarise the school mix.

    The text is split into sentences on full-width and ASCII sentence-ending
    punctuation. Each sentence is classified independently by the BERT
    model, and the per-sentence predictions are tallied.

    Args:
        paragraph: Raw text from the Gradio textbox. ``None`` is treated as
            empty.

    Returns:
        A display string containing:
        - the dominant school;
        - the share of sentences assigned to each predicted school, most
          frequent first;
        - the representative figures/works for the dominant school.
        If the input contains no non-empty sentence, a short prompt asking
        for input is returned instead of raising.
    """
    import re
    from collections import Counter

    # Accept both full-width (。！？) and ASCII (!?) terminators so that
    # mixed punctuation in pasted text still yields separate sentences.
    pieces = re.split(r"[。！？!?]", paragraph or "")
    sentences = [s.strip() for s in pieces if s.strip()]
    if not sentences:
        # Without this guard Counter().most_common(1) is empty and indexing
        # it raises IndexError.
        return "请输入至少一句有效文本。"

    predictions = []
    for sentence in sentences:
        inputs = tokenizer(
            sentence,
            return_tensors="pt",
            truncation=True,
            padding=True,
            max_length=128,
        ).to(device)
        with torch.no_grad():
            outputs = model(**inputs)
        predictions.append(torch.argmax(outputs.logits, dim=1).item())

    count = Counter(predictions)
    total = len(predictions)
    # most_common() orders by frequency (ties keep first-seen order), so the
    # first entry is also the dominant school reported in the header line.
    ranked = count.most_common()
    ratio_str = "\n".join(f"{id2label[k]}: {v / total:.1%}" for k, v in ranked)
    main_label = id2label[ranked[0][0]]
    figures = figure_info.get(main_label, "无对应人物资料")
    return f"主导流派:{main_label}\n\n流派占比:\n{ratio_str}\n\n{figures}"
# Gradio UI: a 5-line text input and a plain-text output, both wired to
# analyze_paragraph.
interface = gr.Interface(
    title="女性主义流派辨析模型",
    description="分析输入文本的女性主义流派占比和主导流派,并给出对应人物著作。",
    inputs=gr.Textbox(placeholder="请输入文本...", lines=5),
    outputs="text",
    fn=analyze_paragraph,
)

# NOTE: the original author explicitly asked NOT to wrap this call in
# `if __name__ == "__main__":`. The interface is launched unconditionally,
# presumably because the hosting setup requires it. Confirm before adding
# a guard.
interface.launch()