# 云端Space代码/db_utils.py # ========================================== # 🔧 P2代码质量优化:数据库工具函数 # ========================================== # 作用:封装 JSON 数据库常用操作,减少重复代码 # 关联文件: # - 数据库连接.py (基础读写) # - router_tasks.py (任务操作) # - router_posts.py (帖子操作) # - router_items.py (商品操作) # ========================================== from typing import Any, Dict, List, Optional, Callable, Union import 数据库连接 as db # ========================================== # 📖 查询工具函数 # ========================================== def get_by_id(file_name: str, item_id: str, id_field: str = "id") -> Optional[Dict]: """ 根据 ID 获取单个记录 参数: file_name: JSON 文件名(如 tasks.json) item_id: 要查找的 ID id_field: ID 字段名(默认 "id") 返回: 找到的记录,或 None 示例: task = get_by_id("tasks.json", "task_123") user = get_by_id("users.json", "user@example.com", id_field="account") """ data = db.load_data(file_name, default_data=[]) if isinstance(data, dict): return data.get(item_id) return next((item for item in data if item.get(id_field) == item_id), None) def get_by_field(file_name: str, field: str, value: Any) -> Optional[Dict]: """ 根据指定字段获取单个记录 参数: file_name: JSON 文件名 field: 字段名 value: 字段值 返回: 找到的记录,或 None """ data = db.load_data(file_name, default_data=[]) if isinstance(data, dict): for item in data.values(): if isinstance(item, dict) and item.get(field) == value: return item return None return next((item for item in data if item.get(field) == value), None) def filter_by(file_name: str, **conditions) -> List[Dict]: """ 根据条件筛选记录 参数: file_name: JSON 文件名 **conditions: 筛选条件(键值对) 返回: 符合条件的记录列表 示例: open_tasks = filter_by("tasks.json", status="open") user_posts = filter_by("posts.json", author="user123", deleted=False) """ data = db.load_data(file_name, default_data=[]) if isinstance(data, dict): data = list(data.values()) result = [] for item in data: if all(item.get(key) == value for key, value in conditions.items()): result.append(item) return result def count_by(file_name: str, **conditions) -> int: """ 统计符合条件的记录数量 参数: file_name: JSON 文件名 **conditions: 筛选条件 返回: 符合条件的记录数量 """ return len(filter_by(file_name, **conditions)) # ========================================== # ✏️ 更新工具函数 # ========================================== def update_by_id(file_name: str, item_id: str, updates: Dict, id_field: str = "id") -> bool: """ 根据 ID 更新记录 参数: file_name: JSON 文件名 item_id: 要更新的 ID updates: 要更新的字段(键值对) id_field: ID 字段名 返回: True 更新成功 / False 记录不存在 示例: update_by_id("tasks.json", "task_123", {"status": "completed"}) """ data = db.load_data(file_name, default_data=[]) if isinstance(data, dict): if item_id in data: data[item_id].update(updates) db.save_data(file_name, data) return True return False for item in data: if item.get(id_field) == item_id: item.update(updates) db.save_data(file_name, data) return True return False def update_with_fn(file_name: str, item_id: str, update_fn: Callable[[Dict], None], id_field: str = "id") -> bool: """ 使用函数更新记录(支持复杂更新逻辑) 参数: file_name: JSON 文件名 item_id: 要更新的 ID update_fn: 更新函数,接收记录 dict,直接修改 id_field: ID 字段名 返回: True 更新成功 / False 记录不存在 示例: def increment_views(item): item["views"] = item.get("views", 0) + 1 update_with_fn("items.json", "item_123", increment_views) """ data = db.load_data(file_name, default_data=[]) if isinstance(data, dict): if item_id in data: update_fn(data[item_id]) db.save_data(file_name, data) return True return False for item in data: if item.get(id_field) == item_id: update_fn(item) db.save_data(file_name, data) return True return False # ========================================== # ➕ 添加工具函数 # ========================================== def insert(file_name: str, item: Dict, prepend: bool = True) -> bool: """ 插入新记录 参数: file_name: JSON 文件名 item: 要插入的记录 prepend: True 插入到开头 / False 插入到末尾 返回: True 插入成功 示例: insert("tasks.json", {"id": "task_123", "title": "新任务"}) """ data = db.load_data(file_name, default_data=[]) if isinstance(data, dict): item_id = item.get("id") or item.get("account") if item_id: data[item_id] = item db.save_data(file_name, data) return True return False if prepend: data.insert(0, item) else: data.append(item) db.save_data(file_name, data) return True def insert_if_not_exists(file_name: str, item: Dict, id_field: str = "id") -> bool: """ 如果不存在则插入 参数: file_name: JSON 文件名 item: 要插入的记录 id_field: ID 字段名 返回: True 插入成功 / False 已存在 """ item_id = item.get(id_field) if not item_id: return False existing = get_by_id(file_name, item_id, id_field) if existing: return False return insert(file_name, item) # ========================================== # ❌ 删除工具函数 # ========================================== def delete_by_id(file_name: str, item_id: str, id_field: str = "id") -> bool: """ 根据 ID 删除记录 参数: file_name: JSON 文件名 item_id: 要删除的 ID id_field: ID 字段名 返回: True 删除成功 / False 记录不存在 """ data = db.load_data(file_name, default_data=[]) if isinstance(data, dict): if item_id in data: del data[item_id] db.save_data(file_name, data) return True return False original_len = len(data) data = [item for item in data if item.get(id_field) != item_id] if len(data) < original_len: db.save_data(file_name, data) return True return False def soft_delete_by_id(file_name: str, item_id: str, id_field: str = "id") -> bool: """ 软删除(标记为已删除,不物理删除) 参数: file_name: JSON 文件名 item_id: 要删除的 ID id_field: ID 字段名 返回: True 成功 / False 记录不存在 """ import time return update_by_id(file_name, item_id, { "deleted": True, "deleted_at": int(time.time()) }, id_field) # ========================================== # 🔍 分页工具函数 # ========================================== def paginate( file_name: str, page: int = 1, limit: int = 20, sort_by: str = None, sort_desc: bool = True, **filters ) -> Dict: """ 分页查询 参数: file_name: JSON 文件名 page: 页码(从 1 开始) limit: 每页数量 sort_by: 排序字段 sort_desc: 是否降序 **filters: 筛选条件 返回: { "data": [...], # 当前页数据 "total": 100, # 总数 "page": 1, # 当前页 "limit": 20, # 每页数量 "pages": 5 # 总页数 } """ # 获取并筛选数据 if filters: data = filter_by(file_name, **filters) else: data = db.load_data(file_name, default_data=[]) if isinstance(data, dict): data = list(data.values()) # 排序 if sort_by: try: data = sorted(data, key=lambda x: x.get(sort_by, 0), reverse=sort_desc) except TypeError: pass # 排序失败时忽略 # 计算分页 total = len(data) pages = (total + limit - 1) // limit # 向上取整 start = (page - 1) * limit end = start + limit return { "data": data[start:end], "total": total, "page": page, "limit": limit, "pages": pages } # ========================================== # 👁️ 访问量记录工具函数 # ========================================== def record_view(data_file: str, item_id: str, user_account: str) -> Optional[Dict]: """ 记录访问量(原子操作,并发安全) 参数: data_file: JSON 文件名(如 Items.json) item_id: 要记录访问的 item/task ID user_account: 访问用户账号 返回: {"views": N, "daily_views": N} 成功 None 记录不存在 逻辑: - 初始化字段: views=0, viewed_by=[], daily_views=0, daily_views_date="" - daily_views 每次调用都增加 - views 只在用户首次访问时增加(user_account 不在 viewed_by 中) - 如果 daily_views_date 不是今天,重置 daily_views=0 并更新日期 并发安全: - 使用 atomic_update 确保读-改-写在同一把锁内完成 - 高并发下不会丢失访问量 """ from datetime import datetime, timedelta # 用于在闭包中存储结果 result_container = [None] def updater(data): # 查找目标记录 target_item = None if isinstance(data, dict): if item_id in data: target_item = data[item_id] else: for item in data: if item.get("id") == item_id: target_item = item break if target_item is None: result_container[0] = None return # 初始化字段(如果不存在) if "views" not in target_item: target_item["views"] = 0 if "viewed_by" not in target_item: target_item["viewed_by"] = [] if "daily_views" not in target_item: target_item["daily_views"] = 0 if "daily_views_date" not in target_item: target_item["daily_views_date"] = "" # 获取今天的日期字符串(使用 UTC+8 时区) now_utc8 = datetime.utcnow() + timedelta(hours=8) today_str = now_utc8.date().isoformat() # "2026-04-02" # 检查是否需要重置日访问量 if target_item["daily_views_date"] != today_str: target_item["daily_views"] = 0 target_item["daily_views_date"] = today_str # 增加日访问量(每次调用都增加) target_item["daily_views"] += 1 # 检查用户是否已访问过 if user_account not in target_item["viewed_by"]: target_item["viewed_by"].append(user_account) target_item["views"] += 1 # 保存结果到闭包容器 result_container[0] = { "views": target_item["views"], "daily_views": target_item["daily_views"] } # 使用原子更新,整个读-改-写过程在同一把锁内完成 db.atomic_update(data_file, updater, default_data=[]) return result_container[0] # ========================================== # 🗂️ 排序结果缓存工具类 # ========================================== import time from typing import Callable, List, Dict, Any class SortCache: """ 排序结果缓存类 - 用于缓存列表排序结果,减少重复排序开销 设计原则: 1. 缓存排序后的 ID 顺序,而非完整数据,避免内存浪费 2. 使用 TTL (默认5分钟) 自动过期,确保数据新鲜度 3. 写操作时清除相关缓存,确保数据一致性 4. 在 load_data 返回的最新数据之上应用缓存的顺序 使用示例: # 在列表接口中 def get_items(sort="time"): items = db.load_data("items.json", default_data=[]) cache_key = f"items:{sort}" def sort_fn(data): if sort == "likes": data.sort(key=lambda x: x.get("likes", 0), reverse=True) else: data.sort(key=lambda x: x.get("created_at", 0), reverse=True) return sort_cache.get_sorted(cache_key, items, sort_fn) # 在写操作后 def create_item(...): ... sort_cache.invalidate("items") """ def __init__(self, ttl: int = 300): """ 初始化排序缓存 参数: ttl: 缓存过期时间(秒),默认 300 秒(5分钟) """ self._cache: Dict[str, tuple] = {} # {cache_key: (sorted_ids, timestamp)} self._ttl = ttl def get_sorted(self, cache_key: str, items: List[Dict], sort_fn: Callable[[List[Dict]], None]) -> List[Dict]: """ 获取排序后的数据(带缓存) 参数: cache_key: 缓存键,应包含数据文件和排序参数 items: 原始数据列表(来自 load_data 的最新数据) sort_fn: 排序函数,接收 items 列表并原地排序 返回: 排序后的 items 列表 """ now = time.time() # 检查缓存是否有效 if cache_key in self._cache: sorted_ids, cached_time = self._cache[cache_key] if now - cached_time < self._ttl: # 缓存有效,用缓存的顺序重排当前数据 id_order = {id_: idx for idx, id_ in enumerate(sorted_ids)} # 按缓存的顺序排序,新数据(不在缓存中的)放在最后 return sorted(items, key=lambda x: id_order.get(x.get("id"), float('inf'))) # 缓存无效或不存在,执行排序 sort_fn(items) # 缓存排序后的 ID 列表 sorted_ids = [item.get("id") for item in items] self._cache[cache_key] = (sorted_ids, now) return items def invalidate(self, prefix: str = ""): """ 清除缓存 参数: prefix: 缓存键前缀,如果指定则只清除匹配的缓存,否则清除所有缓存 """ if prefix: keys_to_remove = [k for k in self._cache if k.startswith(prefix)] for k in keys_to_remove: del self._cache[k] else: self._cache.clear() def get_stats(self) -> Dict[str, Any]: """ 获取缓存统计信息(用于调试) """ now = time.time() valid_count = sum(1 for _, ts in self._cache.values() if now - ts < self._ttl) return { "total_cached": len(self._cache), "valid": valid_count, "expired": len(self._cache) - valid_count, "ttl": self._ttl } # 全局排序缓存实例(TTL 5分钟) sort_cache = SortCache(ttl=300) # ========================================== # 🔄 批量操作工具函数 # ========================================== def batch_update(file_name: str, item_ids: List[str], updates: Dict, id_field: str = "id") -> int: """ 批量更新记录 参数: file_name: JSON 文件名 item_ids: 要更新的 ID 列表 updates: 要更新的字段 id_field: ID 字段名 返回: 更新的记录数量 """ data = db.load_data(file_name, default_data=[]) updated_count = 0 if isinstance(data, dict): for item_id in item_ids: if item_id in data: data[item_id].update(updates) updated_count += 1 else: id_set = set(item_ids) for item in data: if item.get(id_field) in id_set: item.update(updates) updated_count += 1 if updated_count > 0: db.save_data(file_name, data) return updated_count def batch_delete(file_name: str, item_ids: List[str], id_field: str = "id") -> int: """ 批量删除记录 参数: file_name: JSON 文件名 item_ids: 要删除的 ID 列表 id_field: ID 字段名 返回: 删除的记录数量 """ data = db.load_data(file_name, default_data=[]) original_count = len(data) if isinstance(data, list) else len(data) if isinstance(data, dict): for item_id in item_ids: data.pop(item_id, None) else: id_set = set(item_ids) data = [item for item in data if item.get(id_field) not in id_set] new_count = len(data) if isinstance(data, list) else len(data) deleted_count = original_count - new_count if deleted_count > 0: db.save_data(file_name, data) return deleted_count