|
|
import gradio as gr |
|
|
from model import ( |
|
|
cls_predict, |
|
|
det_predict, |
|
|
seg_predict, |
|
|
ALL_SEG_LABELS, |
|
|
ALL_DET_LABELS, |
|
|
ALL_SEG_COLOR_MAP, |
|
|
ALL_CLS_LABELS |
|
|
) |
|
|
import os |
|
|
import requests |
|
|
from PIL import Image |
|
|
from io import BytesIO |
|
|
import time |
|
|
|
|
|
|
|
|
# Branding strings rendered in the page footer.
LOGO_PATH = "logo/logo.png"
COPYRIGHT_TEXT = "© 2025 All Rights Reserved."
SCHOOL_NAME_EN = "School of Information Engineering, Wuhan University of Technology"

# COCO images offered as clickable examples, grouped by task prefix:
# "cls" -> classification tab, "seg" -> segmentation tab, "det" -> detection tab.
# Each URL is cached locally under OUTPUT_DIR as "<prefix>_<n>.jpg".
TASK_EXAMPLE_URLS = dict(
    cls=[
        "http://images.cocodataset.org/val2017/000000000285.jpg",
        "http://images.cocodataset.org/val2017/000000000785.jpg",
        "http://images.cocodataset.org/val2017/000000000724.jpg",
        "http://images.cocodataset.org/val2017/000000001584.jpg",
        "http://images.cocodataset.org/train2017/000000001097.jpg",
    ],
    seg=[
        "http://images.cocodataset.org/val2017/000000000139.jpg",
        "http://images.cocodataset.org/val2017/000000000632.jpg",
        "http://images.cocodataset.org/val2017/000000000885.jpg",
        "http://images.cocodataset.org/train2017/000000000267.jpg",
        "http://images.cocodataset.org/train2017/000000001140.jpg",
    ],
    det=[
        "http://images.cocodataset.org/val2017/000000000785.jpg",
        "http://images.cocodataset.org/val2017/000000001268.jpg",
        "http://images.cocodataset.org/train2017/000000001072.jpg",
        "http://images.cocodataset.org/train2017/000000000119.jpg",
        "http://images.cocodataset.org/train2017/000000000570.jpg",
    ],
)

# Local directory that caches the downloaded example images.
OUTPUT_DIR = "examples"
|
|
|
|
|
|
|
|
def download_and_save_examples(max_retries=3, timeout=15, retry_delay=2):
    """Download the example images into ``OUTPUT_DIR``, skipping ones already present.

    Every URL in ``TASK_EXAMPLE_URLS`` is stored as ``<task>_<n>.jpg`` with ``n``
    starting at 1. Failures are reported on stdout and never raised, so the app
    still starts when the network is unavailable.

    Args:
        max_retries: maximum number of download attempts per image.
        timeout: per-request timeout in seconds, passed to ``requests.get``.
        retry_delay: seconds to sleep between failed attempts.
    """
    # exist_ok avoids the race inherent in an exists() check followed by makedirs().
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    total_urls = sum(len(urls) for urls in TASK_EXAMPLE_URLS.values())
    print(f"🚀 检查和下载 {total_urls} 张示例图片...")

    for prefix, urls in TASK_EXAMPLE_URLS.items():
        for i, url in enumerate(urls):
            filename = f"{prefix}_{i + 1}.jpg"
            filepath = os.path.join(OUTPUT_DIR, filename)
            if os.path.exists(filepath):
                continue

            fetched = _fetch_example(url, filename, max_retries, timeout, retry_delay)
            if fetched is None:
                continue
            content, attempt = fetched
            if _save_example_image(content, filepath, filename):
                print(f" 成功下载并保存: {filename} (尝试 {attempt + 1}/{max_retries})")


def _fetch_example(url, filename, max_retries, timeout, retry_delay):
    """Return ``(body_bytes, attempt_index)`` for *url*, or ``None`` after giving up."""
    for attempt in range(max_retries):
        try:
            # Using the response as a context manager releases the connection
            # even if reading the body fails.
            with requests.get(url, timeout=timeout) as response:
                response.raise_for_status()
                return response.content, attempt
        except requests.exceptions.RequestException as e:
            print(f" ⚠️ 下载 {filename} 失败 (尝试 {attempt + 1}/{max_retries}): {e}")
            err_response = getattr(e, "response", None)
            if err_response is not None and err_response.status_code == 404:
                # A missing resource will not appear on retry, so give up now.
                print(f"❌ 最终下载失败 {filename}: URL {url} 不存在 (404 错误)。")
                return None
            if attempt < max_retries - 1:
                time.sleep(retry_delay)
            else:
                print(f"❌ 最终下载失败 {filename}: 请检查网络连接或 URL。")
    return None


