2ephyrh commited on
Commit
aa2a48a
·
verified ·
1 Parent(s): 32bab37

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +430 -428
app.py CHANGED
@@ -1,429 +1,431 @@
1
- import gradio as gr
2
- from model import (
3
- cls_predict,
4
- det_predict,
5
- seg_predict,
6
- ALL_SEG_LABELS,
7
- ALL_DET_LABELS,
8
- ALL_SEG_COLOR_MAP,
9
- ALL_CLS_LABELS
10
- )
11
- import os
12
- import requests
13
- from PIL import Image
14
- from io import BytesIO
15
- import time # 用于重试
16
-
17
- # --- 配置 Logo 和版权信息 ---
18
- LOGO_PATH = "logo/logo.png" # 请将您的 Logo 图片存放在此路径
19
- # 🌟 更新:版权信息仅保留年份和权利声明
20
- COPYRIGHT_TEXT = "© 2025 All Rights Reserved."
21
- # COLLABORATION_EMAIL 变量已删除
22
- SCHOOL_NAME_EN = "School of Information Engineering, Wuhan University of Technology"
23
-
24
- # --- 自动下载示例图片逻辑 (嵌入) ---
25
-
26
- # 🌟 恢复到用户指定的 COCO ID 列表
27
- TASK_EXAMPLE_URLS = {
28
- "cls": [
29
- # 图像分类示例 (val2017)
30
- "http://images.cocodataset.org/val2017/000000000285.jpg", # 猫和键盘 (保留)
31
- "http://images.cocodataset.org/val2017/000000000785.jpg", # 交通灯/汽车/巴士 (保留)
32
- "http://images.cocodataset.org/val2017/000000000724.jpg", # /餐具 (保留)
33
- "http://images.cocodataset.org/val2017/000000001584.jpg", # 多人,多物体 (保留)
34
- "http://images.cocodataset.org/train2017/000000001097.jpg", # 原 cls_5
35
- ],
36
- "seg": [
37
- # 语义分割示例 (val2017)
38
- "http://images.cocodataset.org/val2017/000000000139.jpg", # 街景/汽车 (保留)
39
- "http://images.cocodataset.org/val2017/000000000632.jpg", # 街景/行人 (保留)
40
- "http://images.cocodataset.org/val2017/000000000885.jpg", # 滑板手 (保留)
41
- "http://images.cocodataset.org/train2017/000000000267.jpg", # 原 seg_4
42
- "http://images.cocodataset.org/train2017/000000001140.jpg", # 原 seg_5
43
- ],
44
- "det": [
45
- # 目标检测示例 (val2017)
46
- "http://images.cocodataset.org/val2017/000000000785.jpg", # 交通灯/汽车/巴士 (保留)
47
- "http://images.cocodataset.org/val2017/000000001268.jpg", # 原 det_2
48
- "http://images.cocodataset.org/train2017/000000001072.jpg", # 原 det_3
49
- "http://images.cocodataset.org/train2017/000000000119.jpg", # 原 det_4
50
- "http://images.cocodataset.org/train2017/000000000570.jpg", # 原 det_5
51
- ]
52
- }
53
-
54
- # 项目期望的本地路径
55
- OUTPUT_DIR = "examples"
56
-
57
-
58
- def download_and_save_examples(max_retries=3):
59
- """下载示例图片到本地 examples/ 目录,使用任务前缀命名,增加重试机制"""
60
- if not os.path.exists(OUTPUT_DIR):
61
- os.makedirs(OUTPUT_DIR)
62
-
63
- total_urls = sum(len(urls) for urls in TASK_EXAMPLE_URLS.values())
64
- print(f"🚀 检查和下载 {total_urls} 张示例图片...")
65
-
66
- # 迭代所有任务和 URL
67
- for prefix, urls in TASK_EXAMPLE_URLS.items():
68
- for i, url in enumerate(urls):
69
- # 文件名格式:cls_1.jpg, seg_2.jpg, det_3.jpg
70
- filename = f"{prefix}_{i + 1}.jpg"
71
- filepath = os.path.join(OUTPUT_DIR, filename)
72
-
73
- if os.path.exists(filepath):
74
- continue # 跳过已存在的文件
75
-
76
- for attempt in range(max_retries):
77
- try:
78
- # 增加更长的超时时间
79
- response = requests.get(url, stream=True, timeout=15)
80
- response.raise_for_status()
81
-
82
- image = Image.open(BytesIO(response.content))
83
- image.save(filepath)
84
- print(f" 成功下载并保存: {filename} (尝 {attempt + 1}/{max_retries})")
85
- break # 成功则跳出重试循环
86
-
87
- except requests.exceptions.RequestException as e:
88
- print(f" ⚠️ 下载 {filename} 失败 (尝试 {attempt + 1}/{max_retries}): {e}")
89
- if attempt < max_retries - 1:
90
- time.sleep(2) # 失败后等待2秒再重试
91
- else:
92
- # 404 错误是 ClientError,意味着 URL 不存在
93
- if '404 Client Error' in str(e):
94
- print(f"❌ 最终下载失败 {filename}: URL {url} 不存在 (404 错误)。")
95
- else:
96
- print(f"❌ 最终下载失败 {filename}: 请检查网络连接或 URL。")
97
- break
98
- except Exception as e:
99
- # 图像处理失败 (如 BytesIO 或 PIL 错误),停止重试
100
- print(f"❌ 图像处理失败 {filename}: {e}")
101
- break
102
-
103
- # 立即执行下载,确保 examples 目录下的文件存在
104
-
105
-
106
- download_and_save_examples()
107
-
108
- # 🌟 关键:创建三个独立的示例列表,用于 Gradio Examples 组件
109
- CLS_EXAMPLES = [[os.path.join(OUTPUT_DIR, f"cls_{i + 1}.jpg")] for i in range(5)]
110
- SEG_EXAMPLES = [[os.path.join(OUTPUT_DIR, f"seg_{i + 1}.jpg")] for i in range(5)]
111
- DET_EXAMPLES = [[os.path.join(OUTPUT_DIR, f"det_{i + 1}.jpg")] for i in range(5)]
112
-
113
-
114
- # --- 辅助函数:生成颜色图例 HTML ---
115
- def generate_legend_html(color_map_dict):
116
- """根据颜色映射字典生成 HTML 图例"""
117
- html_content = "<div style='max-height: 300px; overflow-y: scroll; padding: 10px; border: 1px solid #ccc; background-color: #f7f7f7; border-radius: 8px;'>"
118
- html_content += "<h4 style='margin-top: 0; color: #333;'>🎨 分割颜色图例</h4>"
119
-
120
- if "Error" in color_map_dict:
121
- html_content += "<p style='color: red;'>模型加载失败,图例不可用。</p>"
122
- return html_content
123
-
124
- for label, hex_color in color_map_dict.items():
125
- html_content += f"""
126
- <div style='display: flex; align-items: center; margin-bottom: 5px; font-family: sans-serif;'>
127
- <div style='width: 15px; height: 15px; background-color: {hex_color}; border: 1px solid #333; margin-right: 10px; border-radius: 3px;'></div>
128
- <span style='font-size: 14px; color: #555;'>{label}</span>
129
- </div>
130
- """
131
- html_content += "</div>"
132
- return html_content
133
-
134
-
135
- # --- 辅助函数:类别搜索逻辑 ---
136
- def search_labels(query: str, all_labels) -> str:
137
- """
138
- 在标签列表字典中搜索给定的查询
139
- all_labels 可以是 list (分类/分割) 或 dict (检测)。
140
- """
141
- query = query.strip().lower()
142
- if not query:
143
- return "请输入有效的查询内容。"
144
-
145
- MAX_MATCHES = 10
146
-
147
- # 处理 List (Classification, Segmentation)
148
- if isinstance(all_labels, list):
149
- found_matches = [label for label in all_labels if query in label.lower()]
150
-
151
- if found_matches:
152
- result_list = "\n- ".join(found_matches[:MAX_MATCHES])
153
- summary = f"✅ 找到 {len(found_matches)} 个匹配项 (仅显示前 {MAX_MATCHES} 个):"
154
- return f"{summary}\n- {result_list}"
155
- else:
156
- return f"❌ 未找到包含 '{query}' 的类别。"
157
-
158
- # 处理 Dict (Detection)
159
- elif isinstance(all_labels, dict):
160
- found_matches = {k: v for k, v in all_labels.items() if query in v.lower()}
161
-
162
- if found_matches:
163
- result_list = [f"ID {k}: {v}" for k, v in list(found_matches.items())[:MAX_MATCHES]]
164
- summary = f"✅ 找到 {len(found_matches)} 个匹配项 (仅显示前 {MAX_MATCHES} 个):"
165
- return f"{summary}\n- {result_list}"
166
- else:
167
- return f"❌ 未找到包含 '{query}' 的类别。"
168
-
169
- return "类别数据格式错误或未加载。"
170
-
171
-
172
- # 🌟 美化 1: Gradio Soft 主题 CSS (背景颜色调整)
173
- CUSTOM_CSS = """
174
- /* 整体背景和卡片阴影优化 */
175
- .gradio-container {
176
- background-color: #f7f7f7; /* 调整为柔和的浅灰色背景 */
177
- font-family: 'Inter', system-ui, sans-serif;
178
- }
179
- /* 卡片和主内容区的美化 */
180
- .gradio-container > div {
181
- border-radius: 12px;
182
- box-shadow: 0 2px 8px rgba(0, 0, 0, 0.05); /* 柔和阴影 */
183
- transition: all 0.3s ease;
184
- }
185
-
186
- /* 主标题样式 (用于承载 Logo 和文本) */
187
- h1 {
188
- display: flex;
189
- align-items: center;
190
- justify-content: center;
191
- font-size: 2.2em;
192
- color: #333333; /* 深色文字 */
193
- padding: 20px 0;
194
- margin: 0;
195
- }
196
-
197
- /* 按钮和输入框圆角 */
198
- .gr-button, .gr-textbox, .gr-number, .gr-image {
199
- border-radius: 8px !important;
200
- }
201
-
202
- /* 标签和组件背景 */
203
- .gradio-container > div:not(.prose):not(.gr-row) {
204
- background: white;
205
- padding: 15px;
206
- }
207
- /* 页脚 Logo 样式 */
208
- .footer-logo-container {
209
- display: flex;
210
- flex-direction: column;
211
- align-items: center;
212
- justify-content: center;
213
- text-align: center;
214
- padding-top: 15px;
215
- }
216
-
217
- /* 页脚图标/链接容器样式 */
218
- .footer-links {
219
- margin-top: 10px;
220
- font-size: 14px;
221
- color: #555;
222
- display: flex;
223
- gap: 15px; /* 图标之间的间隔 */
224
- align-items: center;
225
- /* 确保内容居中 */
226
- justify-content: center;
227
- }
228
-
229
- /* 图标样式 */
230
- .footer-icon {
231
- font-size: 18px;
232
- vertical-align: middle;
233
- }
234
-
235
- /* 链接颜色 */
236
- .footer-links a {
237
- color: #555;
238
- text-decoration: none;
239
- }
240
- .footer-links a:hover {
241
- color: #333;
242
- }
243
- """
244
-
245
- with gr.Blocks(
246
- title="AI基础模型视觉任务演示平台",
247
- ) as demo:
248
- # 注入 Favicon (网页选项卡图标)
249
- gr.HTML(f"""
250
- <head>
251
- <link rel='icon' type='image/png' href='file/{LOGO_PATH}'/>
252
- </head>
253
- """, visible=False)
254
-
255
- # 🌟 主标题区域:恢复机器人图标
256
- gr.Markdown("<h1>🤖 AI基础模型视觉任务演示平台</h1>")
257
- gr.Markdown("---")
258
-
259
- # 🌟 简化功能说明 (只关注使用方法)
260
- with gr.Accordion("📚 功能说明"):
261
- gr.Markdown("""
262
- 平台支持图像分类、语义分割和目标检测三大任务。
263
- 您可以通过以下步骤使用平台
264
- 1. **切换项卡**:选择希望执行 AI 任务
265
- 2. **上传或选择图片**:上传您自己的图片或点击下方的示例图片
266
- 3. **设置参数**:对于目标检测调整置信度阈值
267
- 4. **点击提交**:点击“提交任务”按钮,查看 AI 分析结果。
268
- """)
269
-
270
- # 🌟 新增:数据集介绍
271
- with gr.Accordion("📚 基础数据集介绍", open=False):
272
- gr.Markdown("""
273
- ### 📖 模型训练数据集概览
274
-
275
- | 任务 | 模型 | 数据集 | 类别数 | 简介 |
276
- | :--- | :--- | :--- | :--- | :--- |
277
- | 图像 | ViT | **ImageNet-1K** | 1000 | 包含超过 100 万张图像,是图像识别领域的标准基准。 |
278
- | 语义分割 | SegFormer | **ADE20K** | 150 | 专注于场景解析,提供 150 种语义概念像素级。 |
279
- | 目标检测 | YOLOv8n | **COCO** | 80 | 最常用的目标检测数据集之一,包含大量物体实例。 |
280
- """)
281
-
282
- # 🌟 新增:网络结构介绍
283
- with gr.Accordion("🧠 网络结构介绍", open=False):
284
- gr.Markdown("""
285
- ### 💻 模型架构说明
286
-
287
- 1. **图像分类 (ViT):**
288
- * **全称:** Vision Transformer (ViT-Base-Patch16-224)
289
- * **特点:** 基于 Transformer 结构,将图像切片后进行序列输入,通过自注意力机制实现全局建模。
290
-
291
- 2. **语义分割 (SegFormer):**
292
- * **全称:** Segmentation Transformer
293
- * **特点:** 高效的 Transformer 架构,使用轻量级解码器,专注于速度和准确性的平衡。
294
-
295
- 3. **目标检测 (YOLOv8n):**
296
- * **全称:** You Only Look Once, Version 8 (Nano)
297
- * **特点:** 单阶段检测器,以速度著称,Nano (n) 版本在保持高性能的同时,体积最小。
298
- """)
299
-
300
- # --- 任务选项卡 ---
301
- with gr.Tabs():
302
- # 1. 图像分类 Tab
303
- with gr.TabItem("🖼️ 图像分类 (ViT)"):
304
- with gr.Row():
305
- with gr.Column(scale=1):
306
- cls_input = gr.Image(type='pil', label="输入图像")
307
- cls_button = gr.Button("🚀 提交分类任务")
308
- with gr.Column(scale=1):
309
- cls_output = gr.Label(num_top_classes=5, label="分类结果 (前 5)")
310
-
311
- # 🌟 所有分类类别列表
312
- gr.Markdown("### 🌟 模型支持的全部分类类别 (ImageNet-1K)")
313
- cls_category_json = gr.JSON(value=ALL_CLS_LABELS, label="所有类别列表", scale=1)
314
-
315
- # 🌟 查询 UI
316
- with gr.Row():
317
- cls_search_query = gr.Textbox(label="查询类别", placeholder="Search Class Name (e.g., dog)", scale=3)
318
- cls_search_button = gr.Button("🔍 搜索", scale=1)
319
- cls_search_output = gr.Markdown("搜索结果将显示在这里。")
320
-
321
- cls_search_button.click(
322
- fn=search_labels,
323
- inputs=[cls_search_query, cls_category_json],
324
- outputs=cls_search_output
325
- )
326
-
327
- # 🌟 更新:使用 CLS_EXAMPLES
328
- gr.Examples(examples=CLS_EXAMPLES, inputs=[cls_input], label="示例图片")
329
- cls_button.click(cls_predict, inputs=cls_input, outputs=cls_output)
330
-
331
- # 2. 语义分割 Tab
332
- with gr.TabItem("✂️ 语义分割 (SegFormer)"):
333
- with gr.Row():
334
- with gr.Column(scale=2):
335
- seg_input = gr.Image(type='pil', label="输入图像")
336
- seg_button = gr.Button("🚀 提交分割任务")
337
- with gr.Column(scale=2):
338
- seg_output = gr.Image(type='pil', label="分割结果 (叠加)")
339
- with gr.Column(scale=1):
340
- # 🌟 展示图例
341
- gr.HTML(value=generate_legend_html(ALL_SEG_COLOR_MAP), scale=1)
342
-
343
- # 保留完类别列表(以 JSON 格式展示,作为额外的参考)
344
- gr.Markdown("### 完整类别列表 (JSON)")
345
- seg_category_json = gr.JSON(value={f"ID {i}": label for i, label in enumerate(ALL_SEG_LABELS)},
346
- label="所有类别 JSON")
347
-
348
- # 🌟 查询 UI 提示改为英文
349
- with gr.Row():
350
- seg_search_query = gr.Textbox(label="查询类别", placeholder="Search Class Name (e.g., road, sky)",
351
- scale=3)
352
- seg_search_button = gr.Button("🔍 搜索", scale=1)
353
- seg_search_output = gr.Markdown("搜索结果将显示在这里。")
354
-
355
- seg_search_button.click(
356
- fn=search_labels,
357
- inputs=[seg_search_query, seg_category_json],
358
- outputs=seg_search_output
359
- )
360
-
361
- # 🌟 更新:使用 SEG_EXAMPLES
362
- gr.Examples(examples=SEG_EXAMPLES, inputs=[seg_input], label="示例图片")
363
- seg_button.click(seg_predict, inputs=seg_input, outputs=seg_output)
364
-
365
- # 3. 目标检测 Tab
366
- with gr.TabItem("🎯 目标检测 (YOLOv8n)"):
367
- with gr.Row():
368
- with gr.Column(scale=1):
369
- det_input_image = gr.Image(type='pil', label="输入图像")
370
- det_input_number = gr.Number(
371
- precision=2,
372
- minimum=0.01,
373
- maximum=1,
374
- value=0.30,
375
- label='置信度阈值'
376
- )
377
- det_button = gr.Button("🚀 提交检测任务")
378
- with gr.Column(scale=1):
379
- det_output = gr.Image(type='pil', label="检测结果 (边界框)")
380
-
381
- # 🌟 展示目标检测类别列表
382
- gr.Markdown("### 🎯 模型支持的检测类别 (COCO)")
383
- det_category_json = gr.JSON(value=ALL_DET_LABELS, label="所有类别列表")
384
-
385
- # 🌟 查询 UI 提改为英文
386
- with gr.Row():
387
- det_search_query = gr.Textbox(label="查询类别", placeholder="Search Class Name (e.g., bicycle, train)",
388
- scale=3)
389
- det_search_button = gr.Button("🔍 搜索", scale=1)
390
- det_search_output = gr.Markdown("搜索结果将显示在这里。")
391
-
392
- det_search_button.click(
393
- fn=search_labels,
394
- inputs=[det_search_query, det_category_json],
395
- outputs=det_search_output
396
- )
397
-
398
- # 🌟 更新:使用 DET_EXAMPLES
399
- gr.Examples(examples=DET_EXAMPLES, inputs=[det_input_image], label="示例图片")
400
- det_button.click(det_predict, inputs=[det_input_image, det_input_number], outputs=det_output)
401
-
402
- # 🌟 添加页脚和 Logo/版权 (使用图标和本地 Logo)
403
- gr.HTML(
404
- f"""
405
- <div class='footer-logo-container'>
406
- <div class="footer-links">
407
- <p>{COPYRIGHT_TEXT}</p>
408
- </div>
409
-
410
- <div class="footer-links">
411
-
412
- <!-- 🌟 学校图标和名称 (添加超链接) -->
413
- <span class="footer-icon">🏢</span>
414
- <a href='https://wutinfo.whut.edu.cn/' target='_blank' style='text-decoration: none; color: inherit;'>
415
- <span>{SCHOOL_NAME_EN}</span>
416
- </a>
417
-
418
- <!-- 🌟 额外 Logo 位于版权信息之后 -->
419
- <img src='file/{LOGO_PATH}' alt='Logo' style='height: 30px; margin-left: 20px;' onerror="this.style.display='none'">
420
- </div>
421
- </div>
422
- """
423
- )
424
-
425
- if __name__ == "__main__":
426
- gr.close_all()
427
- print("Launching Gradio demo...")
428
- # 🌟 传入 css 参数
 
 
429
  demo.launch(share=True, css=CUSTOM_CSS)
 
