ZHIWEI666 commited on
Commit
0bf23ac
·
verified ·
1 Parent(s): 3b1406b
Files changed (5) hide show
  1. app.py +52 -10
  2. db_utils.py +428 -0
  3. router_tasks.py +164 -92
  4. 安全认证.py +179 -0
  5. 数据库连接.py +114 -3
app.py CHANGED
@@ -73,38 +73,80 @@ app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
73
  # 🛡️ 稳定性优化:全局异常处理器
74
  # ==========================================
75
  # 作用:捕获所有未处理异常,防止单个请求崩溃整个服务
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
 
77
  @app.exception_handler(StarletteHTTPException)
78
  async def http_exception_handler(request, exc):
79
- """HTTP 异常处理,返回友好错误信息"""
80
  logger.warning(f"HTTP {exc.status_code} | {request.url.path} | {exc.detail}")
 
 
 
 
 
 
 
 
 
 
 
 
81
  return JSONResponse(
82
  status_code=exc.status_code,
83
- content={"status": "error", "detail": str(exc.detail)}
84
  )
85
 
86
  @app.exception_handler(RequestValidationError)
87
  async def validation_exception_handler(request, exc):
88
- """请求参数校验失败处理"""
89
  logger.warning(f"Validation Error | {request.url.path} | {exc.errors()}")
 
 
 
 
 
 
 
90
  return JSONResponse(
91
  status_code=422,
92
- content={"status": "error", "detail": "请求参数格式错误", "errors": exc.errors()}
 
 
 
 
 
93
  )
94
 
95
  @app.exception_handler(Exception)
96
  async def global_exception_handler(request, exc):
97
- """全局异常处理,捕获所有未预期异常"""
98
  error_id = f"ERR_{int(time.time())}_{id(exc) % 10000}"
99
  logger.error(f"Unhandled Exception [{error_id}] | {request.url.path} | {type(exc).__name__}: {exc}")
100
  logger.error(traceback.format_exc())
101
  return JSONResponse(
102
  status_code=500,
103
- content={
104
- "status": "error",
105
- "detail": "服务器内部错误,请稍后重试",
106
- "error_id": error_id # 方便排查
107
- }
 
108
  )
109
 
110
 
 
73
  # 🛡️ 稳定性优化:全局异常处理器
74
  # ==========================================
75
  # 作用:捕获所有未处理异常,防止单个请求崩溃整个服务
76
+ # 🚀 P3优化:统一错误响应格式
77
+
78
+ def create_error_response(status_code: int, detail: str, error_type: str = None, error_id: str = None, extra: dict = None):
79
+ """🚀 P3优化:统一错误响应构造器"""
80
+ response = {
81
+ "success": False,
82
+ "status": "error",
83
+ "code": status_code,
84
+ "detail": detail,
85
+ }
86
+ if error_type:
87
+ response["type"] = error_type
88
+ if error_id:
89
+ response["error_id"] = error_id
90
+ if extra:
91
+ response.update(extra)
92
+ return response
93
 
94
  @app.exception_handler(StarletteHTTPException)
95
  async def http_exception_handler(request, exc):
96
+ """🚀 P3优化:HTTP 异常处理,返回统一格式"""
97
  logger.warning(f"HTTP {exc.status_code} | {request.url.path} | {exc.detail}")
98
+
99
+ # 根据状态码确定错误类型
100
+ error_type_map = {
101
+ 400: "bad_request",
102
+ 401: "unauthorized",
103
+ 403: "forbidden",
104
+ 404: "not_found",
105
+ 409: "conflict",
106
+ 429: "rate_limited",
107
+ }
108
+ error_type = error_type_map.get(exc.status_code, "http_error")
109
+
110
  return JSONResponse(
111
  status_code=exc.status_code,
112
+ content=create_error_response(exc.status_code, str(exc.detail), error_type)
113
  )
114
 
115
  @app.exception_handler(RequestValidationError)
116
  async def validation_exception_handler(request, exc):
117
+ """🚀 P3优化:请求参数校验失败处理"""
118
  logger.warning(f"Validation Error | {request.url.path} | {exc.errors()}")
119
+
120
+ # 提取简洁的错误信息
121
+ errors = []
122
+ for err in exc.errors():
123
+ field = ".".join(str(loc) for loc in err["loc"][1:]) # 跳过 'body'
124
+ errors.append({"field": field, "message": err["msg"], "type": err["type"]})
125
+
126
  return JSONResponse(
127
  status_code=422,
128
+ content=create_error_response(
129
+ 422,
130
+ "请求参数格式错误",
131
+ "validation_error",
132
+ extra={"errors": errors}
133
+ )
134
  )
135
 
136
  @app.exception_handler(Exception)
137
  async def global_exception_handler(request, exc):
138
+ """🚀 P3优化:全局异常处理,捕获所有未预期异常"""
139
  error_id = f"ERR_{int(time.time())}_{id(exc) % 10000}"
140
  logger.error(f"Unhandled Exception [{error_id}] | {request.url.path} | {type(exc).__name__}: {exc}")
141
  logger.error(traceback.format_exc())
142
  return JSONResponse(
143
  status_code=500,
144
+ content=create_error_response(
145
+ 500,
146
+ "服务器内部错误,请稍后重试",
147
+ "internal_error",
148
+ error_id=error_id
149
+ )
150
  )
151
 
152
 
