Spaces:
Running
Running
Upload 13 files
Browse files- app.py +19 -15
- router_items.py +21 -10
- 数据库连接.py +51 -48
app.py
CHANGED
|
@@ -1,5 +1,5 @@
|
|
| 1 |
# ⚙️ 后端逻辑/核心服务端.py (Hugging Face Spaces app.py)
|
| 2 |
-
from fastapi import FastAPI, File, UploadFile, Form, Depends
|
| 3 |
from fastapi.middleware.cors import CORSMiddleware
|
| 4 |
from fastapi.responses import Response, JSONResponse
|
| 5 |
from sqlalchemy.orm import Session
|
|
@@ -42,14 +42,26 @@ app.include_router(wallet_router)
|
|
| 42 |
|
| 43 |
@app.get("/")
|
| 44 |
def read_root():
|
| 45 |
-
return {"status": "ok"}
|
|
|
|
|
|
|
|
|
|
| 46 |
|
| 47 |
@app.post("/api/upload")
|
| 48 |
async def upload_file(file: UploadFile = File(...), file_type: str = Form(...)):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 49 |
content = await file.read()
|
|
|
|
|
|
|
|
|
|
| 50 |
file_hash = hashlib.md5(content).hexdigest()[:10]
|
| 51 |
|
| 52 |
-
new_filename = f"{file_hash}
|
| 53 |
safe_filename = urllib.parse.quote(file.filename)
|
| 54 |
safe_url_filename = f"{file_hash}_{safe_filename}"
|
| 55 |
|
|
@@ -57,14 +69,12 @@ async def upload_file(file: UploadFile = File(...), file_type: str = Form(...)):
|
|
| 57 |
target_dir = dir_mapping.get(file_type, "others")
|
| 58 |
full_path_in_repo = f"{target_dir}/{new_filename}"
|
| 59 |
|
|
|
|
| 60 |
db.save_file(full_path_in_repo, content)
|
| 61 |
|
| 62 |
url = f"https://huggingface.co/datasets/{db.DATASET_REPO_ID}/resolve/main/{target_dir}/{safe_url_filename}"
|
| 63 |
return {"status": "success", "url": url, "display_name": file.filename, "hashed_name": new_filename}
|
| 64 |
|
| 65 |
-
# =======================================================
|
| 66 |
-
# 【核心新增】:死链与资源有效性检测 (解决问题1:防买空)
|
| 67 |
-
# =======================================================
|
| 68 |
class ValidateRequest(BaseModel):
|
| 69 |
item_id: str
|
| 70 |
|
|
@@ -79,17 +89,15 @@ async def validate_resource(req: ValidateRequest):
|
|
| 79 |
itype = item.get("type", "")
|
| 80 |
|
| 81 |
if itype.startswith("tool"):
|
| 82 |
-
# 探测 Git 仓库是否为 404 死链
|
| 83 |
try:
|
| 84 |
req_obj = urllib.request.Request(link, method="HEAD", headers={'User-Agent': 'Mozilla/5.0'})
|
| 85 |
with urllib.request.urlopen(req_obj, timeout=5) as response:
|
| 86 |
if response.status >= 400:
|
| 87 |
return JSONResponse(content={"error": "原作者的 Git 仓库已失效或设为私有"}, status_code=400)
|
| 88 |
-
except Exception
|
| 89 |
return JSONResponse(content={"error": "原作者的 Git 仓库无法访问,链接已失效"}, status_code=400)
|
| 90 |
|
| 91 |
elif itype.startswith("app"):
|
| 92 |
-
# 探测 HF 云端的 JSON 文件是否丢失
|
| 93 |
if "resolve/main/" in link:
|
| 94 |
repo_path = urllib.parse.unquote(link.split("resolve/main/")[-1])
|
| 95 |
hf_token = os.environ.get("HF_TOKEN")
|
|
@@ -99,13 +107,10 @@ async def validate_resource(req: ValidateRequest):
|
|
| 99 |
if not exists:
|
| 100 |
return JSONResponse(content={"error": "该工作流的 JSON 文件已在云端损坏或丢失"}, status_code=400)
|
| 101 |
except Exception:
|
| 102 |
-
pass
|
| 103 |
|
| 104 |
return {"status": "success"}
|
| 105 |
|
| 106 |
-
# =======================================================
|
| 107 |
-
# 代理下载与所有权鉴权防线
|
| 108 |
-
# =======================================================
|
| 109 |
class ProxyDownloadRequest(BaseModel):
|
| 110 |
url: str
|
| 111 |
item_id: str
|
|
@@ -125,7 +130,6 @@ async def proxy_download(req_data: ProxyDownloadRequest, sql_db: Session = Depen
|
|
| 125 |
price = int(item.get("price", 0))
|
| 126 |
author = item.get("author")
|
| 127 |
|
| 128 |
-
# 所有权拦截:如果收费且不是作者本人,严查 SQL 所有权表
|
| 129 |
if price > 0 and req_data.account != author:
|
| 130 |
owned = sql_db.query(Ownership).filter(Ownership.account == req_data.account, Ownership.item_id == req_data.item_id).first()
|
| 131 |
if not owned:
|
|
@@ -151,4 +155,4 @@ async def proxy_download(req_data: ProxyDownloadRequest, sql_db: Session = Depen
|
|
| 151 |
return Response(content=content, media_type="application/json")
|
| 152 |
|
| 153 |
except Exception as e:
|
| 154 |
-
return JSONResponse(content={"error":
|
|
|
|
| 1 |
# ⚙️ 后端逻辑/核心服务端.py (Hugging Face Spaces app.py)
|
| 2 |
+
from fastapi import FastAPI, File, UploadFile, Form, Depends, HTTPException
|
| 3 |
from fastapi.middleware.cors import CORSMiddleware
|
| 4 |
from fastapi.responses import Response, JSONResponse
|
| 5 |
from sqlalchemy.orm import Session
|
|
|
|
| 42 |
|
| 43 |
@app.get("/")
|
| 44 |
def read_root():
|
| 45 |
+
return {"status": "ok", "message": "API System Protected & Running"}
|
| 46 |
+
|
| 47 |
+
# 【安全优化】:允许的文件后缀白名单,防挂马
|
| 48 |
+
ALLOWED_EXTENSIONS = {".png", ".jpg", ".jpeg", ".webp", ".json", ".zip"}
|
| 49 |
|
| 50 |
@app.post("/api/upload")
|
| 51 |
async def upload_file(file: UploadFile = File(...), file_type: str = Form(...)):
|
| 52 |
+
# 验证后缀名
|
| 53 |
+
_, ext = os.path.splitext(file.filename)
|
| 54 |
+
if ext.lower() not in ALLOWED_EXTENSIONS:
|
| 55 |
+
return JSONResponse(status_code=400, content={"error": f"安全拦截:不支持上传 {ext} 格式的文件"})
|
| 56 |
+
|
| 57 |
+
# 限制单次读取文件大小,防止撑爆内存
|
| 58 |
content = await file.read()
|
| 59 |
+
if len(content) > 10 * 1024 * 1024: # 10MB 限制
|
| 60 |
+
return JSONResponse(status_code=400, content={"error": "文件超过 10MB 大小限制"})
|
| 61 |
+
|
| 62 |
file_hash = hashlib.md5(content).hexdigest()[:10]
|
| 63 |
|
| 64 |
+
new_filename = f"{file_hash}{ext.lower()}"
|
| 65 |
safe_filename = urllib.parse.quote(file.filename)
|
| 66 |
safe_url_filename = f"{file_hash}_{safe_filename}"
|
| 67 |
|
|
|
|
| 69 |
target_dir = dir_mapping.get(file_type, "others")
|
| 70 |
full_path_in_repo = f"{target_dir}/{new_filename}"
|
| 71 |
|
| 72 |
+
# 交给底层带锁与异步线程的 db 处理
|
| 73 |
db.save_file(full_path_in_repo, content)
|
| 74 |
|
| 75 |
url = f"https://huggingface.co/datasets/{db.DATASET_REPO_ID}/resolve/main/{target_dir}/{safe_url_filename}"
|
| 76 |
return {"status": "success", "url": url, "display_name": file.filename, "hashed_name": new_filename}
|
| 77 |
|
|
|
|
|
|
|
|
|
|
| 78 |
class ValidateRequest(BaseModel):
|
| 79 |
item_id: str
|
| 80 |
|
|
|
|
| 89 |
itype = item.get("type", "")
|
| 90 |
|
| 91 |
if itype.startswith("tool"):
|
|
|
|
| 92 |
try:
|
| 93 |
req_obj = urllib.request.Request(link, method="HEAD", headers={'User-Agent': 'Mozilla/5.0'})
|
| 94 |
with urllib.request.urlopen(req_obj, timeout=5) as response:
|
| 95 |
if response.status >= 400:
|
| 96 |
return JSONResponse(content={"error": "原作者的 Git 仓库已失效或设为私有"}, status_code=400)
|
| 97 |
+
except Exception:
|
| 98 |
return JSONResponse(content={"error": "原作者的 Git 仓库无法访问,链接已失效"}, status_code=400)
|
| 99 |
|
| 100 |
elif itype.startswith("app"):
|
|
|
|
| 101 |
if "resolve/main/" in link:
|
| 102 |
repo_path = urllib.parse.unquote(link.split("resolve/main/")[-1])
|
| 103 |
hf_token = os.environ.get("HF_TOKEN")
|
|
|
|
| 107 |
if not exists:
|
| 108 |
return JSONResponse(content={"error": "该工作流的 JSON 文件已在云端损坏或丢失"}, status_code=400)
|
| 109 |
except Exception:
|
| 110 |
+
pass
|
| 111 |
|
| 112 |
return {"status": "success"}
|
| 113 |
|
|
|
|
|
|
|
|
|
|
| 114 |
class ProxyDownloadRequest(BaseModel):
|
| 115 |
url: str
|
| 116 |
item_id: str
|
|
|
|
| 130 |
price = int(item.get("price", 0))
|
| 131 |
author = item.get("author")
|
| 132 |
|
|
|
|
| 133 |
if price > 0 and req_data.account != author:
|
| 134 |
owned = sql_db.query(Ownership).filter(Ownership.account == req_data.account, Ownership.item_id == req_data.item_id).first()
|
| 135 |
if not owned:
|
|
|
|
| 155 |
return Response(content=content, media_type="application/json")
|
| 156 |
|
| 157 |
except Exception as e:
|
| 158 |
+
return JSONResponse(content={"error": "云端代理读取失败,可能是源文件损坏"}, status_code=500)
|
router_items.py
CHANGED
|
@@ -21,11 +21,11 @@ def get_last_6_months():
|
|
| 21 |
return res
|
| 22 |
|
| 23 |
@router.get("/api/items")
|
| 24 |
-
async def get_items(type: str = "tool", sort: str = "time", limit: int =
|
| 25 |
items_db = db.load_data("items.json", default_data=[])
|
| 26 |
comments_db = db.load_data("comments.json", default_data={})
|
| 27 |
|
| 28 |
-
#
|
| 29 |
if type == "recommend":
|
| 30 |
filtered_items = [item for item in items_db if item.get("type", "").startswith("recommend")]
|
| 31 |
else:
|
|
@@ -34,10 +34,12 @@ async def get_items(type: str = "tool", sort: str = "time", limit: int = 20):
|
|
| 34 |
for item in filtered_items:
|
| 35 |
item["commentsData"] = comments_db.get(item["id"], [])
|
| 36 |
item["comments"] = len(item["commentsData"])
|
|
|
|
| 37 |
if sort == "likes": filtered_items.sort(key=lambda x: x.get("likes", 0), reverse=True)
|
| 38 |
elif sort == "favorites": filtered_items.sort(key=lambda x: x.get("favorites", 0), reverse=True)
|
| 39 |
elif sort == "downloads": filtered_items.sort(key=lambda x: x.get("uses", 0), reverse=True)
|
| 40 |
else: filtered_items.sort(key=lambda x: x.get("created_at", 0), reverse=True)
|
|
|
|
| 41 |
return {"status": "success", "data": filtered_items[:limit]}
|
| 42 |
|
| 43 |
@router.get("/api/creators")
|
|
@@ -54,7 +56,7 @@ async def get_creators(sort: str = "downloads", limit: int = 20):
|
|
| 54 |
|
| 55 |
trend_tools = {m: 0 for m in months}
|
| 56 |
trend_apps = {m: 0 for m in months}
|
| 57 |
-
trend_recommends = {m: 0 for m in months}
|
| 58 |
tools_count = 0
|
| 59 |
apps_count = 0
|
| 60 |
|
|
@@ -68,7 +70,6 @@ async def get_creators(sort: str = "downloads", limit: int = 20):
|
|
| 68 |
if itype == "app": apps_count += 1
|
| 69 |
for m in months: trend_apps[m] += history.get(m, 0)
|
| 70 |
elif itype.startswith("recommend"):
|
| 71 |
-
# 纯链接推荐
|
| 72 |
for m in months: trend_recommends[m] += history.get(m, 0)
|
| 73 |
|
| 74 |
creators.append({
|
|
@@ -82,7 +83,7 @@ async def get_creators(sort: str = "downloads", limit: int = 20):
|
|
| 82 |
"months": months,
|
| 83 |
"tools": [trend_tools[m] for m in months],
|
| 84 |
"apps": [trend_apps[m] for m in months],
|
| 85 |
-
"recommends": [trend_recommends[m] for m in months]
|
| 86 |
}
|
| 87 |
})
|
| 88 |
|
|
@@ -90,11 +91,13 @@ async def get_creators(sort: str = "downloads", limit: int = 20):
|
|
| 90 |
elif sort == "favorites": creators.sort(key=lambda x: x.get("favorites", 0), reverse=True)
|
| 91 |
elif sort == "downloads": creators.sort(key=lambda x: x.get("downloads", 0), reverse=True)
|
| 92 |
else: creators.sort(key=lambda x: x.get("created_at", 0), reverse=True)
|
|
|
|
| 93 |
return {"status": "success", "data": creators[:limit]}
|
| 94 |
|
| 95 |
@router.post("/api/items")
|
| 96 |
async def create_item(item: ItemCreate):
|
| 97 |
-
# 【安全
|
|
|
|
| 98 |
if item.price < 0:
|
| 99 |
raise HTTPException(status_code=400, detail="🚨 安全拦截:商品价格不能为负数")
|
| 100 |
|
|
@@ -110,28 +113,34 @@ async def create_item(item: ItemCreate):
|
|
| 110 |
|
| 111 |
@router.put("/api/items/{item_id}")
|
| 112 |
async def update_item(item_id: str, update_data: ItemUpdate, author: str):
|
| 113 |
-
# 【安全
|
| 114 |
-
if update_data.price is not None
|
| 115 |
-
|
| 116 |
-
|
|
|
|
|
|
|
| 117 |
items_db = db.load_data("items.json", default_data=[])
|
| 118 |
for item in items_db:
|
| 119 |
if item["id"] == item_id:
|
| 120 |
if item.get("author") != author: raise HTTPException(status_code=403, detail="无权修改他人发布的内容")
|
|
|
|
| 121 |
if update_data.title is not None: item["title"] = update_data.title
|
| 122 |
if update_data.shortDesc is not None: item["shortDesc"] = update_data.shortDesc
|
| 123 |
if update_data.fullDesc is not None: item["fullDesc"] = update_data.fullDesc
|
| 124 |
if update_data.link is not None: item["link"] = update_data.link
|
| 125 |
if update_data.coverUrl is not None: item["coverUrl"] = update_data.coverUrl
|
| 126 |
if update_data.price is not None: item["price"] = update_data.price
|
|
|
|
| 127 |
db.save_data("items.json", items_db)
|
| 128 |
return {"status": "success"}
|
|
|
|
| 129 |
raise HTTPException(status_code=404, detail="找不到该内容记录")
|
| 130 |
|
| 131 |
@router.delete("/api/items/{item_id}")
|
| 132 |
async def delete_item(item_id: str, author: str):
|
| 133 |
items_db = db.load_data("items.json", default_data=[])
|
| 134 |
target_idx = next((i for i, item in enumerate(items_db) if item["id"] == item_id), None)
|
|
|
|
| 135 |
if target_idx is None: raise HTTPException(status_code=404, detail="找不到该内容记录")
|
| 136 |
if items_db[target_idx].get("author") != author: raise HTTPException(status_code=403, detail="无权删除他人发布的内容")
|
| 137 |
|
|
@@ -142,6 +151,7 @@ async def delete_item(item_id: str, author: str):
|
|
| 142 |
if item_id in comments_db:
|
| 143 |
del comments_db[item_id]
|
| 144 |
db.save_data("comments.json", comments_db)
|
|
|
|
| 145 |
return {"status": "success"}
|
| 146 |
|
| 147 |
@router.post("/api/items/{item_id}/use")
|
|
@@ -157,4 +167,5 @@ async def record_item_use(item_id: str):
|
|
| 157 |
item["use_history"][current_month] = item["use_history"].get(current_month, 0) + 1
|
| 158 |
db.save_data("items.json", items_db)
|
| 159 |
return {"status": "success", "uses": item["uses"]}
|
|
|
|
| 160 |
raise HTTPException(status_code=404, detail="找不到该内容记录")
|
|
|
|
| 21 |
return res
|
| 22 |
|
| 23 |
@router.get("/api/items")
|
| 24 |
+
async def get_items(type: str = "tool", sort: str = "time", limit: int = 50): # 优化:默认限制调大至 50,提升前端列表体验
|
| 25 |
items_db = db.load_data("items.json", default_data=[])
|
| 26 |
comments_db = db.load_data("comments.json", default_data={})
|
| 27 |
|
| 28 |
+
# 如果是推荐榜,匹配所有 recommend 开头的子类型
|
| 29 |
if type == "recommend":
|
| 30 |
filtered_items = [item for item in items_db if item.get("type", "").startswith("recommend")]
|
| 31 |
else:
|
|
|
|
| 34 |
for item in filtered_items:
|
| 35 |
item["commentsData"] = comments_db.get(item["id"], [])
|
| 36 |
item["comments"] = len(item["commentsData"])
|
| 37 |
+
|
| 38 |
if sort == "likes": filtered_items.sort(key=lambda x: x.get("likes", 0), reverse=True)
|
| 39 |
elif sort == "favorites": filtered_items.sort(key=lambda x: x.get("favorites", 0), reverse=True)
|
| 40 |
elif sort == "downloads": filtered_items.sort(key=lambda x: x.get("uses", 0), reverse=True)
|
| 41 |
else: filtered_items.sort(key=lambda x: x.get("created_at", 0), reverse=True)
|
| 42 |
+
|
| 43 |
return {"status": "success", "data": filtered_items[:limit]}
|
| 44 |
|
| 45 |
@router.get("/api/creators")
|
|
|
|
| 56 |
|
| 57 |
trend_tools = {m: 0 for m in months}
|
| 58 |
trend_apps = {m: 0 for m in months}
|
| 59 |
+
trend_recommends = {m: 0 for m in months}
|
| 60 |
tools_count = 0
|
| 61 |
apps_count = 0
|
| 62 |
|
|
|
|
| 70 |
if itype == "app": apps_count += 1
|
| 71 |
for m in months: trend_apps[m] += history.get(m, 0)
|
| 72 |
elif itype.startswith("recommend"):
|
|
|
|
| 73 |
for m in months: trend_recommends[m] += history.get(m, 0)
|
| 74 |
|
| 75 |
creators.append({
|
|
|
|
| 83 |
"months": months,
|
| 84 |
"tools": [trend_tools[m] for m in months],
|
| 85 |
"apps": [trend_apps[m] for m in months],
|
| 86 |
+
"recommends": [trend_recommends[m] for m in months]
|
| 87 |
}
|
| 88 |
})
|
| 89 |
|
|
|
|
| 91 |
elif sort == "favorites": creators.sort(key=lambda x: x.get("favorites", 0), reverse=True)
|
| 92 |
elif sort == "downloads": creators.sort(key=lambda x: x.get("downloads", 0), reverse=True)
|
| 93 |
else: creators.sort(key=lambda x: x.get("created_at", 0), reverse=True)
|
| 94 |
+
|
| 95 |
return {"status": "success", "data": creators[:limit]}
|
| 96 |
|
| 97 |
@router.post("/api/items")
|
| 98 |
async def create_item(item: ItemCreate):
|
| 99 |
+
# 【安全加固】:强制转换为整数,并拦截负数 (防浮点漏洞与洗钱)
|
| 100 |
+
item.price = int(item.price)
|
| 101 |
if item.price < 0:
|
| 102 |
raise HTTPException(status_code=400, detail="🚨 安全拦截:商品价格不能为负数")
|
| 103 |
|
|
|
|
| 113 |
|
| 114 |
@router.put("/api/items/{item_id}")
|
| 115 |
async def update_item(item_id: str, update_data: ItemUpdate, author: str):
|
| 116 |
+
# 【安全加固】:更新时同样强制转换为整数并拦截负数
|
| 117 |
+
if update_data.price is not None:
|
| 118 |
+
update_data.price = int(update_data.price)
|
| 119 |
+
if update_data.price < 0:
|
| 120 |
+
raise HTTPException(status_code=400, detail="🚨 安全拦截:商品价格不能为负数")
|
| 121 |
+
|
| 122 |
items_db = db.load_data("items.json", default_data=[])
|
| 123 |
for item in items_db:
|
| 124 |
if item["id"] == item_id:
|
| 125 |
if item.get("author") != author: raise HTTPException(status_code=403, detail="无权修改他人发布的内容")
|
| 126 |
+
|
| 127 |
if update_data.title is not None: item["title"] = update_data.title
|
| 128 |
if update_data.shortDesc is not None: item["shortDesc"] = update_data.shortDesc
|
| 129 |
if update_data.fullDesc is not None: item["fullDesc"] = update_data.fullDesc
|
| 130 |
if update_data.link is not None: item["link"] = update_data.link
|
| 131 |
if update_data.coverUrl is not None: item["coverUrl"] = update_data.coverUrl
|
| 132 |
if update_data.price is not None: item["price"] = update_data.price
|
| 133 |
+
|
| 134 |
db.save_data("items.json", items_db)
|
| 135 |
return {"status": "success"}
|
| 136 |
+
|
| 137 |
raise HTTPException(status_code=404, detail="找不到该内容记录")
|
| 138 |
|
| 139 |
@router.delete("/api/items/{item_id}")
|
| 140 |
async def delete_item(item_id: str, author: str):
|
| 141 |
items_db = db.load_data("items.json", default_data=[])
|
| 142 |
target_idx = next((i for i, item in enumerate(items_db) if item["id"] == item_id), None)
|
| 143 |
+
|
| 144 |
if target_idx is None: raise HTTPException(status_code=404, detail="找不到该内容记录")
|
| 145 |
if items_db[target_idx].get("author") != author: raise HTTPException(status_code=403, detail="无权删除他人发布的内容")
|
| 146 |
|
|
|
|
| 151 |
if item_id in comments_db:
|
| 152 |
del comments_db[item_id]
|
| 153 |
db.save_data("comments.json", comments_db)
|
| 154 |
+
|
| 155 |
return {"status": "success"}
|
| 156 |
|
| 157 |
@router.post("/api/items/{item_id}/use")
|
|
|
|
| 167 |
item["use_history"][current_month] = item["use_history"].get(current_month, 0) + 1
|
| 168 |
db.save_data("items.json", items_db)
|
| 169 |
return {"status": "success", "uses": item["uses"]}
|
| 170 |
+
|
| 171 |
raise HTTPException(status_code=404, detail="找不到该内容记录")
|
数据库连接.py
CHANGED
|
@@ -1,86 +1,89 @@
|
|
| 1 |
# ⚙️ 后端逻辑/数据库连接.py
|
| 2 |
import os
|
| 3 |
import json
|
|
|
|
| 4 |
from huggingface_hub import HfApi, hf_hub_download
|
| 5 |
|
| 6 |
HF_TOKEN = os.environ.get("HF_TOKEN")
|
| 7 |
-
# 已经替换为你真实的 Dataset 仓库
|
| 8 |
DATASET_REPO_ID = "ZHIWEI666/ComfyUI-Ranking"
|
| 9 |
|
| 10 |
-
# ==========================================
|
| 11 |
-
# 【核心修改】:智能判断运行环境
|
| 12 |
-
# ==========================================
|
| 13 |
if os.environ.get("SPACE_ID"):
|
| 14 |
-
# 如果检测到在 Hugging Face Spaces 环境运行,则使用 /tmp 目录(绕过容器只读限制)
|
| 15 |
LOCAL_DB_DIR = "/tmp/local_db_data"
|
| 16 |
else:
|
| 17 |
-
# 本地运行时,自动在当前 Python 文件同级目录下创建一个 "cache" 文件夹
|
| 18 |
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
|
| 19 |
LOCAL_DB_DIR = os.path.join(BASE_DIR, "cache")
|
| 20 |
|
| 21 |
api = HfApi() if HF_TOKEN else None
|
| 22 |
|
| 23 |
-
# 确保本地缓存目录存在
|
| 24 |
if not os.path.exists(LOCAL_DB_DIR):
|
| 25 |
os.makedirs(LOCAL_DB_DIR)
|
| 26 |
|
|
|
|
|
|
|
|
|
|
| 27 |
def load_data(file_name: str, default_data=None):
|
| 28 |
-
"""读取数据:
|
| 29 |
if default_data is None:
|
| 30 |
default_data = {} if file_name == "users.json" else []
|
| 31 |
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
return default_data
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 39 |
|
| 40 |
-
|
| 41 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
except Exception as e:
|
| 43 |
-
print(f"
|
| 44 |
-
return default_data
|
| 45 |
|
| 46 |
def save_data(file_name: str, data):
|
| 47 |
-
"""保存数据:写入本地并同步
|
| 48 |
local_path = os.path.join(LOCAL_DB_DIR, file_name)
|
| 49 |
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
|
| 54 |
-
#
|
| 55 |
if HF_TOKEN:
|
| 56 |
-
|
| 57 |
-
api.upload_file(
|
| 58 |
-
path_or_fileobj=local_path,
|
| 59 |
-
path_in_repo=file_name,
|
| 60 |
-
repo_id=DATASET_REPO_ID,
|
| 61 |
-
repo_type="dataset",
|
| 62 |
-
token=HF_TOKEN,
|
| 63 |
-
commit_message=f"Auto-update {file_name}"
|
| 64 |
-
)
|
| 65 |
-
except Exception as e:
|
| 66 |
-
print(f"同步到 HF Dataset 失败: {e}")
|
| 67 |
|
| 68 |
-
# --- 保存真实的二进制文件 ---
|
| 69 |
def save_file(file_path_in_repo: str, content: bytes):
|
|
|
|
| 70 |
local_full_path = os.path.join(LOCAL_DB_DIR, file_path_in_repo)
|
| 71 |
-
|
| 72 |
-
# 自动创建子目录(如 avatars/, tools/)
|
| 73 |
os.makedirs(os.path.dirname(local_full_path), exist_ok=True)
|
| 74 |
|
| 75 |
-
with
|
| 76 |
-
|
| 77 |
-
|
|
|
|
| 78 |
if HF_TOKEN:
|
| 79 |
-
|
| 80 |
-
api.upload_file(
|
| 81 |
-
path_or_fileobj=local_full_path, path_in_repo=file_path_in_repo,
|
| 82 |
-
repo_id=DATASET_REPO_ID, repo_type="dataset",
|
| 83 |
-
token=HF_TOKEN, commit_message=f"Upload File: {file_path_in_repo}"
|
| 84 |
-
)
|
| 85 |
-
except Exception as e:
|
| 86 |
-
print(f"同步文件到 HF Dataset 失败: {e}")
|
|
|
|
| 1 |
# ⚙️ 后端逻辑/数据库连接.py
|
| 2 |
import os
|
| 3 |
import json
|
| 4 |
+
import threading
|
| 5 |
from huggingface_hub import HfApi, hf_hub_download
|
| 6 |
|
| 7 |
HF_TOKEN = os.environ.get("HF_TOKEN")
|
|
|
|
| 8 |
DATASET_REPO_ID = "ZHIWEI666/ComfyUI-Ranking"
|
| 9 |
|
|
|
|
|
|
|
|
|
|
| 10 |
if os.environ.get("SPACE_ID"):
|
|
|
|
| 11 |
LOCAL_DB_DIR = "/tmp/local_db_data"
|
| 12 |
else:
|
|
|
|
| 13 |
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
|
| 14 |
LOCAL_DB_DIR = os.path.join(BASE_DIR, "cache")
|
| 15 |
|
| 16 |
api = HfApi() if HF_TOKEN else None
|
| 17 |
|
|
|
|
| 18 |
if not os.path.exists(LOCAL_DB_DIR):
|
| 19 |
os.makedirs(LOCAL_DB_DIR)
|
| 20 |
|
| 21 |
+
# 【核心优化 1】:引入全局读写锁,防止高并发下 JSON 数据覆写和丢失
|
| 22 |
+
db_lock = threading.Lock()
|
| 23 |
+
|
| 24 |
def load_data(file_name: str, default_data=None):
|
| 25 |
+
"""读取数据:引入线程锁,保证读取时不会读到写入一半的残缺数据"""
|
| 26 |
if default_data is None:
|
| 27 |
default_data = {} if file_name == "users.json" else []
|
| 28 |
|
| 29 |
+
local_path = os.path.join(LOCAL_DB_DIR, file_name)
|
| 30 |
+
|
| 31 |
+
with db_lock:
|
| 32 |
+
if not os.path.exists(local_path):
|
| 33 |
+
if HF_TOKEN:
|
| 34 |
+
try:
|
| 35 |
+
file_path = hf_hub_download(repo_id=DATASET_REPO_ID, repo_type="dataset", filename=file_name, token=HF_TOKEN)
|
| 36 |
+
with open(file_path, "r", encoding="utf-8") as f:
|
| 37 |
+
data = json.load(f)
|
| 38 |
+
with open(local_path, "w", encoding="utf-8") as f:
|
| 39 |
+
json.dump(data, f, ensure_ascii=False, indent=2)
|
| 40 |
+
return data
|
| 41 |
+
except Exception:
|
| 42 |
+
return default_data
|
| 43 |
+
else:
|
| 44 |
return default_data
|
| 45 |
+
|
| 46 |
+
try:
|
| 47 |
+
with open(local_path, "r", encoding="utf-8") as f:
|
| 48 |
+
return json.load(f)
|
| 49 |
+
except Exception as e:
|
| 50 |
+
print(f"解析 {file_name} 失败,启用默认数据。原因: {e}")
|
| 51 |
+
return default_data
|
| 52 |
|
| 53 |
+
def _background_upload_to_hf(local_path, file_name):
|
| 54 |
+
"""【核心优化 2】:后台独立线程执行 HF 推送,彻底解放 FastAPI 主线程性能"""
|
| 55 |
+
try:
|
| 56 |
+
api.upload_file(
|
| 57 |
+
path_or_fileobj=local_path,
|
| 58 |
+
path_in_repo=file_name,
|
| 59 |
+
repo_id=DATASET_REPO_ID,
|
| 60 |
+
repo_type="dataset",
|
| 61 |
+
token=HF_TOKEN,
|
| 62 |
+
commit_message=f"Auto-update {file_name}"
|
| 63 |
+
)
|
| 64 |
except Exception as e:
|
| 65 |
+
print(f"后台同步到 HF Dataset 失败: {e}")
|
|
|
|
| 66 |
|
| 67 |
def save_data(file_name: str, data):
|
| 68 |
+
"""保存数据:加锁写入本地,并触发异步后台同步"""
|
| 69 |
local_path = os.path.join(LOCAL_DB_DIR, file_name)
|
| 70 |
|
| 71 |
+
with db_lock:
|
| 72 |
+
with open(local_path, "w", encoding="utf-8") as f:
|
| 73 |
+
json.dump(data, f, ensure_ascii=False, indent=2)
|
| 74 |
|
| 75 |
+
# 触发后台线程推送,接口毫秒级返回
|
| 76 |
if HF_TOKEN:
|
| 77 |
+
threading.Thread(target=_background_upload_to_hf, args=(local_path, file_name)).start()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 78 |
|
|
|
|
| 79 |
def save_file(file_path_in_repo: str, content: bytes):
|
| 80 |
+
"""保存二进制文件(图片/应用等)"""
|
| 81 |
local_full_path = os.path.join(LOCAL_DB_DIR, file_path_in_repo)
|
|
|
|
|
|
|
| 82 |
os.makedirs(os.path.dirname(local_full_path), exist_ok=True)
|
| 83 |
|
| 84 |
+
with db_lock:
|
| 85 |
+
with open(local_full_path, "wb") as f:
|
| 86 |
+
f.write(content)
|
| 87 |
+
|
| 88 |
if HF_TOKEN:
|
| 89 |
+
threading.Thread(target=_background_upload_to_hf, args=(local_full_path, file_path_in_repo)).start()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|