Spaces:
Sleeping
Sleeping
| import os | |
| import json | |
| from typing import Dict, List | |
| import gradio as gr | |
| import torch | |
| from huggingface_hub import hf_hub_download | |
| from transformers import AutoTokenizer, pipeline | |
# --- Model location and inference settings ---------------------------------
# Hugging Face Hub repository used when no local copy of the model is found.
MODEL_ID = "bztxb/shiluBERT"
# Directory probed for a local model (looked up via its config.json).
LOCAL_MODEL_DIR = os.environ.get("LOCAL_MODEL_DIR", ".")
# Passed as ``max_length`` to both the tokenizer and the classifier.
MAX_LENGTH = int(os.environ.get("MAX_LENGTH", "512"))
# Initial value of the threshold slider in the UI.
THRESHOLD_DEFAULT = float(os.environ.get("THRESHOLD_DEFAULT", "0.5"))
# Passed as ``stride`` to the tokenizer when splitting text into windows.
STRIDE = 0

# --- Device selection -------------------------------------------------------
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Value handed to the pipeline's ``device`` argument: 0 on CUDA, -1 otherwise.
DEVICE_INDEX = -1 if DEVICE != "cuda" else 0

# Text pre-filled in the input box when the UI opens.
DEFAULT_SAMPLE_TEXT = "○嚴私鹽之禁時戶部奏在京各衙門遣官吏人等於長蘆運司關支食鹽有將批文不投運司照買私鹽裝載各處販賣一二次者又有夾帶私鹽沿途發賣者及中鹽客啇支鹽不循舊例每包添私鹽至三四百斤者請令沿途巡檢司批驗所等處務要拘驗鹽批及鹽引數目嚴加盤詰秤掣若有批文違限夾帶私鹽者依律入官官吏人等如例送問仍行巡鹽御史通行嚴禁從之"

# --- Runtime state, filled in by the module-level loading step --------------
# ``load_error`` stays None unless loading fails; ``predict`` checks it first.
load_error = None
tokenizer = None
classifier = None
label_list: List[str] = []
def pick_model_source() -> str:
    """Decide where the model should be loaded from.

    Returns:
        ``LOCAL_MODEL_DIR`` when that directory contains a ``config.json``,
        otherwise the Hub repository id ``MODEL_ID``.
    """
    local_config = os.path.join(LOCAL_MODEL_DIR, "config.json")
    return LOCAL_MODEL_DIR if os.path.exists(local_config) else MODEL_ID
def load_label_list(model_source: str) -> List[str]:
    """Load the ordered label names that accompany the model.

    The labels are read from ``label_map.json``, looked up first next to the
    model on disk and otherwise downloaded from the Hub repository named by
    ``model_source``.

    Args:
        model_source: Either a local model directory or a Hub repository id.

    Returns:
        The label list. Two JSON layouts are accepted: an object with a
        ``"labels"`` list, or a bare list. Any other layout, or a local model
        directory that has no ``label_map.json``, yields an empty list.

    Raises:
        OSError, json.JSONDecodeError: If the file cannot be read or parsed.
        Errors raised by ``hf_hub_download`` propagate unchanged.
    """
    local_path = os.path.join(model_source, "label_map.json")
    if os.path.exists(local_path):
        file_path = local_path
    elif os.path.isdir(model_source):
        # A filesystem path is not a valid Hub repository id, so asking the
        # Hub for it can only fail. Treat a missing local label map as "no
        # custom labels" instead of aborting the whole model load.
        return []
    else:
        file_path = hf_hub_download(repo_id=model_source, filename="label_map.json")

    with open(file_path, "r", encoding="utf-8") as file:
        data = json.load(file)

    # Unwrap the {"labels": [...]} layout; a bare list is used as-is.
    if isinstance(data, dict):
        data = data.get("labels")
    return data if isinstance(data, list) else []
| def map_label_name(raw_label: str) -> str: | |
| if raw_label.startswith("LABEL_"): | |
| try: | |
| idx = int(raw_label.split("_", 1)[1]) | |
| if 0 <= idx < len(label_list): | |
| return str(label_list[idx]) | |
| except Exception: | |
| pass | |
| return raw_label | |
# Load the tokenizer, the classification pipeline and the label names once,
# at import time. A failure in any step is captured as text in ``load_error``
# rather than raised, so the UI can still start and report the problem.
try:
    model_source = pick_model_source()
    tokenizer = AutoTokenizer.from_pretrained(model_source, use_fast=True)
    classifier = pipeline(
        task="text-classification",
        model=model_source,
        tokenizer=tokenizer,
        device=DEVICE_INDEX,
        # top_k=None asks the pipeline for a score for every label.
        top_k=None,
    )
    label_list = load_label_list(model_source)