db_utils.py ADDED
@@ -0,0 +1,428 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 云端Space代码/db_utils.py
2
+ # ==========================================
3
+ # 🔧 P2代码质量优化:数据库工具函数
4
+ # ==========================================
5
+ # 作用:封装 JSON 数据库常用操作,减少重复代码
6
+ # 关联文件:
7
+ # - 数据库连接.py (基础读写)
8
+ # - router_tasks.py (任务操作)
9
+ # - router_posts.py (帖子操作)
10
+ # - router_items.py (商品操作)
11
+ # ==========================================
12
+
13
+ from typing import Any, Dict, List, Optional, Callable, Union
14
+ import 数据库连接 as db
15
+
16
+
17
+ # ==========================================
18
+ # 📖 查询工具函数
19
+ # ==========================================
20
+
21
+ def get_by_id(file_name: str, item_id: str, id_field: str = "id") -> Optional[Dict]:
22
+ """
23
+ 根据 ID 获取单个记录
24
+
25
+ 参数:
26
+ file_name: JSON 文件名(如 tasks.json)
27
+ item_id: 要查找的 ID
28
+ id_field: ID 字段名(默认 "id")
29
+
30
+ 返回:
31
+ 找到的记录,或 None
32
+
33
+ 示例:
34
+ task = get_by_id("tasks.json", "task_123")
35
+ user = get_by_id("users.json", "user@example.com", id_field="account")
36
+ """
37
+ data = db.load_data(file_name, default_data=[])
38
+
39
+ if isinstance(data, dict):
40
+ return data.get(item_id)
41
+
42
+ return next((item for item in data if item.get(id_field) == item_id), None)
43
+
44
+
45
+ def get_by_field(file_name: str, field: str, value: Any) -> Optional[Dict]:
46
+ """
47
+ 根据指定字段获取单个记录
48
+
49
+ 参数:
50
+ file_name: JSON 文件名
51
+ field: 字段名
52
+ value: 字段值
53
+
54
+ 返回:
55
+ 找到的记录,或 None
56
+ """
57
+ data = db.load_data(file_name, default_data=[])
58
+
59
+ if isinstance(data, dict):
60
+ for item in data.values():
61
+ if isinstance(item, dict) and item.get(field) == value:
62
+ return item
63
+ return None
64
+
65
+ return next((item for item in data if item.get(field) == value), None)
66
+
67
+
68
+ def filter_by(file_name: str, **conditions) -> List[Dict]:
69
+ """
70
+ 根据条件筛选记录
71
+
72
+ 参数:
73
+ file_name: JSON 文件名
74
+ **conditions: 筛选条件(键值对)
75
+
76
+ 返回:
77
+ 符合条件的记录列表
78
+
79
+ 示例:
80
+ open_tasks = filter_by("tasks.json", status="open")
81
+ user_posts = filter_by("posts.json", author="user123", deleted=False)
82
+ """
83
+ data = db.load_data(file_name, default_data=[])
84
+
85
+ if isinstance(data, dict):
86
+ data = list(data.values())
87
+
88
+ result = []
89
+ for item in data:
90
+ if all(item.get(key) == value for key, value in conditions.items()):
91
+ result.append(item)
92
+
93
+ return result
94
+
95
+
96
+ def count_by(file_name: str, **conditions) -> int:
97
+ """
98
+ 统计符合条件的记录数量
99
+
100
+ 参数:
101
+ file_name: JSON 文件名
102
+ **conditions: 筛选条件
103
+
104
+ 返回:
105
+ 符合条件的记录数量
106
+ """
107
+ return len(filter_by(file_name, **conditions))
108
+
109
+
110
+ # ==========================================
111
+ # ✏️ 更新工具函数
112
+ # ==========================================
113
+
114
+ def update_by_id(file_name: str, item_id: str, updates: Dict, id_field: str = "id") -> bool:
115
+ """
116
+ 根据 ID 更新记录
117
+
118
+ 参数:
119
+ file_name: JSON 文件名
120
+ item_id: 要更新的 ID
121
+ updates: 要更新的字段(键值对)
122
+ id_field: ID 字段名
123
+
124
+ 返回:
125
+ True 更新成功 / False 记录不存在
126
+
127
+ 示例:
128
+ update_by_id("tasks.json", "task_123", {"status": "completed"})
129
+ """
130
+ data = db.load_data(file_name, default_data=[])
131
+
132
+ if isinstance(data, dict):
133
+ if item_id in data:
134
+ data[item_id].update(updates)
135
+ db.save_data(file_name, data)
136
+ return True
137
+ return False
138
+
139
+ for item in data:
140
+ if item.get(id_field) == item_id:
141
+ item.update(updates)
142
+ db.save_data(file_name, data)
143
+ return True
144
+
145
+ return False
146
+
147
+
148
+ def update_with_fn(file_name: str, item_id: str, update_fn: Callable[[Dict], None], id_field: str = "id") -> bool:
149
+ """
150
+ 使用函数更新记录(支持复杂更新逻辑)
151
+
152
+ 参数:
153
+ file_name: JSON 文件名
154
+ item_id: 要更新的 ID
155
+ update_fn: 更新函数,接收记录 dict,直接修改
156
+ id_field: ID 字段名
157
+
158
+ 返回:
159
+ True 更新成功 / False 记录不存在
160
+
161
+ 示例:
162
+ def increment_views(item):
163
+ item["views"] = item.get("views", 0) + 1
164
+
165
+ update_with_fn("items.json", "item_123", increment_views)
166
+ """
167
+ data = db.load_data(file_name, default_data=[])
168
+
169
+ if isinstance(data, dict):
170
+ if item_id in data:
171
+ update_fn(data[item_id])
172
+ db.save_data(file_name, data)
173
+ return True
174
+ return False
175
+
176
+ for item in data:
177
+ if item.get(id_field) == item_id:
178
+ update_fn(item)
179
+ db.save_data(file_name, data)
180
+ return True
181
+
182
+ return False
183
+
184
+
185
+ # ==========================================
186
+ # ➕ 添加工具函数
187
+ # ==========================================
188
+
189
+ def insert(file_name: str, item: Dict, prepend: bool = True) -> bool:
190
+ """
191
+ 插入新记录
192
+
193
+ 参数:
194
+ file_name: JSON 文件名
195
+ item: 要插入的记录
196
+ prepend: True 插入到开头 / False 插入到末尾
197
+
198
+ 返回:
199
+ True 插入成功
200
+
201
+ 示例:
202
+ insert("tasks.json", {"id": "task_123", "title": "新任务"})
203
+ """
204
+ data = db.load_data(file_name, default_data=[])
205
+
206
+ if isinstance(data, dict):
207
+ item_id = item.get("id") or item.get("account")
208
+ if item_id:
209
+ data[item_id] = item
210
+ db.save_data(file_name, data)
211
+ return True
212
+ return False
213
+
214
+ if prepend:
215
+ data.insert(0, item)
216
+ else:
217
+ data.append(item)
218
+
219
+ db.save_data(file_name, data)
220
+ return True
221
+
222
+
223
+ def insert_if_not_exists(file_name: str, item: Dict, id_field: str = "id") -> bool:
224
+ """
225
+ 如果不存在则插入
226
+
227
+ 参数:
228
+ file_name: JSON 文件名
229
+ item: 要插入的记录
230
+ id_field: ID 字段名
231
+
232
+ 返回:
233
+ True 插入成功 / False 已存在
234
+ """
235
+ item_id = item.get(id_field)
236
+ if not item_id:
237
+ return False
238
+
239
+ existing = get_by_id(file_name, item_id, id_field)
240
+ if existing:
241
+ return False
242
+
243
+ return insert(file_name, item)
244
+
245
+
246
+ # ==========================================
247
+ # ❌ 删除工具函数
248
+ # ==========================================
249
+
250
+ def delete_by_id(file_name: str, item_id: str, id_field: str = "id") -> bool:
251
+ """
252
+ 根据 ID 删除记录
253
+
254
+ 参数:
255
+ file_name: JSON 文件名
256
+ item_id: 要删除的 ID
257
+ id_field: ID 字段名
258
+
259
+ 返回:
260
+ True 删除成功 / False 记录不存在
261
+ """
262
+ data = db.load_data(file_name, default_data=[])
263
+
264
+ if isinstance(data, dict):
265
+ if item_id in data:
266
+ del data[item_id]
267
+ db.save_data(file_name, data)
268
+ return True
269
+ return False
270
+
271
+ original_len = len(data)
272
+ data = [item for item in data if item.get(id_field) != item_id]
273
+
274
+ if len(data) < original_len:
275
+ db.save_data(file_name, data)
276
+ return True
277
+
278
+ return False
279
+
280
+
281
+ def soft_delete_by_id(file_name: str, item_id: str, id_field: str = "id") -> bool:
282
+ """
283
+ 软删除(标记为已删除,不物理删除)
284
+
285
+ 参数:
286
+ file_name: JSON 文件名
287
+ item_id: 要删除的 ID
288
+ id_field: ID 字段名
289
+
290
+ 返回:
291
+ True 成功 / False 记录不存在
292
+ """
293
+ import time
294
+ return update_by_id(file_name, item_id, {
295
+ "deleted": True,
296
+ "deleted_at": int(time.time())
297
+ }, id_field)
298
+
299
+
300
+ # ==========================================
301
+ # 🔍 分页工具函数
302
+ # ==========================================
303
+
304
+ def paginate(
305
+ file_name: str,
306
+ page: int = 1,
307
+ limit: int = 20,
308
+ sort_by: str = None,
309
+ sort_desc: bool = True,
310
+ **filters
311
+ ) -> Dict:
312
+ """
313
+ 分页查询
314
+
315
+ 参数:
316
+ file_name: JSON 文件名
317
+ page: 页码(从 1 开始)
318
+ limit: 每页数量
319
+ sort_by: 排序字段
320
+ sort_desc: 是否降序
321
+ **filters: 筛选条件
322
+
323
+ 返回:
324
+ {
325
+ "data": [...], # 当前页数据
326
+ "total": 100, # 总数
327
+ "page": 1, # 当前页
328
+ "limit": 20, # 每页数量
329
+ "pages": 5 # 总页数
330
+ }
331
+ """
332
+ # 获取并筛选数据
333
+ if filters:
334
+ data = filter_by(file_name, **filters)
335
+ else:
336
+ data = db.load_data(file_name, default_data=[])
337
+ if isinstance(data, dict):
338
+ data = list(data.values())
339
+
340
+ # 排序
341
+ if sort_by:
342
+ try:
343
+ data = sorted(data, key=lambda x: x.get(sort_by, 0), reverse=sort_desc)
344
+ except TypeError:
345
+ pass # 排序失败时忽略
346
+
347
+ # 计算分页
348
+ total = len(data)
349
+ pages = (total + limit - 1) // limit # 向上取整
350
+ start = (page - 1) * limit
351
+ end = start + limit
352
+
353
+ return {
354
+ "data": data[start:end],
355
+ "total": total,
356
+ "page": page,
357
+ "limit": limit,
358
+ "pages": pages
359
+ }
360
+
361
+
362
+ # ==========================================
363
+ # 🔄 批量操作工具函数
364
+ # ==========================================
365
+
366
+ def batch_update(file_name: str, item_ids: List[str], updates: Dict, id_field: str = "id") -> int:
367
+ """
368
+ 批量更新记录
369
+
370
+ 参数:
371
+ file_name: JSON 文件名
372
+ item_ids: 要更新的 ID 列表
373
+ updates: 要更新的字段
374
+ id_field: ID 字段名
375
+
376
+ 返回:
377
+ 更新的记录数量
378
+ """
379
+ data = db.load_data(file_name, default_data=[])
380
+ updated_count = 0
381
+
382
+ if isinstance(data, dict):
383
+ for item_id in item_ids:
384
+ if item_id in data:
385
+ data[item_id].update(updates)
386
+ updated_count += 1
387
+ else:
388
+ id_set = set(item_ids)
389
+ for item in data:
390
+ if item.get(id_field) in id_set:
391
+ item.update(updates)
392
+ updated_count += 1
393
+
394
+ if updated_count > 0:
395
+ db.save_data(file_name, data)
396
+
397
+ return updated_count
398
+
399
+
400
+ def batch_delete(file_name: str, item_ids: List[str], id_field: str = "id") -> int:
401
+ """
402
+ 批量删除记录
403
+
404
+ 参数:
405
+ file_name: JSON 文件名
406
+ item_ids: 要删除的 ID 列表
407
+ id_field: ID 字段名
408
+
409
+ 返回:
410
+ 删除的记录数量
411
+ """
412
+ data = db.load_data(file_name, default_data=[])
413
+ original_count = len(data) if isinstance(data, list) else len(data)
414
+
415
+ if isinstance(data, dict):
416
+ for item_id in item_ids:
417
+ data.pop(item_id, None)
418
+ else:
419
+ id_set = set(item_ids)
420
+ data = [item for item in data if item.get(id_field) not in id_set]
421
+
422
+ new_count = len(data) if isinstance(data, list) else len(data)
423
+ deleted_count = original_count - new_count
424
+
425
+ if deleted_count > 0:
426
+ db.save_data(file_name, data)
427
+
428
+ return deleted_count
router_tasks.py CHANGED
@@ -90,12 +90,15 @@ def check_and_update_expired_tasks(tasks_db, db_session=None):
90
  - open 状态且超过截止日期:自动取消,退还冻结金额
