InferShilu / app.py
bztxb's picture
Update app.py
3641335 verified
# app.py
import gradio as gr
from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline
# 1. 指定已经训练好的多分类模型名称或路径
MODEL_NAME = "bztxb/shiluBERT"
# 载入标签
labels_list = ['上供', '中人', '中央亞', '中央行政', '中央軍', '主副食', '交通', '人事', '人文敎育',
'人物', '任免', '住生活', '佛敎', '保健', '倉庫', '倫理', '倭', '儀式', '儒學', '元',
'兩班', '兵法', '兵站', '其他', '出版', '前史', '勸農', '化學', '匠人', '印刷', '史學',
'司法', '商人', '商品', '商業', '嗜好食品', '器皿祭物', '國王', '國用', '土俗信仰', '土地賣買',
'土木', '地學', '地方自治', '地方行政', '地方軍', '外交', '天氣', '契', '妃嬪', '姓名', '宅地',
'宗社', '宗親', '官廳手工', '官服', '宮官', '宴會', '家具', '家屋', '家族', '家産', '專賣',
'工業', '市場', '常服', '常民', '度量衡', '建築', '建設', '彈劾', '役', '思想', '恤兵', '戰爭',
'戶口', '戶籍', '手工業品', '手數料', '技術敎育', '採鑛', '政論', '政變', '故事', '敎育', '救恤',
'數學', '文學', '明', '曆法', '書冊', '東南亞', '東學', '林業', '果樹園藝', '歐美', '歷史', '殖利',
'民亂', '水利', '水産業', '水運', '治安', '法制', '漁業', '演劇', '物價', '物理', '特殊敎育',
'特殊軍', '特用作物', '獸醫學', '王室', '琉球', '生物', '田制', '田稅', '畜産', '社會紀綱',
'禁火', '禮俗', '禮服', '私營手工', '科學', '移動', '管理', '經營形態', '經筵', '綱常', '綿作',
'編史', '美術', '聚落', '舞踊', '藝術', '藥學', '行刑', '行幸', '行政', '衣生活', '裁判',
'裝身具', '製鍊', '西學', '親族', '語學', '語文學', '諫諍', '變亂', '財政', '貢物', '貨幣',
'貿易', '賃貸', '賃金', '賜給', '賤人', '赴防', '身分', '身分變動', '身良役賤', '軍事', '軍器',
'軍役', '軍政', '軍資', '農作', '農村手工', '農業', '農業技術', '通信', '進上', '運賃', '道敎',
'選拔', '鄕村', '酒類', '醫學', '醫藥', '野', '量田', '金融', '鑛山', '鑛業', '開墾', '關防',
'陸運', '雜稅', '音樂', '風俗', '食生活', '養蠶', '馬政', '鹽業']
# 2. 载入模型和分词器
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
# 3. 使用 pipeline 封装推理逻辑
# 如果模型标签大于2个(多分类),可以指定 return_all_scores=True 来获取各分类概率
clf = pipeline(
task="text-classification",
model=model,
tokenizer=tokenizer,
return_all_scores=True, # 多分类场景标签的得分
function_to_apply="sigmoid"
)
# 4. 定义预测函数
def predict(text_str):
max_len = 510 # 可设置为 512,确保不超出最大限制
if len(text_str.split()) > max_len:
text_str = ' '.join(text_str.split()[:max_len])
texts = [text_str]
#texts = ["○辛亥。總督倉塲侍郎岳爾岱。因病解任。以禮科給事中兆華。署總督倉塲侍郎"]
results = clf(texts)
#print(results)
threshold = 0.5 #设定阈值
# 训练时的标签顺序,比如:
label_list = labels_list # 假设有X个标签
final_predictions = [] # 存放每条文本对应的多标签预测
for single_text_scores in results:
# single_text_scores 是形如 [ {"label":"LABEL_0", "score":...}, {"label":"LABEL_1",...}, ... ]
assigned_labels = []
for label_dict in single_text_scores:
# label_dict["label"] 是 "LABEL_0"/"LABEL_1" 类似这样的名字
# 提取数字索引:
label_idx_str = label_dict["label"].replace("LABEL_", "")
label_idx = int(label_idx_str)
if label_dict["score"] >= threshold:
assigned_labels.append(label_list[label_idx])
final_predictions.append(assigned_labels)
for text, preds in zip(texts, final_predictions):
print("文本:", text)
print("预测标签:", preds, "\n")
return [text, preds][1]
# 5. 构建 Gradio 界面
# 要更个性化渲染,可使用不同的组件(outputs=...)
demo = gr.Interface(
fn=predict,
inputs=gr.Textbox(lines=3, placeholder="请输入待预测文本", label="实录文本"),
outputs="text",
title="推理明/清实录的多标签分类",
description="输入待分类的文本",
theme="compact"
)
# 6. 启动服务
if __name__ == "__main__":
demo.launch(server_name="0.0.0.0", server_port=7860)