""" OASIS Reddit模拟预设脚本 此脚本读取配置文件中的参数来执行模拟,实现全程自动化 功能特性: - 完成模拟后不立即关闭环境,进入等待命令模式 - 支持通过IPC接收Interview命令 - 支持单个Agent采访和批量采访 - 支持远程关闭环境命令 使用方式: python run_reddit_simulation.py --config /path/to/simulation_config.json python run_reddit_simulation.py --config /path/to/simulation_config.json --no-wait # 完成后立即关闭 """ import argparse import asyncio import json import logging import os import random import signal import sys import sqlite3 from datetime import datetime from typing import Dict, Any, List, Optional # 全局变量:用于信号处理 _shutdown_event = None _cleanup_done = False # 添加项目路径 _scripts_dir = os.path.dirname(os.path.abspath(__file__)) _backend_dir = os.path.abspath(os.path.join(_scripts_dir, '..')) _project_root = os.path.abspath(os.path.join(_backend_dir, '..')) sys.path.insert(0, _scripts_dir) sys.path.insert(0, _backend_dir) # 加载项目根目录的 .env 文件(包含 LLM_API_KEY 等配置) from dotenv import load_dotenv _env_file = os.path.join(_project_root, '.env') if os.path.exists(_env_file): load_dotenv(_env_file) else: _backend_env = os.path.join(_backend_dir, '.env') if os.path.exists(_backend_env): load_dotenv(_backend_env) import re class UnicodeFormatter(logging.Formatter): """自定义格式化器,将 Unicode 转义序列转换为可读字符""" UNICODE_ESCAPE_PATTERN = re.compile(r'\\u([0-9a-fA-F]{4})') def format(self, record): result = super().format(record) def replace_unicode(match): try: return chr(int(match.group(1), 16)) except (ValueError, OverflowError): return match.group(0) return self.UNICODE_ESCAPE_PATTERN.sub(replace_unicode, result) class MaxTokensWarningFilter(logging.Filter): """过滤掉 camel-ai 关于 max_tokens 的警告(我们故意不设置 max_tokens,让模型自行决定)""" def filter(self, record): # 过滤掉包含 max_tokens 警告的日志 if "max_tokens" in record.getMessage() and "Invalid or missing" in record.getMessage(): return False return True # 在模块加载时立即添加过滤器,确保在 camel 代码执行前生效 logging.getLogger().addFilter(MaxTokensWarningFilter()) def setup_oasis_logging(log_dir: str): """配置 OASIS 的日志,使用固定名称的日志文件""" os.makedirs(log_dir, exist_ok=True) # 清理旧的日志文件 for f in os.listdir(log_dir): old_log = os.path.join(log_dir, f) if os.path.isfile(old_log) and f.endswith('.log'): try: os.remove(old_log) except OSError: pass formatter = UnicodeFormatter("%(levelname)s - %(asctime)s - %(name)s - %(message)s") loggers_config = { "social.agent": os.path.join(log_dir, "social.agent.log"), "social.twitter": os.path.join(log_dir, "social.twitter.log"), "social.rec": os.path.join(log_dir, "social.rec.log"), "oasis.env": os.path.join(log_dir, "oasis.env.log"), "table": os.path.join(log_dir, "table.log"), } for logger_name, log_file in loggers_config.items(): logger = logging.getLogger(logger_name) logger.setLevel(logging.DEBUG) logger.handlers.clear() file_handler = logging.FileHandler(log_file, encoding='utf-8', mode='w') file_handler.setLevel(logging.DEBUG) file_handler.setFormatter(formatter) logger.addHandler(file_handler) logger.propagate = False try: from camel.models import ModelFactory from camel.types import ModelPlatformType import oasis from oasis import ( ActionType, LLMAction, ManualAction, generate_reddit_agent_graph ) except ImportError as e: print(f"错误: 缺少依赖 {e}") print("请先安装: pip install oasis-ai camel-ai") sys.exit(1) # IPC相关常量 IPC_COMMANDS_DIR = "ipc_commands" IPC_RESPONSES_DIR = "ipc_responses" ENV_STATUS_FILE = "env_status.json" class CommandType: """命令类型常量""" INTERVIEW = "interview" BATCH_INTERVIEW = "batch_interview" CLOSE_ENV = "close_env" class IPCHandler: """IPC命令处理器""" def __init__(self, simulation_dir: str, env, agent_graph): self.simulation_dir = simulation_dir self.env = env self.agent_graph = agent_graph self.commands_dir = os.path.join(simulation_dir, IPC_COMMANDS_DIR) self.responses_dir = os.path.join(simulation_dir, IPC_RESPONSES_DIR) self.status_file = os.path.join(simulation_dir, ENV_STATUS_FILE) self._running = True # 确保目录存在 os.makedirs(self.commands_dir, exist_ok=True) os.makedirs(self.responses_dir, exist_ok=True) def update_status(self, status: str): """更新环境状态""" with open(self.status_file, 'w', encoding='utf-8') as f: json.dump({ "status": status, "timestamp": datetime.now().isoformat() }, f, ensure_ascii=False, indent=2) def poll_command(self) -> Optional[Dict[str, Any]]: """轮询获取待处理命令""" if not os.path.exists(self.commands_dir): return None # 获取命令文件(按时间排序) command_files = [] for filename in os.listdir(self.commands_dir): if filename.endswith('.json'): filepath = os.path.join(self.commands_dir, filename) command_files.append((filepath, os.path.getmtime(filepath))) command_files.sort(key=lambda x: x[1]) for filepath, _ in command_files: try: with open(filepath, 'r', encoding='utf-8') as f: return json.load(f) except (json.JSONDecodeError, OSError): continue return None def send_response(self, command_id: str, status: str, result: Dict = None, error: str = None): """发送响应""" response = { "command_id": command_id, "status": status, "result": result, "error": error, "timestamp": datetime.now().isoformat() } response_file = os.path.join(self.responses_dir, f"{command_id}.json") with open(response_file, 'w', encoding='utf-8') as f: json.dump(response, f, ensure_ascii=False, indent=2) # 删除命令文件 command_file = os.path.join(self.commands_dir, f"{command_id}.json") try: os.remove(command_file) except OSError: pass async def handle_interview(self, command_id: str, agent_id: int, prompt: str) -> bool: """ 处理单个Agent采访命令 Returns: True 表示成功,False 表示失败 """ try: # 获取Agent agent = self.agent_graph.get_agent(agent_id) # 创建Interview动作 interview_action = ManualAction( action_type=ActionType.INTERVIEW, action_args={"prompt": prompt} ) # 执行Interview actions = {agent: interview_action} await self.env.step(actions) # 从数据库获取结果 result = self._get_interview_result(agent_id) self.send_response(command_id, "completed", result=result) print(f" Interview完成: agent_id={agent_id}") return True except Exception as e: error_msg = str(e) print(f" Interview失败: agent_id={agent_id}, error={error_msg}") self.send_response(command_id, "failed", error=error_msg) return False async def handle_batch_interview(self, command_id: str, interviews: List[Dict]) -> bool: """ 处理批量采访命令 Args: interviews: [{"agent_id": int, "prompt": str}, ...] """ try: # 构建动作字典 actions = {} agent_prompts = {} # 记录每个agent的prompt for interview in interviews: agent_id = interview.get("agent_id") prompt = interview.get("prompt", "") try: agent = self.agent_graph.get_agent(agent_id) actions[agent] = ManualAction( action_type=ActionType.INTERVIEW, action_args={"prompt": prompt} ) agent_prompts[agent_id] = prompt except Exception as e: print(f" 警告: 无法获取Agent {agent_id}: {e}") if not actions: self.send_response(command_id, "failed", error="没有有效的Agent") return False # 执行批量Interview await self.env.step(actions) # 获取所有结果 results = {} for agent_id in agent_prompts.keys(): result = self._get_interview_result(agent_id) results[agent_id] = result self.send_response(command_id, "completed", result={ "interviews_count": len(results), "results": results }) print(f" 批量Interview完成: {len(results)} 个Agent") return True except Exception as e: error_msg = str(e) print(f" 批量Interview失败: {error_msg}") self.send_response(command_id, "failed", error=error_msg) return False def _get_interview_result(self, agent_id: int) -> Dict[str, Any]: """从数据库获取最新的Interview结果""" db_path = os.path.join(self.simulation_dir, "reddit_simulation.db") result = { "agent_id": agent_id, "response": None, "timestamp": None } if not os.path.exists(db_path): return result try: conn = sqlite3.connect(db_path) cursor = conn.cursor() # 查询最新的Interview记录 cursor.execute(""" SELECT user_id, info, created_at FROM trace WHERE action = ? AND user_id = ? ORDER BY created_at DESC LIMIT 1 """, (ActionType.INTERVIEW.value, agent_id)) row = cursor.fetchone() if row: user_id, info_json, created_at = row try: info = json.loads(info_json) if info_json else {} result["response"] = info.get("response", info) result["timestamp"] = created_at except json.JSONDecodeError: result["response"] = info_json conn.close() except Exception as e: print(f" 读取Interview结果失败: {e}") return result async def process_commands(self) -> bool: """ 处理所有待处理命令 Returns: True 表示继续运行,False 表示应该退出 """ command = self.poll_command() if not command: return True command_id = command.get("command_id") command_type = command.get("command_type") args = command.get("args", {}) print(f"\n收到IPC命令: {command_type}, id={command_id}") if command_type == CommandType.INTERVIEW: await self.handle_interview( command_id, args.get("agent_id", 0), args.get("prompt", "") ) return True elif command_type == CommandType.BATCH_INTERVIEW: await self.handle_batch_interview( command_id, args.get("interviews", []) ) return True elif command_type == CommandType.CLOSE_ENV: print("收到关闭环境命令") self.send_response(command_id, "completed", result={"message": "环境即将关闭"}) return False else: self.send_response(command_id, "failed", error=f"未知命令类型: {command_type}") return True class RedditSimulationRunner: """Reddit模拟运行器""" # Reddit可用动作(不包含INTERVIEW,INTERVIEW只能通过ManualAction手动触发) AVAILABLE_ACTIONS = [ ActionType.LIKE_POST, ActionType.DISLIKE_POST, ActionType.CREATE_POST, ActionType.CREATE_COMMENT, ActionType.LIKE_COMMENT, ActionType.DISLIKE_COMMENT, ActionType.SEARCH_POSTS, ActionType.SEARCH_USER, ActionType.TREND, ActionType.REFRESH, ActionType.DO_NOTHING, ActionType.FOLLOW, ActionType.MUTE, ] def __init__(self, config_path: str, wait_for_commands: bool = True): """ 初始化模拟运行器 Args: config_path: 配置文件路径 (simulation_config.json) wait_for_commands: 模拟完成后是否等待命令(默认True) """ self.config_path = config_path self.config = self._load_config() self.simulation_dir = os.path.dirname(config_path) self.wait_for_commands = wait_for_commands self.env = None self.agent_graph = None self.ipc_handler = None def _load_config(self) -> Dict[str, Any]: """加载配置文件""" with open(self.config_path, 'r', encoding='utf-8') as f: return json.load(f) def _get_profile_path(self) -> str: """获取Profile文件路径""" return os.path.join(self.simulation_dir, "reddit_profiles.json") def _get_db_path(self) -> str: """获取数据库路径""" return os.path.join(self.simulation_dir, "reddit_simulation.db") def _create_model(self): """ 创建LLM模型 统一使用项目根目录 .env 文件中的配置(优先级最高): - LLM_API_KEY: API密钥 - LLM_BASE_URL: API基础URL - LLM_MODEL_NAME: 模型名称 """ # 优先从 .env 读取配置 llm_api_key = os.environ.get("LLM_API_KEY", "") llm_base_url = os.environ.get("LLM_BASE_URL", "") llm_model = os.environ.get("LLM_MODEL_NAME", "") # 如果 .env 中没有,则使用 config 作为备用 if not llm_model: llm_model = self.config.get("llm_model", "gpt-4o-mini") # 设置 camel-ai 所需的环境变量 if llm_api_key: os.environ["OPENAI_API_KEY"] = llm_api_key if not os.environ.get("OPENAI_API_KEY"): raise ValueError("缺少 API Key 配置,请在项目根目录 .env 文件中设置 LLM_API_KEY") if llm_base_url: os.environ["OPENAI_API_BASE_URL"] = llm_base_url print(f"LLM配置: model={llm_model}, base_url={llm_base_url[:40] if llm_base_url else '默认'}...") return ModelFactory.create( model_platform=ModelPlatformType.OPENAI, model_type=llm_model, ) def _get_active_agents_for_round( self, env, current_hour: int, round_num: int ) -> List: """ 根据时间和配置决定本轮激活哪些Agent """ time_config = self.config.get("time_config", {}) agent_configs = self.config.get("agent_configs", []) base_min = time_config.get("agents_per_hour_min", 5) base_max = time_config.get("agents_per_hour_max", 20) peak_hours = time_config.get("peak_hours", [9, 10, 11, 14, 15, 20, 21, 22]) off_peak_hours = time_config.get("off_peak_hours", [0, 1, 2, 3, 4, 5]) if current_hour in peak_hours: multiplier = time_config.get("peak_activity_multiplier", 1.5) elif current_hour in off_peak_hours: multiplier = time_config.get("off_peak_activity_multiplier", 0.3) else: multiplier = 1.0 target_count = int(random.uniform(base_min, base_max) * multiplier) candidates = [] for cfg in agent_configs: agent_id = cfg.get("agent_id", 0) active_hours = cfg.get("active_hours", list(range(8, 23))) activity_level = cfg.get("activity_level", 0.5) if current_hour not in active_hours: continue if random.random() < activity_level: candidates.append(agent_id) selected_ids = random.sample( candidates, min(target_count, len(candidates)) ) if candidates else [] active_agents = [] for agent_id in selected_ids: try: agent = env.agent_graph.get_agent(agent_id) active_agents.append((agent_id, agent)) except Exception: pass return active_agents async def run(self, max_rounds: int = None): """运行Reddit模拟 Args: max_rounds: 最大模拟轮数(可选,用于截断过长的模拟) """ print("=" * 60) print("OASIS Reddit模拟") print(f"配置文件: {self.config_path}") print(f"模拟ID: {self.config.get('simulation_id', 'unknown')}") print(f"等待命令模式: {'启用' if self.wait_for_commands else '禁用'}") print("=" * 60) time_config = self.config.get("time_config", {}) total_hours = time_config.get("total_simulation_hours", 72) minutes_per_round = time_config.get("minutes_per_round", 30) total_rounds = (total_hours * 60) // minutes_per_round # 如果指定了最大轮数,则截断 if max_rounds is not None and max_rounds > 0: original_rounds = total_rounds total_rounds = min(total_rounds, max_rounds) if total_rounds < original_rounds: print(f"\n轮数已截断: {original_rounds} -> {total_rounds} (max_rounds={max_rounds})") print(f"\n模拟参数:") print(f" - 总模拟时长: {total_hours}小时") print(f" - 每轮时间: {minutes_per_round}分钟") print(f" - 总轮数: {total_rounds}") if max_rounds: print(f" - 最大轮数限制: {max_rounds}") print(f" - Agent数量: {len(self.config.get('agent_configs', []))}") print("\n初始化LLM模型...") model = self._create_model() print("加载Agent Profile...") profile_path = self._get_profile_path() if not os.path.exists(profile_path): print(f"错误: Profile文件不存在: {profile_path}") return self.agent_graph = await generate_reddit_agent_graph( profile_path=profile_path, model=model, available_actions=self.AVAILABLE_ACTIONS, ) db_path = self._get_db_path() if os.path.exists(db_path): os.remove(db_path) print(f"已删除旧数据库: {db_path}") print("创建OASIS环境...") self.env = oasis.make( agent_graph=self.agent_graph, platform=oasis.DefaultPlatformType.REDDIT, database_path=db_path, semaphore=30, # 限制最大并发 LLM 请求数,防止 API 过载 ) await self.env.reset() print("环境初始化完成\n") # 初始化IPC处理器 self.ipc_handler = IPCHandler(self.simulation_dir, self.env, self.agent_graph) self.ipc_handler.update_status("running") # 执行初始事件 event_config = self.config.get("event_config", {}) initial_posts = event_config.get("initial_posts", []) if initial_posts: print(f"执行初始事件 ({len(initial_posts)}条初始帖子)...") initial_actions = {} for post in initial_posts: agent_id = post.get("poster_agent_id", 0) content = post.get("content", "") try: agent = self.env.agent_graph.get_agent(agent_id) if agent in initial_actions: if not isinstance(initial_actions[agent], list): initial_actions[agent] = [initial_actions[agent]] initial_actions[agent].append(ManualAction( action_type=ActionType.CREATE_POST, action_args={"content": content} )) else: initial_actions[agent] = ManualAction( action_type=ActionType.CREATE_POST, action_args={"content": content} ) except Exception as e: print(f" 警告: 无法为Agent {agent_id}创建初始帖子: {e}") if initial_actions: await self.env.step(initial_actions) print(f" 已发布 {len(initial_actions)} 条初始帖子") # 主模拟循环 print("\n开始模拟循环...") start_time = datetime.now() for round_num in range(total_rounds): simulated_minutes = round_num * minutes_per_round simulated_hour = (simulated_minutes // 60) % 24 simulated_day = simulated_minutes // (60 * 24) + 1 active_agents = self._get_active_agents_for_round( self.env, simulated_hour, round_num ) if not active_agents: continue actions = { agent: LLMAction() for _, agent in active_agents } await self.env.step(actions) if (round_num + 1) % 10 == 0 or round_num == 0: elapsed = (datetime.now() - start_time).total_seconds() progress = (round_num + 1) / total_rounds * 100 print(f" [Day {simulated_day}, {simulated_hour:02d}:00] " f"Round {round_num + 1}/{total_rounds} ({progress:.1f}%) " f"- {len(active_agents)} agents active " f"- elapsed: {elapsed:.1f}s") total_elapsed = (datetime.now() - start_time).total_seconds() print(f"\n模拟循环完成!") print(f" - 总耗时: {total_elapsed:.1f}秒") print(f" - 数据库: {db_path}") # 是否进入等待命令模式 if self.wait_for_commands: print("\n" + "=" * 60) print("进入等待命令模式 - 环境保持运行") print("支持的命令: interview, batch_interview, close_env") print("=" * 60) self.ipc_handler.update_status("alive") # 等待命令循环(使用全局 _shutdown_event) try: while not _shutdown_event.is_set(): should_continue = await self.ipc_handler.process_commands() if not should_continue: break try: await asyncio.wait_for(_shutdown_event.wait(), timeout=0.5) break # 收到退出信号 except asyncio.TimeoutError: pass except KeyboardInterrupt: print("\n收到中断信号") except asyncio.CancelledError: print("\n任务被取消") except Exception as e: print(f"\n命令处理出错: {e}") print("\n关闭环境...") # 关闭环境 self.ipc_handler.update_status("stopped") await self.env.close() print("环境已关闭") print("=" * 60) async def main(): parser = argparse.ArgumentParser(description='OASIS Reddit模拟') parser.add_argument( '--config', type=str, required=True, help='配置文件路径 (simulation_config.json)' ) parser.add_argument( '--max-rounds', type=int, default=None, help='最大模拟轮数(可选,用于截断过长的模拟)' ) parser.add_argument( '--no-wait', action='store_true', default=False, help='模拟完成后立即关闭环境,不进入等待命令模式' ) args = parser.parse_args() # 在 main 函数开始时创建 shutdown 事件 global _shutdown_event _shutdown_event = asyncio.Event() if not os.path.exists(args.config): print(f"错误: 配置文件不存在: {args.config}") sys.exit(1) # 初始化日志配置(使用固定文件名,清理旧日志) simulation_dir = os.path.dirname(args.config) or "." setup_oasis_logging(os.path.join(simulation_dir, "log")) runner = RedditSimulationRunner( config_path=args.config, wait_for_commands=not args.no_wait ) await runner.run(max_rounds=args.max_rounds) def setup_signal_handlers(): """ 设置信号处理器,确保收到 SIGTERM/SIGINT 时能够正确退出 让程序有机会正常清理资源(关闭数据库、环境等) """ def signal_handler(signum, frame): global _cleanup_done sig_name = "SIGTERM" if signum == signal.SIGTERM else "SIGINT" print(f"\n收到 {sig_name} 信号,正在退出...") if not _cleanup_done: _cleanup_done = True if _shutdown_event: _shutdown_event.set() else: # 重复收到信号才强制退出 print("强制退出...") sys.exit(1) signal.signal(signal.SIGTERM, signal_handler) signal.signal(signal.SIGINT, signal_handler) if __name__ == "__main__": setup_signal_handlers() try: asyncio.run(main()) except KeyboardInterrupt: print("\n程序被中断") except SystemExit: pass finally: print("模拟进程已退出")