ai4vision / app.py
2ephyrh's picture
Update app.py
aa2a48a verified
import gradio as gr
from model import (
cls_predict,
det_predict,
seg_predict,
ALL_SEG_LABELS,
ALL_DET_LABELS,
ALL_SEG_COLOR_MAP,
ALL_CLS_LABELS
)
import os
import requests
from PIL import Image
from io import BytesIO
import time # 用于重试
# --- 配置 Logo 和版权信息 ---
LOGO_PATH = "logo/logo.png" # 请将您的 Logo 图片存放在此路径
# 🌟 更新:版权信息仅保留年份和权利声明,具体内容在 HTML 中处理
COPYRIGHT_TEXT = "© 2025 All Rights Reserved."
SCHOOL_NAME_EN = "School of Information Engineering, Wuhan University of Technology"
# --- 自动下载示例图片逻辑 (嵌入) ---
# 🌟 恢复到用户指定的 COCO ID 列表
TASK_EXAMPLE_URLS = {
"cls": [
# 图像分类示例 (val2017)
"http://images.cocodataset.org/val2017/000000000285.jpg", # 猫和键盘 (保留)
"http://images.cocodataset.org/val2017/000000000785.jpg", # 交通灯/汽车/巴士 (保留)
"http://images.cocodataset.org/val2017/000000000724.jpg", # 食物/餐具 (保留)
"http://images.cocodataset.org/val2017/000000001584.jpg", # 多人,多物体 (保留)
"http://images.cocodataset.org/train2017/000000001097.jpg", # 原 cls_5
],
"seg": [
# 语义分割示例 (val2017)
"http://images.cocodataset.org/val2017/000000000139.jpg", # 街景/汽车 (保留)
"http://images.cocodataset.org/val2017/000000000632.jpg", # 街景/行人 (保留)
"http://images.cocodataset.org/val2017/000000000885.jpg", # 滑板手 (保留)
"http://images.cocodataset.org/train2017/000000000267.jpg", # 原 seg_4
"http://images.cocodataset.org/train2017/000000001140.jpg", # 原 seg_5
],
"det": [
# 目标检测示例 (val2017)
"http://images.cocodataset.org/val2017/000000000785.jpg", # 交通灯/汽车/巴士 (保留)
"http://images.cocodataset.org/val2017/000000001268.jpg", # 原 det_2
"http://images.cocodataset.org/train2017/000000001072.jpg", # 原 det_3
"http://images.cocodataset.org/train2017/000000000119.jpg", # 原 det_4
"http://images.cocodataset.org/train2017/000000000570.jpg", # 原 det_5
]
}
# 项目期望的本地路径
OUTPUT_DIR = "examples"
def download_and_save_examples(max_retries=3):
"""下载示例图片到本地 examples/ 目录,使用任务前缀命名,增加重试机制"""
if not os.path.exists(OUTPUT_DIR):
os.makedirs(OUTPUT_DIR)
total_urls = sum(len(urls) for urls in TASK_EXAMPLE_URLS.values())
print(f"🚀 检查和下载 {total_urls} 张示例图片...")
# 迭代所有任务和 URL
for prefix, urls in TASK_EXAMPLE_URLS.items():
for i, url in enumerate(urls):
# 文件名格式:cls_1.jpg, seg_2.jpg, det_3.jpg
filename = f"{prefix}_{i + 1}.jpg"
filepath = os.path.join(OUTPUT_DIR, filename)
if os.path.exists(filepath):
continue # 跳过已存在的文件
for attempt in range(max_retries):
try:
# 增加更长的超时时间
response = requests.get(url, stream=True, timeout=15)
response.raise_for_status()
image = Image.open(BytesIO(response.content))
image.save(filepath)
print(f" 成功下载并保存: {filename} (尝试 {attempt + 1}/{max_retries})")
break # 成功则跳出重试循环
except requests.exceptions.RequestException as e:
print(f" ⚠️ 下载 {filename} 失败 (尝试 {attempt + 1}/{max_retries}): {e}")
if attempt < max_retries - 1:
time.sleep(2) # 失败后等待2秒再重试
else:
# 404 错误是 ClientError,意味着 URL 不存在
if '404 Client Error' in str(e):
print(f"❌ 最终下载失败 {filename}: URL {url} 不存在 (404 错误)。")
else:
print(f"❌ 最终下载失败 {filename}: 请检查网络连接或 URL。")
break
except Exception as e:
# 图像处理失败 (如 BytesIO 或 PIL 错误),停止重试
print(f"❌ 图像处理失败 {filename}: {e}")
break
# 立即执行下载,确保 examples 目录下的文件存在
download_and_save_examples()
# 🌟 关键:创建三个独立的示例列表,用于 Gradio Examples 组件
CLS_EXAMPLES = [[os.path.join(OUTPUT_DIR, f"cls_{i + 1}.jpg")] for i in range(5)]
SEG_EXAMPLES = [[os.path.join(OUTPUT_DIR, f"seg_{i + 1}.jpg")] for i in range(5)]
DET_EXAMPLES = [[os.path.join(OUTPUT_DIR, f"det_{i + 1}.jpg")] for i in range(5)]
# --- 辅助函数:生成颜色图例 HTML ---
def generate_legend_html(color_map_dict):
"""根据颜色映射字典生成 HTML 图例"""
html_content = "<div style='max-height: 300px; overflow-y: scroll; padding: 10px; border: 1px solid #ccc; background-color: #f7f7f7; border-radius: 8px;'>"
html_content += "<h4 style='margin-top: 0; color: #333;'>🎨 分割颜色图例</h4>"
if "Error" in color_map_dict:
html_content += "<p style='color: red;'>模型加载失败,图例不可用。</p>"
return html_content
for label, hex_color in color_map_dict.items():
html_content += f"""
<div style='display: flex; align-items: center; margin-bottom: 5px; font-family: sans-serif;'>
<div style='width: 15px; height: 15px; background-color: {hex_color}; border: 1px solid #333; margin-right: 10px; border-radius: 3px;'></div>
<span style='font-size: 14px; color: #555;'>{label}</span>
</div>
"""
html_content += "</div>"
return html_content
# --- 辅助函数:类别搜索逻辑 ---
def search_labels(query: str, all_labels) -> str:
"""
在标签列表或字典中搜索给定的查询。
all_labels 可以是 list (分类/分割) 或 dict (检测)。
"""
query = query.strip().lower()
if not query:
return "请输入有效的查询内容。"
MAX_MATCHES = 10
# 处理 List (Classification, Segmentation)
if isinstance(all_labels, list):
found_matches = [label for label in all_labels if query in label.lower()]
if found_matches:
result_list = "\n- ".join(found_matches[:MAX_MATCHES])
summary = f"✅ 找到 {len(found_matches)} 个匹配项 (仅显示前 {MAX_MATCHES} 个):"
return f"{summary}\n- {result_list}"
else:
return f"❌ 未找到包含 '{query}' 的类别。"
# 处理 Dict (Detection)
elif isinstance(all_labels, dict):
found_matches = {k: v for k, v in all_labels.items() if query in v.lower()}
if found_matches:
result_list = [f"ID {k}: {v}" for k, v in list(found_matches.items())[:MAX_MATCHES]]
summary = f"✅ 找到 {len(found_matches)} 个匹配项 (仅显示前 {MAX_MATCHES} 个):"
return f"{summary}\n- {result_list}"
else:
return f"❌ 未找到包含 '{query}' 的类别。"
return "类别数据格式错误或未加载。"
# 🌟 美化 1: Gradio Soft 主题 CSS (背景颜色调整)
CUSTOM_CSS = """
/* 整体背景和卡片阴影优化 */
.gradio-container {
background-color: #f7f7f7; /* 调整为柔和的浅灰色背景 */
font-family: 'Inter', system-ui, sans-serif;
}
/* 卡片和主内容区的美化 */
.gradio-container > div {
border-radius: 12px;
box-shadow: 0 2px 8px rgba(0, 0, 0, 0.05); /* 柔和阴影 */
transition: all 0.3s ease;
}
/* 主标题样式 (用于承载 Logo 和文本) */
h1 {
display: flex;
align-items: center;
justify-content: center;
font-size: 2.2em;
color: #333333; /* 深色文字 */
padding: 20px 0;
margin: 0;
}
/* 按钮和输入框圆角 */
.gr-button, .gr-textbox, .gr-number, .gr-image {
border-radius: 8px !important;
}
/* 标签和组件背景 */
.gradio-container > div:not(.prose):not(.gr-row) {
background: white;
padding: 15px;
}
/* 页脚 Logo 样式 */
.footer-logo-container {
display: flex;
flex-direction: column;
align-items: center;
justify-content: center;
text-align: center;
padding-top: 15px;
}
/* 页脚图标/链接容器样式 */
.footer-links {
margin-top: 10px;
font-size: 14px;
color: #555;
display: flex;
gap: 15px; /* 图标之间的间隔 */
align-items: center;
/* 确保内容居中 */
justify-content: center;
}
/* 图标样式 */
.footer-icon {
font-size: 18px;
vertical-align: middle;
}
/* 链接颜色 */
.footer-links a {
color: #555;
text-decoration: none;
}
.footer-links a:hover {
color: #333;
}
"""
with gr.Blocks(
title="AI基础模型视觉任务演示平台",
) as demo:
# 注入 Favicon (网页选项卡图标)
gr.HTML(f"""
<head>
<link rel='icon' type='image/png' href='file/{LOGO_PATH}'/>
</head>
""", visible=False)
# 🌟 主标题区域:恢复机器人图标
gr.Markdown("<h1>🤖 AI基础模型视觉任务演示平台</h1>")
gr.Markdown("---")
# 🌟 简化功能说明 (只关注使用方法)
with gr.Accordion("📚 功能说明"):
gr.Markdown("""
本平台支持图像分类、语义分割和目标检测三大任务。
您可以通过以下步骤使用平台:
1. **切换选项卡**:选择您希望执行的 AI 任务。
2. **上传或选择图片**:上传您自己的图片或点击下方的示例图片。
3. **设置参数**:对于目标检测,调整置信度阈值。
4. **点击提交**:点击“提交任务”按钮,查看 AI 分析结果。
""")
# 🌟 新增:数据集介绍
with gr.Accordion("📚 基础数据集介绍", open=False):
gr.Markdown("""
### 📖 模型训练数据集概览
| 任务 | 模型 | 数据集 | 类别数 | 简介 |
| :--- | :--- | :--- | :--- | :--- |
| 图像分类 | ViT | **ImageNet-1K** | 1000 | 包含超过 100 万张图像,是图像识别领域的标准基准。 |
| 语义分割 | SegFormer | **ADE20K** | 150 | 专注于场景解析,提供 150 种语义概念的像素级标注。 |
| 目标检测 | YOLOv8n | **COCO** | 80 | 最常用的目标检测数据集之一,包含大量物体实例。 |
""")
# 🌟 新增:网络结构介绍
with gr.Accordion("🧠 网络结构介绍", open=False):
gr.Markdown("""
### 💻 模型架构说明
1. **图像分类 (ViT):**
* **全称:** Vision Transformer (ViT-Base-Patch16-224)
* **特点:** 基于 Transformer 结构,将图像切片后进行序列输入,通过自注意力机制实现全局建模。
2. **语义分割 (SegFormer):**
* **全称:** Segmentation Transformer
* **特点:** 高效的 Transformer 架构,使用轻量级解码器,专注于速度和准确性的平衡。
3. **目标检测 (YOLOv8n):**
* **全称:** You Only Look Once, Version 8 (Nano)
* **特点:** 单阶段检测器,以速度著称,Nano (n) 版本在保持高性能的同时,体积最小。
""")
# --- 任务选项卡 ---
with gr.Tabs():
# 1. 图像分类 Tab
with gr.TabItem("🖼️ 图像分类 (ViT)"):
with gr.Row():
with gr.Column(scale=1):
cls_input = gr.Image(type='pil', label="输入图像")
cls_button = gr.Button("🚀 提交分类任务")
with gr.Column(scale=1):
cls_output = gr.Label(num_top_classes=5, label="分类结果 (前 5)")
# 🌟 调整顺序:Examples 先于 类别列表/查询 UI
gr.Examples(examples=CLS_EXAMPLES, inputs=[cls_input], label="示例图片")
# 🌟 展示所有分类类别列表
gr.Markdown("### 🌟 模型支持的全部分类类别 (ImageNet-1K)")
cls_category_json = gr.JSON(value=ALL_CLS_LABELS, label="所有类别列表", scale=1)
# 🌟 查询 UI
with gr.Row():
cls_search_query = gr.Textbox(label="查询类别", placeholder="Search Class Name (e.g., dog)", scale=3)
cls_search_button = gr.Button("🔍 搜索", scale=1)
cls_search_output = gr.Markdown("搜索结果将显示在这里。")
cls_search_button.click(
fn=search_labels,
inputs=[cls_search_query, cls_category_json],
outputs=cls_search_output
)
cls_button.click(cls_predict, inputs=cls_input, outputs=cls_output)
# 2. 语义分割 Tab
with gr.TabItem("✂️ 语义分割 (SegFormer)"):
with gr.Row():
with gr.Column(scale=2):
seg_input = gr.Image(type='pil', label="输入图像")
seg_button = gr.Button("🚀 提交分割任务")
with gr.Column(scale=2):
seg_output = gr.Image(type='pil', label="分割结果 (叠加)")
with gr.Column(scale=1):
# 🌟 展示图例
gr.HTML(value=generate_legend_html(ALL_SEG_COLOR_MAP), scale=1)
# 🌟 调整顺序:Examples 先于 类别列表/查询 UI
gr.Examples(examples=SEG_EXAMPLES, inputs=[seg_input], label="示例图片")
# 保留完整的类别列表(以 JSON 格式展示,作为额外的参考)
gr.Markdown("### 完整类别列表 (JSON)")
seg_category_json = gr.JSON(value={f"ID {i}": label for i, label in enumerate(ALL_SEG_LABELS)},
label="所有类别 JSON")
# 🌟 查询 UI 提示改为英文
with gr.Row():
seg_search_query = gr.Textbox(label="查询类别", placeholder="Search Class Name (e.g., road, sky)",
scale=3)
seg_search_button = gr.Button("🔍 搜索", scale=1)
seg_search_output = gr.Markdown("搜索结果将显示在这里。")
seg_search_button.click(
fn=search_labels,
inputs=[seg_search_query, seg_category_json],
outputs=seg_search_output
)
seg_button.click(seg_predict, inputs=seg_input, outputs=seg_output)
# 3. 目标检测 Tab
with gr.TabItem("🎯 目标检测 (YOLOv8n)"):
with gr.Row():
with gr.Column(scale=1):
det_input_image = gr.Image(type='pil', label="输入图像")
det_input_number = gr.Number(
precision=2,
minimum=0.01,
maximum=1,
value=0.30,
label='置信度阈值'
)
det_button = gr.Button("🚀 提交检测任务")
with gr.Column(scale=1):
det_output = gr.Image(type='pil', label="检测结果 (边界框)")
# 🌟 调整顺序:Examples 先于 类别列表/查询 UI
gr.Examples(examples=DET_EXAMPLES, inputs=[det_input_image], label="示例图片")
# 🌟 展示目标检测类别列表
gr.Markdown("### 🎯 模型支持的检测类别 (COCO)")
det_category_json = gr.JSON(value=ALL_DET_LABELS, label="所有类别列表")
# 🌟 查询 UI 提示改为英文
with gr.Row():
det_search_query = gr.Textbox(label="查询类别", placeholder="Search Class Name (e.g., bicycle, train)",
scale=3)
det_search_button = gr.Button("🔍 搜索", scale=1)
det_search_output = gr.Markdown("搜索结果将显示在这里。")
det_search_button.click(
fn=search_labels,
inputs=[det_search_query, det_category_json],
outputs=det_search_output
)
det_button.click(det_predict, inputs=[det_input_image, det_input_number], outputs=det_output)
# 🌟 添加页脚和 Logo/版权
gr.HTML(
f"""
<div class='footer-logo-container'>
<div class="footer-links">
<p>{COPYRIGHT_TEXT}</p>
</div>
<div class="footer-links">
<!-- 🌟 学校图标和名称 (添加超链接) -->
<span class="footer-icon">🏢</span>
<a href='https://wutinfo.whut.edu.cn/' target='_blank' style='text-decoration: none; color: inherit;'>
<span>{SCHOOL_NAME_EN}</span>
</a>
<!-- 🌟 额外 Logo 位于版权信息之后 -->
<img src='file/{LOGO_PATH}' alt='Logo' style='height: 30px; margin-left: 20px;' onerror="this.style.display='none'">
</div>
</div>
"""
)
if __name__ == "__main__":
gr.close_all()
print("Launching Gradio demo...")
# 🌟 传入 css 参数
demo.launch(share=True, css=CUSTOM_CSS)