hutiger's picture
Upload folder using huggingface_hub
bf5b4d8 verified
Raw
History Blame Contribute Delete
16.2 kB
"""
AI 垃圾分类助手 - 高级 Gradio UI 版本
特点:
1. 缩略图显示
2. 点击查看原图
3. Tabs 分区布局
4. 排行榜折叠
5. 更现代化卡片式 UI
6. 更清晰的视觉层级
"""
import base64
import html
import random

import gradio as gr
from PIL import Image

# HEIC/HEIF support: registering the pillow-heif opener lets Image.open read
# .heic/.heif files. HEIF_SUPPORT records whether that registration succeeded.
try:
    import pillow_heif

    pillow_heif.register_heif_opener()
    HEIF_SUPPORT = True
except Exception:
    HEIF_SUPPORT = False

from knowledge import get_class_info
from database import Database
# --------------------------------------------------
# Shared module state
# --------------------------------------------------
# Scratch file the uploaded image is re-encoded to (as JPEG) before
# inference; the classifier and the thumbnail encoder both read it back.
TEMP_IMG = "temp_upload.jpg"
# Classifier instance, created on first use by get_classifier() so the UI can
# start, and report a missing model, before any model has been loaded.
_classifier = None
# Store used for user registration, classification records, per-user stats
# and the leaderboard.
db = Database()
# --------------------------------------------------
# Model loading
# --------------------------------------------------
def get_classifier():
    """Return the shared ``GarbageClassifier``, constructing it on first call.

    Importing ``predict`` and constructing the classifier are deferred until
    the first call. If construction raises (``classify_and_advise`` treats a
    ``FileNotFoundError`` as "model not trained"), the exception propagates,
    the cached instance stays ``None`` and the next call tries again.
    """
    global _classifier
    if _classifier is not None:
        return _classifier

    from predict import GarbageClassifier

    _classifier = GarbageClassifier()
    return _classifier
# --------------------------------------------------
# Image file -> base64 data URI
# --------------------------------------------------
def make_img_data_uri(path, *, mime="image/jpeg"):
    """Read the file at *path* and return it embedded in a ``data:`` URI.

    Args:
        path: Image file to embed.
        mime: MIME type declared in the URI. The default ``image/jpeg``
            matches the JPEG scratch file this app writes; pass the real
            type when embedding other formats.

    Returns:
        ``"data:<mime>;base64,<payload>"``, usable directly as ``<img src>``.
    """
    with open(path, "rb") as f:
        payload = base64.b64encode(f.read()).decode("ascii")
    return f"data:{mime};base64,{payload}"
# --------------------------------------------------
# Main recognition logic
# --------------------------------------------------
def _status_result(css_class, inner_html):
    """Return a ``(result, knowledge, stats)`` tuple showing a single status card.

    Used for the no-image, missing-model and recognition-failure paths; the
    knowledge and stats panes are cleared with empty strings. ``inner_html``
    is inserted verbatim, so callers must escape any untrusted text in it.
    """
    return f"\n<div class='{css_class}'>\n{inner_html}\n</div>\n", "", ""


def _load_rgb_image(image):
    """Return *image* as an RGB ``PIL.Image``.

    ``image`` may be a PIL image (what ``gr.Image(type="pil")`` supplies) or a
    file path. HEIC/HEIF paths go through the pillow-heif opener registered at
    import time, so ``Image.open`` decodes them like any other format.

    Raises:
        RuntimeError: a HEIC/HEIF path was given but pillow-heif is unavailable.
    """
    if isinstance(image, str):
        if image.lower().endswith((".heic", ".heif")) and not HEIF_SUPPORT:
            raise RuntimeError(
                "未安装 pillow-heif,请执行: pip install pillow-heif"
            )
        # ``with`` closes the file handle Image.open keeps for lazy loading;
        # convert() loads the pixels first, so the returned copy is complete.
        with Image.open(image) as opened:
            return opened.convert("RGB")
    return image.convert("RGB")


