"""Gradio app that classifies a paragraph into feminist schools of thought.

The paragraph is split into sentences, each sentence is classified by a
fine-tuned BERT sequence classifier, and the per-sentence votes are
aggregated into a dominant school plus a percentage breakdown.
"""

import re
from collections import Counter

import gradio as gr
import pandas as pd
import torch
from transformers import BertForSequenceClassification, BertTokenizerFast

# Model and tokenizer files both live in the repository root.
model = BertForSequenceClassification.from_pretrained(".")
tokenizer = BertTokenizerFast.from_pretrained(".")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()

# Load the CSV of representative figures and works, keyed by school name.
# If the CSV has several rows for one school, the last row wins.
figures_df = pd.read_csv("figures.csv")
figure_info = {}
for _, row in figures_df.iterrows():
    figure_info[row['流派']] = (
        f"代表人物:{row['代表人物']}\n"
        f"著作:{row['代表著作']}\n"
        f"身份/简介:{row['身份/简介']}"
    )

# NOTE(review): assumes the classifier head emits exactly these four class
# ids in this order -- confirm against the model's own config.
id2label = {
    0: "交叉性女性主义",
    1: "差异女性主义",
    2: "激进女性主义",
    3: "自由女性主义"
}

# Sentence terminators: the Chinese full stop plus exclamation and question
# marks in both half-width and full-width form. Compiled once at import time.
_SENTENCE_END = re.compile(r'[。!?！？]')

# Upper bound on sentences per forward pass so long inputs do not exhaust
# memory while short inputs still run as a single batch.
_BATCH_SIZE = 32


def _split_sentences(paragraph: str) -> list[str]:
    """Split *paragraph* on sentence terminators, dropping blank pieces."""
    pieces = _SENTENCE_END.split(paragraph or "")
    return [piece.strip() for piece in pieces if piece.strip()]


def _predict_label_ids(sentences: list[str]) -> list[int]:
    """Return the arg-max class id for every sentence, in input order."""
    label_ids: list[int] = []
    with torch.no_grad():
        for start in range(0, len(sentences), _BATCH_SIZE):
            batch = sentences[start:start + _BATCH_SIZE]
            inputs = tokenizer(
                batch,
                return_tensors="pt",
                truncation=True,
                padding=True,
                max_length=128,
            ).to(device)
            logits = model(**inputs).logits
            label_ids.extend(torch.argmax(logits, dim=1).tolist())
    return label_ids


def analyze_paragraph(paragraph: str) -> str:
    """Classify *paragraph* sentence by sentence and summarise the result.

    Args:
        paragraph: Free text entered by the user. May be empty or ``None``.

    Returns:
        A multi-line report containing the dominant school, the share of
        sentences assigned to each school, and the figure/works entry for
        the dominant school. If the input contains no sentences, a short
        message asking for text is returned instead.
    """
    sentences = _split_sentences(paragraph)
    if not sentences:
        # Without this guard, most_common(1)[0] raises IndexError on empty,
        # whitespace-only or punctuation-only input.
        return "未检测到有效句子,请输入文本。"

    predictions = _predict_label_ids(sentences)
    count = Counter(predictions)
    total = len(predictions)

    # Shares are listed in order of first appearance in the text.
    ratio_str = "\n".join(
        f"{id2label[k]}: {v/total:.1%}" for k, v in count.items()
    )

    main_label_id = count.most_common(1)[0][0]
    main_label = id2label[main_label_id]
    figures = figure_info.get(main_label, "无对应人物资料")

    return f"主导流派:{main_label}\n\n流派占比:\n{ratio_str}\n\n{figures}"


interface = gr.Interface(
    fn=analyze_paragraph,
    inputs=gr.Textbox(lines=5, placeholder="请输入文本..."),
    outputs="text",
    title="女性主义流派辨析模型",
    description="分析输入文本的女性主义流派占比和主导流派,并给出对应人物著作。"
)

# Do NOT wrap this in an `if __name__ == "__main__":` guard; launch the
# interface directly.
# NOTE(review): presumably required by the hosting environment -- confirm.
interface.launch()