# app.py (loads the model from local files)
import os
import traceback

import gradio as gr
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# ====== Config: local model directory (relative to the Space container's /home/user/app) ======
# If the files were uploaded straight to the Space root, use "./" or "/home/user/app".
LOCAL_MODEL_DIR = "./"  # or "/home/user/app"
# ==============================================================================================
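# For reference (typical Hugging Face layout; exact tokenizer files vary by model),
# the directory is expected to hold config.json, the weights
# (model.safetensors or pytorch_model.bin), and the tokenizer files
# (e.g. tokenizer_config.json plus tokenizer.json or vocab files).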

def choose_device():
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")


def list_dir(path):
    try:
        return os.listdir(path)
    except Exception:
        return []

def load_model_and_tokenizer_local(model_dir, device):
    """Load tokenizer and model from a local directory (forcing local_files_only=True)."""
    model_dir = os.path.abspath(model_dir)
    print(f"[INFO] Loading model from local directory: {model_dir} -> device: {device}")
    print(f"[DEBUG] Directory listing: {list_dir(model_dir)}")

    # Verify that the required files are present.
    must_files = ["config.json"]
    missing = [f for f in must_files if not os.path.exists(os.path.join(model_dir, f))]
    if missing:
        raise FileNotFoundError(
            f"Local model directory is missing required files: {missing}. "
            "Make sure they were uploaded to the Space root."
        )

    # Load the tokenizer (offline only).
    tokenizer = AutoTokenizer.from_pretrained(model_dir, use_fast=True, local_files_only=True)

    # Load the model (supports safetensors as well as PyTorch binaries).
    if device.type == "cuda":
        # Try half precision first to save GPU memory (if the weights support it).
        try:
            model = AutoModelForSequenceClassification.from_pretrained(
                model_dir, torch_dtype=torch.float16, local_files_only=True
            )
        except Exception:
            model = AutoModelForSequenceClassification.from_pretrained(model_dir, local_files_only=True)
    else:
        model = AutoModelForSequenceClassification.from_pretrained(model_dir, local_files_only=True)

    model.to(device).eval()
    return tokenizer, model
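
# Minimal standalone usage sketch (assumes the model files already sit in LOCAL_MODEL_DIR):
#   dev = choose_device()
#   tok, mdl = load_model_and_tokenizer_local(LOCAL_MODEL_DIR, dev)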

# 0 -> harmless, 1 -> harmful
LABEL_MAP = {0: "harmless", 1: "harmful"}

def predict_text(text, tokenizer, model, device, max_length=256):
    toks = tokenizer([text], truncation=True, padding=True, max_length=max_length, return_tensors="pt")
    toks = {k: v.to(device) for k, v in toks.items()}
    with torch.no_grad():
        out = model(**toks)
        logits = out.logits
    logits_cpu = logits.detach().cpu()
    id2label = getattr(model.config, "id2label", None) or {}

    # Single-logit binary head: apply a sigmoid.
    if logits_cpu.dim() == 1 or logits_cpu.size(-1) == 1:
        score = float(torch.sigmoid(logits_cpu.view(-1))[0].item())
        label_id = 1 if score >= 0.5 else 0
        return {
            "text": text,
            "label_id": int(label_id),
            "label": LABEL_MAP.get(label_id, str(label_id)),
            "score": score,
            "notes": "single-logit -> sigmoid",
        }

    probs = torch.softmax(logits_cpu, dim=-1).numpy().tolist()[0]

    # Two-class head: index 1 is treated as "harmful".
    if len(probs) == 2:
        harmful_score = float(probs[1])
        label_id = 1 if harmful_score >= 0.5 else 0
        return {
            "text": text,
            "label_id": int(label_id),
            "label": LABEL_MAP.get(label_id, str(label_id)),
            "score": harmful_score,
            "probs": {"0": float(probs[0]), "1": float(probs[1])},
            "notes": "2-class softmax (index 1 = harmful)",
        }

    # Multi-class head: pick the argmax and map its name heuristically.
    best_idx = int(np.argmax(probs))
    best_score = float(probs[best_idx])
    if id2label:
        raw_label = id2label.get(best_idx, str(best_idx))
        low = str(raw_label).lower()
        if "safe" in low or "clean" in low:
            label_str = "harmless"
        elif any(k in low for k in ("unsafe", "nsfw", "abuse", "toxic", "harm")):
            label_str = "harmful"
        elif best_idx == 1:
            label_str = "harmful"
        else:
            label_str = str(raw_label)
    else:
        label_str = "harmful" if best_idx == 1 else ("harmless" if best_idx == 0 else str(best_idx))
    return {
        "text": text,
        "label_id": int(best_idx),
        "label": label_str,
        "score": best_score,
        "probs": {str(i): float(p) for i, p in enumerate(probs)},
        "notes": "multi-class softmax",
    }
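
# Illustrative return value for a 2-class model (scores are made up):
#   predict_text("今天天气不错", tokenizer, model, device)
#   -> {"text": "今天天气不错", "label_id": 0, "label": "harmless", "score": 0.04,
#       "probs": {"0": 0.96, "1": 0.04}, "notes": "2-class softmax (index 1 = harmful)"}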

# === Entry point: load the model once at startup ===
device = choose_device()
try:
    tokenizer, model = load_model_and_tokenizer_local(LOCAL_MODEL_DIR, device)
    READY = True
except Exception as e:
    print("[ERROR] Failed to load the local model:", e)
    traceback.print_exc()
    tokenizer, model = None, None
    READY = False
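# Note: the READY flag lets the Gradio app start even if loading fails, so the error
# (plus a directory listing) is surfaced in the UI instead of crashing the Space.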

def predict_gradio(text):
    if not READY or tokenizer is None or model is None:
        return {
            "error": "Model is not ready; check the logs.",
            "local_files": list_dir(os.path.abspath(LOCAL_MODEL_DIR)),
        }
    try:
        return predict_text(text, tokenizer, model, device)
    except Exception as e:
        return {"error": f"Inference failed: {e}", "trace": traceback.format_exc()}

# Gradio UI (unchanged)
with gr.Blocks(title="Chinese Content Check (local model)") as demo:
    gr.Markdown(
        "## Chinese content detection (0 = harmless, 1 = harmful)\n"
        "- Runs offline inference against the model files uploaded to this Space."
    )
    with gr.Row():
        txt = gr.Textbox(lines=4, label="Input Chinese text")
        out_json = gr.JSON(label="Detection result (JSON)")
    btn = gr.Button("Detect")
    btn.click(fn=predict_gradio, inputs=[txt], outputs=[out_json])
    gr.Examples(examples=[
        ["今天心情很好,祝你好运!"],            # "I'm in a great mood today, good luck!"
        ["我要去伤害他。"],                    # "I'm going to hurt him."
        ["这是开玩笑的玩梗,不代表真实意思。"],  # "It's just a joking meme, not meant literally."
    ], inputs=[txt])

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)