except Exception as err:
    load_error = str(err)
def split_windows(text: str) -> List[str]:
    """Split ``text`` into pieces of at most ``MAX_LENGTH`` tokens each.

    The text is tokenized with overflow enabled, so every chunk of up to
    ``MAX_LENGTH`` tokens becomes one window. When the tokenizer is a fast
    tokenizer, each window is cut out of the original string using the token
    character offsets, so the classifier sees exactly the characters the user
    typed. Decoding token ids back to text is only used as a fallback, because
    it is lossy: with ``skip_special_tokens=True`` it drops unknown-token
    placeholders and returns the tokenizer's normalised form of the text.

    Args:
        text: The input text to split.

    Returns:
        A non-empty list of window strings. If tokenization produces no
        usable window, the list contains ``text`` itself.
    """
    # Offset mappings are only offered by fast tokenizers.
    use_offsets = bool(getattr(tokenizer, "is_fast", False))
    enc = tokenizer(
        text,
        truncation=True,
        max_length=MAX_LENGTH,
        stride=STRIDE,
        return_overflowing_tokens=True,
        return_offsets_mapping=use_offsets,
        padding=False,
        return_tensors=None,
    )
    input_ids_batch = enc.get("input_ids", [])
    if not input_ids_batch:
        return [text]

    offsets_batch = enc.get("offset_mapping") if use_offsets else None
    if offsets_batch:
        windows: List[str] = []
        for offsets in offsets_batch:
            # Special tokens carry an empty (0, 0) span; keep real tokens only.
            spans = [(start, end) for start, end in offsets if end > start]
            if spans:
                first_char = min(start for start, _ in spans)
                last_char = max(end for _, end in spans)
                windows.append(text[first_char:last_char])
    else:
        windows = [
            tokenizer.decode(ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
            for ids in input_ids_batch
        ]
    return [window for window in windows if window.strip()] or [text]
def normalize_outputs(outputs):
    """Coerce classifier output into a list of per-window result lists.

    A flat list of result dicts (the result for a single input) is wrapped in
    an outer list; anything else that is truthy is returned unchanged. Empty
    or ``None`` output becomes an empty list.
    """
    if not outputs:
        return []
    is_flat = isinstance(outputs, list) and isinstance(outputs[0], dict)
    return [outputs] if is_flat else outputs
def predict(text: str, threshold: float) -> Dict[str, float]:
    """Classify ``text`` and return the labels scoring at least ``threshold``.

    The text is split into windows, every window is classified, and each
    label keeps the highest score it received in any window.

    Args:
        text: Text to classify.
        threshold: Minimum score a label needs to be included.

    Returns:
        A mapping of label name to score (rounded to 6 decimals), ordered by
        descending score. Instead of scores, the mapping holds a single
        message string under ``"error"`` (model failed to load, or empty
        input) or under ``"info"`` (no label reached the threshold).
    """
    if load_error is not None:
        return {"error": load_error}
    if not (text and text.strip()):
        return {"error": "请输入文本。"}

    raw_results = classifier(split_windows(text), truncation=True, max_length=MAX_LENGTH)

    # Aggregate across windows: keep the best score seen for each label.
    best: Dict[str, float] = {}
    for per_window in normalize_outputs(raw_results):
        for entry in per_window:
            name = map_label_name(str(entry.get("label", "UNKNOWN")))
            value = float(entry.get("score", 0.0))
            best[name] = max(value, best.get(name, 0.0))

    ranked = sorted(best.items(), key=lambda pair: pair[1], reverse=True)
    kept = {name: round(value, 6) for name, value in ranked if value >= threshold}
    if kept:
        return kept
    return {"info": f"无标签达到当前阈值 {threshold:.2f},请尝试降低阈值以查看更多结果。"}
# Gradio UI: a text box and a threshold slider feed ``predict``; the result
# dictionary is rendered as JSON.
app = gr.Interface(
    title="明/清实录多标签分类推理",
    fn=predict,
    inputs=[
        gr.Textbox(
            label="输入文本后,可调整阈值以选择不同置信度水平下的标签",
            placeholder="请输入待分类文本...",
            value=DEFAULT_SAMPLE_TEXT,
            lines=8,
        ),
        gr.Slider(
            label="阈值",
            minimum=0.0,
            maximum=1.0,
            step=0.01,
            value=THRESHOLD_DEFAULT,
        ),
    ],
    outputs=gr.JSON(label="预测结果(标签:置信度)"),
    # Example row, currently disabled:
    # examples=[[DEFAULT_SAMPLE_TEXT, THRESHOLD_DEFAULT]],
)
# Start the Gradio server only when this file is executed as a script.
if __name__ == "__main__":
    app.launch()