def _confidence_color(pct):
    """Return the accent colour for a confidence percentage (0-100)."""
    if pct >= 80:
        return "#2e7d32"  # green
    if pct >= 60:
        return "#ef6c00"  # orange
    return "#c62828"  # red


def _render_result_card(best, info, points, data_uri):
    """Build the main result card: thumbnail, label, confidence bar, points.

    Clicking the thumbnail shows a full-screen modal with the same image, and
    clicking the modal hides it. Each card gets a random modal id so its inline
    ``onclick`` handler targets its own modal.
    """
    pct = best["confidence"] * 100
    # Clamp so very low confidences still show a visible sliver of bar.
    progress_width = min(max(pct, 5), 100)
    color = _confidence_color(pct)
    category = info["category"] if info else "未知分类"
    modal_id = f"modal-{random.randint(10000, 99999)}"
    return f"""
<div class='result-card'>
  <div class='thumb-wrapper'>
    <img
      src='{data_uri}'
      class='thumb-image'
      onclick="document.getElementById('{modal_id}').style.display='flex'"
    >
    <div class='thumb-text'>🔍 点击查看原图</div>
  </div>
  <div class='result-content'>
    <div class='result-label'>AI 识别结果</div>
    <div class='result-name' style='color:{color};'>
      {best['class_name_cn']}
    </div>
    <div class='result-category'>
      ♻️ {category}
    </div>
    <div class='confidence-text'>
      识别置信度:{pct:.1f}%
    </div>
    <div class='progress-bar-bg'>
      <div
        class='progress-bar-fill'
        style='width:{progress_width}%;background:{color};'>
      </div>
    </div>
    <div class='score-badge'>
      🎉 获得 +{points} 环保积分
    </div>
  </div>
</div>
<!-- full-size image modal -->
<div
  id='{modal_id}'
  class='image-modal'
  onclick="this.style.display='none'">
  <img src='{data_uri}' class='modal-image'>
  <div class='modal-close'>✕</div>
</div>
"""


def _render_knowledge_card(info, fallback_name):
    """Build the disposal-guide card from a ``get_class_info`` entry.

    When ``info`` is falsy (no knowledge-base entry), a minimal "no
    information" card titled with ``fallback_name`` is returned instead of
    failing on ``info[...]``. Knowledge-base text is project data and is
    inserted unescaped, as before; newlines in ``disposal`` become ``<br>``.
    """
    if not info:
        return f"""
<div class='knowledge-card'>
  <div class='knowledge-title'>
    📋 {fallback_name} 投放指南
  </div>
  <div class='knowledge-body'>
    暂无信息
  </div>
</div>
"""
    disposal_html = info["disposal"].replace("\n", "<br>")
    tips_html = "".join(f"<li>💡 {tip}</li>" for tip in info["tips"])
    return f"""
<div class='knowledge-card'>
  <div class='knowledge-title'>
    📋 {info['name_cn']} 投放指南
  </div>
  <div class='knowledge-body'>
    {disposal_html}
  </div>
  <div class='tips-title'>💡 分类小贴士</div>
  <ul class='tips-list'>
    {tips_html}
  </ul>
  <div class='fun-fact'>
    🎯 {info['fun_fact']}
  </div>
  <div class='degradation'>
    ⏱ 降解时间:{info['degradation_time']}
  </div>
</div>
"""


def _render_stats_card(stats, leaderboard):
    """Build the user-statistics card followed by the leaderboard.

    Usernames are typed by users and the leaderboard shows them to everyone,
    so they are HTML-escaped before being embedded.
    """
    medals = ("🥇", "🥈", "🥉")
    rows = []
    for rank, user in enumerate(leaderboard, start=1):
        badge = medals[rank - 1] if rank <= len(medals) else f"{rank}."
        rows.append(
            "<div class='leader-item'>"
            f"<span>{badge} {html.escape(str(user['username']))}</span>"
            f"<span>{user['total_points']} 分</span>"
            "</div>"
        )
    leaderboard_html = "\n".join(rows)
    display_name = html.escape(str(stats["username"]))
    return f"""
<div class='stats-card'>
  <div class='stats-title'>
    📊 {display_name} 的环保数据
  </div>
  <div class='stats-grid'>
    <div class='stat-box'>
      <div class='stat-number'>{stats['total_points']}</div>
      <div class='stat-label'>总积分</div>
    </div>
    <div class='stat-box'>
      <div class='stat-number'>{stats['total_classifications']}</div>
      <div class='stat-label'>分类次数</div>
    </div>
    <div class='stat-box'>
      <div class='stat-number'>{stats['today']['points']}</div>
      <div class='stat-label'>今日积分</div>
    </div>
  </div>
  <div class='leaderboard-title'>🏆 环保排行榜 TOP5</div>
  <div class='leaderboard-list'>
    {leaderboard_html}
  </div>
</div>
"""