1
+ import gradio as gr
2
+ from model import (
3
+ cls_predict,
4
+ det_predict,
5
+ seg_predict,
6
+ ALL_SEG_LABELS,
7
+ ALL_DET_LABELS,
8
+ ALL_SEG_COLOR_MAP,
9
+ ALL_CLS_LABELS
10
+ )
11
+ import os
12
+ import requests
13
+ from PIL import Image
14
+ from io import BytesIO
15
+ import time # 用于重试
16
+
17
+ # --- 配置 Logo 和版权信息 ---
18
+ LOGO_PATH = "logo/logo.png" # 请将您的 Logo 图片存放在此路径
19
+ # 🌟 更新:版权信息仅保留年份和权利声明,具体内容在 HTML 中处理
20
+ COPYRIGHT_TEXT = "© 2025 All Rights Reserved."
21
+ SCHOOL_NAME_EN = "School of Information Engineering, Wuhan University of Technology"
22
+
23
+ # --- 自动下载示例图片逻辑 (嵌入) ---
24
+
25
+ # 🌟 恢复到用户指定的 COCO ID 列表
26
+ TASK_EXAMPLE_URLS = {
27
+ "cls": [
28
+ # 图像分类示例 (val2017)
29
+ "http://images.cocodataset.org/val2017/000000000285.jpg", # 猫和键盘 (保留)
30
+ "http://images.cocodataset.org/val2017/000000000785.jpg", # 交通灯/汽车/巴士 (保留)
31
+ "http://images.cocodataset.org/val2017/000000000724.jpg", # 食物/餐具 (保留)
32
+ "http://images.cocodataset.org/val2017/000000001584.jpg", # 多人,多 (保留)
33
+ "http://images.cocodataset.org/train2017/000000001097.jpg", # cls_5
34
+ ],
35
+ "seg": [
36
+ # 语义分割示例 (val2017)
37
+ "http://images.cocodataset.org/val2017/000000000139.jpg", # 街景/汽车 (保留)
38
+ "http://images.cocodataset.org/val2017/000000000632.jpg", # 街景/行人 (保留)
39
+ "http://images.cocodataset.org/val2017/000000000885.jpg", # 滑板手 (保留)
40
+ "http://images.cocodataset.org/train2017/000000000267.jpg", # seg_4
41
+ "http://images.cocodataset.org/train2017/000000001140.jpg", # 原 seg_5
42
+ ],
43
+ "det": [
44
+ # 目标检测示例 (val2017)
45
+ "http://images.cocodataset.org/val2017/000000000785.jpg", # 交通灯/汽车/巴士 (保留)
46
+ "http://images.cocodataset.org/val2017/000000001268.jpg", # det_2
47
+ "http://images.cocodataset.org/train2017/000000001072.jpg", # 原 det_3
48
+ "http://images.cocodataset.org/train2017/000000000119.jpg", # 原 det_4
49
+ "http://images.cocodataset.org/train2017/000000000570.jpg", # 原 det_5
50
+ ]
51
+ }
52
+
53
+ # 项目期望的本地路径
54
+ OUTPUT_DIR = "examples"
55
+
56
+
57
+ def download_and_save_examples(max_retries=3):
58
+ """下载示例图片到本地 examples/ 目录,使用任务前缀命名,增加重试机制"""
59
+ if not os.path.exists(OUTPUT_DIR):
60
+ os.makedirs(OUTPUT_DIR)
61
+
62
+ total_urls = sum(len(urls) for urls in TASK_EXAMPLE_URLS.values())
63
+ print(f"🚀 检查和下载 {total_urls} 张示例图片...")
64
+
65
+ # 迭代所有任务和 URL
66
+ for prefix, urls in TASK_EXAMPLE_URLS.items():
67
+ for i, url in enumerate(urls):
68
+ # 文件名格式:cls_1.jpg, seg_2.jpg, det_3.jpg
69
+ filename = f"{prefix}_{i + 1}.jpg"
70
+ filepath = os.path.join(OUTPUT_DIR, filename)
71
+
72
+ if os.path.exists(filepath):
73
+ continue # 跳过已存在的文件
74
+
75
+ for attempt in range(max_retries):
76
+ try:
77
+ # 增加更长的超时时间
78
+ response = requests.get(url, stream=True, timeout=15)
79
+ response.raise_for_status()
80
+
81
+ image = Image.open(BytesIO(response.content))
82
+ image.save(filepath)
83
+ print(f" 成功下载并保存: {filename} (尝试 {attempt + 1}/{max_retries})")
84
+ break # 成功则跳出重循环
85
+
86
+ except requests.exceptions.RequestException as e:
87
+ print(f" ⚠️ 下载 {filename} 失败 (尝试 {attempt + 1}/{max_retries}): {e}")
88
+ if attempt < max_retries - 1:
89
+ time.sleep(2) # 失败后等待2秒再重试
90
+ else:
91
+ # 404 错误是 ClientError,意味着 URL 不存在
92
+ if '404 Client Error' in str(e):
93
+ print(f"❌ 最终下载失败 {filename}: URL {url} 不存在 (404 错误)。")
94
+ else:
95
+ print(f"❌ 最终下载失败 {filename}: 请检查网络连接或 URL。")
96
+ break
97
+ except Exception as e:
98
+ # 图像处理失败 (如 BytesIO 或 PIL 错误),停止重试
99
+ print(f"❌ 图像处理失败 {filename}: {e}")
100
+ break
101
+
102
+ # 立即执行下载,确保 examples 目录下的文件存在
103
+
104
+
105
+ download_and_save_examples()
106
+
107
+ # 🌟 关键:创建三个独立的示例列表,用于 Gradio Examples 组件
108
+ CLS_EXAMPLES = [[os.path.join(OUTPUT_DIR, f"cls_{i + 1}.jpg")] for i in range(5)]
109
+ SEG_EXAMPLES = [[os.path.join(OUTPUT_DIR, f"seg_{i + 1}.jpg")] for i in range(5)]
110
+ DET_EXAMPLES = [[os.path.join(OUTPUT_DIR, f"det_{i + 1}.jpg")] for i in range(5)]
111
+
112
+
113
+ # --- 辅助函数:生成颜色图例 HTML ---
114
+ def generate_legend_html(color_map_dict):
115
+ """根据颜色映射字典生成 HTML 图例"""
116
+ html_content = "<div style='max-height: 300px; overflow-y: scroll; padding: 10px; border: 1px solid #ccc; background-color: #f7f7f7; border-radius: 8px;'>"
117
+ html_content += "<h4 style='margin-top: 0; color: #333;'>🎨 分割颜色图例</h4>"
118
+
119
+ if "Error" in color_map_dict:
120
+ html_content += "<p style='color: red;'>模型加载失败,图例不可用。</p>"
121
+ return html_content
122
+
123
+ for label, hex_color in color_map_dict.items():
124
+ html_content += f"""
125
+ <div style='display: flex; align-items: center; margin-bottom: 5px; font-family: sans-serif;'>
126
+ <div style='width: 15px; height: 15px; background-color: {hex_color}; border: 1px solid #333; margin-right: 10px; border-radius: 3px;'></div>
127
+ <span style='font-size: 14px; color: #555;'>{label}</span>
128
+ </div>
129
+ """
130
+ html_content += "</div>"
131
+ return html_content
132
+
133
+
134
+ # --- 辅助函数:类别搜索逻辑 ---
135
+ def search_labels(query: str, all_labels) -> str:
136
+ """
137
+ 在标签列表或字典中搜索给定的查询。
138
+ all_labels 可以是 list (分类/分割) dict (检测)
139
+ """
140
+ query = query.strip().lower()
141
+ if not query:
142
+ return "请输入有效的查询内容。"
143
+
144
+ MAX_MATCHES = 10
145
+
146
+ # 处理 List (Classification, Segmentation)
147
+ if isinstance(all_labels, list):
148
+ found_matches = [label for label in all_labels if query in label.lower()]
149
+
150
+ if found_matches:
151
+ result_list = "\n- ".join(found_matches[:MAX_MATCHES])
152
+ summary = f" 找到 {len(found_matches)} 个匹配项 (仅显示前 {MAX_MATCHES} 个):"
153
+ return f"{summary}\n- {result_list}"
154
+ else:
155
+ return f"❌ 未找到包含 '{query}' 的类别。"
156
+
157
+ # 处理 Dict (Detection)
158
+ elif isinstance(all_labels, dict):
159
+ found_matches = {k: v for k, v in all_labels.items() if query in v.lower()}
160
+
161
+ if found_matches:
162
+ result_list = [f"ID {k}: {v}" for k, v in list(found_matches.items())[:MAX_MATCHES]]
163
+ summary = f" 找到 {len(found_matches)} 个匹配项 (仅显示前 {MAX_MATCHES} ):"
164
+ return f"{summary}\n- {result_list}"
165
+ else:
166
+ return f"❌ 未找到包含 '{query}' 的类别。"
167
+
168
+ return "类别数据格式错误或未加载。"
169
+
170
+
171
+ # 🌟 美化 1: Gradio Soft 主题 CSS (背景颜色调整)
172
+ CUSTOM_CSS = """
173
+ /* 整体背景和卡片阴影优化 */
174
+ .gradio-container {
175
+ background-color: #f7f7f7; /* 调整为柔和的浅灰色背景 */
176
+ font-family: 'Inter', system-ui, sans-serif;
177
+ }
178
+ /* 卡片和主内容区的美化 */
179
+ .gradio-container > div {
180
+ border-radius: 12px;
181
+ box-shadow: 0 2px 8px rgba(0, 0, 0, 0.05); /* 柔和阴影 */
182
+ transition: all 0.3s ease;
183
+ }
184
+
185
+ /* 主标题样式 (用于承载 Logo 和文本) */
186
+ h1 {
187
+ display: flex;
188
+ align-items: center;
189
+ justify-content: center;
190
+ font-size: 2.2em;
191
+ color: #333333; /* 深色文字 */
192
+ padding: 20px 0;
193
+ margin: 0;
194
+ }
195
+
196
+ /* 按钮和输入框圆角 */
197
+ .gr-button, .gr-textbox, .gr-number, .gr-image {
198
+ border-radius: 8px !important;
199
+ }
200
+
201
+ /* 标签和组件背景 */
202
+ .gradio-container > div:not(.prose):not(.gr-row) {
203
+ background: white;
204
+ padding: 15px;
205
+ }
206
+ /* 页脚 Logo 样式 */
207
+ .footer-logo-container {
208
+ display: flex;
209
+ flex-direction: column;
210
+ align-items: center;
211
+ justify-content: center;
212
+ text-align: center;
213
+ padding-top: 15px;
214
+ }
215
+
216
+ /* 页脚图标/链接容器样式 */
217
+ .footer-links {
218
+ margin-top: 10px;
219
+ font-size: 14px;
220
+ color: #555;
221
+ display: flex;
222
+ gap: 15px; /* 图标之间的间隔 */
223
+ align-items: center;
224
+ /* 确保内容居中 */
225
+ justify-content: center;
226
+ }
227
+
228
+ /* 图标样式 */
229
+ .footer-icon {
230
+ font-size: 18px;
231
+ vertical-align: middle;
232
+ }
233
+
234
+ /* 链接颜色 */
235
+ .footer-links a {
236
+ color: #555;
237
+ text-decoration: none;
238
+ }
239
+ .footer-links a:hover {
240
+ color: #333;
241
+ }
242
+ """
243
+
244
+ with gr.Blocks(
245
+ title="AI基础模型视觉任务演示平台",
246
+ ) as demo:
247
+ # 注入 Favicon (网页选项卡图标)
248
+ gr.HTML(f"""
249
+ <head>
250
+ <link rel='icon' type='image/png' href='file/{LOGO_PATH}'/>
251
+ </head>
252
+ """, visible=False)
253
+
254
+ # 🌟 主标题区域:恢复机器人图标
255
+ gr.Markdown("<h1>🤖 AI基础模型视觉任务演示平台</h1>")
256
+ gr.Markdown("---")
257
+
258
+ # 🌟 简化功能说明 (只关注使用方法)
259
+ with gr.Accordion("📚 功能说明"):
260
+ gr.Markdown("""
261
+ 本平台支持图像分类、语义分割和目标检测三大任务。
262
+ 您可以通过以下步骤使用平台
263
+ 1. **切换选项卡**选择您希望执行的 AI 任务。
264
+ 2. **上传或择图片**:上传自己图片或点击下方的示例图片
265
+ 3. **设置参数**:对于目标检测,调整置信度阈值
266
+ 4. **点击提交**:点击“提交任务”按钮查看 AI 分析结果
267
+ """)
268
+
269
+ # 🌟 新增:数据集介绍
270
+ with gr.Accordion("📚 基础数据集介绍", open=False):
271
+ gr.Markdown("""
272
+ ### 📖 模型训练数据集概览
273
+
274
+ | 任务 | 模型 | 数据集 | 类别数 | 简介 |
275
+ | :--- | :--- | :--- | :--- | :--- |
276
+ | 图像分类 | ViT | **ImageNet-1K** | 1000 | 包含超过 100 万张图像,是图像识别领域的标准基准。 |
277
+ | 语义 | SegFormer | **ADE20K** | 150 | 专注于场景解析,提供 150 种语义概念像素级。 |
278
+ | 目标检测 | YOLOv8n | **COCO** | 80 | 最常用检测数据集之一,包含大量物体实例。 |
279
+ """)
280
+
281
+ # 🌟 新增:网络结构介绍
282
+ with gr.Accordion("🧠 网络结构介绍", open=False):
283
+ gr.Markdown("""
284
+ ### 💻 模型架构说明
285
+
286
+ 1. **图像分类 (ViT):**
287
+ * **全称:** Vision Transformer (ViT-Base-Patch16-224)
288
+ * **特点:** 基于 Transformer 结构,将图像切片后进行序列输入,通过自注意力机制实现全局建模。
289
+
290
+ 2. **语义分割 (SegFormer):**
291
+ * **全称:** Segmentation Transformer
292
+ * **特点:** 高效的 Transformer 架构,使用轻量级解码器,专注于速度和准确性的平衡。
293
+
294
+ 3. **目标检测 (YOLOv8n):**
295
+ * **全称:** You Only Look Once, Version 8 (Nano)
296
+ * **特点:** 单阶段检测器,以速度著称,Nano (n) 版本在保持高性能的同时,体积最小。
297
+ """)
298
+
299
+ # --- 任务选项卡 ---
300
+ with gr.Tabs():
301
+ # 1. 图像分类 Tab
302
+ with gr.TabItem("🖼️ 图像分类 (ViT)"):
303
+ with gr.Row():
304
+ with gr.Column(scale=1):
305
+ cls_input = gr.Image(type='pil', label="输入图像")
306
+ cls_button = gr.Button("🚀 提交分类任务")
307
+ with gr.Column(scale=1):
308
+ cls_output = gr.Label(num_top_classes=5, label="分类结果 (前 5)")
309
+
310
+ # 🌟 调整��序:Examples 先于 类别列表/查询 UI
311
+ gr.Examples(examples=CLS_EXAMPLES, inputs=[cls_input], label="例图片")
312
+
313
+ # 🌟 展示所有别列表
314
+ gr.Markdown("### 🌟 模型支持的全部分类类别 (ImageNet-1K)")
315
+ cls_category_json = gr.JSON(value=ALL_CLS_LABELS, label="所有类别列表", scale=1)
316
+
317
+ # 🌟 查询 UI
318
+ with gr.Row():
319
+ cls_search_query = gr.Textbox(label="查询类别", placeholder="Search Class Name (e.g., dog)", scale=3)
320
+ cls_search_button = gr.Button("🔍 搜索", scale=1)
321
+ cls_search_output = gr.Markdown("搜索结果将显示在这里。")
322
+
323
+ cls_search_button.click(
324
+ fn=search_labels,
325
+ inputs=[cls_search_query, cls_category_json],
326
+ outputs=cls_search_output
327
+ )
328
+
329
+ cls_button.click(cls_predict, inputs=cls_input, outputs=cls_output)
330
+
331
+ # 2. 语义分割 Tab
332
+ with gr.TabItem("✂️ 语义分割 (SegFormer)"):
333
+ with gr.Row():
334
+ with gr.Column(scale=2):
335
+ seg_input = gr.Image(type='pil', label="输入图像")
336
+ seg_button = gr.Button("🚀 提交分割任务")
337
+ with gr.Column(scale=2):
338
+ seg_output = gr.Image(type='pil', label="分割结果 (叠加)")
339
+ with gr.Column(scale=1):
340
+ # 🌟 展示图例
341
+ gr.HTML(value=generate_legend_html(ALL_SEG_COLOR_MAP), scale=1)
342
+
343
+ # 🌟 调顺序:Examples 先于 类别列表/查询 UI
344
+ gr.Examples(examples=SEG_EXAMPLES, inputs=[seg_input], label="示例图片")
345
+
346
+ # 保留完整的类别列表(以 JSON 格式展示,作为额外的参考)
347
+ gr.Markdown("### 完整类别列表 (JSON)")
348
+ seg_category_json = gr.JSON(value={f"ID {i}": label for i, label in enumerate(ALL_SEG_LABELS)},
349
+ label="所有类别 JSON")
350
+
351
+ # 🌟 查询 UI 提示改为英文
352
+ with gr.Row():
353
+ seg_search_query = gr.Textbox(label="查询类别", placeholder="Search Class Name (e.g., road, sky)",
354
+ scale=3)
355
+ seg_search_button = gr.Button("🔍 搜索", scale=1)
356
+ seg_search_output = gr.Markdown("搜索结果将显示在这里。")
357
+
358
+ seg_search_button.click(
359
+ fn=search_labels,
360
+ inputs=[seg_search_query, seg_category_json],
361
+ outputs=seg_search_output
362
+ )
363
+
364
+ seg_button.click(seg_predict, inputs=seg_input, outputs=seg_output)
365
+
366
+ # 3. 目标检测 Tab
367
+ with gr.TabItem("🎯 目标检测 (YOLOv8n)"):
368
+ with gr.Row():
369
+ with gr.Column(scale=1):
370
+ det_input_image = gr.Image(type='pil', label="输入图像")
371
+ det_input_number = gr.Number(
372
+ precision=2,
373
+ minimum=0.01,
374
+ maximum=1,
375
+ value=0.30,
376
+ label='置信度阈值'
377
+ )
378
+ det_button = gr.Button("🚀 提交检测任务")
379
+ with gr.Column(scale=1):
380
+ det_output = gr.Image(type='pil', label="检测结果 (边界框)")
381
+
382
+ # 🌟 调整顺序:Examples 先于 类别列表/查询 UI
383
+ gr.Examples(examples=DET_EXAMPLES, inputs=[det_input_image], label="示例图片")
384
+
385
+ # 🌟 目标检测类别列表
386
+ gr.Markdown("### 🎯 模型支持的检测类别 (COCO)")
387
+ det_category_json = gr.JSON(value=ALL_DET_LABELS, label="所有类别列表")
388
+
389
+ # 🌟 查询 UI 提示改为英文
390
+ with gr.Row():
391
+ det_search_query = gr.Textbox(label="查询类别", placeholder="Search Class Name (e.g., bicycle, train)",
392
+ scale=3)
393
+ det_search_button = gr.Button("🔍 搜索", scale=1)
394
+ det_search_output = gr.Markdown("搜索结果将显示在这里。")
395
+
396
+ det_search_button.click(
397
+ fn=search_labels,
398
+ inputs=[det_search_query, det_category_json],
399
+ outputs=det_search_output
400
+ )
401
+
402
+ det_button.click(det_predict, inputs=[det_input_image, det_input_number], outputs=det_output)
403
+
404
+ # 🌟 添加页脚和 Logo/版权
405
+ gr.HTML(
406
+ f"""
407
+ <div class='footer-logo-container'>
408
+ <div class="footer-links">
409
+ <p>{COPYRIGHT_TEXT}</p>
410
+ </div>
411
+
412
+ <div class="footer-links">
413
+
414
+ <!-- 🌟 学校图标和名称 (添加超链接) -->
415
+ <span class="footer-icon">🏢</span>
416
+ <a href='https://wutinfo.whut.edu.cn/' target='_blank' style='text-decoration: none; color: inherit;'>
417
+ <span>{SCHOOL_NAME_EN}</span>
418
+ </a>
419
+
420
+ <!-- 🌟 额外 Logo 位于版权信息之后 -->
421
+ <img src='file/{LOGO_PATH}' alt='Logo' style='height: 30px; margin-left: 20px;' onerror="this.style.display='none'">
422
+ </div>
423
+ </div>
424
+ """
425
+ )
426
+
427
+ if __name__ == "__main__":
428
+ gr.close_all()
429
+ print("Launching Gradio demo...")
430
+ # 🌟 传入 css 参数
431
  demo.launch(share=True, css=CUSTOM_CSS)