#!/usr/bin/env python3 """ AI 垃圾分类助手 - 环保小能手 主入口程序 (CLI) """ import argparse from pathlib import Path from config import DATASET_DIR, MODEL_DIR, MODEL_PATH def cmd_train(args): from train import train train(args) def cmd_predict(args): from predict import GarbageClassifier from knowledge import get_class_info classifier = GarbageClassifier(model_path=args.model_path) image_path = Path(args.image) if not image_path.exists(): print(f"✗ 图片不存在: {image_path}") return print(f"正在识别: {image_path.name}") results = classifier.predict(str(image_path), top_k=args.top_k) print(f"\n{'='*40}") print("识别结果:") print(f"{'='*40}") for i, r in enumerate(results, 1): pct = r["confidence"] * 100 bar = "█" * int(pct // 5) + "░" * (20 - int(pct // 5)) print(f" {i}. [{bar}] {r['class_name_cn']} ({pct:.1f}%)") if args.detail: info = get_class_info(results[0]["class_name"]) if info: print(f"\n{'='*40}") print(f"垃圾分类指南 - {info['name_cn']}") print(f"分类: {info['category']}") print(f"{'='*40}") print(info["disposal"]) print(f"\n💡 {info['fun_fact']}") def cmd_query(args): from knowledge import search_knowledge results = search_knowledge(args.keyword) if not results: print(f"未找到与 '{args.keyword}' 相关的知识") return for _, info in results: print(f"\n{'='*50}") print(f"{info['name_cn']} | 类别: {info['category']}") print(f"{'='*50}") print(info["description"]) print(f"\n📋 投放方法:") print(info["disposal"]) print(f"\n💡 小贴士:") for tip in info["tips"]: print(f" • {tip}") print(f"\n🎯 {info['fun_fact']}") print(f"⏱ 降解时间: {info['degradation_time']}") def cmd_record(args): from database import Database from predict import GarbageClassifier from knowledge import get_class_info db = Database() user_id = db.register_user(args.username) classifier = GarbageClassifier(model_path=args.model_path) image_path = Path(args.image) if not image_path.exists(): print(f"✗ 图片不存在: {image_path}") return results = classifier.predict(str(image_path)) best = results[0] points = db.add_record(user_id, best["class_name"], best["confidence"]) print(f"\n✓ 已记录! {best['class_name_cn']} (置信度: {best['confidence']*100:.1f}%)") print(f" +{points} 环保积分!") info = get_class_info(best["class_name"]) if info: print(f"\n📋 投放提示: {info['disposal'].split(chr(10))[0]}") def cmd_stats(args): from database import Database from knowledge import KNOWLEDGE_BASE db = Database() user = db.get_user(args.username) if not user: print(f"用户 '{args.username}' 不存在,请先使用 record 命令") return stats = db.get_user_stats(user["id"]) print(f"\n{'='*50}") print(f" 环保统计 - {stats['username']}") print(f"{'='*50}") print(f" 总分类次数: {stats['total_classifications']}") print(f" 总环保积分: {stats['total_points']}") print(f" 今日分类: {stats['today']['count']} 次") print(f" 今日积分: {stats['today']['points']}") if stats["class_distribution"]: print(f"\n 各类别分类统计:") for item in stats["class_distribution"]: cn = KNOWLEDGE_BASE.get(item["predicted_class"], {}).get("name_cn", item["predicted_class"]) print(f" • {cn}: {item['count']} 次") if stats["recent_records"]: print(f"\n 最近记录:") for r in stats["recent_records"][:5]: cn = KNOWLEDGE_BASE.get(r["predicted_class"], {}).get("name_cn", r["predicted_class"]) print(f" • {cn} | 积分: +{r['points']} | {r['created_at']}") def cmd_leaderboard(args): from database import Database db = Database() leaders = db.get_leaderboard(args.limit) if not leaders: print("暂无环保数据,快去分类吧!") return print(f"\n{'='*50}") print(" 🏆 环保积分排行榜") print(f"{'='*50}") print(f" {'排名':>4} {'用户名':<15} {'积分':<8} {'分类次数':<8}") print(f" {'-'*35}") badges = ["🥇", "🥈", "🥉"] for i, u in enumerate(leaders, 1): badge = badges[i - 1] if i <= 3 else " " print(f" {badge} {i:<2} {u['username']:<15} {u['total_points']:<8} {u['total_classifications']:<8}") def cmd_web(args): from app.api import start_api start_api(host=args.host, port=args.port) def cmd_webui(args): from webui import launch_gradio launch_gradio(server_port=args.port) def main(): parser = argparse.ArgumentParser( description="AI 垃圾分类助手 - 环保小能手", formatter_class=argparse.RawDescriptionHelpFormatter, epilog=""" 使用示例: %(prog)s train --data-dir dataset/trashnet %(prog)s predict image.jpg %(prog)s query 塑料 %(prog)s record --username 小明 image.jpg %(prog)s stats --username 小明 %(prog)s leaderboard %(prog)s web # 启动 API 服务 (供小程序调用) %(prog)s webui # 启动 Gradio 网页界面 """, ) sub = parser.add_subparsers(dest="command", help="可用命令") p = sub.add_parser("train", help="训练分类模型") p.add_argument("--data-dir", default=str(DATASET_DIR)) p.add_argument("--model-dir", default=str(MODEL_DIR)) p.add_argument("--epochs", type=int, default=30) p.add_argument("--batch-size", type=int, default=32) p.add_argument("--lr", type=float, default=0.001) p = sub.add_parser("predict", help="分类垃圾图片") p.add_argument("image") p.add_argument("--model-path", default=str(MODEL_PATH)) p.add_argument("--top-k", type=int, default=3) p.add_argument("--no-detail", dest="detail", action="store_false", default=True) p = sub.add_parser("query", help="查询垃圾分类知识") p.add_argument("keyword") p = sub.add_parser("record", help="分类并记录积分") p.add_argument("image") p.add_argument("--username", default="default") p.add_argument("--model-path", default=str(MODEL_PATH)) p = sub.add_parser("stats", help="查看个人统计") p.add_argument("--username", default="default") p = sub.add_parser("leaderboard", help="查看排行榜") p.add_argument("--limit", type=int, default=10) p = sub.add_parser("web", help="启动 API 服务") p.add_argument("--host", default="0.0.0.0") p.add_argument("--port", type=int, default=8000) p = sub.add_parser("webui", help="启动 Gradio 网页界面") p.add_argument("--port", type=int, default=7860) args = parser.parse_args() cmds = { "train": cmd_train, "predict": cmd_predict, "query": cmd_query, "record": cmd_record, "stats": cmd_stats, "leaderboard": cmd_leaderboard, "web": cmd_web, "webui": cmd_webui, } fn = cmds.get(args.command) if fn: fn(args) else: parser.print_help() if __name__ == "__main__": main()