hutiger's picture
Upload folder using huggingface_hub
bf5b4d8 verified
Raw
History Blame Contribute Delete
7.4 kB
#!/usr/bin/env python3
"""
AI 垃圾分类助手 - 环保小能手
主入口程序 (CLI)
"""
import argparse
from pathlib import Path
from config import DATASET_DIR, MODEL_DIR, MODEL_PATH
def cmd_train(args):
    """Handle the ``train`` sub-command.

    Delegates the whole parsed namespace to ``train.train``; the training
    module is imported lazily so other sub-commands do not pay its import cost.
    """
    from train import train as run_training

    run_training(args)
def cmd_predict(args):
    """Handle the ``predict`` sub-command: classify one image and print top-k results.

    Args:
        args: Parsed ``predict`` namespace. Reads ``image``, ``model_path``,
            ``top_k`` and ``detail``.

    Each prediction dict is expected to carry ``confidence`` (0..1),
    ``class_name`` and ``class_name_cn``. Every prediction is printed with a
    20-cell confidence bar (one cell per 5 %). When ``args.detail`` is true,
    the disposal guide for the top-1 class is also printed if
    ``get_class_info`` returns an entry for it.
    """
    image_path = Path(args.image)
    # Validate the path before importing/loading the model so a typo fails
    # fast; is_file() also rejects directories, which exists() let through.
    if not image_path.is_file():
        print(f"✗ 图片不存在: {image_path}")
        return

    from predict import GarbageClassifier
    from knowledge import get_class_info

    classifier = GarbageClassifier(model_path=args.model_path)
    print(f"正在识别: {image_path.name}")
    results = classifier.predict(str(image_path), top_k=args.top_k)
    if not results:
        # Guard before any results[0] access below.
        print("✗ 未能识别该图片")
        return

    bar_width = 20
    print(f"\n{'='*40}")
    print("识别结果:")
    print(f"{'='*40}")
    for i, r in enumerate(results, 1):
        pct = r["confidence"] * 100
        # Clamp so an out-of-range confidence cannot produce a malformed bar.
        filled = min(bar_width, max(0, int(pct // 5)))
        bar = "█" * filled + "░" * (bar_width - filled)
        print(f" {i}. [{bar}] {r['class_name_cn']} ({pct:.1f}%)")

    if args.detail:
        info = get_class_info(results[0]["class_name"])
        if info:
            print(f"\n{'='*40}")
            print(f"垃圾分类指南 - {info['name_cn']}")
            print(f"分类: {info['category']}")
            print(f"{'='*40}")
            print(info["disposal"])
            print(f"\n💡 {info['fun_fact']}")
def cmd_query(args):
    """Handle the ``query`` sub-command: search the knowledge base by keyword.

    Args:
        args: Parsed ``query`` namespace; reads ``keyword``.

    ``search_knowledge`` is expected to return pairs whose second element is
    an info dict (the first element is ignored here). For each match the
    description, disposal method, tips, fun fact and degradation time are
    printed; a "not found" message is printed when there are no matches.
    """
    from knowledge import search_knowledge

    keyword = args.keyword
    matches = search_knowledge(keyword)
    if not matches:
        print(f"未找到与 '{keyword}' 相关的知识")
        return

    divider = "=" * 50
    for _key, entry in matches:
        print("\n" + divider)
        print(f"{entry['name_cn']} | 类别: {entry['category']}")
        print(divider)
        print(entry["description"])
        print("\n📋 投放方法:")
        print(entry["disposal"])
        print("\n💡 小贴士:")
        for tip in entry["tips"]:
            print(f" • {tip}")
        print(f"\n🎯 {entry['fun_fact']}")
        print(f"⏱ 降解时间: {entry['degradation_time']}")
def cmd_record(args):
    """Handle the ``record`` sub-command: classify an image and store the result.

    Args:
        args: Parsed ``record`` namespace; reads ``image``, ``username`` and
            ``model_path``.

    The top-1 prediction is saved with ``Database.add_record`` for the user id
    returned by ``Database.register_user``. The value returned by
    ``add_record`` is printed as the points earned, followed by the first
    line of the class's disposal text when ``get_class_info`` has an entry.
    """
    image_path = Path(args.image)
    # Validate the image before any database call or model load, so a bad
    # path has no side effects at all.
    if not image_path.is_file():
        print(f"✗ 图片不存在: {image_path}")
        return

    from database import Database
    from predict import GarbageClassifier
    from knowledge import get_class_info

    classifier = GarbageClassifier(model_path=args.model_path)
    results = classifier.predict(str(image_path))
    if not results:
        print("✗ 未能识别该图片,未记录积分")
        return
    best = results[0]

    # Touch the database only once there is a prediction to record.
    db = Database()
    user_id = db.register_user(args.username)
    points = db.add_record(user_id, best["class_name"], best["confidence"])
    print(f"\n✓ 已记录! {best['class_name_cn']} (置信度: {best['confidence']*100:.1f}%)")
    print(f" +{points} 环保积分!")

    info = get_class_info(best["class_name"])
    if info:
        # Only the first line of the disposal text. It is computed outside the
        # f-string because a backslash inside an f-string expression is a
        # SyntaxError before Python 3.12.
        first_step = info["disposal"].partition("\n")[0]
        print(f"\n📋 投放提示: {first_step}")
def cmd_stats(args):
    """Handle the ``stats`` sub-command: print one user's environmental stats.

    Args:
        args: Parsed ``stats`` namespace; reads ``username``.

    Prints totals, today's count/points, a per-class breakdown and up to five
    recent records. If ``Database.get_user`` finds no such user, a hint to run
    ``record`` first is printed instead.
    """
    from database import Database
    from knowledge import KNOWLEDGE_BASE

    def display_name(class_key):
        # Fall back to the raw class key when the knowledge base has no
        # Chinese display name for it.
        return KNOWLEDGE_BASE.get(class_key, {}).get("name_cn", class_key)

    store = Database()
    user = store.get_user(args.username)
    if not user:
        print(f"用户 '{args.username}' 不存在,请先使用 record 命令")
        return

    stats = store.get_user_stats(user["id"])
    divider = "=" * 50
    print("\n" + divider)
    print(f" 环保统计 - {stats['username']}")
    print(divider)
    print(f" 总分类次数: {stats['total_classifications']}")
    print(f" 总环保积分: {stats['total_points']}")
    print(f" 今日分类: {stats['today']['count']} 次")
    print(f" 今日积分: {stats['today']['points']}")

    distribution = stats["class_distribution"]
    if distribution:
        print("\n 各类别分类统计:")
        for item in distribution:
            print(f" • {display_name(item['predicted_class'])}: {item['count']} 次")

    recent = stats["recent_records"]
    if recent:
        print("\n 最近记录:")
        for rec in recent[:5]:
            print(f" • {display_name(rec['predicted_class'])} | 积分: +{rec['points']} | {rec['created_at']}")
def cmd_leaderboard(args):
    """Handle the ``leaderboard`` sub-command: print the top users by points.

    Args:
        args: Parsed ``leaderboard`` namespace; ``limit`` is passed straight
            to ``Database.get_leaderboard``.

    The first three ranks get medal emoji. An encouragement message is
    printed when the leaderboard is empty.
    """
    from database import Database

    store = Database()
    leaders = store.get_leaderboard(args.limit)
    if not leaders:
        print("暂无环保数据,快去分类吧!")
        return

    divider = "=" * 50
    print("\n" + divider)
    print(" 🏆 环保积分排行榜")
    print(divider)
    print(f" {'排名':>4} {'用户名':<15} {'积分':<8} {'分类次数':<8}")
    print(" " + "-" * 35)

    medals = ("🥇", "🥈", "🥉")
    for rank, user in enumerate(leaders, start=1):
        badge = medals[rank - 1] if rank <= len(medals) else " "
        print(f" {badge} {rank:<2} {user['username']:<15} {user['total_points']:<8} {user['total_classifications']:<8}")
def cmd_web(args):
    """Handle the ``web`` sub-command: start the API service.

    Per the CLI epilog this is the service the mini-program calls; it binds
    to ``args.host``/``args.port`` via ``app.api.start_api``.
    """
    from app.api import start_api

    host, port = args.host, args.port
    start_api(host=host, port=port)
def cmd_webui(args):
    """Handle the ``webui`` sub-command: launch the Gradio web interface on ``args.port``."""
    from webui import launch_gradio as serve_ui

    port = args.port
    serve_ui(server_port=port)
def _positive_int(text):
    """Parse *text* as an integer >= 1; intended as an argparse ``type=``.

    Raises:
        argparse.ArgumentTypeError: if *text* is not an integer or is < 1.
            argparse reports this as a usage error, so a zero/negative count
            never reaches a sub-command (e.g. ``--top-k 0`` would otherwise
            yield an empty prediction list).
    """
    try:
        value = int(text)
    except ValueError:
        raise argparse.ArgumentTypeError(f"不是有效的整数: {text!r}") from None
    if value < 1:
        raise argparse.ArgumentTypeError(f"必须是正整数,收到 {value}")
    return value


def main():
    """Build the CLI, parse the command line and dispatch to a ``cmd_*`` handler.

    Prints the top-level help when no sub-command is given.
    """
    parser = argparse.ArgumentParser(
        description="AI 垃圾分类助手 - 环保小能手",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
使用示例:
%(prog)s train --data-dir dataset/trashnet
%(prog)s predict image.jpg
%(prog)s query 塑料
%(prog)s record --username 小明 image.jpg
%(prog)s stats --username 小明
%(prog)s leaderboard
%(prog)s web # 启动 API 服务 (供小程序调用)
%(prog)s webui # 启动 Gradio 网页界面
""",
    )
    sub = parser.add_subparsers(dest="command", help="可用命令")

    p = sub.add_parser("train", help="训练分类模型")
    p.add_argument("--data-dir", default=str(DATASET_DIR), help="数据集目录")
    p.add_argument("--model-dir", default=str(MODEL_DIR), help="模型目录")
    p.add_argument("--epochs", type=_positive_int, default=30, help="训练轮数")
    p.add_argument("--batch-size", type=_positive_int, default=32, help="批大小")
    p.add_argument("--lr", type=float, default=0.001, help="学习率")

    p = sub.add_parser("predict", help="分类垃圾图片")
    p.add_argument("image", help="待识别的图片路径")
    p.add_argument("--model-path", default=str(MODEL_PATH), help="模型文件路径")
    p.add_argument("--top-k", type=_positive_int, default=3, help="显示前 K 个预测结果")
    p.add_argument("--no-detail", dest="detail", action="store_false", default=True,
                   help="不显示垃圾分类指南")

    p = sub.add_parser("query", help="查询垃圾分类知识")
    p.add_argument("keyword", help="搜索关键词")

    p = sub.add_parser("record", help="分类并记录积分")
    p.add_argument("image", help="待识别的图片路径")
    p.add_argument("--username", default="default", help="用户名")
    p.add_argument("--model-path", default=str(MODEL_PATH), help="模型文件路径")

    p = sub.add_parser("stats", help="查看个人统计")
    p.add_argument("--username", default="default", help="用户名")

    p = sub.add_parser("leaderboard", help="查看排行榜")
    p.add_argument("--limit", type=_positive_int, default=10, help="显示的用户数量")

    p = sub.add_parser("web", help="启动 API 服务")
    p.add_argument("--host", default="0.0.0.0", help="监听地址")
    p.add_argument("--port", type=int, default=8000, help="监听端口")

    p = sub.add_parser("webui", help="启动 Gradio 网页界面")
    p.add_argument("--port", type=int, default=7860, help="监听端口")

    args = parser.parse_args()
    # Dispatch table: sub-command name -> handler taking the parsed namespace.
    cmds = {
        "train": cmd_train,
        "predict": cmd_predict,
        "query": cmd_query,
        "record": cmd_record,
        "stats": cmd_stats,
        "leaderboard": cmd_leaderboard,
        "web": cmd_web,
        "webui": cmd_webui,
    }
    fn = cmds.get(args.command)
    if fn:
        fn(args)
    else:
        parser.print_help()
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()