Spaces:
Running
Running
Upload 8 files
Browse files- app.py +9 -0
- notifications.py +1 -1
- router_comments.py +1 -1
- router_items.py +1 -1
- router_messages.py +2 -2
- router_users_auth.py +35 -17
- 数据库连接.py +80 -55
app.py
CHANGED
|
@@ -222,6 +222,10 @@ async def on_startup():
|
|
| 222 |
asyncio.create_task(daily_version_check_task())
|
| 223 |
logger.info("✅ 定时版本检测任务已挂载")
|
| 224 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 225 |
logger.info("🎉 ComfyUI-Ranking API 启动完成!")
|
| 226 |
|
| 227 |
|
|
@@ -229,6 +233,11 @@ async def on_startup():
|
|
| 229 |
async def on_shutdown():
|
| 230 |
"""优雅关闭,清理资源"""
|
| 231 |
logger.info("🛑 ComfyUI-Ranking API 正在关闭...")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 232 |
# 这里可以添加其他清理逻辑(如关闭连接池等)
|
| 233 |
logger.info("✅ 关闭完成")
|
| 234 |
|
|
|
|
| 222 |
asyncio.create_task(daily_version_check_task())
|
| 223 |
logger.info("✅ 定时版本检测任务已挂载")
|
| 224 |
|
| 225 |
+
# ========== 启动 HF 批量同步定时器 ==========
|
| 226 |
+
db.start_batch_sync()
|
| 227 |
+
logger.info("✅ HF 批量同步定时器已启动")
|
| 228 |
+
|
| 229 |
logger.info("🎉 ComfyUI-Ranking API 启动完成!")
|
| 230 |
|
| 231 |
|
|
|
|
| 233 |
async def on_shutdown():
|
| 234 |
"""优雅关闭,清理资源"""
|
| 235 |
logger.info("🛑 ComfyUI-Ranking API 正在关闭...")
|
| 236 |
+
|
| 237 |
+
# ========== 关闭前同步所有脏文件 ==========
|
| 238 |
+
db.flush_sync()
|
| 239 |
+
logger.info("✅ HF 批量同步已完成")
|
| 240 |
+
|
| 241 |
# 这里可以添加其他清理逻辑(如关闭连接池等)
|
| 242 |
logger.info("✅ 关闭完成")
|
| 243 |
|
notifications.py
CHANGED
|
@@ -15,7 +15,7 @@ def add_notification(target_account: str, notif_data: dict):
|
|
| 15 |
"type": notif_data.get("type"),
|
| 16 |
"from_user": from_user,
|
| 17 |
"from_name": user_info.get("name", from_user),
|
| 18 |
-
"from_avatar": user_info.get("avatarDataUrl", "
|
| 19 |
"target_item_id": notif_data.get("target_item_id", ""),
|
| 20 |
"target_item_title": notif_data.get("target_item_title", ""),
|
| 21 |
"content": notif_data.get("content", ""),
|
|
|
|
| 15 |
"type": notif_data.get("type"),
|
| 16 |
"from_user": from_user,
|
| 17 |
"from_name": user_info.get("name", from_user),
|
| 18 |
+
"from_avatar": user_info.get("avatarDataUrl", "")
|
| 19 |
"target_item_id": notif_data.get("target_item_id", ""),
|
| 20 |
"target_item_title": notif_data.get("target_item_title", ""),
|
| 21 |
"content": notif_data.get("content", ""),
|
router_comments.py
CHANGED
|
@@ -42,7 +42,7 @@ async def post_comment(comment: CommentCreate):
|
|
| 42 |
item_comments = comments_db.get(comment.item_id, [])
|
| 43 |
new_comment = {
|
| 44 |
"id": f"c_{int(time.time())}_{uuid.uuid4().hex[:6]}", "author": comment.author,
|
| 45 |
-
"authorName": author_info.get("name", comment.author), "avatar": author_info.get("avatarDataUrl", "
|
| 46 |
"content": comment.content, "replyToUser": comment.reply_to_user, "replyToUserName": reply_name,
|
| 47 |
"isDeleted": False, "replies": [], "created_at": int(time.time())
|
| 48 |
}
|
|
|
|
| 42 |
item_comments = comments_db.get(comment.item_id, [])
|
| 43 |
new_comment = {
|
| 44 |
"id": f"c_{int(time.time())}_{uuid.uuid4().hex[:6]}", "author": comment.author,
|
| 45 |
+
"authorName": author_info.get("name", comment.author), "avatar": author_info.get("avatarDataUrl", ""),
|
| 46 |
"content": comment.content, "replyToUser": comment.reply_to_user, "replyToUserName": reply_name,
|
| 47 |
"isDeleted": False, "replies": [], "created_at": int(time.time())
|
| 48 |
}
|
router_items.py
CHANGED
|
@@ -121,7 +121,7 @@ async def get_creators(sort: str = "downloads", limit: int = 20):
|
|
| 121 |
for m in months: trend_recommends[m] += history.get(m, 0)
|
| 122 |
|
| 123 |
creators.append({
|
| 124 |
-
"account": account, "name": u.get("name", account), "avatar": u.get("avatarDataUrl", "
|
| 125 |
"bannerUrl": u.get("bannerUrl"), # 🖼️ 个人资料卡背景图
|
| 126 |
"shortDesc": u.get("intro") or "这个人很懒,什么都没写...", "fullDesc": u.get("intro") or "这个人很懒,什么都没写...",
|
| 127 |
"likes": sum(i.get("likes", 0) for i in u_items), "favorites": sum(i.get("favorites", 0) for i in u_items),
|
|
|
|
| 121 |
for m in months: trend_recommends[m] += history.get(m, 0)
|
| 122 |
|
| 123 |
creators.append({
|
| 124 |
+
"account": account, "name": u.get("name", account), "avatar": u.get("avatarDataUrl", ""),
|
| 125 |
"bannerUrl": u.get("bannerUrl"), # 🖼️ 个人资料卡背景图
|
| 126 |
"shortDesc": u.get("intro") or "这个人很懒,什么都没写...", "fullDesc": u.get("intro") or "这个人很懒,什么都没写...",
|
| 127 |
"likes": sum(i.get("likes", 0) for i in u_items), "favorites": sum(i.get("favorites", 0) for i in u_items),
|
router_messages.py
CHANGED
|
@@ -35,7 +35,7 @@ async def publish_announcement(ann: SystemAnnouncement, current_user: str = Depe
|
|
| 35 |
"type": "system",
|
| 36 |
"from_user": current_user, # 使用真实的管理员账号
|
| 37 |
"from_name": "官方团队",
|
| 38 |
-
"from_avatar": "
|
| 39 |
"content": ann.content,
|
| 40 |
"created_at": int(time.time())
|
| 41 |
}
|
|
@@ -163,7 +163,7 @@ async def get_chat_list(account: str):
|
|
| 163 |
chat_list.append({
|
| 164 |
"account": other_account,
|
| 165 |
"name": other_user.get("name", other_account),
|
| 166 |
-
"avatar": other_user.get("avatarDataUrl", "
|
| 167 |
"last_message": last_msg["content"],
|
| 168 |
"last_time": last_msg["created_at"],
|
| 169 |
"unread_count": unread_count
|
|
|
|
| 35 |
"type": "system",
|
| 36 |
"from_user": current_user, # 使用真实的管理员账号
|
| 37 |
"from_name": "官方团队",
|
| 38 |
+
"from_avatar": ""
|
| 39 |
"content": ann.content,
|
| 40 |
"created_at": int(time.time())
|
| 41 |
}
|
|
|
|
| 163 |
chat_list.append({
|
| 164 |
"account": other_account,
|
| 165 |
"name": other_user.get("name", other_account),
|
| 166 |
+
"avatar": other_user.get("avatarDataUrl", "")
|
| 167 |
"last_message": last_msg["content"],
|
| 168 |
"last_time": last_msg["created_at"],
|
| 169 |
"unread_count": unread_count
|
router_users_auth.py
CHANGED
|
@@ -17,7 +17,7 @@ import random
|
|
| 17 |
import json
|
| 18 |
import 数据库连接 as db
|
| 19 |
from models import UserRegister, UserLogin, SendCodeRequest
|
| 20 |
-
from verify_code_engine import VERIFY_CODES, send_email_code, send_sms_code, cleanup_expired_codes
|
| 21 |
|
| 22 |
# 🔒 P0安全增强:导入密码哈希和 JWT 工具
|
| 23 |
from 安全认证 import hash_password, verify_password, create_token, require_password_match
|
|
@@ -65,6 +65,9 @@ async def send_verify_code(request: Request, req: SendCodeRequest, bg_tasks: Bac
|
|
| 65 |
if req.contact_type == "phone" and user.get("phone") != req.contact:
|
| 66 |
raise HTTPException(status_code=400, detail="填写的手机号与该账号绑定的手机号不一致")
|
| 67 |
|
|
|
|
|
|
|
|
|
|
| 68 |
# 生成6位随机验证码
|
| 69 |
code = str(random.randint(100000, 999999))
|
| 70 |
|
|
@@ -74,10 +77,10 @@ async def send_verify_code(request: Request, req: SendCodeRequest, bg_tasks: Bac
|
|
| 74 |
# 构建缓存键(联系方式_动作类型)
|
| 75 |
cache_key = f"{req.contact}_{req.action_type}"
|
| 76 |
|
| 77 |
-
# 将验证码存入内存缓存,有效期
|
| 78 |
VERIFY_CODES[cache_key] = {
|
| 79 |
"code": code,
|
| 80 |
-
"expires_at": int(time.time()) +
|
| 81 |
}
|
| 82 |
|
| 83 |
# 根据联系方式类型,添加后台发送任务
|
|
@@ -99,14 +102,17 @@ async def send_verify_code(request: Request, req: SendCodeRequest, bg_tasks: Bac
|
|
| 99 |
@router.post("/api/users/send_code")
|
| 100 |
async def send_code_api(req: SendCodeRequest):
|
| 101 |
"""发送验证码接口(同步版本,直接等待发送结果)"""
|
|
|
|
|
|
|
|
|
|
| 102 |
# 生成6位随机验证码
|
| 103 |
code = str(random.randint(100000, 999999))
|
| 104 |
key = f"{req.contact}_{req.action_type}"
|
| 105 |
|
| 106 |
-
# 存入缓存
|
| 107 |
VERIFY_CODES[key] = {
|
| 108 |
"code": code,
|
| 109 |
-
"expires_at": time.time() +
|
| 110 |
}
|
| 111 |
|
| 112 |
# 同步发送(会阻塞等待结果)
|
|
@@ -161,10 +167,11 @@ async def register_user(request: Request, user: UserRegister):
|
|
| 161 |
if user.phone and existing_user.get("phone") == user.phone:
|
| 162 |
raise HTTPException(status_code=400, detail="该手机号已被绑定")
|
| 163 |
|
| 164 |
-
# ========== 第二步:验证码校验 ==========
|
| 165 |
# 根据注册方式构建缓存键
|
| 166 |
cache_key = f"{user.email}_register" if user.email else f"{user.phone}_register"
|
| 167 |
-
|
|
|
|
| 168 |
|
| 169 |
# 兼容新老缓存格式(expires_at 或 expires)
|
| 170 |
expire_time = cached.get("expires_at", cached.get("expires", 0)) if cached else 0
|
|
@@ -186,9 +193,6 @@ async def register_user(request: Request, user: UserRegister):
|
|
| 186 |
raise HTTPException(status_code=400, detail="个人介绍不能超过100个字符")
|
| 187 |
|
| 188 |
# ========== 第四步:保存新用户 ==========
|
| 189 |
-
# 验证通过后,清除已使用的验证码
|
| 190 |
-
VERIFY_CODES.pop(cache_key, None)
|
| 191 |
-
|
| 192 |
# 构建用户数据对象
|
| 193 |
new_user = user.dict()
|
| 194 |
new_user.pop("code", None) # 移除验证码字段,不存入数据库
|
|
@@ -241,15 +245,23 @@ async def login_user(request: Request, user: UserLogin):
|
|
| 241 |
if not require_password_match(stored_password, user.password):
|
| 242 |
raise HTTPException(status_code=401, detail="密码错误")
|
| 243 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 244 |
# 🔒 P0安全增强:生成 JWT Token(替代 mock_token)
|
| 245 |
-
|
|
|
|
|
|
|
| 246 |
|
| 247 |
return {
|
| 248 |
"status": "success",
|
| 249 |
"token": token, # 🔒 JWT Token
|
| 250 |
"account": user.account,
|
| 251 |
"name": user_data["name"],
|
| 252 |
-
"avatar": user_data.get("avatarDataUrl", "
|
| 253 |
}
|
| 254 |
|
| 255 |
|
|
@@ -317,9 +329,10 @@ async def reset_password(request: Request):
|
|
| 317 |
if verify_type == "phone" and user.get("phone") != verify_contact:
|
| 318 |
raise HTTPException(status_code=400, detail="填写的手机号与该账号绑定的手机号不匹配")
|
| 319 |
|
| 320 |
-
# 校验验证码
|
| 321 |
cache_key = f"{verify_contact}_reset"
|
| 322 |
-
|
|
|
|
| 323 |
expire_time = cached.get("expires_at", cached.get("expires", 0)) if cached else 0
|
| 324 |
|
| 325 |
if not cached or cached["code"] != code or time.time() > expire_time:
|
|
@@ -332,11 +345,16 @@ async def reset_password(request: Request):
|
|
| 332 |
raise HTTPException(status_code=400, detail="新密码包含不支持的特殊字符")
|
| 333 |
|
| 334 |
# ========== 第四步:更新密码并保存 ==========
|
| 335 |
-
VERIFY_CODES.pop(cache_key, None) # 清除已使用的验证码
|
| 336 |
-
|
| 337 |
# 🔒 P0安全增强:新密码哈希化存储
|
| 338 |
user["password"] = hash_password(new_password)
|
| 339 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 340 |
db.save_data("users.json", users_db)
|
| 341 |
|
| 342 |
-
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
import json
|
| 18 |
import 数据库连接 as db
|
| 19 |
from models import UserRegister, UserLogin, SendCodeRequest
|
| 20 |
+
from verify_code_engine import VERIFY_CODES, send_email_code, send_sms_code, cleanup_expired_codes, check_send_cooldown
|
| 21 |
|
| 22 |
# 🔒 P0安全增强:导入密码哈希和 JWT 工具
|
| 23 |
from 安全认证 import hash_password, verify_password, create_token, require_password_match
|
|
|
|
| 65 |
if req.contact_type == "phone" and user.get("phone") != req.contact:
|
| 66 |
raise HTTPException(status_code=400, detail="填写的手机号与该账号绑定的手机号不一致")
|
| 67 |
|
| 68 |
+
# 🔒 检查发送频率限制(同一联系方式60秒内只能发送1条)
|
| 69 |
+
check_send_cooldown(req.contact)
|
| 70 |
+
|
| 71 |
# 生成6位随机验证码
|
| 72 |
code = str(random.randint(100000, 999999))
|
| 73 |
|
|
|
|
| 77 |
# 构建缓存键(联系方式_动作类型)
|
| 78 |
cache_key = f"{req.contact}_{req.action_type}"
|
| 79 |
|
| 80 |
+
# 将验证码存入内存缓存,有效期5分钟
|
| 81 |
VERIFY_CODES[cache_key] = {
|
| 82 |
"code": code,
|
| 83 |
+
"expires_at": int(time.time()) + 300 # 当前时间 + 300秒(5分钟)
|
| 84 |
}
|
| 85 |
|
| 86 |
# 根据联系方式类型,添加后台发送任务
|
|
|
|
| 102 |
@router.post("/api/users/send_code")
|
| 103 |
async def send_code_api(req: SendCodeRequest):
|
| 104 |
"""发送验证码接口(同步版本,直接等待发送结果)"""
|
| 105 |
+
# 🔒 检查发送频率限制(同一联系方式60秒内只能发送1条)
|
| 106 |
+
check_send_cooldown(req.contact)
|
| 107 |
+
|
| 108 |
# 生成6位随机验证码
|
| 109 |
code = str(random.randint(100000, 999999))
|
| 110 |
key = f"{req.contact}_{req.action_type}"
|
| 111 |
|
| 112 |
+
# 存入缓存,有效期5分钟
|
| 113 |
VERIFY_CODES[key] = {
|
| 114 |
"code": code,
|
| 115 |
+
"expires_at": time.time() + 300 # 300秒(5分钟)
|
| 116 |
}
|
| 117 |
|
| 118 |
# 同步发送(会阻塞等待结果)
|
|
|
|
| 167 |
if user.phone and existing_user.get("phone") == user.phone:
|
| 168 |
raise HTTPException(status_code=400, detail="该手机号已被绑定")
|
| 169 |
|
| 170 |
+
# ========== 第二步:验证码校验(原子性获取+删除) ==========
|
| 171 |
# 根据注册方式构建缓存键
|
| 172 |
cache_key = f"{user.email}_register" if user.email else f"{user.phone}_register"
|
| 173 |
+
# 🔒 P0安全修复:验证码一次性使用,原子性pop防止并发重用
|
| 174 |
+
cached = VERIFY_CODES.pop(cache_key, None)
|
| 175 |
|
| 176 |
# 兼容新老缓存格式(expires_at 或 expires)
|
| 177 |
expire_time = cached.get("expires_at", cached.get("expires", 0)) if cached else 0
|
|
|
|
| 193 |
raise HTTPException(status_code=400, detail="个人介绍不能超过100个字符")
|
| 194 |
|
| 195 |
# ========== 第四步:保存新用户 ==========
|
|
|
|
|
|
|
|
|
|
| 196 |
# 构建用户数据对象
|
| 197 |
new_user = user.dict()
|
| 198 |
new_user.pop("code", None) # 移除验证码字段,不存入数据库
|
|
|
|
| 245 |
if not require_password_match(stored_password, user.password):
|
| 246 |
raise HTTPException(status_code=401, detail="密码错误")
|
| 247 |
|
| 248 |
+
# 🔒 P0安全增强:登录成功后,检查是否需要迁移旧密码为bcrypt
|
| 249 |
+
if not user_data["password"].startswith('$2b$') and not user_data["password"].startswith('$2a$'):
|
| 250 |
+
# 旧版SHA256密码,自动迁移为bcrypt
|
| 251 |
+
user_data["password"] = hash_password(user.password)
|
| 252 |
+
db.save_data("users.json", users_db)
|
| 253 |
+
|
| 254 |
# 🔒 P0安全增强:生成 JWT Token(替代 mock_token)
|
| 255 |
+
# 获取password_version用于Token生成(如不存在则默认为0)
|
| 256 |
+
password_version = user_data.get("password_version", 0)
|
| 257 |
+
token = create_token(user.account, extra_data={"pwd_ver": password_version})
|
| 258 |
|
| 259 |
return {
|
| 260 |
"status": "success",
|
| 261 |
"token": token, # 🔒 JWT Token
|
| 262 |
"account": user.account,
|
| 263 |
"name": user_data["name"],
|
| 264 |
+
"avatar": user_data.get("avatarDataUrl", "")
|
| 265 |
}
|
| 266 |
|
| 267 |
|
|
|
|
| 329 |
if verify_type == "phone" and user.get("phone") != verify_contact:
|
| 330 |
raise HTTPException(status_code=400, detail="填写的手机号与该账号绑定的手机号不匹配")
|
| 331 |
|
| 332 |
+
# 校验验证码(原子性获取+删除)
|
| 333 |
cache_key = f"{verify_contact}_reset"
|
| 334 |
+
# 🔒 P0安全修复:验证码一次性使用,原子性pop防止并发重用
|
| 335 |
+
cached = VERIFY_CODES.pop(cache_key, None)
|
| 336 |
expire_time = cached.get("expires_at", cached.get("expires", 0)) if cached else 0
|
| 337 |
|
| 338 |
if not cached or cached["code"] != code or time.time() > expire_time:
|
|
|
|
| 345 |
raise HTTPException(status_code=400, detail="新密码包含不支持的特殊字符")
|
| 346 |
|
| 347 |
# ========== 第四步:更新密码并保存 ==========
|
|
|
|
|
|
|
| 348 |
# 🔒 P0安全增强:新密码哈希化存储
|
| 349 |
user["password"] = hash_password(new_password)
|
| 350 |
|
| 351 |
+
# 🔒 P0安全增强:更新password_version使旧Token失效
|
| 352 |
+
import time as time_module
|
| 353 |
+
user["password_version"] = int(time_module.time())
|
| 354 |
+
|
| 355 |
db.save_data("users.json", users_db)
|
| 356 |
|
| 357 |
+
# 🔒 P0安全增强:生成新Token返回给前端替换旧Token
|
| 358 |
+
new_token = create_token(account, extra_data={"pwd_ver": user["password_version"]})
|
| 359 |
+
|
| 360 |
+
return {"status": "success", "message": "密码修改成功", "token": new_token}
|
数据库连接.py
CHANGED
|
@@ -25,8 +25,7 @@ import shutil
|
|
| 25 |
import tempfile
|
| 26 |
import logging
|
| 27 |
from typing import Any, Dict, List, Optional, Union
|
| 28 |
-
from
|
| 29 |
-
from huggingface_hub import HfApi, hf_hub_download
|
| 30 |
|
| 31 |
# 📝 日志配置
|
| 32 |
logger = logging.getLogger("ComfyUI-Ranking.DB")
|
|
@@ -51,8 +50,11 @@ BACKUP_DIR = os.path.join(LOCAL_DB_DIR, "_backups")
|
|
| 51 |
# HuggingFace API 客户端
|
| 52 |
api = HfApi() if HF_TOKEN else None
|
| 53 |
|
| 54 |
-
#
|
| 55 |
-
|
|
|
|
|
|
|
|
|
|
| 56 |
|
| 57 |
# 确保目录存在
|
| 58 |
os.makedirs(LOCAL_DB_DIR, exist_ok=True)
|
|
@@ -403,63 +405,89 @@ def save_data(file_name: str, data: Union[Dict, List]) -> bool:
|
|
| 403 |
# ========== 🚀 P1优化:更新内存缓存 ==========
|
| 404 |
_set_to_cache(file_name, data, local_path)
|
| 405 |
|
| 406 |
-
# ========== 第五步:
|
| 407 |
-
# 🔧 P3优化:使用线程池替代直接创建线程
|
| 408 |
if HF_TOKEN:
|
| 409 |
-
|
| 410 |
-
_upload_executor.submit(_background_upload_to_hf, local_path, file_name)
|
| 411 |
-
except Exception as e:
|
| 412 |
-
logger.warning(f"提交上传任务失败: {e}")
|
| 413 |
|
| 414 |
|
| 415 |
# ==========================================
|
| 416 |
-
# ☁️
|
| 417 |
# ==========================================
|
| 418 |
-
# 特点:
|
| 419 |
-
# - 后台线程执行,不阻塞主流程
|
| 420 |
-
# - 失败自动重试(最多3次)
|
| 421 |
-
# - 指数退避策略
|
| 422 |
|
| 423 |
-
def
|
| 424 |
-
"""
|
| 425 |
-
|
| 426 |
-
|
| 427 |
-
|
| 428 |
-
|
| 429 |
-
|
| 430 |
-
|
| 431 |
-
|
| 432 |
-
|
| 433 |
-
|
| 434 |
-
|
| 435 |
-
|
| 436 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 437 |
repo_id=DATASET_REPO_ID,
|
| 438 |
repo_type="dataset",
|
| 439 |
-
|
| 440 |
-
commit_message=f"
|
| 441 |
)
|
| 442 |
-
|
| 443 |
-
|
| 444 |
-
|
| 445 |
-
|
| 446 |
-
|
| 447 |
-
|
| 448 |
-
|
| 449 |
-
|
| 450 |
-
|
| 451 |
-
|
| 452 |
-
|
| 453 |
|
| 454 |
|
| 455 |
-
def
|
| 456 |
-
"""
|
| 457 |
-
|
| 458 |
-
|
| 459 |
-
|
| 460 |
-
|
| 461 |
-
|
| 462 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 463 |
|
| 464 |
|
| 465 |
# ==========================================
|
|
@@ -658,11 +686,8 @@ def atomic_update(file_name: str, updater, default_data=None):
|
|
| 658 |
# ========== 第四步:更新内存缓存 ==========
|
| 659 |
_set_to_cache(file_name, data, local_path)
|
| 660 |
|
| 661 |
-
# ========== 第五步:
|
| 662 |
if HF_TOKEN:
|
| 663 |
-
|
| 664 |
-
_upload_executor.submit(_background_upload_to_hf, local_path, file_name)
|
| 665 |
-
except Exception as e:
|
| 666 |
-
logger.warning(f"提交上传任务失败: {e}")
|
| 667 |
|
| 668 |
return result
|
|
|
|
| 25 |
import tempfile
|
| 26 |
import logging
|
| 27 |
from typing import Any, Dict, List, Optional, Union
|
| 28 |
+
from huggingface_hub import HfApi, CommitOperationAdd, hf_hub_download
|
|
|
|
| 29 |
|
| 30 |
# 📝 日志配置
|
| 31 |
logger = logging.getLogger("ComfyUI-Ranking.DB")
|
|
|
|
| 50 |
# HuggingFace API 客户端
|
| 51 |
api = HfApi() if HF_TOKEN else None
|
| 52 |
|
| 53 |
+
# ===== HF 批量同步机制 =====
|
| 54 |
+
_dirty_files = set() # 记录自上次同步以来有变更的文件名
|
| 55 |
+
_dirty_files_lock = threading.Lock()
|
| 56 |
+
_BATCH_SYNC_INTERVAL = 300 # 批量同步间隔(秒),默认5分钟
|
| 57 |
+
_batch_sync_timer = None # 定时器引用
|
| 58 |
|
| 59 |
# 确保目录存在
|
| 60 |
os.makedirs(LOCAL_DB_DIR, exist_ok=True)
|
|
|
|
| 405 |
# ========== 🚀 P1优化:更新内存缓存 ==========
|
| 406 |
_set_to_cache(file_name, data, local_path)
|
| 407 |
|
| 408 |
+
# ========== 第五步:标记文件脏,等待批量同步 ==========
|
|
|
|
| 409 |
if HF_TOKEN:
|
| 410 |
+
_mark_dirty(file_name)
|
|
|
|
|
|
|
|
|
|
| 411 |
|
| 412 |
|
| 413 |
# ==========================================
|
| 414 |
+
# ☁️ HF 批量同步机制
|
| 415 |
# ==========================================
|
|
|
|
|
|
|
|
|
|
|
|
|
| 416 |
|
| 417 |
+
def _mark_dirty(file_name: str):
|
| 418 |
+
"""标记文件为脏,等待下次批量同步"""
|
| 419 |
+
with _dirty_files_lock:
|
| 420 |
+
_dirty_files.add(file_name)
|
| 421 |
+
|
| 422 |
+
|
| 423 |
+
def _batch_sync_to_hf(schedule_next=True):
|
| 424 |
+
"""批量同步所有脏文件到 HF Dataset(单次 commit)"""
|
| 425 |
+
# 取出脏文件列表并清空
|
| 426 |
+
with _dirty_files_lock:
|
| 427 |
+
files_to_sync = list(_dirty_files)
|
| 428 |
+
_dirty_files.clear()
|
| 429 |
+
|
| 430 |
+
if not files_to_sync:
|
| 431 |
+
if schedule_next:
|
| 432 |
+
_schedule_next_sync()
|
| 433 |
+
return
|
| 434 |
+
|
| 435 |
+
try:
|
| 436 |
+
operations = []
|
| 437 |
+
for file_name in files_to_sync:
|
| 438 |
+
local_path = os.path.join(LOCAL_DB_DIR, file_name)
|
| 439 |
+
if os.path.exists(local_path):
|
| 440 |
+
operations.append(
|
| 441 |
+
CommitOperationAdd(
|
| 442 |
+
path_in_repo=file_name,
|
| 443 |
+
path_or_fileobj=local_path
|
| 444 |
+
)
|
| 445 |
+
)
|
| 446 |
+
|
| 447 |
+
if operations:
|
| 448 |
+
hf_api = HfApi(token=HF_TOKEN)
|
| 449 |
+
hf_api.create_commit(
|
| 450 |
repo_id=DATASET_REPO_ID,
|
| 451 |
repo_type="dataset",
|
| 452 |
+
operations=operations,
|
| 453 |
+
commit_message=f"batch sync: {', '.join(files_to_sync)}"
|
| 454 |
)
|
| 455 |
+
logger.info(f"✅ 批量同步成功: {len(operations)} 个文件 ({', '.join(files_to_sync)})")
|
| 456 |
+
else:
|
| 457 |
+
logger.info("⏭️ 批量同步: 脏文件均不存在,跳过")
|
| 458 |
+
except Exception as e:
|
| 459 |
+
logger.error(f"🚨 批量同步失败: {e}")
|
| 460 |
+
# 失败的文件重新标记为脏,下次重试
|
| 461 |
+
with _dirty_files_lock:
|
| 462 |
+
_dirty_files.update(files_to_sync)
|
| 463 |
+
finally:
|
| 464 |
+
if schedule_next:
|
| 465 |
+
_schedule_next_sync()
|
| 466 |
|
| 467 |
|
| 468 |
+
def _schedule_next_sync():
|
| 469 |
+
"""调度下一轮批量同步"""
|
| 470 |
+
global _batch_sync_timer
|
| 471 |
+
_batch_sync_timer = threading.Timer(_BATCH_SYNC_INTERVAL, _batch_sync_to_hf)
|
| 472 |
+
_batch_sync_timer.daemon = True
|
| 473 |
+
_batch_sync_timer.start()
|
| 474 |
+
|
| 475 |
+
|
| 476 |
+
def start_batch_sync():
|
| 477 |
+
"""启动批量同步定时器(在 app 启动时调用)"""
|
| 478 |
+
if HF_TOKEN:
|
| 479 |
+
_schedule_next_sync()
|
| 480 |
+
logger.info(f"📡 HF 批量同步已启动,间隔 {_BATCH_SYNC_INTERVAL} 秒")
|
| 481 |
+
|
| 482 |
+
|
| 483 |
+
def flush_sync():
|
| 484 |
+
"""立即同步所有脏文件(用于服务关闭前)"""
|
| 485 |
+
global _batch_sync_timer
|
| 486 |
+
if _batch_sync_timer:
|
| 487 |
+
_batch_sync_timer.cancel()
|
| 488 |
+
_batch_sync_timer = None
|
| 489 |
+
logger.info("🔄 正在执行关闭前同步...")
|
| 490 |
+
_batch_sync_to_hf(schedule_next=False)
|
| 491 |
|
| 492 |
|
| 493 |
# ==========================================
|
|
|
|
| 686 |
# ========== 第四步:更新内存缓存 ==========
|
| 687 |
_set_to_cache(file_name, data, local_path)
|
| 688 |
|
| 689 |
+
# ========== 第五步:标记文件脏,等待批量同步 ==========
|
| 690 |
if HF_TOKEN:
|
| 691 |
+
_mark_dirty(file_name)
|
|
|
|
|
|
|
|
|
|
| 692 |
|
| 693 |
return result
|