Spaces:
Sleeping
Sleeping
"""
AI garbage-classification assistant - Web API (FastAPI).

Provides the backend service endpoints for the mini program.
"""
import os
import tempfile
from pathlib import Path

import uvicorn
from fastapi import FastAPI, UploadFile, File, Form, HTTPException
from fastapi.middleware.cors import CORSMiddleware

from config import UPLOAD_DIR
from database import Database
from knowledge import get_class_info, search_knowledge, KNOWLEDGE_BASE
from predict import GarbageClassifier
# Application object plus the module-level singletons shared by the handlers.
app = FastAPI(title="AI 垃圾分类助手", version="1.0.0")

# NOTE(review): every origin, method and header is allowed here; confirm this
# is intended before exposing the service beyond the mini-program client.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# Filled in on first use by get_classifier() rather than at import time.
classifier = None
db = Database()

# parents=True so a nested upload directory is created on a fresh checkout
# instead of raising FileNotFoundError while the module is being imported.
UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
def get_classifier():
    """Return the shared ``GarbageClassifier``, constructing it on first use.

    The instance is stored in the module-level ``classifier`` global, so later
    calls reuse the same object instead of building a new one.

    Returns:
        The cached ``GarbageClassifier`` instance.
    """
    global classifier
    if classifier is None:
        classifier = GarbageClassifier()
    return classifier
# NOTE(review): no route registration was visible for this handler and "/" is
# not in the route table printed by start_api; the path is assumed - confirm.
@app.get("/")
def root():
    """Return the service name and version as a small JSON-serialisable dict."""
    return {"message": "AI 垃圾分类助手 API", "version": "1.0.0"}
@app.post("/predict")  # path as advertised in the route table printed by start_api
async def predict(file: UploadFile = File(...), username: str = Form("default")):
    """Classify an uploaded image and record the result for ``username``.

    The upload is written into ``UPLOAD_DIR`` and classified; the first entry
    of the classifier output is treated as the best match and stored through
    ``db.add_record``.

    Args:
        file: The uploaded image.
        username: Name passed to ``db.register_user`` to obtain the user id.

    Returns:
        Dict with ``success``, the full ``results`` list, ``points_earned``
        (the value returned by ``db.add_record``) and ``knowledge`` (the
        ``get_class_info`` entry for the best match's class).

    Raises:
        HTTPException: 400 when the upload carries no usable file name.
    """
    # file.filename is chosen by the client. Keep only its final component so
    # a name such as "../../x" or "/abs/x" cannot point outside UPLOAD_DIR.
    safe_name = Path(file.filename or "").name
    if safe_name in ("", ".", ".."):
        raise HTTPException(status_code=400, detail="文件名无效")
    image_path = UPLOAD_DIR / safe_name

    content = await file.read()
    with open(image_path, "wb") as f:
        f.write(content)

    clf = get_classifier()
    results = clf.predict(str(image_path))
    best = results[0]

    user_id = db.register_user(username)
    points = db.add_record(user_id, best["class_name"], best["confidence"], str(image_path))
    info = get_class_info(best["class_name"])
    return {"success": True, "results": results, "points_earned": points, "knowledge": info}
@app.get("/predict_url")  # path as advertised in the route table printed by start_api
def predict_url(image_url: str, username: str = "default"):
    """Download the image at ``image_url``, classify it and record the result.

    The image is decoded, converted to RGB and saved to a temporary JPEG inside
    ``UPLOAD_DIR`` because the classifier is called with a file path. The
    record stored through ``db.add_record`` keeps the URL, not the temp path.

    Args:
        image_url: Address of the image to fetch.
        username: Name passed to ``db.register_user`` to obtain the user id.

    Returns:
        Dict with ``success``, the full ``results`` list, ``points_earned``
        and ``knowledge`` for the first (best) result's class.

    Raises:
        HTTPException: 400 when the download fails or the payload cannot be
            decoded as an image.
    """
    import requests
    from PIL import Image
    import io

    # NOTE(review): image_url is caller-supplied and fetched from the server,
    # so it can target internal hosts (SSRF). Consider an allow-list.
    try:
        response = requests.get(image_url, timeout=10)
        # Without this an HTTP error page would be handed to PIL below.
        response.raise_for_status()
    except requests.RequestException as exc:
        raise HTTPException(status_code=400, detail="图片下载失败") from exc

    try:
        image = Image.open(io.BytesIO(response.content)).convert("RGB")
    except OSError as exc:  # includes PIL.UnidentifiedImageError
        raise HTTPException(status_code=400, detail="无法解析图片") from exc

    # One temp file per request: a single fixed file name would be shared by
    # concurrent requests, which could then classify each other's image.
    fd, temp_path = tempfile.mkstemp(suffix=".jpg", dir=UPLOAD_DIR)
    os.close(fd)
    try:
        image.save(temp_path)
        clf = get_classifier()
        results = clf.predict(str(temp_path))
    finally:
        # The record below stores the URL, so the temp file is not needed again.
        os.remove(temp_path)

    best = results[0]
    user_id = db.register_user(username)
    points = db.add_record(user_id, best["class_name"], best["confidence"], image_url)
    info = get_class_info(best["class_name"])
    return {"success": True, "results": results, "points_earned": points, "knowledge": info}