91
  - in_progress 状态且超过截止日期:标记为过期(不自动取消,需双方处理)
92
  💳 P6支付增强:过期时自动退款
 
93
  """
94
  today = datetime.date.today().isoformat() # "2026-03-30"
95
  updated = False
96
  refund_tasks = [] # 需要退款的任务
 
97
 
98
- for task in tasks_db:
 
99
  deadline = task.get("deadline", "")
100
  status = task.get("status", "")
101
 
@@ -105,57 +108,81 @@ def check_and_update_expired_tasks(tasks_db, db_session=None):
105
  # 检查是否过期
106
  if deadline < today:
107
  if status == "open":
108
- # 开放接单状态且过期:自动取消,退还冻结金额
109
- task["status"] = "expired"
110
- task["expired_at"] = int(time.time())
111
- updated = True
112
-
113
- # 💳 记录需要退款的任务
114
  frozen_amount = task.get("frozen_amount", task.get("total_price", 0))
115
- if frozen_amount > 0:
116
- refund_tasks.append({
117
- "task_id": task["id"],
118
- "publisher": task.get("publisher"),
119
- "amount": frozen_amount,
120
- "title": task.get("title", "")
121
- })
122
- task["refunded"] = True
123
- task["refund_amount"] = frozen_amount
124
 
125
  elif status == "in_progress":
126
- # 进行中且过期:标记过期但取消
127
  task["is_overdue"] = True
128
  updated = True
129
 
130
- # 💳 执行退款操作
131
- if refund_tasks and db_session:
132
- for refund in refund_tasks:
133
- try:
134
- wallet = db_session.query(Wallet).filter(Wallet.account == refund["publisher"]).with_for_update().first()
135
- if wallet:
136
- wallet.frozen_balance = max(0, wallet.frozen_balance - refund["amount"]) # 减少冻结
137
- wallet.balance += refund["amount"] # 返还余额
138
-
139
- # 记录退款交易
140
- create_task_transaction(
141
- db_session, refund["publisher"], "TASK_REFUND",
142
- refund["amount"], task_id=refund["task_id"]
143
- )
144
-
145
- # 发送退款通知
146
- add_notification(refund["publisher"], {
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
147
  "type": "task_refund",
148
  "from_user": "system",
149
- "target_item_id": refund["task_id"],
150
- "target_item_title": refund["title"],
151
- "content": f"💰 任务《{refund['title']}》已过期自动取消,{refund['amount']}积分已退还"
152
  })
 
 
 
153
 
154
- logger.info(f"TASK_REFUND | publisher={refund['publisher']} | task={refund['task_id']} | amount={refund['amount']}")
155
- except Exception as e:
156
- logger.error(f"TASK_REFUND_ERROR | task={refund['task_id']} | error={str(e)}")
157
-
158
- db_session.commit()
 
 
 
 
 
 
 
 
159
 
160
  return updated
161
 
@@ -410,9 +437,10 @@ async def update_task(task_id: str, update_data: TaskUpdate, current_user: str =
410
  raise HTTPException(status_code=404, detail="任务不存在")
411
 
412
  @router.delete("/api/tasks/{task_id}")
413
- async def cancel_task(task_id: str, current_user: str = Depends(require_auth)):
414
  """