def classify_and_advise(image, username="default"):
    """Classify an uploaded image and build the three HTML panes of the UI.

    Args:
        image: A ``PIL.Image`` (what ``gr.Image(type="pil")`` supplies), a
            file path, or ``None`` when nothing was uploaded.
        username: Name the earned points are recorded under. Blank or
            whitespace-only input falls back to ``"default"``.

    Returns:
        ``(result_html, knowledge_html, stats_html)``. On the no-image,
        missing-model and recognition-failure paths only the first element
        is populated; the other two are ``""``.

    Side effects:
        Writes the re-encoded upload to ``TEMP_IMG`` and, on success,
        registers the user and adds a classification record in ``db``.
    """
    if image is None:
        return _status_result(
            "empty-card",
            "<h2>⚠️ 未检测到图片</h2>\n<p>请先上传一张垃圾图片</p>",
        )

    try:
        classifier = get_classifier()
    except FileNotFoundError as e:
        return _status_result(
            "error-card",
            "<h2>❌ 模型未训练</h2>\n"
            f"<p>{html.escape(str(e))}</p>\n"
            "<p>请先运行:</p>\n"
            "<code>python main.py train</code>",
        )

    try:
        # The classifier consumes a file path, so re-encode to the scratch JPEG.
        _load_rgb_image(image).save(TEMP_IMG, "JPEG")
        results = classifier.predict(TEMP_IMG)
        best = results[0]
        info = get_class_info(best["class_name"])
    except Exception as e:  # UI boundary: show the failure instead of crashing
        return _status_result(
            "error-card",
            f"<h2>❌ 识别失败</h2>\n<p>{html.escape(str(e))}</p>",
        )

    username = (username or "").strip() or "default"
    user_id = db.register_user(username)
    points = db.add_record(user_id, best["class_name"], best["confidence"])
    stats = db.get_user_stats(user_id)
    leaderboard = db.get_leaderboard(5)

    return (
        _render_result_card(best, info, points, make_img_data_uri(TEMP_IMG)),
        _render_knowledge_card(info, best["class_name_cn"]),
        _render_stats_card(stats, leaderboard),
    )
