import gradio as gr from model import ( cls_predict, det_predict, seg_predict, ALL_SEG_LABELS, ALL_DET_LABELS, ALL_SEG_COLOR_MAP, ALL_CLS_LABELS ) import os import requests from PIL import Image from io import BytesIO import time # 用于重试 # --- 配置 Logo 和版权信息 --- LOGO_PATH = "logo/logo.png" # 请将您的 Logo 图片存放在此路径 # 🌟 更新:版权信息仅保留年份和权利声明,具体内容在 HTML 中处理 COPYRIGHT_TEXT = "© 2025 All Rights Reserved." SCHOOL_NAME_EN = "School of Information Engineering, Wuhan University of Technology" # --- 自动下载示例图片逻辑 (嵌入) --- # 🌟 恢复到用户指定的 COCO ID 列表 TASK_EXAMPLE_URLS = { "cls": [ # 图像分类示例 (val2017) "http://images.cocodataset.org/val2017/000000000285.jpg", # 猫和键盘 (保留) "http://images.cocodataset.org/val2017/000000000785.jpg", # 交通灯/汽车/巴士 (保留) "http://images.cocodataset.org/val2017/000000000724.jpg", # 食物/餐具 (保留) "http://images.cocodataset.org/val2017/000000001584.jpg", # 多人,多物体 (保留) "http://images.cocodataset.org/train2017/000000001097.jpg", # 原 cls_5 ], "seg": [ # 语义分割示例 (val2017) "http://images.cocodataset.org/val2017/000000000139.jpg", # 街景/汽车 (保留) "http://images.cocodataset.org/val2017/000000000632.jpg", # 街景/行人 (保留) "http://images.cocodataset.org/val2017/000000000885.jpg", # 滑板手 (保留) "http://images.cocodataset.org/train2017/000000000267.jpg", # 原 seg_4 "http://images.cocodataset.org/train2017/000000001140.jpg", # 原 seg_5 ], "det": [ # 目标检测示例 (val2017) "http://images.cocodataset.org/val2017/000000000785.jpg", # 交通灯/汽车/巴士 (保留) "http://images.cocodataset.org/val2017/000000001268.jpg", # 原 det_2 "http://images.cocodataset.org/train2017/000000001072.jpg", # 原 det_3 "http://images.cocodataset.org/train2017/000000000119.jpg", # 原 det_4 "http://images.cocodataset.org/train2017/000000000570.jpg", # 原 det_5 ] } # 项目期望的本地路径 OUTPUT_DIR = "examples" def download_and_save_examples(max_retries=3): """下载示例图片到本地 examples/ 目录,使用任务前缀命名,增加重试机制""" if not os.path.exists(OUTPUT_DIR): os.makedirs(OUTPUT_DIR) total_urls = sum(len(urls) for urls in TASK_EXAMPLE_URLS.values()) print(f"🚀 检查和下载 {total_urls} 张示例图片...") # 迭代所有任务和 URL for prefix, urls in TASK_EXAMPLE_URLS.items(): for i, url in enumerate(urls): # 文件名格式:cls_1.jpg, seg_2.jpg, det_3.jpg filename = f"{prefix}_{i + 1}.jpg" filepath = os.path.join(OUTPUT_DIR, filename) if os.path.exists(filepath): continue # 跳过已存在的文件 for attempt in range(max_retries): try: # 增加更长的超时时间 response = requests.get(url, stream=True, timeout=15) response.raise_for_status() image = Image.open(BytesIO(response.content)) image.save(filepath) print(f" 成功下载并保存: {filename} (尝试 {attempt + 1}/{max_retries})") break # 成功则跳出重试循环 except requests.exceptions.RequestException as e: print(f" ⚠️ 下载 {filename} 失败 (尝试 {attempt + 1}/{max_retries}): {e}") if attempt < max_retries - 1: time.sleep(2) # 失败后等待2秒再重试 else: # 404 错误是 ClientError,意味着 URL 不存在 if '404 Client Error' in str(e): print(f"❌ 最终下载失败 {filename}: URL {url} 不存在 (404 错误)。") else: print(f"❌ 最终下载失败 {filename}: 请检查网络连接或 URL。") break except Exception as e: # 图像处理失败 (如 BytesIO 或 PIL 错误),停止重试 print(f"❌ 图像处理失败 {filename}: {e}") break # 立即执行下载,确保 examples 目录下的文件存在 download_and_save_examples() # 🌟 关键:创建三个独立的示例列表,用于 Gradio Examples 组件 CLS_EXAMPLES = [[os.path.join(OUTPUT_DIR, f"cls_{i + 1}.jpg")] for i in range(5)] SEG_EXAMPLES = [[os.path.join(OUTPUT_DIR, f"seg_{i + 1}.jpg")] for i in range(5)] DET_EXAMPLES = [[os.path.join(OUTPUT_DIR, f"det_{i + 1}.jpg")] for i in range(5)] # --- 辅助函数:生成颜色图例 HTML --- def generate_legend_html(color_map_dict): """根据颜色映射字典生成 HTML 图例""" html_content = "
" html_content += "