def _save_example_image(content, filepath, filename):
    """Decode *content* as an image and write it to *filepath* as JPEG.

    The image is first written to a temporary ``.part`` file and then moved into
    place, so a failed or interrupted save never leaves a truncated file that the
    existence check in ``download_and_save_examples`` would skip on later runs.
    Returns True on success, False (after printing the error) otherwise.
    """
    tmp_path = f"{filepath}.part"
    try:
        with Image.open(BytesIO(content)) as image:
            # JPEG cannot store e.g. RGBA or palette images, so normalise those.
            to_save = image if image.mode in ("RGB", "L") else image.convert("RGB")
            # Explicit format: the ".part" suffix would not let PIL infer it.
            to_save.save(tmp_path, format="JPEG")
        os.replace(tmp_path, filepath)
    except Exception as e:  # best-effort at startup: report any decode/encode/IO error
        print(f"❌ 图像处理失败 {filename}: {e}")
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        return False
    return True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Fetch the example images at import time so the gr.Examples widgets can use local files.
download_and_save_examples()


def _existing_example_paths(prefix):
    """Return ``[[path], ...]`` for each downloaded example image of task *prefix*.

    One candidate ``<prefix>_<n>.jpg`` path (the naming used by
    ``download_and_save_examples``) is generated per URL in
    ``TASK_EXAMPLE_URLS[prefix]``, so the count follows the URL list instead of
    being hard-coded. Candidates whose file does not exist (e.g. the download
    failed) are dropped, so gr.Examples is not handed a missing file.
    """
    candidates = (
        os.path.join(OUTPUT_DIR, f"{prefix}_{i + 1}.jpg")
        for i in range(len(TASK_EXAMPLE_URLS.get(prefix, ())))
    )
    return [[path] for path in candidates if os.path.isfile(path)]


CLS_EXAMPLES = _existing_example_paths("cls")
SEG_EXAMPLES = _existing_example_paths("seg")
DET_EXAMPLES = _existing_example_paths("det")
|
|
|
|
|
|
|
|
|
|
|
def generate_legend_html(color_map_dict):
    """Build the HTML color legend shown next to the segmentation result.

    Args:
        color_map_dict: mapping of label -> CSS color string (e.g. ``"#787878"``).
            A mapping that contains the key ``"Error"``, or a missing/empty
            mapping, is rendered as the "legend unavailable" message.

    Returns:
        A scrollable ``<div>`` with a heading and one swatch + label row per
        entry. Labels and colors are HTML-escaped before interpolation, and the
        outer ``<div>`` is always closed (including on the error path).
    """
    import html  # only needed here; keeps the module import block untouched

    parts = [
        "<div style='max-height: 300px; overflow-y: scroll; padding: 10px; border: 1px solid #ccc; background-color: #f7f7f7; border-radius: 8px;'>",
        "<h4 style='margin-top: 0; color: #333;'>🎨 分割颜色图例</h4>",
    ]

    if not color_map_dict or "Error" in color_map_dict:
        parts.append("<p style='color: red;'>模型加载失败,图例不可用。</p>")
        parts.append("</div>")
        return "".join(parts)

    for label, color in color_map_dict.items():
        swatch_color = html.escape(str(color), quote=True)
        label_text = html.escape(str(label))
        parts.append(
            "<div style='display: flex; align-items: center; margin-bottom: 5px; font-family: sans-serif;'>"
            f"<div style='width: 15px; height: 15px; background-color: {swatch_color}; border: 1px solid #333; margin-right: 10px; border-radius: 3px;'></div>"
            f"<span style='font-size: 14px; color: #555;'>{label_text}</span>"
            "</div>"
        )

    parts.append("</div>")
    # Join once instead of repeated string += (quadratic in the number of labels).
    return "".join(parts)
