# shiluInfer / app.py — author: bztxb
# 完善提示信息 (refine prompt messages), commit e562522
import os
import json
from typing import Dict, List
import gradio as gr
import torch
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer, pipeline
# Hugging Face Hub repo id of the fine-tuned multi-label classifier.
MODEL_ID = "bztxb/shiluBERT"
# Directory checked first for a local checkout of the model (see pick_model_source).
LOCAL_MODEL_DIR = os.getenv("LOCAL_MODEL_DIR", ".")
# Maximum number of tokens per classification window.
MAX_LENGTH = int(os.getenv("MAX_LENGTH", "512"))
# Default minimum score a label must reach to be shown in the UI.
THRESHOLD_DEFAULT = float(os.getenv("THRESHOLD_DEFAULT", "0.5"))
# Token overlap between consecutive windows (0 = disjoint windows).
STRIDE = 0
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# transformers `pipeline(device=...)` argument: CUDA ordinal, or -1 for CPU.
DEVICE_INDEX = 0 if DEVICE == "cuda" else -1
# Sample passage (classical Chinese, Ming/Qing veritable records) pre-filling the textbox.
DEFAULT_SAMPLE_TEXT = "○嚴私鹽之禁時戶部奏在京各衙門遣官吏人等於長蘆運司關支食鹽有將批文不投運司照買私鹽裝載各處販賣一二次者又有夾帶私鹽沿途發賣者及中鹽客啇支鹽不循舊例每包添私鹽至三四百斤者請令沿途巡檢司批驗所等處務要拘驗鹽批及鹽引數目嚴加盤詰秤掣若有批文違限夾帶私鹽者依律入官官吏人等如例送問仍行巡鹽御史通行嚴禁從之"
# Globals populated by the model-loading block below. `load_error` holds the
# failure message when loading did not succeed; `predict` surfaces it to users.
load_error = None
tokenizer = None
classifier = None
label_list: List[str] = []
def pick_model_source() -> str:
    """Prefer a local model checkout when one is present, else the Hub repo id."""
    local_config = os.path.join(LOCAL_MODEL_DIR, "config.json")
    return LOCAL_MODEL_DIR if os.path.exists(local_config) else MODEL_ID
def load_label_list(model_source: str) -> List[str]:
    """Load the human-readable label list shipped next to the model.

    Looks for ``label_map.json`` inside *model_source* on disk first and
    falls back to downloading it from the Hub. Accepts either a bare JSON
    list or a ``{"labels": [...]}`` mapping; anything else yields ``[]``.
    """
    candidate = os.path.join(model_source, "label_map.json")
    if not os.path.exists(candidate):
        candidate = hf_hub_download(repo_id=model_source, filename="label_map.json")
    with open(candidate, "r", encoding="utf-8") as handle:
        payload = json.load(handle)
    if isinstance(payload, list):
        return payload
    if isinstance(payload, dict):
        labels = payload.get("labels")
        if isinstance(labels, list):
            return labels
    return []
def map_label_name(raw_label: str) -> str:
    """Translate a generic ``LABEL_<i>`` id into its human-readable name.

    Labels that do not follow the ``LABEL_`` scheme, or whose index cannot
    be resolved against ``label_list``, are returned untouched.
    """
    if not raw_label.startswith("LABEL_"):
        return raw_label
    try:
        index = int(raw_label.split("_", 1)[1])
        if 0 <= index < len(label_list):
            return str(label_list[index])
    except Exception:
        # Any parse/lookup failure means we cannot map it; fall through.
        pass
    return raw_label
# Eagerly load the tokenizer and classification pipeline at import time so the
# Gradio worker is ready on the first request. Any failure is captured in
# ``load_error`` and surfaced to the user by ``predict`` instead of crashing
# the app at startup.
try:
    model_source = pick_model_source()
    tokenizer = AutoTokenizer.from_pretrained(model_source, use_fast=True)
    classifier = pipeline(
        task="text-classification",
        model=model_source,
        tokenizer=tokenizer,
        top_k=None,  # return scores for every label (multi-label setting)
        device=DEVICE_INDEX,
    )
    label_list = load_label_list(model_source)
except Exception as exc:  # deliberately broad: any load failure is reported in the UI
    load_error = str(exc)
def split_windows(text: str) -> List[str]:
    """Chop *text* into tokenizer windows of at most MAX_LENGTH tokens.

    Relies on the tokenizer's overflow mechanism (with overlap ``STRIDE``)
    and decodes each window back into a plain string. Falls back to the raw
    text when the tokenizer yields nothing usable.
    """
    encoded = tokenizer(
        text,
        truncation=True,
        max_length=MAX_LENGTH,
        stride=STRIDE,
        return_overflowing_tokens=True,
        padding=False,
        return_tensors=None,
    )
    id_windows = encoded.get("input_ids", [])
    if not id_windows:
        return [text]
    pieces: List[str] = []
    for ids in id_windows:
        piece = tokenizer.decode(
            ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )
        # Drop windows that decode to pure whitespace.
        if piece.strip():
            pieces.append(piece)
    return pieces or [text]
def normalize_outputs(outputs):
    """Coerce pipeline output into a list of per-window result lists.

    A single input yields ``[{...}, ...]`` while a batch yields
    ``[[{...}, ...], ...]``; single-input results are wrapped so callers can
    always iterate window by window. Falsy input becomes ``[]``.
    """
    if not outputs:
        return []
    head = outputs[0] if isinstance(outputs, list) else None
    return [outputs] if isinstance(head, dict) else outputs
def predict(text: str, threshold: float) -> Dict[str, float]:
    """Run multi-label classification on *text*, keeping labels >= *threshold*.

    Long inputs are split into tokenizer windows; each label's score is the
    maximum it reaches across all windows. Returns a label-to-score mapping,
    or a diagnostic ``{"error": ...}`` / ``{"info": ...}`` payload.
    """
    if load_error is not None:
        return {"error": load_error}
    if not text or not text.strip():
        return {"error": "请输入文本。"}
    raw = classifier(split_windows(text), truncation=True, max_length=MAX_LENGTH)
    # Aggregate per-label scores as the max over windows.
    best: Dict[str, float] = {}
    for window_result in normalize_outputs(raw):
        for item in window_result:
            name = map_label_name(str(item.get("label", "UNKNOWN")))
            score = float(item.get("score", 0.0))
            if score > best.get(name, 0.0):
                best[name] = score
    ranked = sorted(best.items(), key=lambda pair: pair[1], reverse=True)
    passing = [(name, score) for name, score in ranked if score >= threshold]
    if not passing:
        return {"info": f"无标签达到当前阈值 {threshold:.2f},请尝试降低阈值以查看更多结果。"}
    return {name: round(score, 6) for name, score in passing}
# Gradio UI: a text box plus a confidence-threshold slider feeding ``predict``,
# with the resulting label->score mapping rendered as JSON.
app = gr.Interface(
    fn=predict,
    inputs=[
        gr.Textbox(
            lines=8,
            label="输入文本后,可调整阈值以选择不同置信度水平下的标签",
            placeholder="请输入待分类文本...",
            value=DEFAULT_SAMPLE_TEXT,
        ),
        # Minimum score a label must reach to appear in the output.
        gr.Slider(minimum=0.0, maximum=1.0, value=THRESHOLD_DEFAULT, step=0.01, label="阈值"),
    ],
    outputs=gr.JSON(label="预测结果(标签:置信度)"),
    title="明/清实录多标签分类推理",
    #examples=[[DEFAULT_SAMPLE_TEXT, THRESHOLD_DEFAULT]],
)
if __name__ == "__main__":
    app.launch()