# --------------------------------------------------
# CSS
# --------------------------------------------------
# Stylesheet passed to gr.Blocks(css=...). Class names must match the markup
# produced by classify_and_advise, CLASS_HINT and the UI layout. Redundant
# declarations that were overridden later in the same rule have been merged.
CSS = """
.gradio-container {
width: 100% !important;
max-width: 900px !important;
margin: auto !important;
overflow-x: hidden !important;
}
footer {
display: none !important;
}
/* 标题 */
.main-title {
text-align:center;
padding: 10px 0 20px 0;
}
/* 提示标签 */
.class-badge {
display:inline-block;
padding:6px 14px;
border-radius:20px;
margin:4px;
font-size:13px;
font-weight:bold;
background:#f1f8e9;
border:1px solid #c5e1a5;
}
/* 上传区域 */
.upload-panel {
width: 100%;
max-width: 900px;
margin: auto;
background:white;
border-radius:18px;
padding:20px;
box-shadow:0 4px 15px rgba(0,0,0,0.06);
}
/* 结果卡片 */
.result-card {
box-sizing: border-box;
width: 100%;
max-width: 900px;
margin: 5px auto 0 auto;
display:flex;
align-items:center;
gap:20px;
background:white;
border-radius:20px;
padding:20px;
box-shadow:0 6px 18px rgba(0,0,0,0.08);
}
.thumb-wrapper {
text-align:center;
flex-shrink:0;
}
.thumb-image {
max-width: 95px;
min-width: 95px;
width:95px;
height:95px;
object-fit:cover;
border-radius:14px;
cursor:pointer;
border:3px solid #c8e6c9;
transition:0.2s;
}
.thumb-image:hover {
transform:scale(1.05);
}
.thumb-text {
margin-top:6px;
font-size:11px;
color:#777;
}
.result-content {
flex:1;
}
.result-label {
color:#777;
font-size:13px;
}
.result-name {
font-size:34px;
font-weight:800;
margin:4px 0;
}
.result-category {
font-size:15px;
color:#555;
margin-bottom:10px;
}
.confidence-text {
font-size:13px;
margin-bottom:6px;
color:#666;
}
.progress-bar-bg {
width:100%;
height:10px;
background:#eeeeee;
border-radius:999px;
overflow:hidden;
}
.progress-bar-fill {
height:100%;
border-radius:999px;
}
.score-badge {
display:inline-block;
margin-top:12px;
padding:8px 14px;
background:#fff3e0;
border-radius:999px;
color:#ef6c00;
font-size:13px;
font-weight:bold;
}
/* 投放指南 */
.knowledge-card {
background:white;
padding:22px;
border-radius:18px;
box-shadow:0 4px 15px rgba(0,0,0,0.06);
}
.knowledge-title {
font-size:22px;
font-weight:bold;
color:#2e7d32;
margin-bottom:15px;
}
.knowledge-body {
background:#f8f9fa;
padding:16px;
border-radius:12px;
line-height:1.8;
font-size:15px;
}
.tips-title {
margin-top:18px;
font-size:17px;
font-weight:bold;
color:#ef6c00;
}
.tips-list {
margin-top:8px;
line-height:1.9;
}
.fun-fact {
margin-top:15px;
background:#fff8e1;
padding:12px;
border-radius:12px;
color:#e65100;
font-weight:bold;
}
.degradation {
margin-top:10px;
color:#777;
font-size:13px;
}
/* 环保统计 */
.stats-card {
background:white;
padding:22px;
border-radius:18px;
box-shadow:0 4px 15px rgba(0,0,0,0.06);
}
.stats-title {
font-size:22px;
font-weight:bold;
color:#1565c0;
margin-bottom:20px;
}
.stats-grid {
display:grid;
grid-template-columns:repeat(3,1fr);
gap:15px;
}
.stat-box {
background:#f5f7fa;
padding:18px;
border-radius:14px;
text-align:center;
}
.stat-number {
font-size:28px;
font-weight:bold;
color:#1565c0;
}
.stat-label {
margin-top:6px;
color:#666;
font-size:13px;
}
.leaderboard-title {
margin-top:24px;
font-size:18px;
font-weight:bold;
color:#2e7d32;
}
.leaderboard-list {
margin-top:12px;
}
.leader-item {
display:flex;
justify-content:space-between;
padding:12px 14px;
background:#f8f9fa;
border-radius:12px;
margin-bottom:10px;
font-size:14px;
}
/* 弹窗 */
.image-modal {
display:none;
position:fixed;
top:0;
left:0;
width:100%;
height:100%;
background:rgba(0,0,0,0.92);
z-index:99999;
justify-content:center;
align-items:center;
cursor:pointer;
}
.modal-image {
width: auto;
height: auto;
object-fit: contain;
max-width:90%;
max-height:90%;
border-radius:10px;
}
.modal-close {
position:absolute;
top:20px;
right:30px;
color:white;
font-size:36px;
font-weight:bold;
}
/* 空卡片 */
.empty-card,
.error-card {
text-align:center;
padding:40px;
background:white;
border-radius:18px;
}
/* 手机端适配 */
@media (max-width:768px) {
.result-card {
flex-direction:column;
text-align:center;
}
.stats-grid {
grid-template-columns:1fr;
}
.result-name {
font-size:28px;
}
}
"""
# --------------------------------------------------
# Category hint banner
# --------------------------------------------------
# (icon, label) for every category advertised to the user. The badge row and
# the supported-category count in the footer are both generated from it.
_CLASS_BADGES = (
    ("🥤", "塑料"),
    ("📦", "纸板"),
    ("📄", "纸张"),
    ("🍾", "玻璃"),
    ("🥫", "金属"),
    ("🍂", "其他垃圾"),
)
CLASS_HINT = (
    "\n<div style='text-align:center;margin-bottom:12px;'>\n"
    + "".join(
        f"<span class='class-badge'>{icon} {label}</span>\n"
        for icon, label in _CLASS_BADGES
    )
    + "<div style='margin-top:10px;color:#777;font-size:13px;'>\n"
    + f"本系统当前支持以上 {len(_CLASS_BADGES)} 类垃圾识别\n"
    + "</div>\n"
    + "</div>\n"
)
# --------------------------------------------------
# Gradio UI
# --------------------------------------------------
# Placeholder shown in both detail tabs until the first recognition finishes.
_PENDING_TAB_HTML = "\n<div class='empty-card'>\n等待识别结果...\n</div>\n"

