"""AI garbage-classification assistant - Web API (FastAPI).

Provides the backend HTTP endpoints used by the mini-program client.
"""
import uuid
from pathlib import Path

from fastapi import FastAPI, UploadFile, File, Form, HTTPException
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
import uvicorn

from predict import GarbageClassifier
from knowledge import get_class_info, search_knowledge, KNOWLEDGE_BASE
from database import Database
from config import UPLOAD_DIR

app = FastAPI(title="AI 垃圾分类助手", version="1.0.0")
# Wide-open CORS: any origin, method and header may call this API.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# The classifier is created lazily by get_classifier() so that importing this
# module does not construct the model.
classifier = None
db = Database()
UPLOAD_DIR.mkdir(exist_ok=True)


def get_classifier():
    """Return the shared GarbageClassifier, constructing it on first call."""
    global classifier
    if classifier is None:
        classifier = GarbageClassifier()
    return classifier


def _classify(image_path):
    """Run the classifier on a saved image file and return its result list.

    Raises:
        HTTPException: status 500 if the classifier returns no results, instead
            of letting ``results[0]`` fail later with an IndexError.
    """
    results = get_classifier().predict(str(image_path))
    if not results:
        raise HTTPException(status_code=500, detail="模型未返回识别结果")
    return results


def _record_and_respond(username, results, source):
    """Store the top result for *username* and build the JSON response body.

    ``results[0]`` is treated as the best match. *source* is the value stored
    with the record (the saved file path, or the original image URL).
    """
    best = results[0]
    user_id = db.register_user(username)
    points = db.add_record(user_id, best["class_name"], best["confidence"], source)
    info = get_class_info(best["class_name"])
    return {"success": True, "results": results, "points_earned": points, "knowledge": info}


@app.get("/")
def root():
    """Return the API name and version."""
    return {"message": "AI 垃圾分类助手 API", "version": "1.0.0"}


@app.post("/predict")
async def predict(file: UploadFile = File(...), username: str = Form("default")):
    """Upload an image and classify it.

    The client-supplied filename is reduced to its final path component and
    prefixed with a random hex id. It therefore cannot escape UPLOAD_DIR
    (e.g. ``../../x``) or overwrite another upload that has the same name.
    """
    safe_name = Path(file.filename or "upload.jpg").name
    image_path = UPLOAD_DIR / f"{uuid.uuid4().hex}_{safe_name}"
    content = await file.read()
    image_path.write_bytes(content)
    # Model inference is blocking. Run it in a worker thread so this async
    # handler does not stall the event loop for concurrent requests.
    results = await run_in_threadpool(_classify, image_path)
    return _record_and_respond(username, results, str(image_path))


@app.get("/predict_url")
def predict_url(image_url: str, username: str = "default"):
    """Download an image from *image_url* and classify it.

    NOTE(review): this fetches an arbitrary client-supplied URL from the
    server (SSRF risk). Consider restricting schemes/hosts before deploying.

    Raises:
        HTTPException: status 400 if the download fails or the payload is not
            a decodable image.
    """
    import io

    import requests
    from PIL import Image

    try:
        response = requests.get(image_url, timeout=10)
        # Without this check, an HTML error page would be handed to PIL below.
        response.raise_for_status()
    except requests.RequestException as exc:
        raise HTTPException(status_code=400, detail="图片下载失败") from exc

    try:
        image = Image.open(io.BytesIO(response.content)).convert("RGB")
    except OSError as exc:  # PIL.UnidentifiedImageError is an OSError subclass
        raise HTTPException(status_code=400, detail="无法识别的图片格式") from exc

    # Use a per-request file name. Sync endpoints run concurrently in a thread
    # pool, so one shared temp file could be overwritten by another request
    # between save() and predict().
    temp_path = UPLOAD_DIR / f"url_{uuid.uuid4().hex}.jpg"
    image.save(temp_path)
    try:
        results = _classify(temp_path)
    finally:
        # The record stores image_url, not this file, so it can be removed.
        temp_path.unlink(missing_ok=True)
    return _record_and_respond(username, results, image_url)
@app.get("/knowledge/{class_name}")
def get_knowledge(class_name: str):
    """Return the knowledge entry for *class_name*; respond 404 when it is falsy."""
    entry = get_class_info(class_name)
    if entry:
        return entry
    raise HTTPException(status_code=404, detail="未找到该类别信息")


@app.get("/knowledge")
def search_knowledge_api(q: str = ""):
    """Search the knowledge base for *q*; an empty query returns the whole base."""
    if not q:
        return KNOWLEDGE_BASE
    # search_knowledge(q) is iterated as (key, value) pairs and folded into a dict.
    return {key: value for key, value in search_knowledge(q)}


@app.post("/user/register")
def register_user(username: str = Form(...)):
    """Register *username* through the database layer and echo back its id."""
    return {"user_id": db.register_user(username), "username": username}


@app.get("/user/{username}/stats")
def user_stats(username: str):
    """Return the stats for an existing user; respond 404 if the user is unknown."""
    record = db.get_user(username)
    if record:
        return db.get_user_stats(record["id"])
    raise HTTPException(status_code=404, detail="用户不存在")


@app.get("/leaderboard")
def leaderboard(limit: int = 10):
    """Return ``db.get_leaderboard(limit)`` under the ``leaderboard`` key."""
    board = db.get_leaderboard(limit)
    return {"leaderboard": board}


def start_api(host="0.0.0.0", port=8000):
    """Print a banner listing the routes, then hand ``app`` to ``uvicorn.run``."""
    print(f"🌐 API 服务已启动: http://localhost:{port}")
    route_lines = (
        " POST /predict # 上传图片分类",
        " GET /predict_url # URL 图片分类",
        " GET /knowledge # 查询知识库",
        " POST /user/register # 注册用户",
        " GET /user/{name}/stats # 用户统计",
        " GET /leaderboard # 排行榜",
    )
    for line in route_lines:
        print(line)
    uvicorn.run(app, host=host, port=port)


if __name__ == "__main__":
    start_api()