# NOTE(review): no route registration was visible for this handler. start_api
# only lists "GET /knowledge", so the per-class path below is an assumption -
# confirm it against the client before relying on it.
@app.get("/knowledge/{class_name}")
def get_knowledge(class_name: str):
    """Return the knowledge-base entry for one class.

    Args:
        class_name: Class to look up with ``get_class_info``.

    Returns:
        Whatever ``get_class_info`` returns for ``class_name``.

    Raises:
        HTTPException: 404 when ``get_class_info`` returns a falsy value.
    """
    info = get_class_info(class_name)
    if not info:
        raise HTTPException(status_code=404, detail="未找到该类别信息")
    return info
@app.get("/knowledge")  # path as advertised in the route table printed by start_api
def search_knowledge_api(q: str = ""):
    """Search the knowledge base, or return all of it when ``q`` is empty.

    Args:
        q: Search text handed to ``search_knowledge``; empty means "no filter".

    Returns:
        ``KNOWLEDGE_BASE`` itself when ``q`` is empty, otherwise a dict built
        from the pairs produced by ``search_knowledge(q)``.
    """
    if not q:
        return KNOWLEDGE_BASE
    # search_knowledge produces (key, value) pairs; dict() collects them
    # directly, no pass-through comprehension needed.
    return dict(search_knowledge(q))
@app.post("/user/register")  # path as advertised in the route table printed by start_api
def register_user(username: str = Form(...)):
    """Register ``username`` and return its id.

    Args:
        username: Required form field with the name to register.

    Returns:
        Dict with ``user_id`` (the value returned by ``db.register_user``) and
        the submitted ``username``.
    """
    user_id = db.register_user(username)
    return {"user_id": user_id, "username": username}
# start_api prints this route as "/user/{name}/stats"; the path parameter is
# spelled {username} here so that it binds to the function argument.
@app.get("/user/{username}/stats")
def user_stats(username: str):
    """Return the statistics stored for ``username``.

    Args:
        username: Name looked up with ``db.get_user``.

    Returns:
        The result of ``db.get_user_stats`` for the user's ``id``.

    Raises:
        HTTPException: 404 when ``db.get_user`` returns a falsy value.
    """
    user = db.get_user(username)
    if not user:
        raise HTTPException(status_code=404, detail="用户不存在")
    return db.get_user_stats(user["id"])
@app.get("/leaderboard")  # path as advertised in the route table printed by start_api
def leaderboard(limit: int = 10):
    """Return the leaderboard wrapped in a ``{"leaderboard": ...}`` dict.

    Args:
        limit: Passed unchanged to ``db.get_leaderboard``; defaults to 10.

    Returns:
        Dict whose ``leaderboard`` key holds ``db.get_leaderboard(limit)``.
    """
    return {"leaderboard": db.get_leaderboard(limit)}
def start_api(host="0.0.0.0", port=8000):
    """Print an overview of the routes, then serve ``app`` with uvicorn.

    Args:
        host: Interface handed to ``uvicorn.run``.
        port: Port handed to ``uvicorn.run``; also shown in the banner URL.
    """
    print(f"🌐 API 服务已启动: http://localhost:{port}")
    # Plain literals: none of these lines interpolates a value, so the f prefix
    # (and the doubled braces it required) is unnecessary. Output is unchanged.
    print(" POST /predict # 上传图片分类")
    print(" GET /predict_url # URL 图片分类")
    print(" GET /knowledge # 查询知识库")
    print(" POST /user/register # 注册用户")
    print(" GET /user/{name}/stats # 用户统计")
    print(" GET /leaderboard # 排行榜")
    uvicorn.run(app, host=host, port=port)
# Start the server only when this file is executed as a script, not on import.
if __name__ == "__main__":
    start_api()