File size: 3,672 Bytes
bf5b4d8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
"""
AI 垃圾分类助手 - Web API (FastAPI)
为小程序提供后端服务接口
"""

from fastapi import FastAPI, UploadFile, File, Form, HTTPException
from fastapi.middleware.cors import CORSMiddleware
import uvicorn

from predict import GarbageClassifier
from knowledge import get_class_info, search_knowledge, KNOWLEDGE_BASE
from database import Database
from config import UPLOAD_DIR

# Application object that all route decorators below attach to.
app = FastAPI(title="AI 垃圾分类助手", version="1.0.0")
# NOTE(review): CORS accepts every origin, method and header. Presumably this
# is for the mini-program client during development; confirm before exposing
# the service publicly.
app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])

# The classifier is created on first use by get_classifier(), not at import.
classifier = None
db = Database()

# parents=True so a nested upload directory is created even when its parent
# directories do not exist yet; exist_ok=True keeps restarts from failing.
UPLOAD_DIR.mkdir(parents=True, exist_ok=True)


def get_classifier():
    """Return the module-wide classifier, constructing it on the first call."""
    global classifier
    if classifier is not None:
        return classifier
    # First request: build the classifier once and keep it for later calls.
    classifier = GarbageClassifier()
    return classifier


@app.get("/")
def root():
    """Return a small payload with the API name and version."""
    return dict(message="AI 垃圾分类助手 API", version="1.0.0")


@app.post("/predict")
async def predict(file: UploadFile = File(...), username: str = Form("default")):
    """Classify an uploaded image and record the result for ``username``.

    The upload is written into ``UPLOAD_DIR`` under a server-generated name,
    passed to the classifier, and the first result is stored through ``db``.

    Args:
        file: The uploaded image (multipart form field).
        username: Name the record is stored under; defaults to "default".

    Returns:
        A dict with ``success``, the full classifier ``results``, the
        ``points_earned`` value returned by ``db.add_record`` and the
        ``knowledge`` entry looked up for the first result's class.

    Raises:
        HTTPException: 422 when the classifier returns no results.
    """
    import uuid
    from pathlib import Path

    # file.filename is controlled by the client. Keep only its last path
    # component (also treating "\" as a separator) so values such as
    # "../../app.py" cannot escape UPLOAD_DIR, and fall back to a fixed name
    # when the client sent none.
    client_name = Path((file.filename or "").replace("\\", "/")).name or "upload"
    # A random prefix keeps uploads that share a name from overwriting each
    # other and guarantees the final name is never "." or "..".
    image_path = UPLOAD_DIR / f"{uuid.uuid4().hex}_{client_name}"

    content = await file.read()
    with open(image_path, "wb") as f:
        f.write(content)

    clf = get_classifier()
    results = clf.predict(str(image_path))
    if not results:
        # Avoid an IndexError (HTTP 500) when nothing was recognised.
        raise HTTPException(status_code=422, detail="未能识别图片内容")
    best = results[0]

    user_id = db.register_user(username)
    points = db.add_record(user_id, best["class_name"], best["confidence"], str(image_path))

    info = get_class_info(best["class_name"])
    return {"success": True, "results": results, "points_earned": points, "knowledge": info}


@app.get("/predict_url")
def predict_url(image_url: str, username: str = "default"):
    """Download the image at ``image_url``, classify it and record the result.

    Args:
        image_url: Address of the image to fetch.
        username: Name the record is stored under; defaults to "default".

    Returns:
        A dict with ``success``, the full classifier ``results``, the
        ``points_earned`` value returned by ``db.add_record`` and the
        ``knowledge`` entry looked up for the first result's class.

    Raises:
        HTTPException: 400 when the URL cannot be downloaded or its content
            cannot be opened as an image; 422 when the classifier returns no
            results.
    """
    import io
    import uuid

    import requests
    from PIL import Image

    # NOTE(review): image_url is caller-supplied, so this endpoint can be
    # aimed at internal addresses (SSRF). Consider restricting the allowed
    # hosts before exposing the service publicly.
    try:
        response = requests.get(image_url, timeout=10)
        # Without this an HTTP error page would be handed to the image decoder.
        response.raise_for_status()
    except requests.RequestException as exc:
        raise HTTPException(status_code=400, detail="无法下载图片") from exc

    try:
        image = Image.open(io.BytesIO(response.content)).convert("RGB")
    except OSError as exc:
        # Pillow signals unreadable or unsupported image data with OSError.
        raise HTTPException(status_code=400, detail="无法解析图片") from exc

    # This handler is a plain "def", so FastAPI runs it in a thread pool and
    # requests can overlap. A per-request file name stops them from
    # overwriting each other's temporary image.
    temp_path = UPLOAD_DIR / f"url_{uuid.uuid4().hex}.jpg"
    try:
        image.save(temp_path)
        clf = get_classifier()
        results = clf.predict(str(temp_path))
    finally:
        # Only the URL is recorded below, so the temporary file can be removed.
        if temp_path.exists():
            temp_path.unlink()

    if not results:
        # Avoid an IndexError (HTTP 500) when nothing was recognised.
        raise HTTPException(status_code=422, detail="未能识别图片内容")
    best = results[0]

    user_id = db.register_user(username)
    points = db.add_record(user_id, best["class_name"], best["confidence"], image_url)

    info = get_class_info(best["class_name"])
    return {"success": True, "results": results, "points_earned": points, "knowledge": info}


@app.get("/knowledge/{class_name}")
def get_knowledge(class_name: str):
    """Return the knowledge entry for ``class_name``, or 404 if there is none."""
    entry = get_class_info(class_name)
    if entry:
        return entry
    raise HTTPException(status_code=404, detail="未找到该类别信息")


@app.get("/knowledge")
def search_knowledge_api(q: str = ""):
    """Search the knowledge base for ``q``; return the whole base when ``q`` is empty."""
    if not q:
        return KNOWLEDGE_BASE
    # search_knowledge yields (key, value) pairs; collect them into a dict.
    matches = {}
    for key, value in search_knowledge(q):
        matches[key] = value
    return matches


@app.post("/user/register")
def register_user(username: str = Form(...)):
    """Register ``username`` through ``db`` and return its id with the name."""
    return {"user_id": db.register_user(username), "username": username}


@app.get("/user/{username}/stats")
def user_stats(username: str):
    """Return the statistics for ``username``, or 404 if the user is unknown."""
    record = db.get_user(username)
    if record:
        return db.get_user_stats(record["id"])
    raise HTTPException(status_code=404, detail="用户不存在")


@app.get("/leaderboard")
def leaderboard(limit: int = 10):
    """Return the leaderboard from ``db`` under the ``leaderboard`` key."""
    ranking = db.get_leaderboard(limit)
    return {"leaderboard": ranking}


def start_api(host="0.0.0.0", port=8000):
    """Print a summary of the available routes, then serve ``app`` with uvicorn.

    Args:
        host: Interface uvicorn binds to.
        port: Port uvicorn listens on; also shown in the printed URL.
    """
    # Only the first line interpolates a value, so the rest are plain strings
    # rather than placeholder-free f-strings.
    banner = (
        f"🌐 API 服务已启动: http://localhost:{port}",
        "   POST /predict          # 上传图片分类",
        "   GET  /predict_url       # URL 图片分类",
        "   GET  /knowledge         # 查询知识库",
        "   POST /user/register     # 注册用户",
        "   GET  /user/{name}/stats # 用户统计",
        "   GET  /leaderboard       # 排行榜",
    )
    for line in banner:
        print(line)
    uvicorn.run(app, host=host, port=port)


# Start the server with the default host and port when run as a script.
if __name__ == "__main__":
    start_api()