🎨 分割颜色图例

" if "Error" in color_map_dict: html_content += "

模型加载失败,图例不可用。

" return html_content for label, hex_color in color_map_dict.items(): html_content += f"""
{label}
""" html_content += "
" return html_content # --- 辅助函数:类别搜索逻辑 --- def search_labels(query: str, all_labels) -> str: """ 在标签列表或字典中搜索给定的查询。 all_labels 可以是 list (分类/分割) 或 dict (检测)。 """ query = query.strip().lower() if not query: return "请输入有效的查询内容。" MAX_MATCHES = 10 # 处理 List (Classification, Segmentation) if isinstance(all_labels, list): found_matches = [label for label in all_labels if query in label.lower()] if found_matches: result_list = "\n- ".join(found_matches[:MAX_MATCHES]) summary = f"✅ 找到 {len(found_matches)} 个匹配项 (仅显示前 {MAX_MATCHES} 个):" return f"{summary}\n- {result_list}" else: return f"❌ 未找到包含 '{query}' 的类别。" # 处理 Dict (Detection) elif isinstance(all_labels, dict): found_matches = {k: v for k, v in all_labels.items() if query in v.lower()} if found_matches: result_list = [f"ID {k}: {v}" for k, v in list(found_matches.items())[:MAX_MATCHES]] summary = f"✅ 找到 {len(found_matches)} 个匹配项 (仅显示前 {MAX_MATCHES} 个):" return f"{summary}\n- {result_list}" else: return f"❌ 未找到包含 '{query}' 的类别。" return "类别数据格式错误或未加载。" # 🌟 美化 1: Gradio Soft 主题 CSS (背景颜色调整) CUSTOM_CSS = """ /* 整体背景和卡片阴影优化 */ .gradio-container { background-color: #f7f7f7; /* 调整为柔和的浅灰色背景 */ font-family: 'Inter', system-ui, sans-serif; } /* 卡片和主内容区的美化 */ .gradio-container > div { border-radius: 12px; box-shadow: 0 2px 8px rgba(0, 0, 0, 0.05); /* 柔和阴影 */ transition: all 0.3s ease; } /* 主标题样式 (用于承载 Logo 和文本) */ h1 { display: flex; align-items: center; justify-content: center; font-size: 2.2em; color: #333333; /* 深色文字 */ padding: 20px 0; margin: 0; } /* 按钮和输入框圆角 */ .gr-button, .gr-textbox, .gr-number, .gr-image { border-radius: 8px !important; } /* 标签和组件背景 */ .gradio-container > div:not(.prose):not(.gr-row) { background: white; padding: 15px; } /* 页脚 Logo 样式 */ .footer-logo-container { display: flex; flex-direction: column; align-items: center; justify-content: center; text-align: center; padding-top: 15px; } /* 页脚图标/链接容器样式 */ .footer-links { margin-top: 10px; font-size: 14px; color: #555; display: flex; gap: 15px; /* 图标之间的间隔 */ align-items: center; /* 确保内容居中 */ justify-content: center; } /* 图标样式 */ .footer-icon { font-size: 18px; vertical-align: middle; } /* 链接颜色 */ .footer-links a { color: #555; text-decoration: none; } .footer-links a:hover { color: #333; } """ with gr.Blocks( title="AI基础模型视觉任务演示平台", ) as demo: # 注入 Favicon (网页选项卡图标) gr.HTML(f""" """, visible=False) # 🌟 主标题区域:恢复机器人图标 gr.Markdown("

🤖 AI基础模型视觉任务演示平台

") gr.Markdown("---") # 🌟 简化功能说明 (只关注使用方法) with gr.Accordion("📚 功能说明"): gr.Markdown(""" 本平台支持图像分类、语义分割和目标检测三大任务。 您可以通过以下步骤使用平台: 1. **切换选项卡**:选择您希望执行的 AI 任务。 2. **上传或选择图片**:上传您自己的图片或点击下方的示例图片。 3. **设置参数**:对于目标检测,调整置信度阈值。 4. **点击提交**:点击“提交任务”按钮,查看 AI 分析结果。 """) # 🌟 新增:数据集介绍 with gr.Accordion("📚 基础数据集介绍", open=False): gr.Markdown(""" ### 📖 模型训练数据集概览 | 任务 | 模型 | 数据集 | 类别数 | 简介 | | :--- | :--- | :--- | :--- | :--- | | 图像分类 | ViT | **ImageNet-1K** | 1000 | 包含超过 100 万张图像,是图像识别领域的标准基准。 | | 语义分割 | SegFormer | **ADE20K** | 150 | 专注于场景解析,提供 150 种语义概念的像素级标注。 | | 目标检测 | YOLOv8n | **COCO** | 80 | 最常用的目标检测数据集之一,包含大量物体实例。 | """) # 🌟 新增:网络结构介绍 with gr.Accordion("🧠 网络结构介绍", open=False): gr.Markdown(""" ### 💻 模型架构说明 1. **图像分类 (ViT):** * **全称:** Vision Transformer (ViT-Base-Patch16-224) * **特点:** 基于 Transformer 结构,将图像切片后进行序列输入,通过自注意力机制实现全局建模。 2. **语义分割 (SegFormer):** * **全称:** Segmentation Transformer * **特点:** 高效的 Transformer 架构,使用轻量级解码器,专注于速度和准确性的平衡。 3. **目标检测 (YOLOv8n):** * **全称:** You Only Look Once, Version 8 (Nano) * **特点:** 单阶段检测器,以速度著称,Nano (n) 版本在保持高性能的同时,体积最小。 """) # --- 任务选项卡 --- with gr.Tabs(): # 1. 图像分类 Tab with gr.TabItem("🖼️ 图像分类 (ViT)"): with gr.Row(): with gr.Column(scale=1): cls_input = gr.Image(type='pil', label="输入图像") cls_button = gr.Button("🚀 提交分类任务") with gr.Column(scale=1): cls_output = gr.Label(num_top_classes=5, label="分类结果 (前 5)") # 🌟 调整顺序:Examples 先于 类别列表/查询 UI gr.Examples(examples=CLS_EXAMPLES, inputs=[cls_input], label="示例图片") # 🌟 展示所有分类类别列表 gr.Markdown("### 🌟 模型支持的全部分类类别 (ImageNet-1K)") cls_category_json = gr.JSON(value=ALL_CLS_LABELS, label="所有类别列表", scale=1) # 🌟 查询 UI with gr.Row(): cls_search_query = gr.Textbox(label="查询类别", placeholder="Search Class Name (e.g., dog)", scale=3) cls_search_button = gr.Button("🔍 搜索", scale=1) cls_search_output = gr.Markdown("搜索结果将显示在这里。") cls_search_button.click( fn=search_labels, inputs=[cls_search_query, cls_category_json], outputs=cls_search_output ) cls_button.click(cls_predict, inputs=cls_input, outputs=cls_output) # 2. 语义分割 Tab with gr.TabItem("✂️ 语义分割 (SegFormer)"): with gr.Row(): with gr.Column(scale=2): seg_input = gr.Image(type='pil', label="输入图像") seg_button = gr.Button("🚀 提交分割任务") with gr.Column(scale=2): seg_output = gr.Image(type='pil', label="分割结果 (叠加)") with gr.Column(scale=1): # 🌟 展示图例 gr.HTML(value=generate_legend_html(ALL_SEG_COLOR_MAP), scale=1) # 🌟 调整顺序:Examples 先于 类别列表/查询 UI gr.Examples(examples=SEG_EXAMPLES, inputs=[seg_input], label="示例图片") # 保留完整的类别列表(以 JSON 格式展示,作为额外的参考) gr.Markdown("### 完整类别列表 (JSON)") seg_category_json = gr.JSON(value={f"ID {i}": label for i, label in enumerate(ALL_SEG_LABELS)}, label="所有类别 JSON") # 🌟 查询 UI 提示改为英文 with gr.Row(): seg_search_query = gr.Textbox(label="查询类别", placeholder="Search Class Name (e.g., road, sky)", scale=3) seg_search_button = gr.Button("🔍 搜索", scale=1) seg_search_output = gr.Markdown("搜索结果将显示在这里。") seg_search_button.click( fn=search_labels, inputs=[seg_search_query, seg_category_json], outputs=seg_search_output ) seg_button.click(seg_predict, inputs=seg_input, outputs=seg_output) # 3. 目标检测 Tab with gr.TabItem("🎯 目标检测 (YOLOv8n)"): with gr.Row(): with gr.Column(scale=1): det_input_image = gr.Image(type='pil', label="输入图像") det_input_number = gr.Number( precision=2, minimum=0.01, maximum=1, value=0.30, label='置信度阈值' ) det_button = gr.Button("🚀 提交检测任务") with gr.Column(scale=1): det_output = gr.Image(type='pil', label="检测结果 (边界框)") # 🌟 调整顺序:Examples 先于 类别列表/查询 UI gr.Examples(examples=DET_EXAMPLES, inputs=[det_input_image], label="示例图片") # 🌟 展示目标检测类别列表 gr.Markdown("### 🎯 模型支持的检测类别 (COCO)") det_category_json = gr.JSON(value=ALL_DET_LABELS, label="所有类别列表") # 🌟 查询 UI 提示改为英文 with gr.Row(): det_search_query = gr.Textbox(label="查询类别", placeholder="Search Class Name (e.g., bicycle, train)", scale=3) det_search_button = gr.Button("🔍 搜索", scale=1) det_search_output = gr.Markdown("搜索结果将显示在这里。") det_search_button.click( fn=search_labels, inputs=[det_search_query, det_category_json], outputs=det_search_output ) det_button.click(det_predict, inputs=[det_input_image, det_input_number], outputs=det_output) # 🌟 添加页脚和 Logo/版权 gr.HTML( f""" """ ) if __name__ == "__main__": gr.close_all() print("Launching Gradio demo...") # 🌟 传入 css 参数 demo.launch(share=True, css=CUSTOM_CSS)