|
|
|
|
|
|
|
|
|
|
|
def search_labels(query: str, all_labels, max_matches: int = 10) -> str:
    """Search a label collection for entries containing *query* (case-insensitive).

    Args:
        query: free-text search string; surrounding whitespace is ignored and
            ``None`` is treated as empty.
        all_labels: either a list of label names (classification) or a dict
            mapping an ID to a label name (segmentation / detection). Dict keys
            that already start with ``"ID "`` are shown as-is; other keys are
            prefixed with ``"ID "``.
        max_matches: maximum number of matches listed in the result.

    Returns:
        A Markdown string: a summary line followed by a bullet list of at most
        *max_matches* matches, a "not found" message, or an error message for an
        empty query or unsupported label data.
    """
    query = (query or "").strip().lower()
    if not query:
        return "请输入有效的查询内容。"

    if isinstance(all_labels, list):
        found_matches = [str(label) for label in all_labels if query in str(label).lower()]
    elif isinstance(all_labels, dict):
        found_matches = []
        for key, name in all_labels.items():
            if query in str(name).lower():
                # The segmentation JSON is keyed "ID <n>" already; avoid "ID ID <n>".
                label_id = key if str(key).startswith("ID ") else f"ID {key}"
                found_matches.append(f"{label_id}: {name}")
    else:
        return "类别数据格式错误或未加载。"

    if not found_matches:
        return f"❌ 未找到包含 '{query}' 的类别。"

    summary = f"✅ 找到 {len(found_matches)} 个匹配项 (仅显示前 {max_matches} 个):"
    # Render as a Markdown bullet list for both list and dict inputs.
    return summary + "\n- " + "\n- ".join(found_matches[:max_matches])
