Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| """ | |
| AI 垃圾分类助手 - 环保小能手 | |
| 主入口程序 (CLI) | |
| """ | |
| import argparse | |
| from pathlib import Path | |
| from config import DATASET_DIR, MODEL_DIR, MODEL_PATH | |
def cmd_train(args):
    """Run model training with the parsed CLI namespace.

    The whole ``args`` namespace (``data_dir``, ``model_dir``, ``epochs``,
    ``batch_size``, ``lr``) is handed to ``train.train`` unchanged. The import
    happens inside the function, so the ``train`` module is only loaded when
    this sub-command is actually used.
    """
    from train import train as run_training

    run_training(args)
def cmd_predict(args):
    """Classify one image and print the top-k predictions.

    Args:
        args: parsed CLI namespace; reads ``image``, ``model_path``,
            ``top_k`` and ``detail``.

    Prints one bar line per prediction. When ``args.detail`` is true, it also
    prints the disposal guide for the best prediction (if ``get_class_info``
    returns one). Returns None in every case.

    The image path is validated *before* the function-local ``predict`` /
    ``knowledge`` imports and the classifier construction. A bad path is
    therefore reported immediately, without paying for loading the model.
    """
    image_path = Path(args.image)
    # is_file() (not exists()) so a directory is rejected here rather than
    # failing later inside the classifier.
    if not image_path.is_file():
        print(f"✗ 图片不存在: {image_path}")
        return

    from predict import GarbageClassifier
    from knowledge import get_class_info

    classifier = GarbageClassifier(model_path=args.model_path)
    print(f"正在识别: {image_path.name}")
    results = classifier.predict(str(image_path), top_k=args.top_k)
    if not results:
        # Presumably happens with --top-k <= 0; without this guard the
        # results[0] lookup below raises IndexError.
        print("✗ 未得到识别结果")
        return

    rule = "=" * 40
    print(f"\n{rule}")
    print("识别结果:")
    print(rule)
    for i, r in enumerate(results, 1):
        pct = r["confidence"] * 100
        # 20-cell bar, one cell per 5 %. Clamp so an out-of-range score can
        # neither overflow nor underflow the bar width.
        filled = min(max(int(pct // 5), 0), 20)
        bar = "█" * filled + "░" * (20 - filled)
        print(f" {i}. [{bar}] {r['class_name_cn']} ({pct:.1f}%)")

    if args.detail:
        info = get_class_info(results[0]["class_name"])
        if info:
            print(f"\n{rule}")
            print(f"垃圾分类指南 - {info['name_cn']}")
            print(f"分类: {info['category']}")
            print(rule)
            print(info["disposal"])
            print(f"\n💡 {info['fun_fact']}")
def cmd_query(args):
    """Search the knowledge base for ``args.keyword`` and print each hit.

    ``search_knowledge`` is unpacked as an iterable of 2-tuples whose second
    element is the info dict; the first element is ignored. For each hit this
    prints the name, category, description, disposal method, tips, fun fact
    and degradation time. If nothing matches, a "not found" message is printed
    instead.
    """
    from knowledge import search_knowledge

    keyword = args.keyword
    matches = search_knowledge(keyword)
    if not matches:
        print(f"未找到与 '{keyword}' 相关的知识")
        return

    rule = "=" * 50
    for _key, entry in matches:
        print(f"\n{rule}")
        print(f"{entry['name_cn']} | 类别: {entry['category']}")
        print(rule)
        print(entry["description"])
        print("\n📋 投放方法:")
        print(entry["disposal"])
        print("\n💡 小贴士:")
        for tip in entry["tips"]:
            print(f" • {tip}")
        print(f"\n🎯 {entry['fun_fact']}")
        print(f"⏱ 降解时间: {entry['degradation_time']}")
def cmd_record(args):
    """Classify one image and credit the best prediction to a user.

    Args:
        args: parsed CLI namespace; reads ``image``, ``username`` and
            ``model_path``.

    Prints the recorded class, its confidence and the points returned by
    ``Database.add_record``. If ``get_class_info`` knows the class, it also
    prints the first line of its disposal text. Returns None in every case.

    Ordering matters here:
    - The image path is checked before anything else.
    - The user is registered only after a prediction is available.
    A missing image or an empty prediction therefore leaves no trace in the
    database and does not trigger the function-local imports or model
    construction.
    """
    image_path = Path(args.image)
    # is_file() so a directory is rejected up front as well.
    if not image_path.is_file():
        print(f"✗ 图片不存在: {image_path}")
        return

    from database import Database
    from predict import GarbageClassifier
    from knowledge import get_class_info

    classifier = GarbageClassifier(model_path=args.model_path)
    results = classifier.predict(str(image_path))
    if not results:
        # Nothing to record; results[0] would raise IndexError.
        print("✗ 未得到识别结果")
        return
    best = results[0]

    db = Database()
    user_id = db.register_user(args.username)
    points = db.add_record(user_id, best["class_name"], best["confidence"])
    print(f"\n✓ 已记录! {best['class_name_cn']} (置信度: {best['confidence']*100:.1f}%)")
    print(f" +{points} 环保积分!")

    info = get_class_info(best["class_name"])
    if info:
        # Only the first line of the (possibly multi-line) disposal text;
        # partition keeps "" for an empty string, like split("\n")[0].
        first_line = info["disposal"].partition("\n")[0]
        print(f"\n📋 投放提示: {first_line}")
def cmd_stats(args):
    """Print the personal statistics of the user named ``args.username``.

    If ``Database.get_user`` returns nothing, a hint to run ``record`` first
    is printed. Otherwise the output contains:
    - the totals;
    - today's count and points;
    - the per-class breakdown;
    - the first five entries of ``recent_records``.
    Class keys are shown via their ``KNOWLEDGE_BASE`` ``name_cn``, falling
    back to the raw key.
    """
    from database import Database
    from knowledge import KNOWLEDGE_BASE

    def label(class_key):
        # Chinese display name, or the raw key if the knowledge base lacks it.
        return KNOWLEDGE_BASE.get(class_key, {}).get("name_cn", class_key)

    db = Database()
    name = args.username
    user = db.get_user(name)
    if not user:
        print(f"用户 '{name}' 不存在,请先使用 record 命令")
        return

    stats = db.get_user_stats(user["id"])
    rule = "=" * 50
    print(f"\n{rule}")
    print(f" 环保统计 - {stats['username']}")
    print(rule)
    print(f" 总分类次数: {stats['total_classifications']}")
    print(f" 总环保积分: {stats['total_points']}")
    today = stats["today"]
    print(f" 今日分类: {today['count']} 次")
    print(f" 今日积分: {today['points']}")

    distribution = stats["class_distribution"]
    if distribution:
        print("\n 各类别分类统计:")
        for entry in distribution:
            print(f" • {label(entry['predicted_class'])}: {entry['count']} 次")

    recent = stats["recent_records"]
    if recent:
        print("\n 最近记录:")
        for rec in recent[:5]:
            print(f" • {label(rec['predicted_class'])} | 积分: +{rec['points']} | {rec['created_at']}")
def cmd_leaderboard(args):
    """Print the points leaderboard, limited to ``args.limit`` entries.

    The first three ranks get a medal emoji. With no data, an encouragement
    message is printed instead of the table.
    """
    from database import Database

    db = Database()
    ranking = db.get_leaderboard(args.limit)
    if not ranking:
        print("暂无环保数据,快去分类吧!")
        return

    rule = "=" * 50
    print(f"\n{rule}")
    print(" 🏆 环保积分排行榜")
    print(rule)
    print(f" {'排名':>4} {'用户名':<15} {'积分':<8} {'分类次数':<8}")
    print(f" {'-'*35}")

    medals = ("🥇", "🥈", "🥉")
    for rank, row in enumerate(ranking, start=1):
        medal = medals[rank - 1] if rank <= len(medals) else " "
        print(f" {medal} {rank:<2} {row['username']:<15} {row['total_points']:<8} {row['total_classifications']:<8}")
def cmd_web(args):
    """Start the API service (described in the CLI help as the mini-program backend).

    Binds to ``args.host``:``args.port`` by delegating to
    ``app.api.start_api``, which is imported only when this command runs.
    """
    from app.api import start_api as serve_api

    bind_host, bind_port = args.host, args.port
    serve_api(host=bind_host, port=bind_port)
def cmd_webui(args):
    """Launch the Gradio web UI on ``args.port``.

    Delegates to ``webui.launch_gradio``, which is imported only when this
    command runs.
    """
    from webui import launch_gradio as start_gradio

    ui_port = args.port
    start_gradio(server_port=ui_port)
def main():
    """Parse the command line and run the selected sub-command.

    Sub-commands: train, predict, query, record, stats, leaderboard, web and
    webui. Each one maps to the matching ``cmd_*`` function, which receives
    the parsed namespace. The path defaults come from ``config``
    (``DATASET_DIR``, ``MODEL_DIR``, ``MODEL_PATH``). Without a sub-command,
    the top-level help is printed.
    """
    parser = argparse.ArgumentParser(
        description="AI 垃圾分类助手 - 环保小能手",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
使用示例:
%(prog)s train --data-dir dataset/trashnet
%(prog)s predict image.jpg
%(prog)s query 塑料
%(prog)s record --username 小明 image.jpg
%(prog)s stats --username 小明
%(prog)s leaderboard
%(prog)s web # 启动 API 服务 (供小程序调用)
%(prog)s webui # 启动 Gradio 网页界面
""",
    )
    subparsers = parser.add_subparsers(dest="command", help="可用命令")

    train_parser = subparsers.add_parser("train", help="训练分类模型")
    train_parser.add_argument("--data-dir", default=str(DATASET_DIR))
    train_parser.add_argument("--model-dir", default=str(MODEL_DIR))
    train_parser.add_argument("--epochs", type=int, default=30)
    train_parser.add_argument("--batch-size", type=int, default=32)
    train_parser.add_argument("--lr", type=float, default=0.001)

    predict_parser = subparsers.add_parser("predict", help="分类垃圾图片")
    predict_parser.add_argument("image")
    predict_parser.add_argument("--model-path", default=str(MODEL_PATH))
    predict_parser.add_argument("--top-k", type=int, default=3)
    # Detail output is on by default; --no-detail turns it off.
    predict_parser.add_argument(
        "--no-detail", dest="detail", action="store_false", default=True
    )

    query_parser = subparsers.add_parser("query", help="查询垃圾分类知识")
    query_parser.add_argument("keyword")

    record_parser = subparsers.add_parser("record", help="分类并记录积分")
    record_parser.add_argument("image")
    record_parser.add_argument("--username", default="default")
    record_parser.add_argument("--model-path", default=str(MODEL_PATH))

    stats_parser = subparsers.add_parser("stats", help="查看个人统计")
    stats_parser.add_argument("--username", default="default")

    board_parser = subparsers.add_parser("leaderboard", help="查看排行榜")
    board_parser.add_argument("--limit", type=int, default=10)

    web_parser = subparsers.add_parser("web", help="启动 API 服务")
    web_parser.add_argument("--host", default="0.0.0.0")
    web_parser.add_argument("--port", type=int, default=8000)

    webui_parser = subparsers.add_parser("webui", help="启动 Gradio 网页界面")
    webui_parser.add_argument("--port", type=int, default=7860)

    args = parser.parse_args()

    handlers = {
        "train": cmd_train,
        "predict": cmd_predict,
        "query": cmd_query,
        "record": cmd_record,
        "stats": cmd_stats,
        "leaderboard": cmd_leaderboard,
        "web": cmd_web,
        "webui": cmd_webui,
    }
    handler = handlers.get(args.command)
    if handler is None:
        # No sub-command given (args.command is None).
        parser.print_help()
        return
    handler(args)
if __name__ == "__main__":
    # Script entry point. main() returns None, so SystemExit(None) exits with
    # status 0 unless argparse or a handler raises first.
    raise SystemExit(main())