| |
| import os |
| import gradio as gr |
| import torch |
| import numpy as np |
| from transformers import AutoTokenizer, AutoModelForSequenceClassification |
| import traceback |
|
|
| |
| |
# Directory holding the locally uploaded model files (the Space root).
LOCAL_MODEL_DIR = "./"
| |
|
|
def choose_device():
    """Return the torch device to run on: CUDA when available, else CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")
|
|
def list_dir(path):
    """Best-effort directory listing used for diagnostics.

    Returns the entries of *path*, or an empty list when the directory
    cannot be read (missing, not a directory, permission denied).
    """
    try:
        return os.listdir(path)
    except OSError:
        # Deliberate best-effort: callers only use this for debug output,
        # so a missing/unreadable directory is not an error here.
        return []
|
|
def load_model_and_tokenizer_local(model_dir, device):
    """Load tokenizer and model strictly from a local directory.

    Forces ``local_files_only=True`` everywhere so no Hub download is
    ever attempted (the Space is expected to ship the model files).

    Args:
        model_dir: directory containing the HF model files (config.json etc.).
        device: ``torch.device`` the model is moved to.

    Returns:
        ``(tokenizer, model)`` with the model in eval mode on *device*.

    Raises:
        FileNotFoundError: when a required model file is missing.
    """
    model_dir = os.path.abspath(model_dir)
    print(f"[INFO] 尝试从本地加载模型:{model_dir} -> device: {device}")
    print(f"[DEBUG] 目录列表:{list_dir(model_dir)}")

    # Minimal sanity check before handing the path to transformers,
    # so the user gets an actionable message instead of a deep stack trace.
    # isfile (not exists) so a directory named config.json does not pass.
    must_files = ["config.json"]
    missing = [f for f in must_files if not os.path.isfile(os.path.join(model_dir, f))]
    if missing:
        raise FileNotFoundError(f"本地模型目录缺少必须文件: {missing}. 请确认已上传到 Space 根目录。")

    tokenizer = AutoTokenizer.from_pretrained(model_dir, use_fast=True, local_files_only=True)

    model = None
    if device.type == "cuda":
        # Prefer fp16 on GPU to halve memory; fall back to default precision,
        # but log the failure instead of swallowing it silently.
        try:
            model = AutoModelForSequenceClassification.from_pretrained(
                model_dir, torch_dtype=torch.float16, local_files_only=True
            )
        except Exception:
            print("[WARN] fp16 load failed; falling back to default precision.")
            traceback.print_exc()
            model = None
    if model is None:
        model = AutoModelForSequenceClassification.from_pretrained(model_dir, local_files_only=True)

    model.to(device).eval()
    return tokenizer, model
|
|
| |
# Fallback Chinese label names for the binary harmless/harmful scheme.
LABEL_MAP = {0: "无害", 1: "有害"}


def _single_logit_result(text, logits_cpu):
    """Interpret a single-logit head: sigmoid gives P(有害)."""
    score = float(torch.sigmoid(logits_cpu.view(-1))[0].item())
    label_id = 1 if score >= 0.5 else 0
    return {
        "text": text,
        "label_id": int(label_id),
        "label": LABEL_MAP.get(label_id, str(label_id)),
        "score": score,
        "notes": "single-logit -> sigmoid",
    }


def _binary_result(text, probs):
    """Interpret a 2-way softmax; by convention index 1 = 有害."""
    harmful_score = float(probs[1])
    label_id = 1 if harmful_score >= 0.5 else 0
    return {
        "text": text,
        "label_id": int(label_id),
        "label": LABEL_MAP.get(label_id, str(label_id)),
        "score": harmful_score,
        "probs": {"0": float(probs[0]), "1": float(probs[1])},
        "notes": "2-class softmax (index 1 = 有害)",
    }


def _multiclass_result(text, probs, id2label):
    """Pick the argmax class and map its raw label name to 无害/有害.

    Harmful keywords are tested BEFORE safe ones: "unsafe" contains the
    substring "safe", so testing "safe" first would misclassify it as 无害.
    """
    best_idx = int(np.argmax(probs))
    best_score = float(probs[best_idx])
    if id2label:
        raw_label = id2label.get(best_idx, str(best_idx))
        low = str(raw_label).lower()
        if any(k in low for k in ("unsafe", "nsfw", "abuse", "toxic", "harm")):
            label_cn = "有害"
        elif "safe" in low or "clean" in low:
            label_cn = "无害"
        elif str(best_idx) == "1":
            label_cn = "有害"
        else:
            # Unknown label name: surface it as-is rather than guessing.
            label_cn = str(raw_label)
    else:
        label_cn = "有害" if best_idx == 1 else ("无害" if best_idx == 0 else str(best_idx))
    return {
        "text": text,
        "label_id": int(best_idx),
        "label": label_cn,
        "score": best_score,
        "probs": {str(i): float(p) for i, p in enumerate(probs)},
        "notes": "multi-class softmax",
    }


def predict_text(text, tokenizer, model, device, max_length=256):
    """Classify *text* and return a JSON-serializable result dict.

    Handles three head shapes: single logit (sigmoid), 2-class softmax,
    and multi-class softmax (label mapped via ``model.config.id2label``).

    Args:
        text: input string to classify.
        tokenizer: HF tokenizer (callable returning pt tensors).
        model: sequence-classification model producing ``.logits``.
        device: torch.device the inputs are moved to.
        max_length: truncation length for tokenization.

    Returns:
        Dict with at least ``text``, ``label_id``, ``label``, ``score``, ``notes``.
    """
    toks = tokenizer([text], truncation=True, padding=True, max_length=max_length, return_tensors="pt")
    toks = {k: v.to(device) for k, v in toks.items()}
    with torch.no_grad():
        logits = model(**toks).logits
    logits_cpu = logits.detach().cpu()
    id2label = getattr(model.config, "id2label", None) or {}

    if logits_cpu.dim() == 1 or logits_cpu.size(-1) == 1:
        return _single_logit_result(text, logits_cpu)

    probs = torch.softmax(logits_cpu, dim=-1).numpy().tolist()[0]
    if len(probs) == 2:
        return _binary_result(text, probs)
    return _multiclass_result(text, probs, id2label)
|
|
| |
# --- Module-level initialization: load the model once at import time ---
device = choose_device()
try:
    tokenizer, model = load_model_and_tokenizer_local(LOCAL_MODEL_DIR, device)
    READY = True
except Exception as e:
    # Boundary handler: keep the app importable even when the model files
    # are missing, so the UI can show a diagnostic instead of crashing.
    print("[ERROR] 本地模型加载失败:", e)
    traceback.print_exc()
    tokenizer, model = None, None
    READY = False
|
|
def predict_gradio(text):
    """Gradio callback: classify *text*, returning a JSON-friendly dict.

    Surfaces a diagnostic (including the local file listing) when the
    model failed to load, and the traceback when inference itself fails.
    """
    model_ready = READY and tokenizer is not None and model is not None
    if not model_ready:
        return {
            "error": "模型未就绪,请查看日志。",
            "local_files": list_dir(os.path.abspath(LOCAL_MODEL_DIR)),
        }
    try:
        result = predict_text(text, tokenizer, model, device)
    except Exception as e:
        return {"error": f"推理失败: {e}", "trace": traceback.format_exc()}
    return result
|
|
| |
# --- Gradio UI definition (built at import time; layout follows creation order) ---
with gr.Blocks(title="中文内容检查(本地模型)") as demo:
    gr.Markdown("## 中文内容检测(0 = 无害,1 = 有害)\n- 使用本地上传到 Space 的模型文件进行离线推理。")
    with gr.Row():
        # Input textbox and JSON output side by side.
        txt = gr.Textbox(lines=4, label="输入中文文本")
        out_json = gr.JSON(label="检测结果(JSON)")
    btn = gr.Button("检测")
    btn.click(fn=predict_gradio, inputs=[txt], outputs=[out_json])
    # Clickable sample inputs shown below the controls.
    gr.Examples(examples=[
        ["今天心情很好,祝你好运!"],
        ["我要去伤害他。"],
        ["这是开玩笑的玩梗,不代表真实意思。"]
    ], inputs=[txt])
|
|
if __name__ == "__main__":
    # Bind to all interfaces on the standard Hugging Face Spaces port.
    demo.launch(server_name="0.0.0.0", server_port=7860)
|
|