Spaces:
Sleeping
Sleeping
| # app.py | |
| import gradio as gr | |
| from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline | |
| # 1. 指定已经训练好的多分类模型名称或路径 | |
| MODEL_NAME = "bztxb/shiluBERT" | |
| # 载入标签 | |
| labels_list = ['上供', '中人', '中央亞', '中央行政', '中央軍', '主副食', '交通', '人事', '人文敎育', | |
| '人物', '任免', '住生活', '佛敎', '保健', '倉庫', '倫理', '倭', '儀式', '儒學', '元', | |
| '兩班', '兵法', '兵站', '其他', '出版', '前史', '勸農', '化學', '匠人', '印刷', '史學', | |
| '司法', '商人', '商品', '商業', '嗜好食品', '器皿祭物', '國王', '國用', '土俗信仰', '土地賣買', | |
| '土木', '地學', '地方自治', '地方行政', '地方軍', '外交', '天氣', '契', '妃嬪', '姓名', '宅地', | |
| '宗社', '宗親', '官廳手工', '官服', '宮官', '宴會', '家具', '家屋', '家族', '家産', '專賣', | |
| '工業', '市場', '常服', '常民', '度量衡', '建築', '建設', '彈劾', '役', '思想', '恤兵', '戰爭', | |
| '戶口', '戶籍', '手工業品', '手數料', '技術敎育', '採鑛', '政論', '政變', '故事', '敎育', '救恤', | |
| '數學', '文學', '明', '曆法', '書冊', '東南亞', '東學', '林業', '果樹園藝', '歐美', '歷史', '殖利', | |
| '民亂', '水利', '水産業', '水運', '治安', '法制', '漁業', '演劇', '物價', '物理', '特殊敎育', | |
| '特殊軍', '特用作物', '獸醫學', '王室', '琉球', '生物', '田制', '田稅', '畜産', '社會紀綱', | |
| '禁火', '禮俗', '禮服', '私營手工', '科學', '移動', '管理', '經營形態', '經筵', '綱常', '綿作', | |
| '編史', '美術', '聚落', '舞踊', '藝術', '藥學', '行刑', '行幸', '行政', '衣生活', '裁判', | |
| '裝身具', '製鍊', '西學', '親族', '語學', '語文學', '諫諍', '變亂', '財政', '貢物', '貨幣', | |
| '貿易', '賃貸', '賃金', '賜給', '賤人', '赴防', '身分', '身分變動', '身良役賤', '軍事', '軍器', | |
| '軍役', '軍政', '軍資', '農作', '農村手工', '農業', '農業技術', '通信', '進上', '運賃', '道敎', | |
| '選拔', '鄕村', '酒類', '醫學', '醫藥', '野', '量田', '金融', '鑛山', '鑛業', '開墾', '關防', | |
| '陸運', '雜稅', '音樂', '風俗', '食生活', '養蠶', '馬政', '鹽業'] | |
| # 2. 载入模型和分词器 | |
| tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) | |
| model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME) | |
| # 3. 使用 pipeline 封装推理逻辑 | |
| # 如果模型标签大于2个(多分类),可以指定 return_all_scores=True 来获取各分类概率 | |
| clf = pipeline( | |
| task="text-classification", | |
| model=model, | |
| tokenizer=tokenizer, | |
| return_all_scores=True, # 多分类场景标签的得分 | |
| function_to_apply="sigmoid" | |
| ) | |
| # 4. 定义预测函数 | |
| def predict(text_str): | |
| max_len = 510 # 可设置为 512,确保不超出最大限制 | |
| if len(text_str.split()) > max_len: | |
| text_str = ' '.join(text_str.split()[:max_len]) | |
| texts = [text_str] | |
| #texts = ["○辛亥。總督倉塲侍郎岳爾岱。因病解任。以禮科給事中兆華。署總督倉塲侍郎"] | |
| results = clf(texts) | |
| #print(results) | |
| threshold = 0.5 #设定阈值 | |
| # 训练时的标签顺序,比如: | |
| label_list = labels_list # 假设有X个标签 | |
| final_predictions = [] # 存放每条文本对应的多标签预测 | |
| for single_text_scores in results: | |
| # single_text_scores 是形如 [ {"label":"LABEL_0", "score":...}, {"label":"LABEL_1",...}, ... ] | |
| assigned_labels = [] | |
| for label_dict in single_text_scores: | |
| # label_dict["label"] 是 "LABEL_0"/"LABEL_1" 类似这样的名字 | |
| # 提取数字索引: | |
| label_idx_str = label_dict["label"].replace("LABEL_", "") | |
| label_idx = int(label_idx_str) | |
| if label_dict["score"] >= threshold: | |
| assigned_labels.append(label_list[label_idx]) | |
| final_predictions.append(assigned_labels) | |
| for text, preds in zip(texts, final_predictions): | |
| print("文本:", text) | |
| print("预测标签:", preds, "\n") | |
| return [text, preds][1] | |
| # 5. 构建 Gradio 界面 | |
| # 要更个性化渲染,可使用不同的组件(outputs=...) | |
| demo = gr.Interface( | |
| fn=predict, | |
| inputs=gr.Textbox(lines=3, placeholder="请输入待预测文本", label="实录文本"), | |
| outputs="text", | |
| title="推理明/清实录的多标签分类", | |
| description="输入待分类的文本", | |
| theme="compact" | |
| ) | |
| # 6. 启动服务 | |
| if __name__ == "__main__": | |
| demo.launch(server_name="0.0.0.0", server_port=7860) |