hutiger's picture
Upload folder using huggingface_hub
bf5b4d8 verified
Raw
History Blame Contribute Delete
3.67 kB
"""
AI 垃圾分类助手 - Web API (FastAPI)
为小程序提供后端服务接口
"""
from fastapi import FastAPI, UploadFile, File, Form, HTTPException
from fastapi.middleware.cors import CORSMiddleware
import uvicorn
from predict import GarbageClassifier
from knowledge import get_class_info, search_knowledge, KNOWLEDGE_BASE
from database import Database
from config import UPLOAD_DIR
# FastAPI application object; every route in this module is registered on it.
app = FastAPI(title="AI 垃圾分类助手", version="1.0.0")
# CORS is left fully open: any origin, method and header is accepted.
app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])
# Shared classifier instance; stays None until get_classifier() builds it.
classifier = None
db = Database()
# parents=True so start-up does not fail when UPLOAD_DIR is nested under a
# directory that does not exist yet (assumes UPLOAD_DIR is a pathlib.Path,
# as its use with `.mkdir` and `/` in this module suggests).
UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
def get_classifier():
    """Return the module-wide classifier, constructing it on the first call.

    The instance is stored in the module-level ``classifier`` variable so
    later calls reuse it instead of building a new ``GarbageClassifier``.
    """
    global classifier
    if classifier is not None:
        return classifier
    classifier = GarbageClassifier()
    return classifier
@app.get("/")
def root():
    """Return a small identification payload with the API name and version."""
    return dict(message="AI 垃圾分类助手 API", version="1.0.0")
@app.post("/predict")
async def predict(file: UploadFile = File(...), username: str = Form("default")):
    """Classify an uploaded image and record the result for ``username``.

    The upload is stored under ``UPLOAD_DIR``, passed to the classifier, and
    the top-ranked result is written to the database for the given user.

    Args:
        file: The uploaded image.
        username: Name of the user the record is attributed to; it is passed
            to ``db.register_user`` on every call.

    Returns:
        A dict with ``success``, the classifier ``results``, the
        ``points_earned`` value returned by ``db.add_record`` and the
        ``knowledge`` entry for the top class.

    Raises:
        HTTPException: 400 when the upload is empty, 422 when the classifier
            returns no results.
    """
    import uuid
    from pathlib import Path

    content = await file.read()
    if not content:
        raise HTTPException(status_code=400, detail="上传的文件为空")

    # ``file.filename`` is client-controlled and may be missing or contain
    # directory parts ("../../x"). Keep only the final path component
    # (normalising Windows separators first) so the write stays inside
    # UPLOAD_DIR.
    raw_name = (file.filename or "").replace("\\", "/")
    safe_name = Path(raw_name).name or "upload"
    # A random prefix stops two uploads with the same name from overwriting
    # each other (the stored path is also what the DB record points to).
    image_path = UPLOAD_DIR / f"{uuid.uuid4().hex}_{safe_name}"
    with open(image_path, "wb") as f:
        f.write(content)

    clf = get_classifier()
    results = clf.predict(str(image_path))
    if not results:
        # Without this guard ``results[0]`` would raise IndexError -> HTTP 500.
        raise HTTPException(status_code=422, detail="未能识别图片内容")
    best = results[0]

    user_id = db.register_user(username)
    points = db.add_record(user_id, best["class_name"], best["confidence"], str(image_path))
    info = get_class_info(best["class_name"])
    return {"success": True, "results": results, "points_earned": points, "knowledge": info}