415
  取消任务(仅发布者可操作,且仅在 open 状态时可取消)
 
416
  """
417
  tasks_db = db.load_data("tasks.json", default_data=[])
418
 
@@ -424,9 +452,41 @@ async def cancel_task(task_id: str, current_user: str = Depends(require_auth)):
424
  if task.get("status") != "open":
425
  raise HTTPException(status_code=400, detail="只能取消开放状态的任务")
426
 
427
- task["status"] = "cancelled"
428
- db.save_data("tasks.json", tasks_db)
429
- return {"status": "success"}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
430
 
431
  raise HTTPException(status_code=404, detail="任务不存在")
432
 
@@ -619,6 +679,7 @@ async def accept_task(task_id: str, is_accepted: bool, feedback: str = None, cur
619
  - is_accepted=True: 验收通过,支付尾款给接��者
620
  - is_accepted=False: 验收不通过,可以要求修改或发起申诉
621
  💳 P6支付增强:使用SQL钱包支付,记录交易流水,发送支付通知
 
622
  """
623
  tasks_db = db.load_data("tasks.json", default_data=[])
624
 
@@ -638,50 +699,61 @@ async def accept_task(task_id: str, is_accepted: bool, feedback: str = None, cur
638
 
639
  if is_accepted:
640
  # 💳 验收通过:支付尾款给接单者
641
- publisher_wallet = db_session.query(Wallet).filter(Wallet.account == current_user).with_for_update().first()
642
- if not publisher_wallet or publisher_wallet.frozen_balance < remaining:
643
- raise HTTPException(status_code=400, detail="冻结余额不足支付尾款")
644
-
645
- # 扣除发布者尾款(从冻结余额
646
- publisher_wallet.frozen_balance -= remaining
647
-
648
- # 给接单者全款(订金+尾款)
649
- assignee_wallet = db_session.query(Wallet).filter(Wallet.account == assignee_account).with_for_update().first()
650
- if not assignee_wallet:
651
- assignee_wallet = Wallet(account=assignee_account, balance=0)
652
- db_session.add(assignee_wallet)
653
- db_session.flush()
654
-
655
- assignee_wallet.balance += total_price # 全款进入可用余额
656
-
657
- # 💳 记录交易流水
658
- # 1. 发布者支付尾款
659
- create_task_transaction(
660
- db_session, current_user, "TASK_PAYMENT",
661
- -remaining, related_account=assignee_account, task_id=task_id
662
- )
663
- # 2. 接单者收入
664
- create_task_transaction(
665
- db_session, assignee_account, "TASK_INCOME",
666
- total_price, related_account=current_user, task_id=task_id
667
- )
668
-
669
- task["status"] = "completed"
670
- task["completed_at"] = int(time.time())
671
-
672
- db_session.commit()
673
-
674
- logger.info(f"TASK_COMPLETE | publisher={current_user} | assignee={assignee_account} | task={task_id} | total={total_price}")
675
- message = f"验收通过,已支付 {total_price} 积分给接单者"
676
-
677
- # 🔔 支付通知:接单者收到款项
678
- add_notification(assignee_account, {
679
- "type": "task_payment",
680
- "from_user": current_user,
681
- "target_item_id": task_id,
682
- "target_item_title": task.get("title", ""),
683
- "content": f"💰 任务《{task.get('title', '')}》验收通过,{total_price}积分已到账"
684
- })
 
 
 
 
 
 
 
 
 
 
 
685
  else:
686
  # 验收不通过:回到进行中状态,允许修改后重新提交
687
  task["status"] = "in_progress"
@@ -689,6 +761,8 @@ async def accept_task(task_id: str, is_accepted: bool, feedback: str = None, cur
689
  task["submitted_at"] = None
690
  message = "验收不通过,接单者可以修改后重新提交"
691
 
 
 
692
  # 🔔 通知接单者:验收未通过
693
  add_notification(assignee_account, {
694
  "type": "task_rejected",
@@ -698,8 +772,6 @@ async def accept_task(task_id: str, is_accepted: bool, feedback: str = None, cur
698
  "content": f"任务《{task.get('title', '')}》验收未通过,请修改后重新提交"
699
  })
700
 
701
- db.save_data("tasks.json", tasks_db)
702
-
703
  return {"status": "success", "message": message}
704
 
705
  raise HTTPException(status_code=404, detail="任务不存在")
 
90
  - open 状态且超过截止日期:自动取消,退还冻结金额
91
  - in_progress 状态且超过截止日期:标记为过期(不自动取消,需双方处理)
92
  💳 P6支付增强:过期时自动退款
93
+ 💳 P0修复:事务完整性保护,失败则回滚
94
  """
95
  today = datetime.date.today().isoformat() # "2026-03-30"
96
  updated = False
97
  refund_tasks = [] # 需要退款的任务
98
+ expired_task_indices = [] # 记录过期任务的索引
99
 
100
+ # 第一阶段:扫描过期任务(不修改数据)
101
+ for idx, task in enumerate(tasks_db):
102
  deadline = task.get("deadline", "")
103
  status = task.get("status", "")
104
 
 
108
  # 检查是否过期
109
  if deadline < today:
110
  if status == "open":
111
+ # 开放接单状态且过期:记录需要处理
 
 
 
 
 
112
  frozen_amount = task.get("frozen_amount", task.get("total_price", 0))
113
+ expired_task_indices.append({
114
+ "index": idx,
115
+ "task_id": task["id"],
116
+ "publisher": task.get("publisher"),
117
+ "amount": frozen_amount,
118
+ "title": task.get("title", "")
119
+ })
120
+ updated = True
 
121
 
122
  elif status == "in_progress":
123
+ # 进行中且过期:直接标记涉及资金)
124
  task["is_overdue"] = True
125
  updated = True
126
 
127
+ # 第二阶段:执行退款操作(事务保护)
128
+ if expired_task_indices and db_session:
129
+ try:
130
+ refund_results = [] # 记录退款结果
131
+
132
+ for item in expired_task_indices:
133
+ if item["amount"] > 0:
134
+ wallet = db_session.query(Wallet).filter(Wallet.account == item["publisher"]).with_for_update().first()
135
+ if wallet:
136
+ wallet.frozen_balance = max(0, wallet.frozen_balance - item["amount"])
137
+ wallet.balance += item["amount"]
138
+
139
+ # 记录退款交易
140
+ create_task_transaction(
141
+ db_session, item["publisher"], "TASK_REFUND",
142
+ item["amount"], task_id=item["task_id"]
143
+ )
144
+
145
+ refund_results.append(item)
146
+
147
+ # 💳 P0修复:先 commit 数据库,再修改 JSON
148
+ db_session.commit()
149
+
150
+ # commit 成功后,修改 JSON 数据
151
+ for item in expired_task_indices:
152
+ task = tasks_db[item["index"]]
153
+ task["status"] = "expired"
154
+ task["expired_at"] = int(time.time())
155
+ if item["amount"] > 0:
156
+ task["refunded"] = True
157
+ task["refund_amount"] = item["amount"]
158
+
159
+ # 发送退款通知(在事务外执行,失败不影响主流程)
160
+ for item in refund_results:
161
+ try:
162
+ add_notification(item["publisher"], {
163
  "type": "task_refund",
164
  "from_user": "system",
165
+ "target_item_id": item["task_id"],
166
+ "target_item_title": item["title"],
167
+ "content": f"💰 任务《{item['title']}》已过期自动取消,{item['amount']}积分已退还"
168
  })
169
+ logger.info(f"TASK_REFUND | publisher={item['publisher']} | task={item['task_id']} | amount={item['amount']}")
170
+ except Exception as e:
171
+ logger.warning(f"TASK_REFUND_NOTIFY_ERROR | task={item['task_id']} | error={str(e)}")
172
 
173
+ except Exception as e:
174
+ # 💳 P0修复:事务失败,回滚所有操作
175
+ db_session.rollback()
176
+ logger.error(f"TASK_EXPIRED_REFUND_ROLLBACK | error={str(e)}")
177
+ # 不修改 JSON,下次再试
178
+ return False
179
+ elif expired_task_indices:
180
+ # 没有 db_session 但有过期任务:只更新状态,不处理退款
181
+ for item in expired_task_indices:
182
+ task = tasks_db[item["index"]]
183
+ task["status"] = "expired"
184
+ task["expired_at"] = int(time.time())
185
+ # 不标记退款,等待下次带 db_session 的调用
186
 
187
  return updated
188
 
 
437
  raise HTTPException(status_code=404, detail="任务不存在")
438
 
439
  @router.delete("/api/tasks/{task_id}")
440
+ async def cancel_task(task_id: str, current_user: str = Depends(require_auth), db_session: Session = Depends(get_db)):
441
  """
442
  取消任务(仅发布者可操作,且仅在 open 状态时可取消)
443
+ 💳 P0修复:取消时退还冻结资金
444
  """
445
  tasks_db = db.load_data("tasks.json", default_data=[])
446
 
 
452
  if task.get("status") != "open":
453
  raise HTTPException(status_code=400, detail="只能取消开放状态的任务")
454
 
455
+ # 💳 P0修复:退还冻结资金
456
+ frozen_amount = task.get("frozen_amount", task.get("total_price", 0))
457
+ refund_success = False
458
+
459
+ if frozen_amount > 0:
460
+ try:
461
+ wallet = db_session.query(Wallet).filter(Wallet.account == current_user).with_for_update().first()
462
+ if wallet:
463
+ wallet.frozen_balance = max(0, wallet.frozen_balance - frozen_amount)
464
+ wallet.balance += frozen_amount
465
+
466
+ # 记录退款交易
467
+ create_task_transaction(
468
+ db_session, current_user, "TASK_CANCEL_REFUND",
469
+ frozen_amount, task_id=task_id
470
+ )
471
+ db_session.commit()
472
+ refund_success = True
473
+ logger.info(f"TASK_CANCEL_REFUND | user={current_user} | task={task_id} | amount={frozen_amount}")
474
+ except Exception as e:
475
+ db_session.rollback()
476
+ logger.error(f"TASK_CANCEL_REFUND_ERROR | task={task_id} | error={str(e)}")
477
+ raise HTTPException(status_code=500, detail="退款失败,请稍后重试")
478
+ else:
479
+ refund_success = True # 无需退款
480
+
481
+ # 只有退款成功才更新状态
482
+ if refund_success:
483
+ task["status"] = "cancelled"
484
+ task["cancelled_at"] = int(time.time())
485
+ task["refunded"] = frozen_amount > 0
486
+ task["refund_amount"] = frozen_amount
487
+ db.save_data("tasks.json", tasks_db)
488
+
489
+ return {"status": "success", "refunded_amount": frozen_amount}
490
 
491
  raise HTTPException(status_code=404, detail="任务不存在")
492
 
 
679
  - is_accepted=True: 验收通过,支付尾款给接��者
680
  - is_accepted=False: 验收不通过,可以要求修改或发起申诉
681
  💳 P6支付增强:使用SQL钱包支付,记录交易流水,发送支付通知
682
+ 💳 P0修复:事务完整性保护,失败时回滚
683
  """
684
  tasks_db = db.load_data("tasks.json", default_data=[])
685
 
 
699
 
700
  if is_accepted:
701
  # 💳 验收通过:支付尾款给接单者
702
+ # 💳 P0修复:使用事务保护,失败则回滚
703
+ try:
704
+ publisher_wallet = db_session.query(Wallet).filter(Wallet.account == current_user).with_for_update().first()
705
+ if not publisher_wallet or publisher_wallet.frozen_balance < remaining:
706
+ raise HTTPException(status_code=400, detail="冻结余额不足支付尾款")
707
+
708
+ # 扣除发布者尾款(从冻结余额)
709
+ publisher_wallet.frozen_balance -= remaining
710
+
711
+ # 给接单者全款(订金+尾款)
712
+ assignee_wallet = db_session.query(Wallet).filter(Wallet.account == assignee_account).with_for_update().first()
713
+ if not assignee_wallet:
714
+ assignee_wallet = Wallet(account=assignee_account, balance=0)
715
+ db_session.add(assignee_wallet)
716
+ db_session.flush()
717
+
718
+ assignee_wallet.balance += total_price # 全款进入可用余额
719
+
720
+ # 💳 记录交易流水
721
+ # 1. 发布者支付尾款
722
+ create_task_transaction(
723
+ db_session, current_user, "TASK_PAYMENT",
724
+ -remaining, related_account=assignee_account, task_id=task_id
725
+ )
726
+ # 2. 接单者收入
727
+ create_task_transaction(
728
+ db_session, assignee_account, "TASK_INCOME",
729
+ total_price, related_account=current_user, task_id=task_id
730
+ )
731
+
732
+ task["status"] = "completed"
733
+ task["completed_at"] = int(time.time())
734
+
735
+ # 💳 P0修复:先保存JSON,再 commit,确保原子性
736
+ db.save_data("tasks.json", tasks_db)
737
+ db_session.commit()
738
+
739
+ logger.info(f"TASK_COMPLETE | publisher={current_user} | assignee={assignee_account} | task={task_id} | total={total_price}")
740
+ message = f"验收通过,已支付 {total_price} 积分给接单者"
741
+
742
+ # 🔔 支付通知:接单者收到款项
743
+ add_notification(assignee_account, {
744
+ "type": "task_payment",
745
+ "from_user": current_user,
746
+ "target_item_id": task_id,
747
+ "target_item_title": task.get("title", ""),
748
+ "content": f"💰 任务《{task.get('title', '')}》验收通过,{total_price}积分已到账"
749
+ })
750
+
751
+ except HTTPException:
752
+ raise # 重新抛出 HTTP 异常
753
+ except Exception as e:
754
+ db_session.rollback()
755
+ logger.error(f"TASK_ACCEPT_ERROR | task={task_id} | error={str(e)}")
756
+ raise HTTPException(status_code=500, detail="验收处理失败,请稍后重试")
757
  else:
758
  # 验收不通过:回到进行中状态,允许修改后重新提交
759
  task["status"] = "in_progress"
 
761
  task["submitted_at"] = None
762
  message = "验收不通过,接单者可以修改后重新提交"
763
 
764
+ db.save_data("tasks.json", tasks_db)
765
+
766
  # 🔔 通知接单者:验收未通过
767
  add_notification(assignee_account, {
768
  "type": "task_rejected",
 
772
  "content": f"任务《{task.get('title', '')}》验收未通过,请修改后重新提交"
773
  })
774
 
 
 
775
  return {"status": "success", "message": message}
776
 
777
  raise HTTPException(status_code=404, detail="任务不存在")
安全认证.py CHANGED
@@ -313,3 +313,182 @@ def verify_token_with_fallback(token: str) -> Tuple[bool, Optional[str], str]:
313
  return True, account, ""
314
 
315
  return False, None, error_msg
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
313
  return True, account, ""
314
 
315
  return False, None, error_msg
316
+
317
+
318
+ # ==========================================
319
+ # 🛡️ P2优化:统一权限检查装饰器
320
+ # ==========================================
321
+ # 作用:提供所有者、管理员等权限检查功能
322
+ # 用法:减少路由中的重复权限检查代码
323
+
324
+ import os
325
+ from functools import wraps
326
+
327
+ # 管理员账号列表
328
+ ADMIN_ACCOUNTS = [a.strip() for a in os.getenv("ADMIN_ACCOUNTS", "admin").split(",")]
329
+
330
+
331
+ def is_admin(account: str) -> bool:
332
+ """
333
+ 检查用户是否为管理员
334
+
335
+ 参数:
336
+ account: 用户账号
337
+
338
+ 返回:
339
+ True 是管理员 / False 不是
340
+ """
341
+ return account in ADMIN_ACCOUNTS
342
+
343
+
344
+ async def require_admin(authorization: str = Header(None, alias="Authorization")) -> str:
345
+ """
346
+ FastAPI 依赖:要求管理员权限
347
+
348
+ 使用方法:
349
+ @router.delete("/api/admin/users/{user_id}")
350
+ async def delete_user(user_id: str, admin: str = Depends(require_admin)):
351
+ ...
352
+
353
+ 异常:
354
+ 401: 未登录
355
+ 403: 非管理员
356
+ """
357
+ account = await require_auth(authorization)
358
+
359
+ if not is_admin(account):
360
+ raise HTTPException(status_code=403, detail="需要管理员权限")
361
+
362
+ return account
363
+
364
+
365
+ def check_ownership(
366
+ item: dict,
367
+ current_user: str,
368
+ owner_field: str = "author",
369
+ allow_admin: bool = True
370
+ ) -> bool:
371
+ """
372
+ 检查用户是否为资源所有者
373
+
374
+ 参数:
375
+ item: 资源对象
376
+ current_user: 当前用户账号
377
+ owner_field: 所有者字段名(默认 "author")
378
+ allow_admin: 是否允许管理员操作(默认 True)
379
+
380
+ 返回:
381
+ True 有权限 / False 无权限
382
+
383
+ 示例:
384
+ if not check_ownership(task, current_user, "publisher"):
385
+ raise HTTPException(403, "无权操作")
386
+ """
387
+ # 检查是否为所有者
388
+ if item.get(owner_field) == current_user:
389
+ return True
390
+
391
+ # 检查是否为管理员
392
+ if allow_admin and is_admin(current_user):
393
+ return True
394
+
395
+ return False
396
+
397
+
398
+ def require_ownership(
399
+ item: dict,
400
+ current_user: str,
401
+ owner_field: str = "author",
402
+ allow_admin: bool = True,
403
+ error_msg: str = "无权操作此资源"
404
+ ):
405
+ """
406
+ 要求所有者权限(失败时抛出异常)
407
+
408
+ 参数:
409
+ item: 资源对象
410
+ current_user: 当前用户账号
411
+ owner_field: 所有者字段名
412
+ allow_admin: 是否允许管理员操作
413
+ error_msg: 错误提示
414
+
415
+ 异常:
416
+ 403 Forbidden: 无权操作
417
+
418
+ 示例:
419
+ task = get_by_id("tasks.json", task_id)
420
+ require_ownership(task, current_user, "publisher")
421
+ # 继续执行更新操作...
422
+ """
423
+ if not check_ownership(item, current_user, owner_field, allow_admin):
424
+ raise HTTPException(status_code=403, detail=error_msg)
425
+
426
+
427
+ def require_not_self(
428
+ target_user: str,
429
+ current_user: str,
430
+ error_msg: str = "不能对自己执行此操作"
431
+ ):
432
+ """
433
+ 要求操作对象不是自己
434
+
435
+ 使用场景:
436
+ - 不能关注自己
437
+ - 不能给自己发私信
438
+ - 不能接自己发布的任务
439
+
440
+ 异常:
441
+ 400 Bad Request: 不能对自己执行此操作
442
+ """
443
+ if target_user == current_user:
444
+ raise HTTPException(status_code=400, detail=error_msg)
445
+
446
+
447
+ def require_item_exists(
448
+ item: Optional[dict],
449
+ item_type: str = "资源"
450
+ ):
451
+ """
452
+ 要求资源存在
453
+
454
+ 参数:
455
+ item: 资源对象(可能为 None)
456
+ item_type: 资源类型名称(用于错误提示)
457
+
458
+ 异常:
459
+ 404 Not Found: 资源不存在
460
+
461
+ 示例:
462
+ task = get_by_id("tasks.json", task_id)
463
+ require_item_exists(task, "任务")
464
+ """
465
+ if not item:
466
+ raise HTTPException(status_code=404, detail=f"{item_type}不存在")
467
+
468
+
469
+ def require_status(
470
+ item: dict,
471
+ allowed_statuses: list,
472
+ status_field: str = "status",
473
+ error_msg: str = None
474
+ ):
475
+ """
476
+ 要求资源处于指定状态
477
+
478
+ 参数:
479
+ item: 资源对象
480
+ allowed_statuses: 允许的状态列表
481
+ status_field: 状态字段名
482
+ error_msg: 自定义错误信息
483
+
484
+ 异常:
485
+ 400 Bad Request: 状态不允许
486
+
487
+ 示例:
488
+ require_status(task, ["open"], error_msg="只能取消开放状态的任务")
489
+ """
490
+ current_status = item.get(status_field)
491
+ if current_status not in allowed_statuses:
492
+ if not error_msg:
493
+ error_msg = f"当前状态 ({current_status}) 不允许此操作"
494
+ raise HTTPException(status_code=400, detail=error_msg)
数据库连接.py CHANGED
@@ -59,6 +59,97 @@ os.makedirs(LOCAL_DB_DIR, exist_ok=True)
59
  os.makedirs(BACKUP_DIR, exist_ok=True)
60
 
61
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
  # ==========================================
63
  # 🔐 并发控制:线程锁 + 文件锁
64
  # ==========================================
@@ -126,17 +217,19 @@ else:
126
  # 📖 数据读取函数
127
  # ==========================================
128
  # 特点:
 
129
  # - 线程安全(threading.Lock)
130
  # - 文件不存在时从 HF 下载
131
  # - JSON 解析失败时返回默认值
132
 
133
- def load_data(file_name: str, default_data: Optional[Union[Dict, List]] = None) -> Union[Dict, List]:
134
  """
135
  从 JSON 文件加载数据
136
 
137
  参数:
138
  file_name: 文件名(如 users.json)
139
  default_data: 数据不存在时的默认值
 
140
 
141
  返回:
142
  解析后的 JSON 数据(dict 或 list)
@@ -146,6 +239,13 @@ def load_data(file_name: str, default_data: Optional[Union[Dict, List]] = None)
146
  default_data = {} if file_name == "users.json" else []
147
 
148
  local_path = os.path.join(LOCAL_DB_DIR, file_name)
 
 
 
 
 
 
 
149
  file_lock = _get_file_lock(file_name)
150
 
151
  with file_lock:
@@ -167,6 +267,8 @@ def load_data(file_name: str, default_data: Optional[Union[Dict, List]] = None)
167
  with open(local_path, "w", encoding="utf-8") as f:
168
  json.dump(data, f, ensure_ascii=False, indent=2)
169
  print(f"✅ 从 HF Dataset 下载 {file_name} 成功")
 
 
170
  return data
171
  except Exception as e:
172
  print(f"⚠️ 从 HF 下载 {file_name} 失败: {e}")
@@ -179,9 +281,13 @@ def load_data(file_name: str, default_data: Optional[Union[Dict, List]] = None)
179
  # 获取共享锁(读锁)
180
  _lock_file(f, exclusive=False)
181
  try:
182
- return json.load(f)
183
  finally:
184
  _unlock_file(f)
 
 
 
 
185
  except json.JSONDecodeError as e:
186
  print(f"🚨 JSON 解析错误 {file_name}: {e}")
187
  # 尝试从备份恢复
@@ -218,7 +324,8 @@ def _recover_from_backup(file_name: str, default_data: Union[Dict, List]) -> Uni
218
  # 1. 先写临时文件,验证成功后原子重命名
219
  # 2. 写入前备份上一版本
220
  # 3. 写入后验证数据完整性
221
- # 4. 异步同步到 HuggingFace
 
222
 
223
  def save_data(file_name: str, data: Union[Dict, List]) -> bool:
224
  """
@@ -235,6 +342,7 @@ def save_data(file_name: str, data: Union[Dict, List]) -> bool:
235
  - 原子写入:先写临时文件再重命名
236
  - 自动备份:保留上一版本
237
  - 完整性校验:验证 JSON 可解析
 
238
  - 异步同步:后台上传到 HF Dataset
239
  """
240
  local_path = os.path.join(LOCAL_DB_DIR, file_name)
@@ -285,6 +393,9 @@ def save_data(file_name: str, data: Union[Dict, List]) -> bool:
285
  print(f"🚨 保存 {file_name} 失败: {e}")
286
  raise
287
 
 
 
 
288
  # ========== 第五步:异步同步到云端 ==========
289
  # 🔧 P3优化:使用线程池替代直接创建线程
290
  if HF_TOKEN:
 
59
  os.makedirs(BACKUP_DIR, exist_ok=True)
60
 
61
 
62
+ # ==========================================
63
+ # 🚀 P1性能优化:内存缓存层
64
+ # ==========================================
65
+ # 缓存配置(秒)
66
+ CACHE_TTL = {
67
+ "items.json": 120, # 列表数据 2 分钟
68
+ "tasks.json": 60, # 任务数据 1 分钟
69
+ "users.json": 300, # 用户数据 5 分钟
70
+ "posts.json": 120, # 帖子数据 2 分钟
71
+ "comments.json": 60, # 评论数据 1 分钟
72
+ "disputes.json": 60, # 申诉数据 1 分钟
73
+ "_default": 60 # 默认 1 分钟
74
+ }
75
+
76
+ # 内存缓存存储
77
+ _memory_cache = {} # {file_name: {"data": ..., "time": ..., "mtime": ...}}
78
+ _cache_lock = threading.Lock()
79
+
80
+
81
+ def _get_cache_ttl(file_name: str) -> int:
82
+ """获取文件的缓存 TTL"""
83
+ return CACHE_TTL.get(file_name, CACHE_TTL["_default"])
84
+
85
+
86
+ def _get_from_cache(file_name: str, local_path: str) -> Optional[Union[Dict, List]]:
87
+ """
88
+ 从内存缓存获取数据
89
+
90
+ 返回:
91
+ 缓存数据(如果有效)或 None
92
+ """
93
+ with _cache_lock:
94
+ if file_name not in _memory_cache:
95
+ return None
96
+
97
+ cache_entry = _memory_cache[file_name]
98
+ now = time.time()
99
+
100
+ # 检查 TTL 是否过期
101
+ if now - cache_entry["time"] > _get_cache_ttl(file_name):
102
+ del _memory_cache[file_name]
103
+ return None
104
+
105
+ # 检查文件是否被修改(mtime 变化)
106
+ try:
107
+ current_mtime = os.path.getmtime(local_path)
108
+ if current_mtime != cache_entry.get("mtime"):
109
+ del _memory_cache[file_name]
110
+ return None
111
+ except OSError:
112
+ return None
113
+
114
+ logger.debug(f"✨ 内存缓存命中: {file_name}")
115
+ # 返回深拷贝,防止外部修改影响缓存
116
+ import copy
117
+ return copy.deepcopy(cache_entry["data"])
118
+
119
+
120
+ def _set_to_cache(file_name: str, data: Union[Dict, List], local_path: str):
121
+ """将数据存入内存缓存"""
122
+ with _cache_lock:
123
+ try:
124
+ mtime = os.path.getmtime(local_path) if os.path.exists(local_path) else 0
125
+ except OSError:
126
+ mtime = 0
127
+
128
+ import copy
129
+ _memory_cache[file_name] = {
130
+ "data": copy.deepcopy(data),
131
+ "time": time.time(),
132
+ "mtime": mtime
133
+ }
134
+ logger.debug(f"💾 内存缓存更新: {file_name}")
135
+
136
+
137
+ def invalidate_cache(file_name: str = None):
138
+ """
139
+ 使缓存失效
140
+
141
+ 参数:
142
+ file_name: 指定文件名,为 None 时清空所有缓存
143
+ """
144
+ with _cache_lock:
145
+ if file_name:
146
+ _memory_cache.pop(file_name, None)
147
+ logger.debug(f"🗑️ 缓存失效: {file_name}")
148
+ else:
149
+ _memory_cache.clear()
150
+ logger.debug("🗑️ 所有缓存已清空")
151
+
152
+
153
  # ==========================================
154
  # 🔐 并发控制:线程锁 + 文件锁
155
  # ==========================================
 
217
  # 📖 数据读取函数
218
  # ==========================================
219
  # 特点:
220
+ # - 🚀 P1优化:内存缓存优先
221
  # - 线程安全(threading.Lock)
222
  # - 文件不存在时从 HF 下载
223
  # - JSON 解析失败时返回默认值
224
 
225
+ def load_data(file_name: str, default_data: Optional[Union[Dict, List]] = None, skip_cache: bool = False) -> Union[Dict, List]:
226
  """
227
  从 JSON 文件加载数据
228
 
229
  参数:
230
  file_name: 文件名(如 users.json)
231
  default_data: 数据不存在时的默认值
232
+ skip_cache: 是否跳过内存缓存(默认 False)
233
 
234
  返回:
235
  解析后的 JSON 数据(dict 或 list)
 
239
  default_data = {} if file_name == "users.json" else []
240
 
241
  local_path = os.path.join(LOCAL_DB_DIR, file_name)
242
+
243
+ # 🚀 P1优化:先检查内存缓存
244
+ if not skip_cache:
245
+ cached_data = _get_from_cache(file_name, local_path)
246
+ if cached_data is not None:
247
+ return cached_data
248
+
249
  file_lock = _get_file_lock(file_name)
250
 
251
  with file_lock:
 
267
  with open(local_path, "w", encoding="utf-8") as f:
268
  json.dump(data, f, ensure_ascii=False, indent=2)
269
  print(f"✅ 从 HF Dataset 下载 {file_name} 成功")
270
+ # 🚀 P1优化:存入内存缓存
271
+ _set_to_cache(file_name, data, local_path)
272
  return data
273
  except Exception as e:
274
  print(f"⚠️ 从 HF 下载 {file_name} 失败: {e}")
 
281
  # 获取共享锁(读锁)
282
  _lock_file(f, exclusive=False)
283
  try:
284
+ data = json.load(f)
285
  finally:
286
  _unlock_file(f)
287
+
288
+ # 🚀 P1优化:存入内存缓存
289
+ _set_to_cache(file_name, data, local_path)
290
+ return data
291
  except json.JSONDecodeError as e:
292
  print(f"🚨 JSON 解析错误 {file_name}: {e}")
293
  # 尝试从备份恢复
 
324
  # 1. 先写临时文件,验证成功后原子重命名
325
  # 2. 写入前备份上一版本
326
  # 3. 写入后验证数据完整性
327
+ # 4. 🚀 P1优化:保存后更新内存缓存
328
+ # 5. 异步同步到 HuggingFace
329
 
330
  def save_data(file_name: str, data: Union[Dict, List]) -> bool:
331
  """
 
342
  - 原子写入:先写临时文件再重命名
343
  - 自动备份:保留上一版本
344
  - 完整性校验:验证 JSON 可解析
345
+ - 🚀 P1优化:保存后更新内存缓存
346
  - 异步同步:后台上传到 HF Dataset
347
  """
348
  local_path = os.path.join(LOCAL_DB_DIR, file_name)
 
393
  print(f"🚨 保存 {file_name} 失败: {e}")
394
  raise
395
 
396
+ # ========== 🚀 P1优化:更新内存缓存 ==========
397
+ _set_to_cache(file_name, data, local_path)
398
+
399
  # ========== 第五步:异步同步到云端 ==========
400
  # 🔧 P3优化:使用线程池替代直接创建线程
401
  if HF_TOKEN: