#!/usr/bin/env python3
"""
AI 垃圾分类助手 - 环保小能手
主入口程序 (CLI)
"""
import argparse
from pathlib import Path
from config import DATASET_DIR, MODEL_DIR, MODEL_PATH
def cmd_train(args):
    """Handle the ``train`` subcommand.

    Forwards the parsed argument namespace unchanged to ``train.train``.
    The ``train`` module is imported only when this handler runs.

    Args:
        args: Parsed CLI namespace for the ``train`` subcommand.
    """
    from train import train as run_training

    run_training(args)
def cmd_predict(args):
    """Handle the ``predict`` subcommand: classify one image and print results.

    The image path is validated *before* the project modules are imported
    and the classifier is constructed, so a mistyped path fails immediately
    instead of first paying for the model load.

    Args:
        args: Parsed CLI namespace providing ``image`` (path to the picture),
            ``model_path``, ``top_k`` (number of predictions to show) and
            ``detail`` (whether to print the disposal guide for the best
            prediction).

    Returns:
        None. All output is printed to stdout.
    """
    image_path = Path(args.image)
    if not image_path.exists():
        print(f"✗ 图片不存在: {image_path}")
        return

    # Deferred until the input is known to be usable.
    from predict import GarbageClassifier
    from knowledge import get_class_info

    classifier = GarbageClassifier(model_path=args.model_path)
    print(f"正在识别: {image_path.name}")
    results = classifier.predict(str(image_path), top_k=args.top_k)

    print(f"\n{'='*40}")
    print("识别结果:")
    print(f"{'='*40}")
    for i, r in enumerate(results, 1):
        # Confidence is presumably a 0..1 fraction (it is scaled by 100
        # here) -- TODO confirm against GarbageClassifier.predict.
        pct = r["confidence"] * 100
        # One bar cell per 5 percentage points; clamp so an out-of-range
        # score can never produce a bar that is not exactly 20 cells wide.
        filled = min(20, max(0, int(pct // 5)))
        bar = "█" * filled + "░" * (20 - filled)
        print(f" {i}. [{bar}] {r['class_name_cn']} ({pct:.1f}%)")

    # Guard against an empty prediction list before indexing the best hit.
    if args.detail and results:
        info = get_class_info(results[0]["class_name"])
        if info:
            print(f"\n{'='*40}")
            print(f"垃圾分类指南 - {info['name_cn']}")
            print(f"分类: {info['category']}")
            print(f"{'='*40}")
            print(info["disposal"])
            print(f"\n💡 {info['fun_fact']}")
def cmd_query(args):
    """Handle the ``query`` subcommand: print knowledge entries for a keyword.

    Calls ``knowledge.search_knowledge`` with ``args.keyword`` and prints
    every returned entry (name, category, description, disposal method,
    tips, fun fact and degradation time). Prints a "not found" message when
    the search yields nothing.

    Args:
        args: Parsed CLI namespace providing ``keyword``.
    """
    from knowledge import search_knowledge

    matches = search_knowledge(args.keyword)
    if not matches:
        print(f"未找到与 '{args.keyword}' 相关的知识")
        return

    rule = "=" * 50
    # Each match is a pair; only the second element (the entry) is shown.
    for _, entry in matches:
        print(f"\n{rule}")
        print(f"{entry['name_cn']} | 类别: {entry['category']}")
        print(rule)
        print(entry["description"])
        print("\n📋 投放方法:")
        print(entry["disposal"])
        print("\n💡 小贴士:")
        for tip in entry["tips"]:
            print(f" • {tip}")
        print(f"\n🎯 {entry['fun_fact']}")
        print(f"⏱ 降解时间: {entry['degradation_time']}")
def cmd_record(args):
    """Handle the ``record`` subcommand: classify an image and award points.

    The image path is validated first. Only when the file exists are the
    project modules imported, the user registered, and the model loaded,
    so an invalid path no longer leaves a database side effect behind or
    pays for a model load.

    Args:
        args: Parsed CLI namespace providing ``image`` (path to the picture),
            ``username`` and ``model_path``.

    Returns:
        None. All output is printed to stdout.
    """
    image_path = Path(args.image)
    if not image_path.exists():
        print(f"✗ 图片不存在: {image_path}")
        return

    # Deferred until the input is known to be usable.
    from database import Database
    from predict import GarbageClassifier
    from knowledge import get_class_info

    db = Database()
    user_id = db.register_user(args.username)
    classifier = GarbageClassifier(model_path=args.model_path)

    results = classifier.predict(str(image_path))
    best = results[0]  # the first prediction is the one that gets recorded
    points = db.add_record(user_id, best["class_name"], best["confidence"])
    print(f"\n✓ 已记录! {best['class_name_cn']} (置信度: {best['confidence']*100:.1f}%)")
    print(f" +{points} 环保积分!")

    info = get_class_info(best["class_name"])
    if info:
        # Show only the first line of the disposal instructions as a hint.
        first_line = info["disposal"].split("\n")[0]
        print(f"\n📋 投放提示: {first_line}")
def cmd_stats(args):
    """Handle the ``stats`` subcommand: print one user's statistics.

    Looks the user up by ``args.username``; if no such user exists a hint
    to run the ``record`` command first is printed. Otherwise prints the
    totals, today's numbers, the per-class distribution and up to five
    recent records returned by ``Database.get_user_stats``.

    Args:
        args: Parsed CLI namespace providing ``username``.
    """
    from database import Database
    from knowledge import KNOWLEDGE_BASE

    db = Database()
    user = db.get_user(args.username)
    if not user:
        print(f"用户 '{args.username}' 不存在,请先使用 record 命令")
        return

    def display_name(class_key):
        """Return the knowledge-base display name, or the raw key if unknown."""
        return KNOWLEDGE_BASE.get(class_key, {}).get("name_cn", class_key)

    stats = db.get_user_stats(user["id"])
    rule = "=" * 50

    print(f"\n{rule}")
    print(f" 环保统计 - {stats['username']}")
    print(rule)
    print(f" 总分类次数: {stats['total_classifications']}")
    print(f" 总环保积分: {stats['total_points']}")
    print(f" 今日分类: {stats['today']['count']} 次")
    print(f" 今日积分: {stats['today']['points']}")

    if stats["class_distribution"]:
        print("\n 各类别分类统计:")
        for row in stats["class_distribution"]:
            label = display_name(row["predicted_class"])
            print(f" • {label}: {row['count']} 次")

    if stats["recent_records"]:
        print("\n 最近记录:")
        # Only the first five recent records are displayed.
        for record in stats["recent_records"][:5]:
            label = display_name(record["predicted_class"])
            print(f" • {label} | 积分: +{record['points']} | {record['created_at']}")
def cmd_leaderboard(args):
    """Handle the ``leaderboard`` subcommand: print the points ranking.

    Fetches up to ``args.limit`` entries from ``Database.get_leaderboard``
    and prints them as a table; the first three ranks get a medal marker.
    Prints a placeholder message when there is no data.

    Args:
        args: Parsed CLI namespace providing ``limit``.
    """
    from database import Database

    leaders = Database().get_leaderboard(args.limit)
    if not leaders:
        print("暂无环保数据,快去分类吧!")
        return

    rule = "=" * 50
    print(f"\n{rule}")
    print(" 🏆 环保积分排行榜")
    print(rule)
    print(f" {'排名':>4} {'用户名':<15} {'积分':<8} {'分类次数':<8}")
    print(f" {'-'*35}")

    medals = ("🥇", "🥈", "🥉")
    for rank, row in enumerate(leaders, 1):
        # Ranks beyond the medal list get a blank placeholder instead.
        marker = medals[rank - 1] if rank <= len(medals) else " "
        print(f" {marker} {rank:<2} {row['username']:<15} {row['total_points']:<8} {row['total_classifications']:<8}")
def cmd_web(args):
    """Handle the ``web`` subcommand: start the API service.

    Passes ``args.host`` and ``args.port`` to ``app.api.start_api``. The
    ``app.api`` module is imported only when this handler runs.

    Args:
        args: Parsed CLI namespace providing ``host`` and ``port``.
    """
    from app.api import start_api as serve_api

    serve_api(host=args.host, port=args.port)
def cmd_webui(args):
    """Handle the ``webui`` subcommand: launch the Gradio web interface.

    Passes ``args.port`` as ``server_port`` to ``webui.launch_gradio``. The
    ``webui`` module is imported only when this handler runs.

    Args:
        args: Parsed CLI namespace providing ``port``.
    """
    from webui import launch_gradio as open_ui

    open_ui(server_port=args.port)
def main(argv=None):
    """Parse command-line arguments and dispatch to the matching handler.

    Builds one sub-parser per command (``train``, ``predict``, ``query``,
    ``record``, ``stats``, ``leaderboard``, ``web``, ``webui``) and calls
    the corresponding ``cmd_*`` function with the parsed namespace. When no
    command is given, the help text is printed instead.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) lets argparse read ``sys.argv[1:]``, which keeps the
            existing ``main()`` call working while allowing the CLI to be
            invoked programmatically.
    """
    parser = argparse.ArgumentParser(
        description="AI 垃圾分类助手 - 环保小能手",
        # Raw formatter keeps the line breaks of the usage examples below.
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
使用示例:
%(prog)s train --data-dir dataset/trashnet
%(prog)s predict image.jpg
%(prog)s query 塑料
%(prog)s record --username 小明 image.jpg
%(prog)s stats --username 小明
%(prog)s leaderboard
%(prog)s web # 启动 API 服务 (供小程序调用)
%(prog)s webui # 启动 Gradio 网页界面
""",
    )
    sub = parser.add_subparsers(dest="command", help="可用命令")

    # train: model training options.
    p = sub.add_parser("train", help="训练分类模型")
    p.add_argument("--data-dir", default=str(DATASET_DIR))
    p.add_argument("--model-dir", default=str(MODEL_DIR))
    p.add_argument("--epochs", type=int, default=30)
    p.add_argument("--batch-size", type=int, default=32)
    p.add_argument("--lr", type=float, default=0.001)

    # predict: classify a single image; detail output is on unless
    # --no-detail is passed.
    p = sub.add_parser("predict", help="分类垃圾图片")
    p.add_argument("image")
    p.add_argument("--model-path", default=str(MODEL_PATH))
    p.add_argument("--top-k", type=int, default=3)
    p.add_argument("--no-detail", dest="detail", action="store_false", default=True)

    # query: knowledge-base lookup by keyword.
    p = sub.add_parser("query", help="查询垃圾分类知识")
    p.add_argument("keyword")

    # record: classify an image and store the result for a user.
    p = sub.add_parser("record", help="分类并记录积分")
    p.add_argument("image")
    p.add_argument("--username", default="default")
    p.add_argument("--model-path", default=str(MODEL_PATH))

    # stats: per-user statistics.
    p = sub.add_parser("stats", help="查看个人统计")
    p.add_argument("--username", default="default")

    # leaderboard: ranking of users.
    p = sub.add_parser("leaderboard", help="查看排行榜")
    p.add_argument("--limit", type=int, default=10)

    # web: API service.
    p = sub.add_parser("web", help="启动 API 服务")
    p.add_argument("--host", default="0.0.0.0")
    p.add_argument("--port", type=int, default=8000)

    # webui: Gradio interface.
    p = sub.add_parser("webui", help="启动 Gradio 网页界面")
    p.add_argument("--port", type=int, default=7860)

    args = parser.parse_args(argv)

    # Map each sub-command name to its handler function.
    cmds = {
        "train": cmd_train,
        "predict": cmd_predict,
        "query": cmd_query,
        "record": cmd_record,
        "stats": cmd_stats,
        "leaderboard": cmd_leaderboard,
        "web": cmd_web,
        "webui": cmd_webui,
    }
    fn = cmds.get(args.command)
    if fn:
        fn(args)
    else:
        # No sub-command supplied: show usage rather than failing silently.
        parser.print_help()
# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|