@app.get("/predict_url")
def predict_url(image_url: str, username: str = "default"):
    """Download an image from ``image_url``, classify it and record the result.

    Args:
        image_url: HTTP(S) address of the image to classify.
        username: Name of the user the record is attributed to; it is passed
            to ``db.register_user`` on every call.

    Returns:
        A dict with ``success``, the classifier ``results``, the
        ``points_earned`` value returned by ``db.add_record`` and the
        ``knowledge`` entry for the top class.

    Raises:
        HTTPException: 400 when the URL scheme is not http/https, the download
            fails, or the payload cannot be decoded as an image; 422 when the
            classifier returns no results.
    """
    import io
    import uuid
    from urllib.parse import urlparse

    import requests
    from PIL import Image

    # SECURITY: this endpoint makes the server fetch a caller-supplied URL
    # (SSRF risk). Only the scheme is checked here; blocking internal or
    # private hosts and capping the response size should be handled by the
    # deployment or added separately.
    if urlparse(image_url).scheme not in ("http", "https"):
        raise HTTPException(status_code=400, detail="仅支持 http/https 图片地址")

    try:
        response = requests.get(image_url, timeout=10)
        # Without this, a 404/500 error page would be handed to the decoder.
        response.raise_for_status()
    except requests.RequestException as exc:
        raise HTTPException(status_code=400, detail="图片下载失败") from exc

    try:
        image = Image.open(io.BytesIO(response.content)).convert("RGB")
    except OSError as exc:
        # Pillow reports unreadable or truncated images as OSError subclasses.
        raise HTTPException(status_code=400, detail="无法解析图片内容") from exc

    # A per-request file name: a fixed "url_temp.jpg" would be shared by
    # concurrent requests, so one request could classify another's image.
    temp_path = UPLOAD_DIR / f"url_{uuid.uuid4().hex}.jpg"
    try:
        image.save(temp_path)
        results = get_classifier().predict(str(temp_path))
    finally:
        # The DB record stores the URL, not this path, so the file is only
        # needed during prediction (assumes predict() reads it eagerly).
        temp_path.unlink(missing_ok=True)

    if not results:
        # Without this guard ``results[0]`` would raise IndexError -> HTTP 500.
        raise HTTPException(status_code=422, detail="未能识别图片内容")
    best = results[0]

    user_id = db.register_user(username)
    points = db.add_record(user_id, best["class_name"], best["confidence"], image_url)
    info = get_class_info(best["class_name"])
    return {"success": True, "results": results, "points_earned": points, "knowledge": info}
@app.get("/knowledge/{class_name}")
def get_knowledge(class_name: str):
    """Return the knowledge entry for ``class_name``.

    Raises:
        HTTPException: 404 when ``get_class_info`` yields a falsy value.
    """
    entry = get_class_info(class_name)
    if entry:
        return entry
    raise HTTPException(status_code=404, detail="未找到该类别信息")
@app.get("/knowledge")
def search_knowledge_api(q: str = ""):
    """Search the knowledge base, or return all of it when ``q`` is empty.

    With a non-empty ``q`` the (key, value) pairs produced by
    ``search_knowledge`` are collected into a dict.
    """
    if not q:
        return KNOWLEDGE_BASE
    return {name: entry for name, entry in search_knowledge(q)}
@app.post("/user/register")
def register_user(username: str = Form(...)):
    """Register ``username`` through the database and echo back its id and name."""
    return {"user_id": db.register_user(username), "username": username}
@app.get("/user/{username}/stats")
def user_stats(username: str):
    """Return the statistics stored for ``username``.

    Raises:
        HTTPException: 404 when ``db.get_user`` yields a falsy value.
    """
    record = db.get_user(username)
    if record:
        return db.get_user_stats(record["id"])
    raise HTTPException(status_code=404, detail="用户不存在")
@app.get("/leaderboard")
def leaderboard(limit: int = 10):
    """Return the leaderboard from the database, passing ``limit`` through unchanged."""
    rows = db.get_leaderboard(limit)
    return {"leaderboard": rows}
def start_api(host="0.0.0.0", port=8000):
    """Print a short route overview, then run the app under uvicorn.

    Args:
        host: Interface uvicorn binds to.
        port: Port uvicorn listens on; also shown in the printed banner URL.
    """
    banner = (
        f"🌐 API 服务已启动: http://localhost:{port}",
        f" POST /predict # 上传图片分类",
        f" GET /predict_url # URL 图片分类",
        f" GET /knowledge # 查询知识库",
        f" POST /user/register # 注册用户",
        f" GET /user/{{name}}/stats # 用户统计",
        f" GET /leaderboard # 排行榜",
    )
    for line in banner:
        print(line)
    uvicorn.run(app, host=host, port=port)
# Start the server with start_api()'s defaults (host 0.0.0.0, port 8000)
# only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    start_api()