yushize's picture
Update app.py
4010573 verified
import os
import gc
import shutil
import numpy as np
import torch
from typing import Dict, Tuple, List
from transformers import AutoModel, AutoTokenizer
import gradio as gr
# 兼容新旧 huggingface_hub 版本(老版本没有 delete_cache_entries)
try:
from huggingface_hub import scan_cache_dir, delete_cache_entries
except Exception:
from huggingface_hub import scan_cache_dir
delete_cache_entries = None
# -----------------------
# 配置参数
# -----------------------
MODEL_OPTIONS = {
"Qwen3-0.6B (xulab-research/patent-classifier-0.6B)": "xulab-research/patent-classifier-0.6B",
"Qwen3-4B (xulab-research/patent-classifier-4B)": "xulab-research/patent-classifier-4B",
}
# 修复:去掉尾部多余空格,避免 Dropdown 默认值和字典 key 不匹配
DEFAULT_MODEL_KEY = "Qwen3-4B (xulab-research/patent-classifier-4B)"
THRESHOLDS_BY_MODEL = {
"xulab-research/patent-classifier": np.array(
[0.55, 0.35, 0.45, 0.5, 0.4, 0.35, 0.45, 0.5, 0.45], dtype=np.float32
),
"xulab-research/patent-classifier-4B": np.array(
[0.5, 0.3, 0.35, 0.3, 0.15, 0.3, 0.4, 0.55, 0.35], dtype=np.float32
),
}
CLASS_INDEX_START = 0
CLASS_NAMES = [
"非AI类", "知识处理", "语音识别", "AI硬件", "进化计算",
"自然语言处理", "机器学习", "计算机视觉", "规划与控制"
]
# -----------------------
# 设备与 dtype
# -----------------------
def pick_device() -> torch.device:
if torch.cuda.is_available():
return torch.device("cuda")
if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
return torch.device("mps")
return torch.device("cpu")
DEVICE = pick_device()
def device_str(d: torch.device) -> str:
if d.type == "cuda":
return f"CUDA({torch.cuda.get_device_name(0)})"
if d.type == "mps":
return "Apple Silicon(MPS)"
return "CPU"
def preferred_dtype() -> torch.dtype:
if DEVICE.type == "cuda":
return torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
if DEVICE.type == "mps":
return torch.float16
return torch.float32
def get_device_map() -> str:
if torch.cuda.is_available():
return "auto"
if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
return "mps"
return "cpu"
# 可选:固定远端模型 revision,避免缓存多版本堆积
HF_REVISION = os.getenv("HF_REVISION", None)
# 可选:自定义缓存目录(如 /tmp/hf_cache,重启即清)
HF_CACHE_DIR = os.getenv("HF_HUB_CACHE", None)
if HF_CACHE_DIR:
os.environ["HF_HOME"] = HF_CACHE_DIR
os.environ["HF_HUB_CACHE"] = HF_CACHE_DIR
os.environ["TRANSFORMERS_CACHE"] = HF_CACHE_DIR
# -----------------------
# 模型缓存(按需加载)
# -----------------------
MODEL_CACHE: Dict[str, Tuple[AutoTokenizer, AutoModel]] = {}
def clear_cuda_cache():
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
def load_model(hf_repo: str) -> Tuple[AutoTokenizer, AutoModel]:
if hf_repo in MODEL_CACHE:
return MODEL_CACHE[hf_repo]
def _do_load():
tok_kwargs = dict(trust_remote_code=True)
mdl_kwargs = dict(
trust_remote_code=True,
device_map=get_device_map(),
low_cpu_mem_usage=True,
use_safetensors=True,
)
if HF_REVISION:
tok_kwargs["revision"] = HF_REVISION
mdl_kwargs["revision"] = HF_REVISION
if HF_CACHE_DIR:
tok_kwargs["cache_dir"] = HF_CACHE_DIR
mdl_kwargs["cache_dir"] = HF_CACHE_DIR
dtype = preferred_dtype()
if mdl_kwargs["device_map"] in ("auto", "cuda", "mps"):
mdl_kwargs["torch_dtype"] = dtype
tokenizer = AutoTokenizer.from_pretrained(hf_repo, **tok_kwargs)
model = AutoModel.from_pretrained(hf_repo, **mdl_kwargs)
model.eval()
return tokenizer, model
try:
tokenizer, model = _do_load()
MODEL_CACHE[hf_repo] = (tokenizer, model)
return tokenizer, model
except RuntimeError as e:
if "out of memory" in str(e).lower() or "cuda" in str(e).lower():
clear_cuda_cache()
tokenizer, model = _do_load()
MODEL_CACHE[hf_repo] = (tokenizer, model)
return tokenizer, model
raise
def extract_logits(outputs):
if isinstance(outputs, torch.Tensor):
return outputs
if hasattr(outputs, "logits"):
return outputs.logits
if isinstance(outputs, (tuple, list)) and len(outputs) > 0 and isinstance(outputs[0], torch.Tensor):
return outputs[0]
raise ValueError("无法识别模型输出格式:期望 tensor / obj.logits / tuple[0] 为 tensor")
# -----------------------
# 缓存清理工具(兼容无 delete_cache_entries 的旧版 hub)
# -----------------------
def _infer_hub_dir_from_env() -> str:
if HF_CACHE_DIR:
base = HF_CACHE_DIR
else:
base = os.getenv("HF_HOME", os.path.expanduser("~/.cache/huggingface"))
return os.path.join(base, "hub")
def clean_hf_cache(keep_repo_ids: List[str]) -> str:
"""
仅保留 keep_repo_ids 的缓存,清理其它模型/旧 revision。
"""
try:
info = scan_cache_dir()
if delete_cache_entries is not None:
to_delete = []
for repo in info.repos:
if repo.repo_id not in keep_repo_ids:
to_delete.extend(repo.revisions.values())
else:
if HF_REVISION:
for rev in list(repo.revisions.values()):
if rev.commit_hash != HF_REVISION:
to_delete.append(rev)
if to_delete:
delete_cache_entries(to_delete)
return "缓存清理完成。"
# 旧版 fallback:直接按目录删除
hub_root = getattr(info, "cache_dir", None)
hub_dir = os.path.join(hub_root, "hub") if hub_root else _infer_hub_dir_from_env()
if not os.path.isdir(hub_dir):
return f"未找到缓存目录({hub_dir}),无需清理。"
removed = 0
keep_set = set(keep_repo_ids)
for entry in os.listdir(hub_dir):
path = os.path.join(hub_dir, entry)
if not os.path.isdir(path):
continue
if entry.startswith("models--"):
parts = entry.split("--", 2)
repo_id = f"{parts[1]}/{parts[2]}" if len(parts) >= 3 else None
if repo_id and repo_id in keep_set:
if HF_REVISION:
snapshots = os.path.join(path, "snapshots")
if os.path.isdir(snapshots):
for rev in os.listdir(snapshots):
if rev != HF_REVISION:
shutil.rmtree(os.path.join(snapshots, rev), ignore_errors=True)
continue
shutil.rmtree(path, ignore_errors=True)
removed += 1
elif entry.startswith(("datasets--", "spaces--")):
shutil.rmtree(path, ignore_errors=True)
removed += 1
return f"缓存清理完成(删除目录数:{removed})。"
except Exception as e:
return f"缓存清理失败:{e}"
# -----------------------
# 推理
# -----------------------
def predict(text: str, model_choice: str):
if not isinstance(text, str) or not text.strip():
return {}
hf_repo = MODEL_OPTIONS.get(model_choice, MODEL_OPTIONS[DEFAULT_MODEL_KEY])
tokenizer, model = load_model(hf_repo)
default_thr = THRESHOLDS_BY_MODEL.get(
"xulab-research/patent-classifier-4B",
np.array([0.5] * len(CLASS_NAMES), dtype=np.float32),
)
thresholds = THRESHOLDS_BY_MODEL.get(hf_repo, default_thr)
inputs = tokenizer(
text,
padding=True,
truncation=True,
max_length=256,
return_tensors="pt"
)
with torch.no_grad():
outputs = model(**inputs)
logits = extract_logits(outputs)
if logits.ndim == 1:
logits = logits.unsqueeze(0)
probabilities = torch.sigmoid(logits).detach().cpu().numpy()[0]
predicted_indices = []
for i, prob in enumerate(probabilities):
thr = thresholds[i] if i < len(thresholds) else 0.5
if prob >= thr:
class_idx = i + CLASS_INDEX_START
predicted_indices.append(class_idx)
if not predicted_indices:
max_idx = int(np.argmax(probabilities))
predicted_indices = [max_idx + CLASS_INDEX_START]
result = {}
for idx in predicted_indices:
if 0 <= idx - CLASS_INDEX_START < len(probabilities) and idx < len(CLASS_NAMES):
result[CLASS_NAMES[idx]] = float(probabilities[idx - CLASS_INDEX_START])
return result
# -----------------------
# 界面
# -----------------------
description = (
"专利分类器 - 输入专利摘要文本,模型将预测所属类别。\n\n"
"支持选择两种模型:\n"
f"- 0.6B:{MODEL_OPTIONS['Qwen3-0.6B (xulab-research/patent-classifier-0.6B)']}\n"
f"- 4B:{MODEL_OPTIONS['Qwen3-4B (xulab-research/patent-classifier-4B)']}\n\n"
f"当前设备:{device_str(DEVICE)}"
)
ai_subfields_text = """8个AI子领域\n
1.知识处理:知识处理领域包括用于表示世界事实并从知识库中推导出新事实(或知识)的方法。例如,专家系统通常包含一个知识库和一种推理方法来从该知识库中获得新事实。\n
2.语音识别:语音识别包括从音频信号中理解词语序列的方法。例如,噪声通道模型是一种统计方法,通过贝叶斯规则从语音输入中识别最可能的词语序列。\n
3.AI硬件:AI硬件领域包括专为执行人工智能软件而设计的物理硬件。例如,谷歌设计的张量处理单元(TPU)就是为了更高效地运行神经网络算法。AI硬件可能包括逻辑电路、存储器、视频、处理器和固态技术,也可能包括实现其他AI组成技术(如机器学习算法)的嵌入式软件。\n
4.进化计算:进化计算是一类利用自然演化特性的计算方法。例如,遗传算法通过选择最优的随机变异体以最大化适应度来执行算法变异选择。\n
5.自然语言处理:自然语言处理包括用于理解和使用以人类自然语言编码的数据的方法。例如,语言模型用于表示语言表达的概率分布。\n
6.机器学习:机器学习领域包含一类广泛的计算学习模型。例如,监督学习分类模型是一种基于预标记训练数据学习进行分类的算法。机器学习技术包括但不限于神经网络、模糊逻辑、自适应系统、概率网络、回归分析以及智能搜索。\n
7.计算机视觉:计算机视觉领域包括从图像和视频等视觉输入中提取和理解信息的方法。例如,边缘检测技术可识别图像中的边界和轮廓。其他计算机视觉子领域还包括目标识别、图像处理(如变换、增强或还原)、颜色处理和格式转换等。\n
8.规划与控制:规划与控制领域包括识别并执行实现特定目标的计划的方法。规划的关键方面包括表示行动和世界状态、推理行动的后果,并在潜在计划中高效地搜索。现代控制理论包括在时间维度上最大化目标函数的方法。例如,随机最优控制处理在不确定环境中的动态优化问题。此外,规划与控制还涵盖用于管理/行政的数据系统(例如:组织和员工的管理,包括库存、工作流程、预测和时间管理)、自适应控制系统,以及系统模型或模拟器。
"""
example_texts = [
"本发明提供一种录制与播放用户语音的方法以及使用此方法的电子字 典,所述的方法适用于一电子装置,其中电子装置至少包括一屏幕、一发音 指示、一录音指示键以及一存储器。当屏幕显示一屏幕显示数据的发音指示 时,按下录音指示键以进入语音录制模式。接着输入用户语音,并且储存用 户语音至存储器。之后记录用以显示用户语音在存储器中的位置的存储器地 址或录制数据索引,并将存储器地址或录制数据索引链结至屏幕显示数据。 本发明还提供一种使用上述方法的电子字典。",
"本发明提供一种多业务接入网的控制系统,包括:对用户进行认证、授权 和地址分配的用户管理功能体UMF、管理用户在接入网络中各网元之间链路和 资源的链路管理功能体LMF以及根据用户属性进行资源接纳控制和策略执行或 部署到接入节点AN和接入网网络侧边缘ANE之间的网络设备中的策略执行功 能体PEF,所述链路管理功能体、策略执行功能体分别与接入节点、接入网网 络侧边缘相连,所述用户管理功能体与接入网网络侧边缘相连,所述策略执行 功能体分别与所述链路管理功能体、用户管理功能体相连。本发明还提供一种 多业务接入网的控制方法。本发明通过链路管理功能体对不同的业务用不同的 链路进行区分,保证多业务在接入网中的QoS,解决了用户动态接入多业务和 基于多业务的接入网QoS控制的问题。",
"本发明关于用于一电子装置的影像还原方法及其相关装置。为了更有效 率地还原模糊的影像,本发明提供一种用于一电子装置的影像还原方法,包 含有于接收一被摄物的一影像时,产生一加速度信号;测量该电子装置与该 被摄物之间的距离,以产生一物距;以及根据该加速度信号及该物距,还原 该影像",
]
with gr.Blocks(title="专利分类器") as demo:
gr.Markdown("# 专利分类器")
gr.Markdown(description)
with gr.Row():
model_choice = gr.Dropdown(
label="选择模型",
choices=list(MODEL_OPTIONS.keys()),
value=DEFAULT_MODEL_KEY
)
input_box = gr.Textbox(label="专利摘要", lines=5, placeholder="请输入专利摘要文本...")
with gr.Row():
predict_btn = gr.Button("预测", variant="primary")
clear_btn = gr.Button("清空")
#clean_btn = gr.Button("清理缓存(仅保留当前所选模型)", variant="secondary")
output_label = gr.Label(label="预测类别", num_top_classes=len(CLASS_NAMES))
gr.Examples(
examples=[[t, DEFAULT_MODEL_KEY] for t in example_texts],
inputs=[input_box, model_choice],
label="示例"
)
predict_btn.click(
fn=predict,
inputs=[input_box, model_choice],
outputs=output_label,
concurrency_limit=1
)
clear_btn.click(lambda: ("", {}), outputs=[input_box, output_label])
def _clean(selected_key: str):
repo_id = MODEL_OPTIONS.get(selected_key, MODEL_OPTIONS[DEFAULT_MODEL_KEY])
msg = clean_hf_cache([repo_id])
return gr.update(value={}), gr.update(value=""), msg
clean_status = gr.Markdown("")
#clean_btn.click(_clean, inputs=[model_choice], outputs=[output_label, input_box, clean_status])
gr.Markdown(ai_subfields_text)
demo.launch()