Spaces:
Running
Running
| # router_messages.py | |
| from fastapi import APIRouter, HTTPException, Depends | |
| from pydantic import BaseModel | |
| import time | |
| import uuid | |
| import subprocess | |
| import os | |
| import 数据库连接 as db | |
| from notifications import add_notification | |
| from models import PrivateMessage | |
| from 安全认证 import require_auth, is_admin | |
| router = APIRouter() | |
| # ========================================== | |
| # 新增:系统公告请求体模型 | |
| # ========================================== | |
| class SystemAnnouncement(BaseModel): | |
| admin_account: str | |
| content: str | |
| # ========================================== | |
| # 新增:发布系统公告接口 (仅限管理员,使用JWT验证) | |
| # ========================================== | |
| async def publish_announcement(ann: SystemAnnouncement, current_user: str = Depends(require_auth)): | |
| # 🔒 P0安全修复:使用环境变量配置的管理员列表 | |
| if not is_admin(current_user): | |
| raise HTTPException(status_code=403, detail="无权发布系统公告,仅管理员可操作") | |
| # 查询管理员信息 | |
| users_db = db.load_data("users.json", default_data={}) | |
| admin_info = users_db.get(current_user, {}) | |
| announcements_db = db.load_data("announcements.json", default_data=[]) | |
| new_ann = { | |
| "id": f"sys_{int(time.time())}_{uuid.uuid4().hex[:6]}", | |
| "type": "system", | |
| "from_user": current_user, # 使用真实的管理员账号 | |
| "from_name": admin_info.get("name", current_user), # 使用真实昵称,fallback 为账号 | |
| "from_avatar": admin_info.get("avatarDataUrl", ""), # 使用真实头像 | |
| "content": ann.content, | |
| "created_at": int(time.time()) | |
| } | |
| announcements_db.append(new_ann) | |
| db.save_data("announcements.json", announcements_db) | |
| return {"status": "success"} | |
| # ========================================== | |
| # 管理员调试:执行 Python 脚本 | |
| # ========================================== | |
| class AdminScriptRequest(BaseModel): | |
| admin_account: str | |
| script_name: str | |
| # 🔒 P0安全修复:脚本白名单(仅允许执行指定的脚本) | |
| # 警告:添加新脚本前请确保其安全性 | |
| ALLOWED_SCRIPTS = { | |
| "密码迁移.py", # 用户密码哈希化迁移 | |
| "测试脚本.py", # 接口测试工具 | |
| "迁移_余额合并.py", # 一次性余额合并迁移(执行后可移除) | |
| } | |
| async def run_admin_script(req: AdminScriptRequest, current_user: str = Depends(require_auth)): | |
| """ | |
| 管理员专属:执行指定的 Python 脚本 | |
| 🔒 P0安全修复:白名单 + 路径穿越防护 | |
| """ | |
| # 🔒 P0安全修复:使用环境变量配置的管理员列表 | |
| if not is_admin(current_user): | |
| raise HTTPException(status_code=403, detail="无权执行此操作,仅管理员可操作") | |
| script_name = req.script_name.strip() | |
| if not script_name: | |
| raise HTTPException(status_code=400, detail="脚本名称不能为空") | |
| # 🔒 P0安全修复:路径穿越攻击防护 | |
| if ".." in script_name or "/" in script_name or "\\" in script_name: | |
| raise HTTPException(status_code=400, detail="🚨 安全拦截:脚本名称包含非法字符") | |
| # 🔒 P0安全修复:白名单检查 | |
| if script_name not in ALLOWED_SCRIPTS: | |
| raise HTTPException( | |
| status_code=403, | |
| detail=f"🚨 安全拦截:脚本 [{script_name}] 不在白名单中。允许的脚本: {list(ALLOWED_SCRIPTS)}" | |
| ) | |
| # 获取当前工作目录 | |
| current_dir = os.path.dirname(os.path.abspath(__file__)) | |
| script_path = os.path.join(current_dir, script_name) | |
| # 检查文件是否存在 | |
| if not os.path.exists(script_path): | |
| return { | |
| "status": "error", | |
| "output": f"❌ 脚本文件不存在: {script_name}\n\n白名单脚本: {list(ALLOWED_SCRIPTS)}" | |
| } | |
| try: | |
| # 执行脚本,设置超时 60 秒 | |
| result = subprocess.run( | |
| ["python", script_path], | |
| capture_output=True, | |
| text=True, | |
| timeout=60, | |
| cwd=current_dir, | |
| encoding="utf-8" | |
| ) | |
| output = "" | |
| if result.stdout: | |
| output += f"📝 标准输出:\n{result.stdout}\n" | |
| if result.stderr: | |
| output += f"\n⚠️ 错误输出:\n{result.stderr}" | |
| if not output: | |
| output = "✅ 脚本执行完成,无输出" | |
| return { | |
| "status": "success" if result.returncode == 0 else "error", | |
| "return_code": result.returncode, | |
| "output": output | |
| } | |
| except subprocess.TimeoutExpired: | |
| return { | |
| "status": "error", | |
| "output": "❌ 脚本执行超时 (60秒)" | |
| } | |
| except Exception as e: | |
| return { | |
| "status": "error", | |
| "output": f"❌ 执行异常: {str(e)}" | |
| } | |
| # ========================================== | |
| # 原有功能:私信与聊天 | |
| # ========================================== | |
| async def send_private_message(msg: PrivateMessage): | |
| chats_db = db.load_data("chats.json", default_data={}) | |
| conv_id = f"{min(msg.sender, msg.receiver)}_{max(msg.sender, msg.receiver)}" | |
| if conv_id not in chats_db: chats_db[conv_id] = [] | |
| chat_msg = {"id": f"chat_{int(time.time())}_{uuid.uuid4().hex[:6]}", "sender": msg.sender, "receiver": msg.receiver, "content": msg.content, "created_at": int(time.time()), "is_read": False} | |
| chats_db[conv_id].append(chat_msg) | |
| db.save_data("chats.json", chats_db) | |
| add_notification(msg.receiver, {"type": "private", "from_user": msg.sender, "content": msg.content}) | |
| return {"status": "success"} | |
| async def get_chat_list(account: str): | |
| chats_db = db.load_data("chats.json", default_data={}) | |
| users_db = db.load_data("users.json", default_data={}) | |
| chat_list = [] | |
| for conv_id, msgs in chats_db.items(): | |
| if account in conv_id: | |
| other_account = conv_id.replace(account, "").replace("_", "") | |
| if not msgs: continue | |
| last_msg = msgs[-1] | |
| unread_count = sum(1 for m in msgs if m["receiver"] == account and not m.get("is_read")) | |
| other_user = users_db.get(other_account, {}) | |
| chat_list.append({ | |
| "account": other_account, | |
| "name": other_user.get("name", other_account), | |
| "avatar": other_user.get("avatarDataUrl", ""), | |
| "last_message": last_msg["content"], | |
| "last_time": last_msg["created_at"], | |
| "unread_count": unread_count | |
| }) | |
| chat_list.sort(key=lambda x: x["last_time"], reverse=True) | |
| return {"status": "success", "data": chat_list} | |
| async def get_chat_history(account: str, target_account: str): | |
| chats_db = db.load_data("chats.json", default_data={}) | |
| conv_id = f"{min(account, target_account)}_{max(account, target_account)}" | |
| msgs = chats_db.get(conv_id, []) | |
| now = int(time.time()) | |
| seven_days = 7 * 24 * 3600 | |
| valid_msgs = [] | |
| modified = False | |
| for m in msgs: | |
| if not m.get("is_read") or (now - m.get("created_at", 0) < seven_days): | |
| valid_msgs.append(m) | |
| else: | |
| modified = True | |
| # 本次访问即为已读 | |
| if m["receiver"] == account and not m.get("is_read"): | |
| m["is_read"] = True | |
| modified = True | |
| if modified or len(valid_msgs) != len(msgs): | |
| chats_db[conv_id] = valid_msgs | |
| db.save_data("chats.json", chats_db) | |
| return {"status": "success", "data": valid_msgs} | |
| # ========================================== | |
| # 改造:获取通知列表 (加入系统公告懒加载注入) | |
| # 使用 atomic_update 避免并发覆盖问题 | |
| # 🔥 性能优化:先用只读方式检查是否有实际变更,避免无意义的写入和HF上传 | |
| # ========================================== | |
| async def get_messages(account: str, count_only: bool = False, current_user: str = Depends(require_auth)): | |
| # 🔥 count_only 模式:轻量级轮询,只返回未读数,不标记已读 | |
| if count_only: | |
| messages_db = db.load_data("messages.json", default_data={}) | |
| user_msgs = messages_db.get(account, []) | |
| now = int(time.time()) | |
| seven_days = 7 * 24 * 3600 | |
| # 只统计未过期的未读消息 | |
| unread = sum(1 for m in user_msgs | |
| if not m.get("is_read") | |
| and (now - m.get("created_at", 0) < seven_days)) | |
| return {"status": "success", "unread_count": unread} | |
| # 公告是只读的,先加载 | |
| announcements_db = db.load_data("announcements.json", default_data=[]) | |
| # 🔥 性能优化:先用只读方式检查是否有实际变更需要写入 | |
| messages_db = db.load_data("messages.json", default_data={}) | |
| user_msgs = messages_db.get(account, []) | |
| now = int(time.time()) | |
| seven_days = 7 * 24 * 3600 | |
| # 检查三个条件判断是否需要写入 | |
| # 1. 是否有新公告需要注入 | |
| user_msg_ids = {m.get("id") for m in user_msgs} | |
| has_new_announcements = any( | |
| ann.get("id") not in user_msg_ids | |
| for ann in announcements_db | |
| ) | |
| # 2. 是否有未读消息需要标记已读 | |
| has_unread = any(not m.get("is_read") for m in user_msgs) | |
| # 3. 是否有已读超过7天的消息需要清理 | |
| has_expired = any( | |
| m.get("is_read") and (now - m.get("created_at", 0) >= seven_days) | |
| for m in user_msgs | |
| ) | |
| needs_update = has_new_announcements or has_unread or has_expired | |
| # 🔥 修复:在标记已读之前先计算真实的未读数 | |
| unread_before_mark = sum(1 for m in user_msgs | |
| if not m.get("is_read") | |
| and (now - m.get("created_at", 0) < seven_days)) | |
| if not needs_update: | |
| # 无变更,直接返回只读数据,不触发写入和HF上传 | |
| return { | |
| "status": "success", | |
| "data": user_msgs, | |
| "unread_count": unread_before_mark # 🔥 修复:返回真实的未读数 | |
| } | |
| # 有变更需要写入,使用 atomic_update 保证并发安全 | |
| result_container = [None] | |
| def updater(data): | |
| user_msgs = data.get(account, []) | |
| # --- 核心改造区:瞬间比对并注入全局公告 --- | |
| user_msg_ids = {m.get("id") for m in user_msgs} | |
| injected = False | |
| for ann in announcements_db: | |
| if ann.get("id") not in user_msg_ids: | |
| new_sys_msg = dict(ann) | |
| new_sys_msg["is_read"] = False | |
| new_sys_msg["receiver"] = account | |
| user_msgs.append(new_sys_msg) | |
| injected = True | |
| if injected: | |
| # 重新按照时间倒序排列,让新公告置顶 | |
| user_msgs.sort(key=lambda x: x.get("created_at", 0), reverse=True) | |
| # ---------------------------------------- | |
| now = int(time.time()) | |
| seven_days = 7 * 24 * 3600 | |
| valid = [] | |
| # 如果注入了新公告,则判定为需要回写数据库保存 | |
| modified = injected | |
| for m in user_msgs: | |
| if not m.get("is_read") or (now - m.get("created_at", 0) < seven_days): | |
| valid.append(m) | |
| else: | |
| modified = True | |
| # 本次访问即为已读 - 将所有未读消息标记为已读 | |
| if not m.get("is_read"): | |
| m["is_read"] = True | |
| modified = True | |
| # 原地修改 data,atomic_update 会自动保存 | |
| data[account] = valid | |
| # 通过闭包返回结果 | |
| result_container[0] = { | |
| "status": "success", | |
| "data": valid, | |
| "unread_count": unread_before_mark # 🔥 修复:返回标记已读前的真实未读数 | |
| } | |
| db.atomic_update("messages.json", updater, default_data={}) | |
| return result_container[0] | |
| async def mark_messages_read(account: str): | |
| """ | |
| 标记消息为已读(原子操作,并发安全) | |
| """ | |
| def updater(data): | |
| user_msgs = data.get(account, []) | |
| modified = False | |
| for m in user_msgs: | |
| if not m.get("is_read"): | |
| m["is_read"] = True | |
| modified = True | |
| # 原地修改 data,atomic_update 会自动保存 | |
| if modified: | |
| data[account] = user_msgs | |
| db.atomic_update("messages.json", updater, default_data={}) | |
| return {"status": "success"} |