File size: 4,962 Bytes
d3b85a7
 
 
610f255
d3b85a7
610f255
 
d3b85a7
610f255
7345988
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d3b85a7
610f255
 
 
 
 
68b725b
610f255
 
 
 
3641335
610f255
 
 
 
 
 
3641335
 
 
 
9e62f27
610f255
 
 
 
9e62f27
3641335
 
9e62f27
610f255
 
 
 
 
3641335
 
610f255
 
 
 
 
 
 
9e62f27
610f255
 
 
 
 
 
32043d6
610f255
 
3641335
68b725b
32043d6
1a4a4bc
 
610f255
d3b85a7
610f255
d3b85a7
610f255
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
# app.py

import gradio as gr
from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline

# 1. 指定已经训练好的多分类模型名称或路径
MODEL_NAME = "bztxb/shiluBERT"

# 载入标签
labels_list = ['上供', '中人', '中央亞', '中央行政', '中央軍', '主副食', '交通', '人事', '人文敎育',
               '人物', '任免', '住生活', '佛敎', '保健', '倉庫', '倫理', '倭', '儀式', '儒學', '元',
               '兩班', '兵法', '兵站', '其他', '出版', '前史', '勸農', '化學', '匠人', '印刷', '史學',
               '司法', '商人', '商品', '商業', '嗜好食品', '器皿祭物', '國王', '國用', '土俗信仰', '土地賣買',
               '土木', '地學', '地方自治', '地方行政', '地方軍', '外交', '天氣', '契', '妃嬪', '姓名', '宅地',
               '宗社', '宗親', '官廳手工', '官服', '宮官', '宴會', '家具', '家屋', '家族', '家産', '專賣',
               '工業', '市場', '常服', '常民', '度量衡', '建築', '建設', '彈劾', '役', '思想', '恤兵', '戰爭',
               '戶口', '戶籍', '手工業品', '手數料', '技術敎育', '採鑛', '政論', '政變', '故事', '敎育', '救恤',
               '數學', '文學', '明', '曆法', '書冊', '東南亞', '東學', '林業', '果樹園藝', '歐美', '歷史', '殖利',
               '民亂', '水利', '水産業', '水運', '治安', '法制', '漁業', '演劇', '物價', '物理', '特殊敎育',
               '特殊軍', '特用作物', '獸醫學', '王室', '琉球', '生物', '田制', '田稅', '畜産', '社會紀綱',
               '禁火', '禮俗', '禮服', '私營手工', '科學', '移動', '管理', '經營形態', '經筵', '綱常', '綿作',
               '編史', '美術', '聚落', '舞踊', '藝術', '藥學', '行刑', '行幸', '行政', '衣生活', '裁判',
               '裝身具', '製鍊', '西學', '親族', '語學', '語文學', '諫諍', '變亂', '財政', '貢物', '貨幣',
               '貿易', '賃貸', '賃金', '賜給', '賤人', '赴防', '身分', '身分變動', '身良役賤', '軍事', '軍器',
               '軍役', '軍政', '軍資', '農作', '農村手工', '農業', '農業技術', '通信', '進上', '運賃', '道敎',
               '選拔', '鄕村', '酒類', '醫學', '醫藥', '野', '量田', '金融', '鑛山', '鑛業', '開墾', '關防',
               '陸運', '雜稅', '音樂', '風俗', '食生活', '養蠶', '馬政', '鹽業']

# 2. 载入模型和分词器
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)

# 3. 使用 pipeline 封装推理逻辑
#    如果模型标签大于2个(多分类),可以指定 return_all_scores=True 来获取各分类概率
clf = pipeline(
    task="text-classification",
    model=model,
    tokenizer=tokenizer,
    return_all_scores=True,  # 多分类场景标签的得分
    function_to_apply="sigmoid" 
)

# 4. 定义预测函数
def predict(text_str):

    max_len = 510  # 可设置为 512,确保不超出最大限制
    if len(text_str.split()) > max_len:
        text_str = ' '.join(text_str.split()[:max_len])
    texts = [text_str]
    
    #texts = ["○辛亥。總督倉塲侍郎岳爾岱。因病解任。以禮科給事中兆華。署總督倉塲侍郎"]
    results = clf(texts)
    #print(results)
    threshold = 0.5 #设定阈值
    
    # 训练时的标签顺序,比如: 
    label_list = labels_list  # 假设有X个标签
    
    final_predictions = []  # 存放每条文本对应的多标签预测
    for single_text_scores in results:
        # single_text_scores 是形如 [ {"label":"LABEL_0", "score":...}, {"label":"LABEL_1",...}, ... ]
        assigned_labels = []
        for label_dict in single_text_scores:
            # label_dict["label"] 是 "LABEL_0"/"LABEL_1" 类似这样的名字
            # 提取数字索引:
            label_idx_str = label_dict["label"].replace("LABEL_", "")
            label_idx = int(label_idx_str)
            
            if label_dict["score"] >= threshold:
                assigned_labels.append(label_list[label_idx])
        
        final_predictions.append(assigned_labels)
    
    for text, preds in zip(texts, final_predictions):
        print("文本:", text)
        print("预测标签:", preds, "\n")
    return [text, preds][1]

# 5. 构建 Gradio 界面
#    要更个性化渲染,可使用不同的组件(outputs=...)
demo = gr.Interface(
    fn=predict,
    inputs=gr.Textbox(lines=3, placeholder="请输入待预测文本", label="实录文本"),
    outputs="text",
    title="推理明/清实录的多标签分类",
    description="输入待分类的文本",
    theme="compact"
)

# 6. 启动服务
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)