ComfyUI-Ranking-API / db_utils.py
ZHIWEI666's picture
修复时间错乱
a2e28fb verified
# 云端Space代码/db_utils.py
# ==========================================
# 🔧 P2代码质量优化:数据库工具函数
# ==========================================
# 作用:封装 JSON 数据库常用操作,减少重复代码
# 关联文件:
# - 数据库连接.py (基础读写)
# - router_tasks.py (任务操作)
# - router_posts.py (帖子操作)
# - router_items.py (商品操作)
# ==========================================
from typing import Any, Dict, List, Optional, Callable, Union
import 数据库连接 as db
# ==========================================
# 📖 查询工具函数
# ==========================================
def get_by_id(file_name: str, item_id: str, id_field: str = "id") -> Optional[Dict]:
"""
根据 ID 获取单个记录
参数:
file_name: JSON 文件名(如 tasks.json)
item_id: 要查找的 ID
id_field: ID 字段名(默认 "id")
返回:
找到的记录,或 None
示例:
task = get_by_id("tasks.json", "task_123")
user = get_by_id("users.json", "user@example.com", id_field="account")
"""
data = db.load_data(file_name, default_data=[])
if isinstance(data, dict):
return data.get(item_id)
return next((item for item in data if item.get(id_field) == item_id), None)
def get_by_field(file_name: str, field: str, value: Any) -> Optional[Dict]:
"""
根据指定字段获取单个记录
参数:
file_name: JSON 文件名
field: 字段名
value: 字段值
返回:
找到的记录,或 None
"""
data = db.load_data(file_name, default_data=[])
if isinstance(data, dict):
for item in data.values():
if isinstance(item, dict) and item.get(field) == value:
return item
return None
return next((item for item in data if item.get(field) == value), None)
def filter_by(file_name: str, **conditions) -> List[Dict]:
"""
根据条件筛选记录
参数:
file_name: JSON 文件名
**conditions: 筛选条件(键值对)
返回:
符合条件的记录列表
示例:
open_tasks = filter_by("tasks.json", status="open")
user_posts = filter_by("posts.json", author="user123", deleted=False)
"""
data = db.load_data(file_name, default_data=[])
if isinstance(data, dict):
data = list(data.values())
result = []
for item in data:
if all(item.get(key) == value for key, value in conditions.items()):
result.append(item)
return result
def count_by(file_name: str, **conditions) -> int:
"""
统计符合条件的记录数量
参数:
file_name: JSON 文件名
**conditions: 筛选条件
返回:
符合条件的记录数量
"""
return len(filter_by(file_name, **conditions))
# ==========================================
# ✏️ 更新工具函数
# ==========================================
def update_by_id(file_name: str, item_id: str, updates: Dict, id_field: str = "id") -> bool:
"""
根据 ID 更新记录
参数:
file_name: JSON 文件名
item_id: 要更新的 ID
updates: 要更新的字段(键值对)
id_field: ID 字段名
返回:
True 更新成功 / False 记录不存在
示例:
update_by_id("tasks.json", "task_123", {"status": "completed"})
"""
data = db.load_data(file_name, default_data=[])
if isinstance(data, dict):
if item_id in data:
data[item_id].update(updates)
db.save_data(file_name, data)
return True
return False
for item in data:
if item.get(id_field) == item_id:
item.update(updates)
db.save_data(file_name, data)
return True
return False
def update_with_fn(file_name: str, item_id: str, update_fn: Callable[[Dict], None], id_field: str = "id") -> bool:
"""
使用函数更新记录(支持复杂更新逻辑)
参数:
file_name: JSON 文件名
item_id: 要更新的 ID
update_fn: 更新函数,接收记录 dict,直接修改
id_field: ID 字段名
返回:
True 更新成功 / False 记录不存在
示例:
def increment_views(item):
item["views"] = item.get("views", 0) + 1
update_with_fn("items.json", "item_123", increment_views)
"""
data = db.load_data(file_name, default_data=[])
if isinstance(data, dict):
if item_id in data:
update_fn(data[item_id])
db.save_data(file_name, data)
return True
return False
for item in data:
if item.get(id_field) == item_id:
update_fn(item)
db.save_data(file_name, data)
return True
return False
# ==========================================
# ➕ 添加工具函数
# ==========================================
def insert(file_name: str, item: Dict, prepend: bool = True) -> bool:
"""
插入新记录
参数:
file_name: JSON 文件名
item: 要插入的记录
prepend: True 插入到开头 / False 插入到末尾
返回:
True 插入成功
示例:
insert("tasks.json", {"id": "task_123", "title": "新任务"})
"""
data = db.load_data(file_name, default_data=[])
if isinstance(data, dict):
item_id = item.get("id") or item.get("account")
if item_id:
data[item_id] = item
db.save_data(file_name, data)
return True
return False
if prepend:
data.insert(0, item)
else:
data.append(item)
db.save_data(file_name, data)
return True
def insert_if_not_exists(file_name: str, item: Dict, id_field: str = "id") -> bool:
"""
如果不存在则插入
参数:
file_name: JSON 文件名
item: 要插入的记录
id_field: ID 字段名
返回:
True 插入成功 / False 已存在
"""
item_id = item.get(id_field)
if not item_id:
return False
existing = get_by_id(file_name, item_id, id_field)
if existing:
return False
return insert(file_name, item)
# ==========================================
# ❌ 删除工具函数
# ==========================================
def delete_by_id(file_name: str, item_id: str, id_field: str = "id") -> bool:
"""
根据 ID 删除记录
参数:
file_name: JSON 文件名
item_id: 要删除的 ID
id_field: ID 字段名
返回:
True 删除成功 / False 记录不存在
"""
data = db.load_data(file_name, default_data=[])
if isinstance(data, dict):
if item_id in data:
del data[item_id]
db.save_data(file_name, data)
return True
return False
original_len = len(data)
data = [item for item in data if item.get(id_field) != item_id]
if len(data) < original_len:
db.save_data(file_name, data)
return True
return False
def soft_delete_by_id(file_name: str, item_id: str, id_field: str = "id") -> bool:
"""
软删除(标记为已删除,不物理删除)
参数:
file_name: JSON 文件名
item_id: 要删除的 ID
id_field: ID 字段名
返回:
True 成功 / False 记录不存在
"""
import time
return update_by_id(file_name, item_id, {
"deleted": True,
"deleted_at": int(time.time())
}, id_field)
# ==========================================
# 🔍 分页工具函数
# ==========================================
def paginate(
file_name: str,
page: int = 1,
limit: int = 20,
sort_by: str = None,
sort_desc: bool = True,
**filters
) -> Dict:
"""
分页查询
参数:
file_name: JSON 文件名
page: 页码(从 1 开始)
limit: 每页数量
sort_by: 排序字段
sort_desc: 是否降序
**filters: 筛选条件
返回:
{
"data": [...], # 当前页数据
"total": 100, # 总数
"page": 1, # 当前页
"limit": 20, # 每页数量
"pages": 5 # 总页数
}
"""
# 获取并筛选数据
if filters:
data = filter_by(file_name, **filters)
else:
data = db.load_data(file_name, default_data=[])
if isinstance(data, dict):
data = list(data.values())
# 排序
if sort_by:
try:
data = sorted(data, key=lambda x: x.get(sort_by, 0), reverse=sort_desc)
except TypeError:
pass # 排序失败时忽略
# 计算分页
total = len(data)
pages = (total + limit - 1) // limit # 向上取整
start = (page - 1) * limit
end = start + limit
return {
"data": data[start:end],
"total": total,
"page": page,
"limit": limit,
"pages": pages
}
# ==========================================
# 👁️ 访问量记录工具函数
# ==========================================
def record_view(data_file: str, item_id: str, user_account: str) -> Optional[Dict]:
"""
记录访问量(原子操作,并发安全)
参数:
data_file: JSON 文件名(如 Items.json)
item_id: 要记录访问的 item/task ID
user_account: 访问用户账号
返回:
{"views": N, "daily_views": N} 成功
None 记录不存在
逻辑:
- 初始化字段: views=0, viewed_by=[], daily_views=0, daily_views_date=""
- daily_views 每次调用都增加
- views 只在用户首次访问时增加(user_account 不在 viewed_by 中)
- 如果 daily_views_date 不是今天,重置 daily_views=0 并更新日期
并发安全:
- 使用 atomic_update 确保读-改-写在同一把锁内完成
- 高并发下不会丢失访问量
"""
from datetime import datetime, timedelta
# 用于在闭包中存储结果
result_container = [None]
def updater(data):
# 查找目标记录
target_item = None
if isinstance(data, dict):
if item_id in data:
target_item = data[item_id]
else:
for item in data:
if item.get("id") == item_id:
target_item = item
break
if target_item is None:
result_container[0] = None
return
# 初始化字段(如果不存在)
if "views" not in target_item:
target_item["views"] = 0
if "viewed_by" not in target_item:
target_item["viewed_by"] = []
if "daily_views" not in target_item:
target_item["daily_views"] = 0
if "daily_views_date" not in target_item:
target_item["daily_views_date"] = ""
# 获取今天的日期字符串(使用 UTC+8 时区)
now_utc8 = datetime.utcnow() + timedelta(hours=8)
today_str = now_utc8.date().isoformat() # "2026-04-02"
# 检查是否需要重置日访问量
if target_item["daily_views_date"] != today_str:
target_item["daily_views"] = 0
target_item["daily_views_date"] = today_str
# 增加日访问量(每次调用都增加)
target_item["daily_views"] += 1
# 检查用户是否已访问过
if user_account not in target_item["viewed_by"]:
target_item["viewed_by"].append(user_account)
target_item["views"] += 1
# 保存结果到闭包容器
result_container[0] = {
"views": target_item["views"],
"daily_views": target_item["daily_views"]
}
# 使用原子更新,整个读-改-写过程在同一把锁内完成
db.atomic_update(data_file, updater, default_data=[])
return result_container[0]
# ==========================================
# 🗂️ 排序结果缓存工具类
# ==========================================
import time
from typing import Callable, List, Dict, Any
class SortCache:
"""
排序结果缓存类 - 用于缓存列表排序结果,减少重复排序开销
设计原则:
1. 缓存排序后的 ID 顺序,而非完整数据,避免内存浪费
2. 使用 TTL (默认5分钟) 自动过期,确保数据新鲜度
3. 写操作时清除相关缓存,确保数据一致性
4. 在 load_data 返回的最新数据之上应用缓存的顺序
使用示例:
# 在列表接口中
def get_items(sort="time"):
items = db.load_data("items.json", default_data=[])
cache_key = f"items:{sort}"
def sort_fn(data):
if sort == "likes":
data.sort(key=lambda x: x.get("likes", 0), reverse=True)
else:
data.sort(key=lambda x: x.get("created_at", 0), reverse=True)
return sort_cache.get_sorted(cache_key, items, sort_fn)
# 在写操作后
def create_item(...):
...
sort_cache.invalidate("items")
"""
def __init__(self, ttl: int = 300):
"""
初始化排序缓存
参数:
ttl: 缓存过期时间(秒),默认 300 秒(5分钟)
"""
self._cache: Dict[str, tuple] = {} # {cache_key: (sorted_ids, timestamp)}
self._ttl = ttl
def get_sorted(self, cache_key: str, items: List[Dict], sort_fn: Callable[[List[Dict]], None]) -> List[Dict]:
"""
获取排序后的数据(带缓存)
参数:
cache_key: 缓存键,应包含数据文件和排序参数
items: 原始数据列表(来自 load_data 的最新数据)
sort_fn: 排序函数,接收 items 列表并原地排序
返回:
排序后的 items 列表
"""
now = time.time()
# 检查缓存是否有效
if cache_key in self._cache:
sorted_ids, cached_time = self._cache[cache_key]
if now - cached_time < self._ttl:
# 缓存有效,用缓存的顺序重排当前数据
id_order = {id_: idx for idx, id_ in enumerate(sorted_ids)}
# 按缓存的顺序排序,新数据(不在缓存中的)放在最后
return sorted(items, key=lambda x: id_order.get(x.get("id"), float('inf')))
# 缓存无效或不存在,执行排序
sort_fn(items)
# 缓存排序后的 ID 列表
sorted_ids = [item.get("id") for item in items]
self._cache[cache_key] = (sorted_ids, now)
return items
def invalidate(self, prefix: str = ""):
"""
清除缓存
参数:
prefix: 缓存键前缀,如果指定则只清除匹配的缓存,否则清除所有缓存
"""
if prefix:
keys_to_remove = [k for k in self._cache if k.startswith(prefix)]
for k in keys_to_remove:
del self._cache[k]
else:
self._cache.clear()
def get_stats(self) -> Dict[str, Any]:
"""
获取缓存统计信息(用于调试)
"""
now = time.time()
valid_count = sum(1 for _, ts in self._cache.values() if now - ts < self._ttl)
return {
"total_cached": len(self._cache),
"valid": valid_count,
"expired": len(self._cache) - valid_count,
"ttl": self._ttl
}
# 全局排序缓存实例(TTL 5分钟)
sort_cache = SortCache(ttl=300)
# ==========================================
# 🔄 批量操作工具函数
# ==========================================
def batch_update(file_name: str, item_ids: List[str], updates: Dict, id_field: str = "id") -> int:
"""
批量更新记录
参数:
file_name: JSON 文件名
item_ids: 要更新的 ID 列表
updates: 要更新的字段
id_field: ID 字段名
返回:
更新的记录数量
"""
data = db.load_data(file_name, default_data=[])
updated_count = 0
if isinstance(data, dict):
for item_id in item_ids:
if item_id in data:
data[item_id].update(updates)
updated_count += 1
else:
id_set = set(item_ids)
for item in data:
if item.get(id_field) in id_set:
item.update(updates)
updated_count += 1
if updated_count > 0:
db.save_data(file_name, data)
return updated_count
def batch_delete(file_name: str, item_ids: List[str], id_field: str = "id") -> int:
"""
批量删除记录
参数:
file_name: JSON 文件名
item_ids: 要删除的 ID 列表
id_field: ID 字段名
返回:
删除的记录数量
"""
data = db.load_data(file_name, default_data=[])
original_count = len(data) if isinstance(data, list) else len(data)
if isinstance(data, dict):
for item_id in item_ids:
data.pop(item_id, None)
else:
id_set = set(item_ids)
data = [item for item in data if item.get(id_field) not in id_set]
new_count = len(data) if isinstance(data, list) else len(data)
deleted_count = original_count - new_count
if deleted_count > 0:
db.save_data(file_name, data)
return deleted_count