yushize commited on
Commit
4010573
·
verified ·
1 Parent(s): c41ae26

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +292 -47
app.py CHANGED
@@ -1,82 +1,327 @@
1
  import os
 
 
2
  import numpy as np
3
  import torch
 
4
  from transformers import AutoModel, AutoTokenizer
5
  import gradio as gr
6
 
7
- # 配置参数
8
- MODEL_NAME = "yushize/patent-classifier-backup"
9
- THRESHOLDS = np.array([0.55, 0.35, 0.45, 0.5, 0.4, 0.35, 0.45, 0.5, 0.45])
10
- CLASS_INDEX_START = 0 # 类别索引从0开始
 
 
11
 
12
- # 加载模型和分词器
13
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
14
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
15
- model = AutoModel.from_pretrained(MODEL_NAME, trust_remote_code=True).to(device)
16
- model.eval()
 
 
 
 
17
 
18
- # 类别名称(根据您的实际类别修改)
 
 
 
 
 
 
 
 
19
  CLASS_NAMES = [
20
  "非AI类", "知识处理", "语音识别", "AI硬件", "进化计算",
21
  "自然语言处理", "机器学习", "计算机视觉", "规划与控制"
22
  ]
23
 
24
- def predict(text):
25
- """对单条文本进行预测"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  if not isinstance(text, str) or not text.strip():
27
  return {}
28
-
29
- # 分词和编码
 
 
 
 
 
 
 
 
30
  inputs = tokenizer(
31
  text,
32
  padding=True,
33
  truncation=True,
34
  max_length=256,
35
  return_tensors="pt"
36
- ).to(device)
37
-
38
- # 推理
39
  with torch.no_grad():
40
  outputs = model(**inputs)
41
-
42
- logits = outputs
43
- probabilities = torch.sigmoid(logits).cpu().numpy()[0]
44
-
45
- # 获取超过阈值的类别索引
 
 
46
  predicted_indices = []
47
  for i, prob in enumerate(probabilities):
48
- if prob >= THRESHOLDS[i]:
49
- # CLASS_INDEX_START 是类别索引的起始值(0或1)
50
- class_idx = i + CLASS_INDEX_START
51
  predicted_indices.append(class_idx)
52
-
53
- # 如果没有超过阈值的类别,返回概率最高的类别
54
  if not predicted_indices:
55
- max_idx = np.argmax(probabilities)
56
  predicted_indices = [max_idx + CLASS_INDEX_START]
57
-
58
- # 创建gr.Label需要的输出格式
59
  result = {}
60
  for idx in predicted_indices:
61
- if idx < len(CLASS_NAMES):
62
  result[CLASS_NAMES[idx]] = float(probabilities[idx - CLASS_INDEX_START])
63
-
64
  return result
65
 
66
- # 创建 Gradio 界面
67
- description = "专利分类器 - 输入专利摘要文本,模型将预测所属类别"
68
-
69
- iface = gr.Interface(
70
- fn=predict,
71
- inputs=gr.Textbox(label="专利摘要", lines=5, placeholder="请输入专利摘要文本..."),
72
- outputs=gr.Label(label="预测类别", num_top_classes=9),
73
- title="专利分类器",
74
- description=description,
75
- examples=[
76
- ["本发明提供一种录制与播放用户语音的方法以及使用此方法的电子字 典,所述的方法适用于一电子装置,其中电子装置至少包括一屏幕、一发音 指示、一录音指示键以及一存储器。当屏幕显示一屏幕显示数据的发音指示 时,按下录音指示键以进入语音录制模式。接着输入用户语音,并且储存用 户语音至存储器。之后记录用以显示用户语音在存储器中的位置的存储器地 址或录制数据索引,并将存储器地址或录制数据索引链结至屏幕显示数据。 本发明还提供一种使用上述方法的电子字典。"],
77
- ["本发明提供一种多业务接入网的控制系统,包括:对用户进行认证、授权 和地址分配的用户管理功能体UMF、管理用户在接入网络中各网元之间链路和 资源的链路管理功能体LMF以及根据用户属性进行资源接纳控制和策略执行或 部署到接入节点AN和接入网网络侧边缘ANE之间的网络设备中的策略执行功 能体PEF,所述链路管理功能体、策略执行功能体分别与接入节点、接入网网 络侧边缘相连,所述用户管理功能体与接入网网络侧边缘相连,所述策略执行 功能体分别与所述链路管理功能体、用户管理功能体相连。本发明还提供一种 多业务接入网的控制方法。本发明通过链路管理功能体对不同的业务用不同的 链路进行区分,保证多业务在接入网中的QoS,解决了用户动态接入多业务和 基于多业务的接入网QoS控制的问题。"],
78
- ["本发明关于用于一电子装置的影像还原方法及其相关装置。为了更有效 率地还原模糊的影像,本发明提供一种用于一电子装置的影像还原方法,包 含有于接收一被摄物的一影像时,产生一加速度信号;测量该电子装置与该 被摄物之间的距离,以产生一物距;以及根据该加速度信号及该物距,还原 该影像"]
79
- ]
80
  )
81
 
82
- iface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
+ import gc
3
+ import shutil
4
  import numpy as np
5
  import torch
6
+ from typing import Dict, Tuple, List
7
  from transformers import AutoModel, AutoTokenizer
8
  import gradio as gr
9
 
10
+ # 兼容新旧 huggingface_hub 版本(老版本没有 delete_cache_entries)
11
+ try:
12
+ from huggingface_hub import scan_cache_dir, delete_cache_entries
13
+ except Exception:
14
+ from huggingface_hub import scan_cache_dir
15
+ delete_cache_entries = None
16
 
17
+ # -----------------------
18
+ # 配置参数
19
+ # -----------------------
20
+ MODEL_OPTIONS = {
21
+ "Qwen3-0.6B (xulab-research/patent-classifier-0.6B)": "xulab-research/patent-classifier-0.6B",
22
+ "Qwen3-4B (xulab-research/patent-classifier-4B)": "xulab-research/patent-classifier-4B",
23
+ }
24
+ # 修复:去掉尾部多余空格,避免 Dropdown 默认值和字典 key 不匹配
25
+ DEFAULT_MODEL_KEY = "Qwen3-4B (xulab-research/patent-classifier-4B)"
26
 
27
+ THRESHOLDS_BY_MODEL = {
28
+ "xulab-research/patent-classifier": np.array(
29
+ [0.55, 0.35, 0.45, 0.5, 0.4, 0.35, 0.45, 0.5, 0.45], dtype=np.float32
30
+ ),
31
+ "xulab-research/patent-classifier-4B": np.array(
32
+ [0.5, 0.3, 0.35, 0.3, 0.15, 0.3, 0.4, 0.55, 0.35], dtype=np.float32
33
+ ),
34
+ }
35
+ CLASS_INDEX_START = 0
36
  CLASS_NAMES = [
37
  "非AI类", "知识处理", "语音识别", "AI硬件", "进化计算",
38
  "自然语言处理", "机器学习", "计算机视觉", "规划与控制"
39
  ]
40
 
41
+ # -----------------------
42
+ # 设备与 dtype
43
+ # -----------------------
44
+ def pick_device() -> torch.device:
45
+ if torch.cuda.is_available():
46
+ return torch.device("cuda")
47
+ if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
48
+ return torch.device("mps")
49
+ return torch.device("cpu")
50
+
51
+ DEVICE = pick_device()
52
+
53
+ def device_str(d: torch.device) -> str:
54
+ if d.type == "cuda":
55
+ return f"CUDA({torch.cuda.get_device_name(0)})"
56
+ if d.type == "mps":
57
+ return "Apple Silicon(MPS)"
58
+ return "CPU"
59
+
60
+ def preferred_dtype() -> torch.dtype:
61
+ if DEVICE.type == "cuda":
62
+ return torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
63
+ if DEVICE.type == "mps":
64
+ return torch.float16
65
+ return torch.float32
66
+
67
+ def get_device_map() -> str:
68
+ if torch.cuda.is_available():
69
+ return "auto"
70
+ if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
71
+ return "mps"
72
+ return "cpu"
73
+
74
+ # 可选:固定远端模型 revision,避免缓存多版本堆积
75
+ HF_REVISION = os.getenv("HF_REVISION", None)
76
+
77
+ # 可选:自定义缓存目录(如 /tmp/hf_cache,重启即清)
78
+ HF_CACHE_DIR = os.getenv("HF_HUB_CACHE", None)
79
+ if HF_CACHE_DIR:
80
+ os.environ["HF_HOME"] = HF_CACHE_DIR
81
+ os.environ["HF_HUB_CACHE"] = HF_CACHE_DIR
82
+ os.environ["TRANSFORMERS_CACHE"] = HF_CACHE_DIR
83
+
84
+ # -----------------------
85
+ # 模型缓存(按需加载)
86
+ # -----------------------
87
+ MODEL_CACHE: Dict[str, Tuple[AutoTokenizer, AutoModel]] = {}
88
+
89
+ def clear_cuda_cache():
90
+ if torch.cuda.is_available():
91
+ torch.cuda.empty_cache()
92
+ gc.collect()
93
+
94
+ def load_model(hf_repo: str) -> Tuple[AutoTokenizer, AutoModel]:
95
+ if hf_repo in MODEL_CACHE:
96
+ return MODEL_CACHE[hf_repo]
97
+
98
+ def _do_load():
99
+ tok_kwargs = dict(trust_remote_code=True)
100
+ mdl_kwargs = dict(
101
+ trust_remote_code=True,
102
+ device_map=get_device_map(),
103
+ low_cpu_mem_usage=True,
104
+ use_safetensors=True,
105
+ )
106
+ if HF_REVISION:
107
+ tok_kwargs["revision"] = HF_REVISION
108
+ mdl_kwargs["revision"] = HF_REVISION
109
+ if HF_CACHE_DIR:
110
+ tok_kwargs["cache_dir"] = HF_CACHE_DIR
111
+ mdl_kwargs["cache_dir"] = HF_CACHE_DIR
112
+
113
+ dtype = preferred_dtype()
114
+ if mdl_kwargs["device_map"] in ("auto", "cuda", "mps"):
115
+ mdl_kwargs["torch_dtype"] = dtype
116
+
117
+ tokenizer = AutoTokenizer.from_pretrained(hf_repo, **tok_kwargs)
118
+ model = AutoModel.from_pretrained(hf_repo, **mdl_kwargs)
119
+ model.eval()
120
+ return tokenizer, model
121
+
122
+ try:
123
+ tokenizer, model = _do_load()
124
+ MODEL_CACHE[hf_repo] = (tokenizer, model)
125
+ return tokenizer, model
126
+ except RuntimeError as e:
127
+ if "out of memory" in str(e).lower() or "cuda" in str(e).lower():
128
+ clear_cuda_cache()
129
+ tokenizer, model = _do_load()
130
+ MODEL_CACHE[hf_repo] = (tokenizer, model)
131
+ return tokenizer, model
132
+ raise
133
+
134
+ def extract_logits(outputs):
135
+ if isinstance(outputs, torch.Tensor):
136
+ return outputs
137
+ if hasattr(outputs, "logits"):
138
+ return outputs.logits
139
+ if isinstance(outputs, (tuple, list)) and len(outputs) > 0 and isinstance(outputs[0], torch.Tensor):
140
+ return outputs[0]
141
+ raise ValueError("无法识别模型输出格式:期望 tensor / obj.logits / tuple[0] 为 tensor")
142
+
143
+ # -----------------------
144
+ # 缓存清理工具(兼容无 delete_cache_entries 的旧版 hub)
145
+ # -----------------------
146
+ def _infer_hub_dir_from_env() -> str:
147
+ if HF_CACHE_DIR:
148
+ base = HF_CACHE_DIR
149
+ else:
150
+ base = os.getenv("HF_HOME", os.path.expanduser("~/.cache/huggingface"))
151
+ return os.path.join(base, "hub")
152
+
153
+ def clean_hf_cache(keep_repo_ids: List[str]) -> str:
154
+ """
155
+ 仅保留 keep_repo_ids 的缓存,清理其它模型/旧 revision。
156
+ """
157
+ try:
158
+ info = scan_cache_dir()
159
+ if delete_cache_entries is not None:
160
+ to_delete = []
161
+ for repo in info.repos:
162
+ if repo.repo_id not in keep_repo_ids:
163
+ to_delete.extend(repo.revisions.values())
164
+ else:
165
+ if HF_REVISION:
166
+ for rev in list(repo.revisions.values()):
167
+ if rev.commit_hash != HF_REVISION:
168
+ to_delete.append(rev)
169
+ if to_delete:
170
+ delete_cache_entries(to_delete)
171
+ return "缓存清理完成。"
172
+
173
+ # 旧版 fallback:直接按目录删除
174
+ hub_root = getattr(info, "cache_dir", None)
175
+ hub_dir = os.path.join(hub_root, "hub") if hub_root else _infer_hub_dir_from_env()
176
+ if not os.path.isdir(hub_dir):
177
+ return f"未找到缓存目录({hub_dir}),无需清理。"
178
+
179
+ removed = 0
180
+ keep_set = set(keep_repo_ids)
181
+ for entry in os.listdir(hub_dir):
182
+ path = os.path.join(hub_dir, entry)
183
+ if not os.path.isdir(path):
184
+ continue
185
+ if entry.startswith("models--"):
186
+ parts = entry.split("--", 2)
187
+ repo_id = f"{parts[1]}/{parts[2]}" if len(parts) >= 3 else None
188
+ if repo_id and repo_id in keep_set:
189
+ if HF_REVISION:
190
+ snapshots = os.path.join(path, "snapshots")
191
+ if os.path.isdir(snapshots):
192
+ for rev in os.listdir(snapshots):
193
+ if rev != HF_REVISION:
194
+ shutil.rmtree(os.path.join(snapshots, rev), ignore_errors=True)
195
+ continue
196
+ shutil.rmtree(path, ignore_errors=True)
197
+ removed += 1
198
+ elif entry.startswith(("datasets--", "spaces--")):
199
+ shutil.rmtree(path, ignore_errors=True)
200
+ removed += 1
201
+ return f"缓存清理完成(删除目录数:{removed})。"
202
+ except Exception as e:
203
+ return f"缓存清理失败:{e}"
204
+
205
+ # -----------------------
206
+ # 推理
207
+ # -----------------------
208
+ def predict(text: str, model_choice: str):
209
  if not isinstance(text, str) or not text.strip():
210
  return {}
211
+
212
+ hf_repo = MODEL_OPTIONS.get(model_choice, MODEL_OPTIONS[DEFAULT_MODEL_KEY])
213
+ tokenizer, model = load_model(hf_repo)
214
+
215
+ default_thr = THRESHOLDS_BY_MODEL.get(
216
+ "xulab-research/patent-classifier-4B",
217
+ np.array([0.5] * len(CLASS_NAMES), dtype=np.float32),
218
+ )
219
+ thresholds = THRESHOLDS_BY_MODEL.get(hf_repo, default_thr)
220
+
221
  inputs = tokenizer(
222
  text,
223
  padding=True,
224
  truncation=True,
225
  max_length=256,
226
  return_tensors="pt"
227
+ )
 
 
228
  with torch.no_grad():
229
  outputs = model(**inputs)
230
+ logits = extract_logits(outputs)
231
+
232
+ if logits.ndim == 1:
233
+ logits = logits.unsqueeze(0)
234
+
235
+ probabilities = torch.sigmoid(logits).detach().cpu().numpy()[0]
236
+
237
  predicted_indices = []
238
  for i, prob in enumerate(probabilities):
239
+ thr = thresholds[i] if i < len(thresholds) else 0.5
240
+ if prob >= thr:
241
+ class_idx = i + CLASS_INDEX_START
242
  predicted_indices.append(class_idx)
243
+
 
244
  if not predicted_indices:
245
+ max_idx = int(np.argmax(probabilities))
246
  predicted_indices = [max_idx + CLASS_INDEX_START]
247
+
 
248
  result = {}
249
  for idx in predicted_indices:
250
+ if 0 <= idx - CLASS_INDEX_START < len(probabilities) and idx < len(CLASS_NAMES):
251
  result[CLASS_NAMES[idx]] = float(probabilities[idx - CLASS_INDEX_START])
252
+
253
  return result
254
 
255
+ # -----------------------
256
+ # 界面
257
+ # -----------------------
258
+ description = (
259
+ "专利分类器 - 输入专利摘要文本,模型将预测所属类别。\n\n"
260
+ "支持选择两种模型:\n"
261
+ f"- 0.6B:{MODEL_OPTIONS['Qwen3-0.6B (xulab-research/patent-classifier-0.6B)']}\n"
262
+ f"- 4B:{MODEL_OPTIONS['Qwen3-4B (xulab-research/patent-classifier-4B)']}\n\n"
263
+ f"当前设备:{device_str(DEVICE)}"
 
 
 
 
 
264
  )
265
 
266
+ ai_subfields_text = """8个AI子领域\n
267
+ 1.知识处理:知识处理领域包括用于表示世界事实并从知识库中推导出新事实(或知识)的方法。例如,专家系统通常包含一个知识库和一种推理方法来从该知识库中获得新事实。\n
268
+ 2.语音识别:语音识别包括从音频信号中理解词语序列的方法。例如,噪声通道模型是一种统计方法,通过贝叶斯规则从语音输入中识别最可能的词语序列。\n
269
+ 3.AI硬件:AI硬件领域包括专为执行人工智能软件而设计的物理硬件。例如,谷歌设计的张量处理单元(TPU)就是为了更高效地运行神经网络算法。AI硬件可能包括逻辑电路、存储器、视频、处理器和固态技术,也可能包括实现其他AI组成技术(如机器学习算法)的嵌入式软件。\n
270
+ 4.进化计算:进化计算是一类利用自然演化特性的计算方法。例如,遗传算法通过选择最优的随机变异体以最大化适应度来执行算法变异选择。\n
271
+ 5.自然语言处理:自然语言处理包括用于理解和使用以人类自然语言编码的数据的方法。例如,语言模型用于表示语言表达的概率分布。\n
272
+ 6.机器学习:机器学习领域包含一类广泛的计算学习模型。例如,监督学习分类模型是一种基于预标记训练数据学习进行分类的算法。机器学习技术包括但不限于神经网络、模糊逻辑、自适应系统、概率网络、回归分析以及智能搜索。\n
273
+ 7.计算机视觉:计算机视觉领域包括从图像和视频等视觉输入中提取和理解信息的方法。例如,边缘检测技术可识别图像中的边界和轮廓。其他计算机视觉子领域还包括目标识别、图像处理(如变换、增强或还原)、颜色处理和格式转换等。\n
274
+ 8.规划与控制:规划与控制领域包括识别并执行实现特定目标的计划的方法。规划的关键方面包括表示行动和世界状态、推理行动的后果,并在潜��计划中高效地搜索。现代控制理论包括在时间维度上最大化目标函数的方法。例如,随机最优控制处理在不确定环境中的动态优化问题。此外,规划与控制还涵盖用于管理/行政的数据系统(例如:组织和员工的管理,包括库存、工作流程、预测和时间管理)、自适应控制系统,以及系统模型或模拟器。
275
+ """
276
+
277
+ example_texts = [
278
+ "本发明提供一种录制与播放用户语音的方法以及使用此方法的电子字 典,所述的方法适用于一电子装置,其中电子装置至少包括一屏幕、一发音 指示、一录音指示键以及一存储器。当屏幕显示一屏幕显示数据的发音指示 时,按下录音指示键以进入语音录制模式。接着输入用户语音,并且储存用 户语音至存储器。之后记录用以显示用户语音在存储器中的位置的存储器地 址或录制数据索引,并将存储器地址或录制数据索引链结至屏幕显示数据。 本发明还提供一种使用上述方法的电子字典。",
279
+ "本发明提供一种多业务接入网的控制系统,包括:对用户进行认证、授权 和地址分配的用户管理功能体UMF、管理用户在接入网络中各网元之间链路和 资源的链路管理功能体LMF以及根据用户属性进行资源接纳控制和策略执行或 部署到接入节点AN和接入网网络侧边缘ANE之间的网络设备中的策略执行功 能体PEF,所述链路管理功能体、策略执行功能体分别与接入节点、接入网网 络侧边缘相连,所述用户管理功能体与接入网网络侧边缘相连,所述策略执行 功能体分别与所述链路管理功能体、用户管理功能体相连。本发明还提供一种 多业务接入网的控制方法。本发明通过链路管理功能体对不同的业务用不同的 链路进行区分,保证多业务在接入网中的QoS,解决了用户动态接入多业务和 基于多业务的接入网QoS控制的问题。",
280
+ "本发明关于用于一电子装置的影像还原方法及其相关装置。为了更有效 率地还原模糊的影像,本发明提供一种用于一电子装置的影像还原方法,包 含有于接收一被摄物的一影像时,产生一加速度信号;测量该电子装置与该 被摄物之间的距离,以产生一物距;以及根据该加速度信号及该物距,还原 该影像",
281
+ ]
282
+
283
+
284
+ with gr.Blocks(title="专利分类器") as demo:
285
+ gr.Markdown("# 专利分类器")
286
+ gr.Markdown(description)
287
+
288
+ with gr.Row():
289
+ model_choice = gr.Dropdown(
290
+ label="选择模型",
291
+ choices=list(MODEL_OPTIONS.keys()),
292
+ value=DEFAULT_MODEL_KEY
293
+ )
294
+
295
+ input_box = gr.Textbox(label="专利摘要", lines=5, placeholder="请输入专利摘要文本...")
296
+ with gr.Row():
297
+ predict_btn = gr.Button("预测", variant="primary")
298
+ clear_btn = gr.Button("清空")
299
+ #clean_btn = gr.Button("清理缓存(仅保留当前所选模型)", variant="secondary")
300
+
301
+ output_label = gr.Label(label="预测类别", num_top_classes=len(CLASS_NAMES))
302
+
303
+ gr.Examples(
304
+ examples=[[t, DEFAULT_MODEL_KEY] for t in example_texts],
305
+ inputs=[input_box, model_choice],
306
+ label="示例"
307
+ )
308
+
309
+ predict_btn.click(
310
+ fn=predict,
311
+ inputs=[input_box, model_choice],
312
+ outputs=output_label,
313
+ concurrency_limit=1
314
+ )
315
+ clear_btn.click(lambda: ("", {}), outputs=[input_box, output_label])
316
+
317
+ def _clean(selected_key: str):
318
+ repo_id = MODEL_OPTIONS.get(selected_key, MODEL_OPTIONS[DEFAULT_MODEL_KEY])
319
+ msg = clean_hf_cache([repo_id])
320
+ return gr.update(value={}), gr.update(value=""), msg
321
+
322
+ clean_status = gr.Markdown("")
323
+ #clean_btn.click(_clean, inputs=[model_choice], outputs=[output_label, input_box, clean_status])
324
+
325
+ gr.Markdown(ai_subfields_text)
326
+
327
+ demo.launch()