|
|
|
|
|
|
|
|
|
|
|
# Extra stylesheet for the demo page; passed to demo.launch(css=CUSTOM_CSS) in the
# __main__ block. The footer-* classes style the footer gr.HTML fragment.
# NOTE(review): the `.gr-button` / `.gr-textbox` / `.gr-number` / `.gr-image` /
# `.gr-row` selectors come from older Gradio class names and may not match any
# element in current releases — confirm in the browser dev tools.
CUSTOM_CSS = """
/* 整体背景和卡片阴影优化 */
.gradio-container {
    background-color: #f7f7f7; /* 调整为柔和的浅灰色背景 */
    font-family: 'Inter', system-ui, sans-serif;
}
/* 卡片和主内容区的美化 */
.gradio-container > div {
    border-radius: 12px;
    box-shadow: 0 2px 8px rgba(0, 0, 0, 0.05); /* 柔和阴影 */
    transition: all 0.3s ease;
}

/* 主标题样式 (用于承载 Logo 和文本) */
h1 {
    display: flex;
    align-items: center;
    justify-content: center;
    font-size: 2.2em;
    color: #333333; /* 深色文字 */
    padding: 20px 0;
    margin: 0;
}

/* 按钮和输入框圆角 */
.gr-button, .gr-textbox, .gr-number, .gr-image {
    border-radius: 8px !important;
}

/* 标签和组件背景 */
.gradio-container > div:not(.prose):not(.gr-row) {
    background: white;
    padding: 15px;
}
/* 页脚 Logo 样式 */
.footer-logo-container {
    display: flex;
    flex-direction: column;
    align-items: center;
    justify-content: center;
    text-align: center;
    padding-top: 15px;
}

/* 页脚图标/链接容器样式 */
.footer-links {
    margin-top: 10px;
    font-size: 14px;
    color: #555;
    display: flex;
    gap: 15px; /* 图标之间的间隔 */
    align-items: center;
    /* 确保内容居中 */
    justify-content: center;
}

/* 图标样式 */
.footer-icon {
    font-size: 18px;
    vertical-align: middle;
}

/* 链接颜色 */
.footer-links a {
    color: #555;
    text-decoration: none;
}
.footer-links a:hover {
    color: #333;
}
"""
|
|
|
|
|
# Declarative UI: every component created inside a `with` container is placed in
# that container, so the statement order below defines the page layout.
with gr.Blocks(
    title="AI基础模型视觉任务演示平台",
) as demo:

    # NOTE(review): this hidden fragment tries to set the browser favicon by
    # injecting <head>/<link> markup. gr.HTML content is rendered inside the page
    # body and this component is invisible, so it likely has no effect; the
    # `file/` URL prefix is also Gradio-version dependent. launch(favicon_path=...)
    # is the documented mechanism — confirm and consider removing this.
    gr.HTML(f"""
    <head>
        <link rel='icon' type='image/png' href='file/{LOGO_PATH}'/>
    </head>
    """, visible=False)

    # Page title and a horizontal rule.
    gr.Markdown("<h1>🤖 AI基础模型视觉任务演示平台</h1>")
    gr.Markdown("---")

    # Usage instructions (expanded by default).
    with gr.Accordion("📚 功能说明"):
        gr.Markdown("""
        本平台支持图像分类、语义分割和目标检测三大任务。
        您可以通过以下步骤使用平台:
        1. **切换选项卡**:选择您希望执行的 AI 任务。
        2. **上传或选择图片**:上传您自己的图片或点击下方的示例图片。
        3. **设置参数**:对于目标检测,调整置信度阈值。
        4. **点击提交**:点击“提交任务”按钮,查看 AI 分析结果。
        """)

    # Collapsed by default: overview table of the training datasets.
    with gr.Accordion("📚 基础数据集介绍", open=False):
        gr.Markdown("""
        ### 📖 模型训练数据集概览

        | 任务 | 模型 | 数据集 | 类别数 | 简介 |
        | :--- | :--- | :--- | :--- | :--- |
        | 图像分类 | ViT | **ImageNet-1K** | 1000 | 包含超过 100 万张图像,是图像识别领域的标准基准。 |
        | 语义分割 | SegFormer | **ADE20K** | 150 | 专注于场景解析,提供 150 种语义概念的像素级标注。 |
        | 目标检测 | YOLOv8n | **COCO** | 80 | 最常用的目标检测数据集之一,包含大量物体实例。 |
        """)

    # Collapsed by default: short description of each model architecture.
    with gr.Accordion("🧠 网络结构介绍", open=False):
        gr.Markdown("""
        ### 💻 模型架构说明

        1. **图像分类 (ViT):**
            * **全称:** Vision Transformer (ViT-Base-Patch16-224)
            * **特点:** 基于 Transformer 结构,将图像切片后进行序列输入,通过自注意力机制实现全局建模。

        2. **语义分割 (SegFormer):**
            * **全称:** Segmentation Transformer
            * **特点:** 高效的 Transformer 架构,使用轻量级解码器,专注于速度和准确性的平衡。

        3. **目标检测 (YOLOv8n):**
            * **全称:** You Only Look Once, Version 8 (Nano)
            * **特点:** 单阶段检测器,以速度著称,Nano (n) 版本在保持高性能的同时,体积最小。
        """)

    # One tab per vision task; each tab has input/output panes, examples,
    # the full label list and a label search box.
    with gr.Tabs():

        # ---- Image classification ----
        with gr.TabItem("🖼️ 图像分类 (ViT)"):
            with gr.Row():
                with gr.Column(scale=1):
                    cls_input = gr.Image(type='pil', label="输入图像")
                    cls_button = gr.Button("🚀 提交分类任务")
                with gr.Column(scale=1):
                    # Displays at most the 5 top-scoring classes of the prediction.
                    cls_output = gr.Label(num_top_classes=5, label="分类结果 (前 5)")

            # Clicking an example loads that image into cls_input.
            gr.Examples(examples=CLS_EXAMPLES, inputs=[cls_input], label="示例图片")

            # Full label data; this component is also the data source for the search below.
            gr.Markdown("### 🌟 模型支持的全部分类类别 (ImageNet-1K)")
            cls_category_json = gr.JSON(value=ALL_CLS_LABELS, label="所有类别列表", scale=1)

            with gr.Row():
                cls_search_query = gr.Textbox(label="查询类别", placeholder="Search Class Name (e.g., dog)", scale=3)
                cls_search_button = gr.Button("🔍 搜索", scale=1)
            cls_search_output = gr.Markdown("搜索结果将显示在这里。")

            # NOTE(review): passing the JSON component as an input sends the whole label
            # list back from the browser on every click; binding the labels server-side
            # (e.g. functools.partial) would avoid that round-trip.
            cls_search_button.click(
                fn=search_labels,
                inputs=[cls_search_query, cls_category_json],
                outputs=cls_search_output
            )

            cls_button.click(cls_predict, inputs=cls_input, outputs=cls_output)

        # ---- Semantic segmentation ----
        with gr.TabItem("✂️ 语义分割 (SegFormer)"):
            with gr.Row():
                with gr.Column(scale=2):
                    seg_input = gr.Image(type='pil', label="输入图像")
                    seg_button = gr.Button("🚀 提交分割任务")
                with gr.Column(scale=2):
                    seg_output = gr.Image(type='pil', label="分割结果 (叠加)")
                with gr.Column(scale=1):

                    # Static legend, generated once when the layout is built.
                    gr.HTML(value=generate_legend_html(ALL_SEG_COLOR_MAP), scale=1)

            gr.Examples(examples=SEG_EXAMPLES, inputs=[seg_input], label="示例图片")

            # Labels re-keyed as "ID <index>" -> label name, so search_labels
            # presumably receives a dict here (assuming gr.JSON passes it through).
            gr.Markdown("### 完整类别列表 (JSON)")
            seg_category_json = gr.JSON(value={f"ID {i}": label for i, label in enumerate(ALL_SEG_LABELS)},
                                        label="所有类别 JSON")

            with gr.Row():
                seg_search_query = gr.Textbox(label="查询类别", placeholder="Search Class Name (e.g., road, sky)",
                                              scale=3)
                seg_search_button = gr.Button("🔍 搜索", scale=1)
            seg_search_output = gr.Markdown("搜索结果将显示在这里。")

            seg_search_button.click(
                fn=search_labels,
                inputs=[seg_search_query, seg_category_json],
                outputs=seg_search_output
            )

            seg_button.click(seg_predict, inputs=seg_input, outputs=seg_output)

        # ---- Object detection ----
        with gr.TabItem("🎯 目标检测 (YOLOv8n)"):
            with gr.Row():
                with gr.Column(scale=1):
                    det_input_image = gr.Image(type='pil', label="输入图像")
                    # Confidence threshold in [0.01, 1]; passed to det_predict as its second input.
                    det_input_number = gr.Number(
                        precision=2,
                        minimum=0.01,
                        maximum=1,
                        value=0.30,
                        label='置信度阈值'
                    )
                    det_button = gr.Button("🚀 提交检测任务")
                with gr.Column(scale=1):
                    det_output = gr.Image(type='pil', label="检测结果 (边界框)")

            gr.Examples(examples=DET_EXAMPLES, inputs=[det_input_image], label="示例图片")

            gr.Markdown("### 🎯 模型支持的检测类别 (COCO)")
            det_category_json = gr.JSON(value=ALL_DET_LABELS, label="所有类别列表")

            with gr.Row():
                det_search_query = gr.Textbox(label="查询类别", placeholder="Search Class Name (e.g., bicycle, train)",
                                              scale=3)
                det_search_button = gr.Button("🔍 搜索", scale=1)
            det_search_output = gr.Markdown("搜索结果将显示在这里。")

            det_search_button.click(
                fn=search_labels,
                inputs=[det_search_query, det_category_json],
                outputs=det_search_output
            )

            det_button.click(det_predict, inputs=[det_input_image, det_input_number], outputs=det_output)

    # Footer: copyright line, school link and logo (styled by the footer-* CSS classes).
    # NOTE(review): the `file/` image URL prefix depends on the Gradio version and the
    # file may need to be listed in launch(allowed_paths=...); onerror hides the <img>
    # if it fails to load — confirm the logo actually renders.
    gr.HTML(
        f"""
        <div class='footer-logo-container'>
            <div class="footer-links">
                <p>{COPYRIGHT_TEXT}</p>
            </div>

            <div class="footer-links">

                <!-- 🌟 学校图标和名称 (添加超链接) -->
                <span class="footer-icon">🏢</span>
                <a href='https://wutinfo.whut.edu.cn/' target='_blank' style='text-decoration: none; color: inherit;'>
                    <span>{SCHOOL_NAME_EN}</span>
                </a>

                <!-- 🌟 额外 Logo 位于版权信息之后 -->
                <img src='file/{LOGO_PATH}' alt='Logo' style='height: 30px; margin-left: 20px;' onerror="this.style.display='none'">
            </div>
        </div>
        """
    )
|
|
|
|
|
if __name__ == "__main__":
    # Close any Gradio servers still running in this process before launching.
    gr.close_all()
    print("Launching Gradio demo...")

    # The hidden gr.HTML <head> fragment in the layout is not a reliable way to
    # set the browser-tab icon; launch(favicon_path=...) is the documented option.
    # Both logo-related options are only passed when the logo file exists.
    has_logo = os.path.isfile(LOGO_PATH)
    demo.launch(
        share=True,
        css=CUSTOM_CSS,
        favicon_path=LOGO_PATH if has_logo else None,
        # Permit Gradio to serve the logo file, which the footer <img> references.
        allowed_paths=[LOGO_PATH] if has_logo else None,
    )