File size: 7,395 Bytes
bf5b4d8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
#!/usr/bin/env python3
"""
AI 垃圾分类助手 - 环保小能手
主入口程序 (CLI)
"""

import argparse
from pathlib import Path
from config import DATASET_DIR, MODEL_DIR, MODEL_PATH


def cmd_train(args):
    """Handle the ``train`` subcommand by forwarding the parsed arguments to training."""
    # Deferred import: the training module is only loaded when this subcommand runs.
    from train import train as run_training

    run_training(args)


def cmd_predict(args):
    """Handle the ``predict`` subcommand: classify one image and print the top-k results.

    Each prediction is shown with a 20-cell confidence bar (one cell per 5%).
    Unless ``--no-detail`` was given (``args.detail`` is False), the disposal
    guide for the best match is printed afterwards.

    Args:
        args: parsed CLI namespace; reads ``image``, ``model_path``, ``top_k``
            and ``detail``.
    """
    from predict import GarbageClassifier
    from knowledge import get_class_info

    # Validate the input before constructing the classifier, so a bad path
    # fails fast instead of first paying the model-loading cost.
    image_path = Path(args.image)
    if not image_path.exists():
        print(f"✗ 图片不存在: {image_path}")
        return

    classifier = GarbageClassifier(model_path=args.model_path)
    print(f"正在识别: {image_path.name}")
    results = classifier.predict(str(image_path), top_k=args.top_k)
    if not results:
        # Guard the results[0] access below (e.g. top_k <= 0 or an empty prediction).
        print("✗ 未能识别该图片")
        return

    rule = "=" * 40
    print(f"\n{rule}")
    print("识别结果:")
    print(rule)
    for rank, result in enumerate(results, 1):
        pct = result["confidence"] * 100
        # Clamp to [0, 20] cells so a confidence slightly outside [0, 1]
        # cannot produce a bar of the wrong width.
        filled = min(max(int(pct // 5), 0), 20)
        bar = "█" * filled + "░" * (20 - filled)
        print(f"  {rank}. [{bar}] {result['class_name_cn']} ({pct:.1f}%)")

    if args.detail:
        info = get_class_info(results[0]["class_name"])
        if info:
            print(f"\n{rule}")
            print(f"垃圾分类指南 - {info['name_cn']}")
            print(f"分类: {info['category']}")
            print(rule)
            print(info["disposal"])
            print(f"\n💡 {info['fun_fact']}")


def cmd_query(args):
    """Handle the ``query`` subcommand: print every knowledge entry matching ``args.keyword``.

    For each match, the description, disposal method, tips, fun fact and
    degradation time are printed. A "not found" message is printed when
    nothing matches.
    """
    from knowledge import search_knowledge

    matches = search_knowledge(args.keyword)
    if not matches:
        print(f"未找到与 '{args.keyword}' 相关的知识")
        return

    rule = "=" * 50
    for _key, entry in matches:
        print(f"\n{rule}")
        print(f"{entry['name_cn']} | 类别: {entry['category']}")
        print(rule)
        print(entry["description"])
        print("\n📋 投放方法:")
        print(entry["disposal"])
        print("\n💡 小贴士:")
        for tip in entry["tips"]:
            print(f"  • {tip}")
        print(f"\n🎯 {entry['fun_fact']}")
        print(f"⏱ 降解时间: {entry['degradation_time']}")


def cmd_record(args):
    """Handle the ``record`` subcommand: classify an image and store the result for a user.

    The best prediction is saved via ``Database.add_record`` and the points it
    returns are printed, followed by the first line of the disposal guide for
    that class (when the knowledge base has one).

    The image path is validated before touching the database or loading the
    model, so a mistyped path neither registers a user nor loads the model.

    Args:
        args: parsed CLI namespace; reads ``image``, ``username`` and ``model_path``.
    """
    from database import Database
    from predict import GarbageClassifier
    from knowledge import get_class_info

    image_path = Path(args.image)
    if not image_path.exists():
        print(f"✗ 图片不存在: {image_path}")
        return

    db = Database()
    user_id = db.register_user(args.username)

    classifier = GarbageClassifier(model_path=args.model_path)
    results = classifier.predict(str(image_path))
    if not results:
        # Nothing to record; avoid an IndexError on results[0].
        print("✗ 未能识别该图片")
        return

    best = results[0]
    points = db.add_record(user_id, best["class_name"], best["confidence"])

    print(f"\n✓ 已记录! {best['class_name_cn']} (置信度: {best['confidence']*100:.1f}%)")
    print(f"  +{points} 环保积分!")

    info = get_class_info(best["class_name"])
    if info:
        # Only the first line of the (possibly multi-line) disposal text is shown.
        first_line = info["disposal"].partition("\n")[0]
        print(f"\n📋 投放提示: {first_line}")


def cmd_stats(args):
    """Handle the ``stats`` subcommand: print one user's totals, per-class counts and recent records.

    If the user does not exist, a hint to run ``record`` first is printed
    instead. At most the first five entries of ``recent_records`` are shown.
    """
    from database import Database
    from knowledge import KNOWLEDGE_BASE

    def display_name(class_key):
        # Chinese display name from the knowledge base, falling back to the raw key.
        return KNOWLEDGE_BASE.get(class_key, {}).get("name_cn", class_key)

    db = Database()
    user = db.get_user(args.username)
    if not user:
        print(f"用户 '{args.username}' 不存在,请先使用 record 命令")
        return

    stats = db.get_user_stats(user["id"])
    rule = "=" * 50
    print(f"\n{rule}")
    print(f"  环保统计 - {stats['username']}")
    print(rule)
    print(f"  总分类次数: {stats['total_classifications']}")
    print(f"  总环保积分: {stats['total_points']}")
    print(f"  今日分类: {stats['today']['count']} 次")
    print(f"  今日积分: {stats['today']['points']}")

    distribution = stats["class_distribution"]
    if distribution:
        print("\n  各类别分类统计:")
        for item in distribution:
            print(f"    • {display_name(item['predicted_class'])}: {item['count']} 次")

    recent = stats["recent_records"]
    if recent:
        print("\n  最近记录:")
        for rec in recent[:5]:
            print(f"    • {display_name(rec['predicted_class'])} | 积分: +{rec['points']} | {rec['created_at']}")


def cmd_leaderboard(args):
    """Handle the ``leaderboard`` subcommand: print up to ``args.limit`` users as a ranked table.

    The first three ranks get medal emojis; the others get two spaces of padding.
    """
    from database import Database

    leaders = Database().get_leaderboard(args.limit)
    if not leaders:
        print("暂无环保数据,快去分类吧!")
        return

    rule = "=" * 50
    print(f"\n{rule}")
    print("  🏆 环保积分排行榜")
    print(rule)
    print(f"  {'排名':>4} {'用户名':<15} {'积分':<8} {'分类次数':<8}")
    print("  " + "-" * 35)

    medals = ("🥇", "🥈", "🥉")
    for rank, entry in enumerate(leaders, start=1):
        medal = medals[rank - 1] if rank <= len(medals) else "  "
        print(
            f"  {medal} {rank:<2} {entry['username']:<15} "
            f"{entry['total_points']:<8} {entry['total_classifications']:<8}"
        )


def cmd_web(args):
    """Handle the ``web`` subcommand: start the API service (used by the mini-program client)."""
    # Deferred import: the API module is only loaded when this subcommand runs.
    from app.api import start_api as serve

    serve(host=args.host, port=args.port)


def cmd_webui(args):
    """Handle the ``webui`` subcommand: launch the Gradio web interface on ``args.port``."""
    # Deferred import: the web UI module is only loaded when this subcommand runs.
    from webui import launch_gradio as launch

    launch(server_port=args.port)


def _positive_int(text):
    """argparse ``type=`` converter that accepts only integers >= 1.

    Rejecting 0 or negative values at parse time gives a clean usage error.
    Otherwise they surface later: ``--top-k 0`` used to crash ``predict`` on
    an empty result list, and ``--limit 0`` used to report an empty leaderboard.
    """
    try:
        value = int(text)
    except ValueError:
        raise argparse.ArgumentTypeError(f"不是有效的整数: {text!r}") from None
    if value < 1:
        raise argparse.ArgumentTypeError(f"必须为正整数: {value}")
    return value


def main():
    """Build the CLI parser, parse ``sys.argv`` and dispatch to the chosen subcommand.

    Prints the top-level help when no subcommand is given.
    """
    parser = argparse.ArgumentParser(
        description="AI 垃圾分类助手 - 环保小能手",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
使用示例:
  %(prog)s train --data-dir dataset/trashnet
  %(prog)s predict image.jpg
  %(prog)s query 塑料
  %(prog)s record --username 小明 image.jpg
  %(prog)s stats --username 小明
  %(prog)s leaderboard
  %(prog)s web          # 启动 API 服务 (供小程序调用)
  %(prog)s webui        # 启动 Gradio 网页界面
        """,
    )
    sub = parser.add_subparsers(dest="command", help="可用命令")

    p = sub.add_parser("train", help="训练分类模型")
    p.add_argument("--data-dir", default=str(DATASET_DIR), help="训练数据集目录")
    p.add_argument("--model-dir", default=str(MODEL_DIR), help="模型目录")
    p.add_argument("--epochs", type=int, default=30, help="训练轮数")
    p.add_argument("--batch-size", type=_positive_int, default=32, help="批大小 (正整数)")
    p.add_argument("--lr", type=float, default=0.001, help="学习率")

    p = sub.add_parser("predict", help="分类垃圾图片")
    p.add_argument("image", help="待识别的图片路径")
    p.add_argument("--model-path", default=str(MODEL_PATH), help="模型文件路径")
    p.add_argument("--top-k", type=_positive_int, default=3, help="显示前 K 个识别结果 (正整数)")
    p.add_argument("--no-detail", dest="detail", action="store_false", default=True,
                   help="不显示垃圾分类指南")

    p = sub.add_parser("query", help="查询垃圾分类知识")
    p.add_argument("keyword", help="查询关键词")

    p = sub.add_parser("record", help="分类并记录积分")
    p.add_argument("image", help="待识别并记录的图片路径")
    p.add_argument("--username", default="default", help="用户名")
    p.add_argument("--model-path", default=str(MODEL_PATH), help="模型文件路径")

    p = sub.add_parser("stats", help="查看个人统计")
    p.add_argument("--username", default="default", help="用户名")

    p = sub.add_parser("leaderboard", help="查看排行榜")
    p.add_argument("--limit", type=_positive_int, default=10, help="显示的用户数量 (正整数)")

    p = sub.add_parser("web", help="启动 API 服务")
    p.add_argument("--host", default="0.0.0.0", help="监听地址")
    p.add_argument("--port", type=int, default=8000, help="监听端口")

    p = sub.add_parser("webui", help="启动 Gradio 网页界面")
    p.add_argument("--port", type=int, default=7860, help="监听端口")

    args = parser.parse_args()

    # Subcommand name -> handler; unknown/missing commands fall through to help.
    cmds = {
        "train": cmd_train,
        "predict": cmd_predict,
        "query": cmd_query,
        "record": cmd_record,
        "stats": cmd_stats,
        "leaderboard": cmd_leaderboard,
        "web": cmd_web,
        "webui": cmd_webui,
    }
    fn = cmds.get(args.command)
    if fn:
        fn(args)
    else:
        parser.print_help()


if __name__ == "__main__":
    main()