with gr.Blocks(
    fill_width=False,
    title="AI 垃圾分类助手",
    theme=gr.themes.Soft(primary_hue="green"),
    css=CSS,
) as demo:
    # Page header (HTML inside Markdown; kept flush-left so Markdown does not
    # treat it as an indented code block).
    gr.Markdown(
        "\n"
        "<div class='main-title'>\n"
        "<h1>♻️ AI 垃圾分类助手</h1>\n"
        "<h3>拍照识别 · 投放指南 · 环保积分</h3>\n"
        "</div>\n"
    )
    gr.HTML(CLASS_HINT)

    # Upload panel: image, user name and the trigger button.
    with gr.Group(elem_classes="upload-panel"):
        image_input = gr.Image(
            type="pil",
            label="📷 上传垃圾图片",
            height=220,
            elem_id="upload-image",
        )
        username_input = gr.Textbox(
            label="👤 用户名",
            value="default",
            placeholder="输入用户名记录积分",
        )
        submit_btn = gr.Button("🔍 开始识别", variant="primary", size="lg")

    # Main result card.
    gr.Markdown("### 🤖 AI 识别结果")
    result_output = gr.HTML(
        value=(
            "\n"
            "<div class='empty-card'>\n"
            "<h2>📷 等待上传图片</h2>\n"
            "<p>上传垃圾图片后点击「开始识别」</p>\n"
            "</div>\n"
        )
    )

    # Secondary details, one tab each.
    with gr.Tabs():
        with gr.Tab("📋 投放指南"):
            knowledge_output = gr.HTML(value=_PENDING_TAB_HTML)
        with gr.Tab("📊 环保统计"):
            stats_output = gr.HTML(value=_PENDING_TAB_HTML)

    # classify_and_advise returns (result, knowledge, stats) in this order.
    submit_btn.click(
        fn=classify_and_advise,
        inputs=[image_input, username_input],
        outputs=[result_output, knowledge_output, stats_output],
    )
# --------------------------------------------------
# Launch
# --------------------------------------------------
def launch_gradio(server_port=7860, *, server_name="0.0.0.0", share=False):
    """Start the Gradio server for ``demo``.

    Args:
        server_port: TCP port to listen on.
        server_name: Interface to bind. The default ``"0.0.0.0"`` listens on
            all interfaces (reachable from other machines); pass
            ``"127.0.0.1"`` to keep the app local-only.
        share: Whether to ask Gradio to create a public share link.
    """
    # The printed URL is only a convenience hint for local access.
    print(f"🌐 Gradio Web 界面: http://localhost:{server_port}")
    demo.launch(server_name=server_name, server_port=server_port, share=share)


if __name__ == "__main__":
    launch_gradio()