"""AI garbage-classification assistant -- advanced Gradio UI.

Features:
1. Thumbnail of the uploaded image in the result card
2. Click the thumbnail to view the full-size image in a modal
3. Tabbed layout (disposal guide / statistics)
4. Leaderboard kept out of the main view (inside the statistics tab)
5. Modern card-style UI
6. Clear visual hierarchy
"""

import base64
import contextlib
import html
import os
import random
import tempfile

import gradio as gr
from PIL import Image

# Optional HEIC/HEIF support: registering the opener lets Image.open()
# decode .heic/.heif files directly.
try:
    import pillow_heif

    pillow_heif.register_heif_opener()
    HEIF_SUPPORT = True
except Exception:
    HEIF_SUPPORT = False

from knowledge import get_class_info
from database import Database

# --------------------------------------------------
# Database
# --------------------------------------------------
db = Database()

# Classifier is created lazily on first use (see get_classifier()).
_classifier = None

# Legacy fixed temp path. classify_and_advise no longer writes it (each
# request now uses its own temp file so concurrent requests cannot clobber
# each other); kept so existing importers of this name keep working.
TEMP_IMG = "temp_upload.jpg"

# Markers for leaderboard places 1-3; later places are shown as "N.".
_MEDALS = ("🥇", "🥈", "🥉")


# --------------------------------------------------
# Model loading
# --------------------------------------------------
def get_classifier():
    """Return the shared GarbageClassifier, importing and creating it on first call.

    Raises:
        FileNotFoundError: propagated from GarbageClassifier(); presumably
            raised when the trained model file is missing (the UI reports it
            as "model not trained") -- confirm against predict.py.
    """
    global _classifier
    if _classifier is None:
        # Deferred import so the UI can start before the model exists.
        from predict import GarbageClassifier

        _classifier = GarbageClassifier()
    return _classifier


# --------------------------------------------------
# Image -> base64 data URI
# --------------------------------------------------
def make_img_data_uri(path):
    """Read the JPEG file at *path* and return it as a base64 ``data:`` URI."""
    with open(path, "rb") as f:
        return "data:image/jpeg;base64," + base64.b64encode(f.read()).decode()


# --------------------------------------------------
# Internal helpers
# --------------------------------------------------
def _load_rgb_image(image):
    """Return *image* (a PIL image or a file path) converted to RGB.

    Raises:
        RuntimeError: a .heic/.heif path is given but pillow-heif is missing.
    """
    if isinstance(image, str):
        if image.lower().endswith((".heic", ".heif")) and not HEIF_SUPPORT:
            raise RuntimeError(
                "未安装 pillow-heif,请执行: pip install pillow-heif"
            )
        # The registered HEIF opener decodes .heic/.heif here (handling row
        # stride correctly); the context manager closes the file handle.
        with Image.open(image) as opened:
            return opened.convert("RGB")
    return image.convert("RGB")


def _predict(classifier, img):
    """Classify *img* through a private temporary JPEG file.

    Returns:
        (results, data_uri): the classifier's result list and a data URI of
        the exact JPEG that was classified (used for the thumbnail/modal).
    """
    fd, tmp_path = tempfile.mkstemp(suffix=".jpg")
    os.close(fd)
    try:
        img.save(tmp_path, "JPEG")
        results = classifier.predict(tmp_path)
        data_uri = make_img_data_uri(tmp_path)
    finally:
        # Best-effort cleanup; a failed delete must not hide the result.
        with contextlib.suppress(OSError):
            os.remove(tmp_path)
    return results, data_uri


def _confidence_colors(pct):
    """Return (foreground, background) colours for a confidence percentage."""
    if pct >= 80:
        return "#2e7d32", "#e8f5e9"
    if pct >= 60:
        return "#ef6c00", "#fff3e0"
    return "#c62828", "#ffebee"


def _render_result_card(best, info, points, data_uri):
    """Render the result card (thumbnail, class, confidence bar, points) plus modal."""
    pct = best["confidence"] * 100
    progress_width = min(max(pct, 5), 100)  # keep the bar visible when tiny
    result_color, result_bg = _confidence_colors(pct)
    category = info["category"] if info else "未知分类"
    modal_id = f"modal-{random.randint(10000, 99999)}"
    open_modal = f"document.getElementById('{modal_id}').style.display='flex'"
    return f"""
<div class="result-card">
  <div class="thumb-wrapper">
    <img class="thumb-image" src="{data_uri}" alt="" onclick="{open_modal}">
    <div class="thumb-text">🔍 点击查看原图</div>
  </div>
  <div class="result-content">
    <div class="result-label">AI 识别结果</div>
    <div class="result-name" style="color:{result_color}">{best['class_name_cn']}</div>
    <div class="result-category">♻️ {category}</div>
    <div class="confidence-text">识别置信度:{pct:.1f}%</div>
    <div class="progress-bar-bg" style="background:{result_bg}">
      <div class="progress-bar-fill" style="width:{progress_width:.1f}%;background:{result_color}"></div>
    </div>
    <div class="score-badge">🎉 获得 +{points} 环保积分</div>
  </div>
</div>
<div id="{modal_id}" class="image-modal" onclick="this.style.display='none'">
  <span class="modal-close">&times;</span>
  <img class="modal-image" src="{data_uri}" alt="">
</div>
"""


def _render_knowledge_card(best, info):
    """Render the disposal guide; falls back gracefully when *info* is None."""
    if not info:
        # Unknown class: previously this path crashed on info['name_cn'].
        return f"""
<div class="knowledge-card">
  <div class="knowledge-title">📋 {best['class_name_cn']} 投放指南</div>
  <div class="knowledge-body">暂无信息</div>
</div>
"""
    disposal_html = info["disposal"].replace("\n", "<br>")
    tips_html = "".join(f"<li>💡 {tip}</li>" for tip in info["tips"])
    return f"""
<div class="knowledge-card">
  <div class="knowledge-title">📋 {info['name_cn']} 投放指南</div>
  <div class="knowledge-body">{disposal_html}</div>
  <div class="tips-title">💡 分类小贴士</div>
  <ul class="tips-list">{tips_html}</ul>
  <div class="fun-fact">🎯 {info['fun_fact']}</div>
  <div class="degradation">⏱ 降解时间:{info['degradation_time']}</div>
</div>
"""


def _render_stats_card(stats, leaderboard):
    """Render the user's statistics and the leaderboard (usernames escaped)."""
    items = []
    for rank, user in enumerate(leaderboard):
        medal = _MEDALS[rank] if rank < len(_MEDALS) else f"{rank + 1}."
        # Usernames are free text typed by users -> escape to prevent XSS.
        name = html.escape(str(user["username"]))
        items.append(
            '<div class="leader-item">'
            f"<span>{medal} {name}</span>"
            f"<span>{user['total_points']} 分</span>"
            "</div>"
        )
    leaderboard_html = "".join(items)
    username = html.escape(str(stats["username"]))
    return f"""
<div class="stats-card">
  <div class="stats-title">📊 {username} 的环保数据</div>
  <div class="stats-grid">
    <div class="stat-box">
      <div class="stat-number">{stats['total_points']}</div>
      <div class="stat-label">总积分</div>
    </div>
    <div class="stat-box">
      <div class="stat-number">{stats['total_classifications']}</div>
      <div class="stat-label">分类次数</div>
    </div>
    <div class="stat-box">
      <div class="stat-number">{stats['today']['points']}</div>
      <div class="stat-label">今日积分</div>
    </div>
  </div>
  <div class="leaderboard-title">🏆 环保排行榜 TOP5</div>
  <div class="leaderboard-list">{leaderboard_html}</div>
</div>
"""


