# router_messages.py from fastapi import APIRouter, HTTPException, Depends from pydantic import BaseModel import time import uuid import subprocess import os import 数据库连接 as db from notifications import add_notification from models import PrivateMessage from 安全认证 import require_auth, is_admin router = APIRouter() # ========================================== # 新增:系统公告请求体模型 # ========================================== class SystemAnnouncement(BaseModel): admin_account: str content: str # ========================================== # 新增:发布系统公告接口 (仅限管理员,使用JWT验证) # ========================================== @router.post("/api/system/announcement") async def publish_announcement(ann: SystemAnnouncement, current_user: str = Depends(require_auth)): # 🔒 P0安全修复:使用环境变量配置的管理员列表 if not is_admin(current_user): raise HTTPException(status_code=403, detail="无权发布系统公告,仅管理员可操作") # 查询管理员信息 users_db = db.load_data("users.json", default_data={}) admin_info = users_db.get(current_user, {}) announcements_db = db.load_data("announcements.json", default_data=[]) new_ann = { "id": f"sys_{int(time.time())}_{uuid.uuid4().hex[:6]}", "type": "system", "from_user": current_user, # 使用真实的管理员账号 "from_name": admin_info.get("name", current_user), # 使用真实昵称,fallback 为账号 "from_avatar": admin_info.get("avatarDataUrl", ""), # 使用真实头像 "content": ann.content, "created_at": int(time.time()) } announcements_db.append(new_ann) db.save_data("announcements.json", announcements_db) return {"status": "success"} # ========================================== # 管理员调试:执行 Python 脚本 # ========================================== class AdminScriptRequest(BaseModel): admin_account: str script_name: str # 🔒 P0安全修复:脚本白名单(仅允许执行指定的脚本) # 警告:添加新脚本前请确保其安全性 ALLOWED_SCRIPTS = { "密码迁移.py", # 用户密码哈希化迁移 "测试脚本.py", # 接口测试工具 "迁移_余额合并.py", # 一次性余额合并迁移(执行后可移除) } @router.post("/api/admin/run-script") async def run_admin_script(req: AdminScriptRequest, current_user: str = Depends(require_auth)): """ 管理员专属:执行指定的 Python 脚本 🔒 P0安全修复:白名单 + 路径穿越防护 """ # 🔒 P0安全修复:使用环境变量配置的管理员列表 if not is_admin(current_user): raise HTTPException(status_code=403, detail="无权执行此操作,仅管理员可操作") script_name = req.script_name.strip() if not script_name: raise HTTPException(status_code=400, detail="脚本名称不能为空") # 🔒 P0安全修复:路径穿越攻击防护 if ".." in script_name or "/" in script_name or "\\" in script_name: raise HTTPException(status_code=400, detail="🚨 安全拦截:脚本名称包含非法字符") # 🔒 P0安全修复:白名单检查 if script_name not in ALLOWED_SCRIPTS: raise HTTPException( status_code=403, detail=f"🚨 安全拦截:脚本 [{script_name}] 不在白名单中。允许的脚本: {list(ALLOWED_SCRIPTS)}" ) # 获取当前工作目录 current_dir = os.path.dirname(os.path.abspath(__file__)) script_path = os.path.join(current_dir, script_name) # 检查文件是否存在 if not os.path.exists(script_path): return { "status": "error", "output": f"❌ 脚本文件不存在: {script_name}\n\n白名单脚本: {list(ALLOWED_SCRIPTS)}" } try: # 执行脚本,设置超时 60 秒 result = subprocess.run( ["python", script_path], capture_output=True, text=True, timeout=60, cwd=current_dir, encoding="utf-8" ) output = "" if result.stdout: output += f"📝 标准输出:\n{result.stdout}\n" if result.stderr: output += f"\n⚠️ 错误输出:\n{result.stderr}" if not output: output = "✅ 脚本执行完成,无输出" return { "status": "success" if result.returncode == 0 else "error", "return_code": result.returncode, "output": output } except subprocess.TimeoutExpired: return { "status": "error", "output": "❌ 脚本执行超时 (60秒)" } except Exception as e: return { "status": "error", "output": f"❌ 执行异常: {str(e)}" } # ========================================== # 原有功能:私信与聊天 # ========================================== @router.post("/api/messages/private") async def send_private_message(msg: PrivateMessage): chats_db = db.load_data("chats.json", default_data={}) conv_id = f"{min(msg.sender, msg.receiver)}_{max(msg.sender, msg.receiver)}" if conv_id not in chats_db: chats_db[conv_id] = [] chat_msg = {"id": f"chat_{int(time.time())}_{uuid.uuid4().hex[:6]}", "sender": msg.sender, "receiver": msg.receiver, "content": msg.content, "created_at": int(time.time()), "is_read": False} chats_db[conv_id].append(chat_msg) db.save_data("chats.json", chats_db) add_notification(msg.receiver, {"type": "private", "from_user": msg.sender, "content": msg.content}) return {"status": "success"} @router.get("/api/chats/{account}") async def get_chat_list(account: str): chats_db = db.load_data("chats.json", default_data={}) users_db = db.load_data("users.json", default_data={}) chat_list = [] for conv_id, msgs in chats_db.items(): if account in conv_id: other_account = conv_id.replace(account, "").replace("_", "") if not msgs: continue last_msg = msgs[-1] unread_count = sum(1 for m in msgs if m["receiver"] == account and not m.get("is_read")) other_user = users_db.get(other_account, {}) chat_list.append({ "account": other_account, "name": other_user.get("name", other_account), "avatar": other_user.get("avatarDataUrl", ""), "last_message": last_msg["content"], "last_time": last_msg["created_at"], "unread_count": unread_count }) chat_list.sort(key=lambda x: x["last_time"], reverse=True) return {"status": "success", "data": chat_list} @router.get("/api/chats/{account}/{target_account}") async def get_chat_history(account: str, target_account: str): chats_db = db.load_data("chats.json", default_data={}) conv_id = f"{min(account, target_account)}_{max(account, target_account)}" msgs = chats_db.get(conv_id, []) now = int(time.time()) seven_days = 7 * 24 * 3600 valid_msgs = [] modified = False for m in msgs: if not m.get("is_read") or (now - m.get("created_at", 0) < seven_days): valid_msgs.append(m) else: modified = True # 本次访问即为已读 if m["receiver"] == account and not m.get("is_read"): m["is_read"] = True modified = True if modified or len(valid_msgs) != len(msgs): chats_db[conv_id] = valid_msgs db.save_data("chats.json", chats_db) return {"status": "success", "data": valid_msgs} # ========================================== # 改造:获取通知列表 (加入系统公告懒加载注入) # 使用 atomic_update 避免并发覆盖问题 # 🔥 性能优化:先用只读方式检查是否有实际变更,避免无意义的写入和HF上传 # ========================================== @router.get("/api/messages/{account}") async def get_messages(account: str, count_only: bool = False, current_user: str = Depends(require_auth)): # 🔥 count_only 模式:轻量级轮询,只返回未读数,不标记已读 if count_only: messages_db = db.load_data("messages.json", default_data={}) user_msgs = messages_db.get(account, []) now = int(time.time()) seven_days = 7 * 24 * 3600 # 只统计未过期的未读消息 unread = sum(1 for m in user_msgs if not m.get("is_read") and (now - m.get("created_at", 0) < seven_days)) return {"status": "success", "unread_count": unread} # 公告是只读的,先加载 announcements_db = db.load_data("announcements.json", default_data=[]) # 🔥 性能优化:先用只读方式检查是否有实际变更需要写入 messages_db = db.load_data("messages.json", default_data={}) user_msgs = messages_db.get(account, []) now = int(time.time()) seven_days = 7 * 24 * 3600 # 检查三个条件判断是否需要写入 # 1. 是否有新公告需要注入 user_msg_ids = {m.get("id") for m in user_msgs} has_new_announcements = any( ann.get("id") not in user_msg_ids for ann in announcements_db ) # 2. 是否有未读消息需要标记已读 has_unread = any(not m.get("is_read") for m in user_msgs) # 3. 是否有已读超过7天的消息需要清理 has_expired = any( m.get("is_read") and (now - m.get("created_at", 0) >= seven_days) for m in user_msgs ) needs_update = has_new_announcements or has_unread or has_expired # 🔥 修复:在标记已读之前先计算真实的未读数 unread_before_mark = sum(1 for m in user_msgs if not m.get("is_read") and (now - m.get("created_at", 0) < seven_days)) if not needs_update: # 无变更,直接返回只读数据,不触发写入和HF上传 return { "status": "success", "data": user_msgs, "unread_count": unread_before_mark # 🔥 修复:返回真实的未读数 } # 有变更需要写入,使用 atomic_update 保证并发安全 result_container = [None] def updater(data): user_msgs = data.get(account, []) # --- 核心改造区:瞬间比对并注入全局公告 --- user_msg_ids = {m.get("id") for m in user_msgs} injected = False for ann in announcements_db: if ann.get("id") not in user_msg_ids: new_sys_msg = dict(ann) new_sys_msg["is_read"] = False new_sys_msg["receiver"] = account user_msgs.append(new_sys_msg) injected = True if injected: # 重新按照时间倒序排列,让新公告置顶 user_msgs.sort(key=lambda x: x.get("created_at", 0), reverse=True) # ---------------------------------------- now = int(time.time()) seven_days = 7 * 24 * 3600 valid = [] # 如果注入了新公告,则判定为需要回写数据库保存 modified = injected for m in user_msgs: if not m.get("is_read") or (now - m.get("created_at", 0) < seven_days): valid.append(m) else: modified = True # 本次访问即为已读 - 将所有未读消息标记为已读 if not m.get("is_read"): m["is_read"] = True modified = True # 原地修改 data,atomic_update 会自动保存 data[account] = valid # 通过闭包返回结果 result_container[0] = { "status": "success", "data": valid, "unread_count": unread_before_mark # 🔥 修复:返回标记已读前的真实未读数 } db.atomic_update("messages.json", updater, default_data={}) return result_container[0] @router.post("/api/messages/{account}/read") async def mark_messages_read(account: str): """ 标记消息为已读(原子操作,并发安全) """ def updater(data): user_msgs = data.get(account, []) modified = False for m in user_msgs: if not m.get("is_read"): m["is_read"] = True modified = True # 原地修改 data,atomic_update 会自动保存 if modified: data[account] = user_msgs db.atomic_update("messages.json", updater, default_data={}) return {"status": "success"}