# --------------------------------------------------
# Main recognition logic
# --------------------------------------------------
def classify_and_advise(image, username="default"):
    """Classify an uploaded image, record points, and render the three UI panels.

    Args:
        image: PIL image from the upload widget, a file path (HEIC/HEIF
            paths need pillow-heif), or None when nothing was uploaded.
        username: name under which points are recorded; blank values fall
            back to "default".

    Returns:
        (result_html, knowledge_html, stats_html). On early failure (no
        image, model missing, recognition error) only the first element is
        filled and the other two are "".

    Side effects:
        On success registers the user (via db.register_user) and adds a
        classification record in the database.
    """
    if image is None:
        return (
            '<div class="empty-card"><h3>⚠️ 未检测到图片</h3>'
            "<p>请先上传一张垃圾图片</p></div>",
            "",
            "",
        )

    # Load the model
    try:
        classifier = get_classifier()
    except FileNotFoundError as e:
        return (
            '<div class="error-card"><h3>❌ 模型未训练</h3>'
            f"<p>{html.escape(str(e))}</p>"
            "<p>请先运行:</p><code>python main.py train</code></div>",
            "",
            "",
        )

    # Inference (UI boundary: any failure is reported as a card)
    try:
        img = _load_rgb_image(image)
        results, data_uri = _predict(classifier, img)
        best = results[0]
        info = get_class_info(best["class_name"])
    except Exception as e:
        return (
            '<div class="error-card"><h3>❌ 识别失败</h3>'
            f"<p>{html.escape(str(e))}</p></div>",
            "",
            "",
        )

    # User records
    username = str(username or "").strip() or "default"
    user_id = db.register_user(username)
    points = db.add_record(user_id, best["class_name"], best["confidence"])
    stats = db.get_user_stats(user_id)
    leaderboard = db.get_leaderboard(5)

    return (
        _render_result_card(best, info, points, data_uri),
        _render_knowledge_card(best, info),
        _render_stats_card(stats, leaderboard),
    )


# --------------------------------------------------
# CSS
# --------------------------------------------------
CSS = """
.gradio-container {
    width: 100% !important;
    max-width: 900px !important;
    margin: auto !important;
    overflow-x: hidden !important;
}
footer { display: none !important; }

/* 标题 */
.main-title { text-align: center; padding: 10px 0 20px 0; }

/* 提示标签 */
.class-badge {
    display: inline-block; padding: 6px 14px; border-radius: 20px; margin: 4px;
    font-size: 13px; font-weight: bold; background: #f1f8e9; border: 1px solid #c5e1a5;
}

/* 上传区域 */
.upload-panel {
    width: 100%; max-width: 900px; margin: auto; background: white;
    border-radius: 18px; padding: 20px; box-shadow: 0 4px 15px rgba(0,0,0,0.06);
}

/* 结果卡片 */
.result-card {
    box-sizing: border-box; width: 100%; max-width: 900px; margin: 5px auto 0 auto;
    display: flex; align-items: center; gap: 20px; background: white;
    border-radius: 20px; padding: 20px; box-shadow: 0 6px 18px rgba(0,0,0,0.08);
}
.thumb-wrapper { text-align: center; flex-shrink: 0; }
.thumb-image {
    max-width: 95px; min-width: 95px; width: 95px; height: 95px; object-fit: cover;
    border-radius: 14px; cursor: pointer; border: 3px solid #c8e6c9; transition: 0.2s;
}
.thumb-image:hover { transform: scale(1.05); }
.thumb-text { margin-top: 6px; font-size: 11px; color: #777; }
.result-content { flex: 1; }
.result-label { color: #777; font-size: 13px; }
.result-name { font-size: 34px; font-weight: 800; margin: 4px 0; }
.result-category { font-size: 15px; color: #555; margin-bottom: 10px; }
.confidence-text { font-size: 13px; margin-bottom: 6px; color: #666; }
.progress-bar-bg {
    width: 100%; height: 10px; background: #eeeeee; border-radius: 999px; overflow: hidden;
}
.progress-bar-fill { height: 100%; border-radius: 999px; }
.score-badge {
    display: inline-block; margin-top: 12px; padding: 8px 14px; background: #fff3e0;
    border-radius: 999px; color: #ef6c00; font-size: 13px; font-weight: bold;
}

/* 投放指南 */
.knowledge-card {
    background: white; padding: 22px; border-radius: 18px; box-shadow: 0 4px 15px rgba(0,0,0,0.06);
}
.knowledge-title { font-size: 22px; font-weight: bold; color: #2e7d32; margin-bottom: 15px; }
.knowledge-body {
    background: #f8f9fa; padding: 16px; border-radius: 12px; line-height: 1.8; font-size: 15px;
}
.tips-title { margin-top: 18px; font-size: 17px; font-weight: bold; color: #ef6c00; }
.tips-list { margin-top: 8px; line-height: 1.9; }
.fun-fact {
    margin-top: 15px; background: #fff8e1; padding: 12px; border-radius: 12px;
    color: #e65100; font-weight: bold;
}
.degradation { margin-top: 10px; color: #777; font-size: 13px; }

/* 环保统计 */
.stats-card {
    background: white; padding: 22px; border-radius: 18px; box-shadow: 0 4px 15px rgba(0,0,0,0.06);
}
.stats-title { font-size: 22px; font-weight: bold; color: #1565c0; margin-bottom: 20px; }
.stats-grid { display: grid; grid-template-columns: repeat(3,1fr); gap: 15px; }
.stat-box { background: #f5f7fa; padding: 18px; border-radius: 14px; text-align: center; }
.stat-number { font-size: 28px; font-weight: bold; color: #1565c0; }
.stat-label { margin-top: 6px; color: #666; font-size: 13px; }
.leaderboard-title { margin-top: 24px; font-size: 18px; font-weight: bold; color: #2e7d32; }
.leaderboard-list { margin-top: 12px; }
.leader-item {
    display: flex; justify-content: space-between; padding: 12px 14px; background: #f8f9fa;
    border-radius: 12px; margin-bottom: 10px; font-size: 14px;
}

/* 弹窗 */
.image-modal {
    display: none; position: fixed; top: 0; left: 0; width: 100%; height: 100%;
    background: rgba(0,0,0,0.92); z-index: 99999; justify-content: center;
    align-items: center; cursor: pointer;
}
.modal-image {
    width: auto; height: auto; object-fit: contain; max-width: 90%; max-height: 90%;
    border-radius: 10px;
}
.modal-close {
    position: absolute; top: 20px; right: 30px; color: white; font-size: 36px; font-weight: bold;
}

/* 空卡片 */
.empty-card, .error-card {
    text-align: center; padding: 40px; background: white; border-radius: 18px;
}

/* 手机端适配 */
@media (max-width: 768px) {
    .result-card { flex-direction: column; text-align: center; }
    .stats-grid { grid-template-columns: 1fr; }
    .result-name { font-size: 28px; }
}
"""

# --------------------------------------------------
# Supported-class hint
# --------------------------------------------------
CLASS_HINT = """
<div style="text-align:center">
  <span class="class-badge">🥤 塑料</span>
  <span class="class-badge">📦 纸板</span>
  <span class="class-badge">📄 纸张</span>
  <span class="class-badge">🍾 玻璃</span>
  <span class="class-badge">🥫 金属</span>
  <span class="class-badge">🍂 其他垃圾</span>
  <p style="color:#777;font-size:13px">本系统当前支持以上 6 类垃圾识别</p>
</div>
"""

# --------------------------------------------------
# Gradio UI
# --------------------------------------------------
with gr.Blocks(
    fill_width=False,
    title="AI 垃圾分类助手",
    theme=gr.themes.Soft(primary_hue="green"),
    css=CSS,
) as demo:
    gr.Markdown(
        """
<div class="main-title">
<h1>♻️ AI 垃圾分类助手</h1>
<p>拍照识别 · 投放指南 · 环保积分</p>
</div>
"""
    )
    gr.HTML(CLASS_HINT)

    # Upload area
    with gr.Group(elem_classes="upload-panel"):
        image_input = gr.Image(
            type="pil",
            label="📷 上传垃圾图片",
            height=220,
            elem_id="upload-image",
        )
        username_input = gr.Textbox(
            label="👤 用户名",
            value="default",
            placeholder="输入用户名记录积分",
        )
        submit_btn = gr.Button("🔍 开始识别", variant="primary", size="lg")

    # Recognition result area
    gr.Markdown("### 🤖 AI 识别结果")
    result_output = gr.HTML(
        value='<div class="empty-card"><h3>📷 等待上传图片</h3>'
        "<p>上传垃圾图片后点击「开始识别」</p></div>"
    )

    # Tabs
    with gr.Tabs():
        with gr.Tab("📋 投放指南"):
            knowledge_output = gr.HTML(
                value='<div class="empty-card">等待识别结果...</div>'
            )
        with gr.Tab("📊 环保统计"):
            stats_output = gr.HTML(
                value='<div class="empty-card">等待识别结果...</div>'
            )

    # Button event
    submit_btn.click(
        fn=classify_and_advise,
        inputs=[image_input, username_input],
        outputs=[result_output, knowledge_output, stats_output],
    )


# --------------------------------------------------
# Launch
# --------------------------------------------------
def launch_gradio(server_port=7860):
    """Start the Gradio server on all interfaces at *server_port*."""
    print(f"🌐 Gradio Web 界面: http://localhost:{server_port}")
    demo.launch(server_name="0.0.0.0", server_port=server_port, share=False)


if __name__ == "__main__":
    launch_gradio()