Spaces:
Sleeping
Sleeping
| # coding: utf-8 | |
| """ | |
| processor.py | |
| Used to clean raw trace data into standard storage structure for reinforcement learning training. | |
| """ | |
| import json | |
| import os | |
| import datetime | |
| from typing import Any | |
| import threading | |
| from aworld.utils import import_package | |
| from aworld.replay_buffer.base import DataRow, Experience, ExpMeta | |
| from aworld.logs.util import logger | |
| from aworld.utils.common import get_local_ip | |
class ReplayBufferExporter:
    """Cleans raw trace spans into the standard replay-buffer storage
    structure used for reinforcement learning training.

    Spans named ``step_execution_*`` are grouped by ``task_id`` into one
    JSON file per task; existing file content is merged in, rows are sorted
    by execute_time, and the result is optionally mirrored to OSS.
    """

    def __init__(self):
        """Initialize ReplayBufferExporter instance."""
        # One lock per output file so concurrent exports to the same task
        # file never interleave their read/merge/write cycles.
        self._file_locks = {}
        # Guards mutation of self._file_locks itself.
        self._lock_dict_lock = threading.Lock()
        # task_id -> resolved output path, so repeated exports for the same
        # task keep appending to the same file within this process.
        self._task_output_paths = {}

    def _get_file_lock(self, file_path):
        """Get (lazily creating) the lock dedicated to the specified file."""
        with self._lock_dict_lock:
            if file_path not in self._file_locks:
                self._file_locks[file_path] = threading.Lock()
            return self._file_locks[file_path]

    @staticmethod
    def _parse_execute_time(start_time):
        """Collapse a ``YYYY-MM-DD HH:MM:SS.ffffff`` timestamp into a
        sortable float such as ``20240102030405.0``.

        Returns 0.0 for a missing or malformed value. (The previous code
        crashed with AttributeError when ``start_time`` was absent, because
        the integer default ``0`` has no ``.split``.)
        """
        try:
            digits = (str(start_time).split('.')[0]
                      .replace(' ', '').replace('-', '').replace(':', ''))
            return float(digits)
        except ValueError:
            return 0.0

    @staticmethod
    def _loads_or_raw(attr, key, default):
        """Return ``attr[key]`` JSON-decoded when possible.

        Falls back to the raw attribute value when it is not valid JSON
        (or not a string), and to ``default`` when the key is absent.
        """
        if key not in attr:
            return default
        try:
            return json.loads(attr[key])
        except (json.JSONDecodeError, TypeError):
            return attr[key]

    @staticmethod
    def _init_oss_bucket():
        """Build the OSS bucket client from environment variables.

        Returns None (after logging a warning) when credentials are
        incomplete or the client cannot be constructed, which disables
        OSS export for this run.
        """
        access_key_id = os.getenv('OSS_ACCESS_KEY_ID')
        access_key_secret = os.getenv('OSS_ACCESS_KEY_SECRET')
        endpoint = os.getenv('OSS_ENDPOINT')
        bucket_name = os.getenv('OSS_BUCKET_NAME')
        if not all([access_key_id, access_key_secret, endpoint, bucket_name]):
            logger.warn("Missing required OSS environment variables")
            return None
        try:
            # oss2 is only needed — and therefore only imported — when OSS
            # export is actually enabled.
            import_package("oss2")
            import oss2
            auth = oss2.Auth(access_key_id, access_key_secret)
            return oss2.Bucket(auth, endpoint, bucket_name)
        except Exception as e:
            logger.warn(f"Failed to initialize OSS client, endpoint: {endpoint}, bucket: {bucket_name}. Error: {str(e)}")
            return None

    def _group_spans_by_task(self, spans):
        """Parse ``step_execution_*`` spans into
        ``{task_id: {exp_id: {'exp_meta': ..., 'exp_data': ...}}}``.

        Spans without both ``exp_id`` and ``task_id`` attributes are
        skipped; a later span with the same exp_id overwrites an earlier one.
        """
        task_groups = {}
        for span_data in spans:
            # Only process spans with 'step_execution_' prefix.
            if not span_data['name'].startswith('step_execution_'):
                continue
            attr = span_data.get('attributes', {})
            exp_id = attr.get('exp_id')
            task_id = attr.get('task_id', '')
            if not exp_id or not task_id:
                continue
            group = task_groups.setdefault(task_id, {}).setdefault(
                exp_id, {'exp_meta': None, 'exp_data': None})
            group['exp_meta'] = ExpMeta(
                task_id,
                attr.get('task_name', ''),
                attr.get('agent_id', ''),
                attr.get('step', 0),
                self._parse_execute_time(span_data.get('start_time', '0')),
                attr.get('pre_agent', ''),
            )
            group['exp_data'] = Experience(
                self._loads_or_raw(attr, 'observation', {}),
                self._loads_or_raw(attr, 'actions', []),
                attr.get('reward', 0.0),
                attr.get('adv_t', 0.0),
                attr.get('v_t', 0.0),
                self._loads_or_raw(attr, 'messages', []),
            )
        return task_groups

    def _resolve_output_path(self, task_id, output_dir):
        """Return (and cache) the output JSON path for ``task_id``.

        Layout: <REPLAY_TRACE_DATASET_PATH or output_dir/YYYYMMDD/<ip>/replays>/
        task_replay_<task_id>.json
        """
        output_path = self._task_output_paths.get(task_id)
        if output_path:
            return output_path
        timestamp = datetime.datetime.now().strftime("%Y%m%d")
        replay_dir = os.path.join(output_dir or "./trace_data", timestamp, get_local_ip(), "replays")
        export_dir = os.path.abspath(os.getenv("REPLAY_TRACE_DATASET_PATH", replay_dir))
        os.makedirs(export_dir, exist_ok=True)
        output_path = os.path.join(export_dir, f"task_replay_{task_id}.json")
        self._task_output_paths[task_id] = output_path
        return output_path

    def _export_task(self, exp_groups, output_path):
        """Merge new experiences with any existing file content and rewrite
        ``output_path`` sorted by execute_time, under the per-file lock.
        """
        data_rows = []
        with self._get_file_lock(output_path):
            # Read existing data (if any) so repeated exports accumulate.
            if os.path.exists(output_path):
                try:
                    with open(output_path, 'r', encoding='utf-8') as f:
                        existing_data = json.load(f)
                    data_rows.extend(DataRow(
                        ExpMeta(**row['exp_meta']),
                        Experience(**row['exp_data']),
                        row['id']
                    ) for row in existing_data)
                except Exception as e:
                    # Best-effort: a corrupt existing file is skipped, not fatal.
                    logger.warn(f"Failed to read existing file {output_path}: {str(e)}")
            # Add new, fully-populated rows.
            for exp_id, group in exp_groups.items():
                if group['exp_meta'] and group['exp_data']:
                    data_rows.append(DataRow(group['exp_meta'], group['exp_data'], exp_id))
            data_rows.sort(key=lambda x: x.exp_meta.execute_time)
            with open(output_path, 'w', encoding='utf-8') as f:
                json.dump([row.to_dict() for row in data_rows], f, ensure_ascii=False, indent=2)
        logger.info(f"Processing completed, exported {len(data_rows)} experiences to {output_path}")

    @staticmethod
    def _upload_to_oss(bucket, output_path):
        """Mirror ``output_path`` to OSS.

        The key reuses the last 4 path components (date/ip/replays/file)
        when available, otherwise falls back to a flat replay_buffer/ prefix.
        Failures are logged, never raised.
        """
        try:
            path_parts = os.path.abspath(output_path).split(os.sep)
            if len(path_parts) >= 4:
                oss_key = os.sep.join(path_parts[-4:])
            else:
                oss_key = f"replay_buffer/{os.path.basename(output_path)}"
            bucket.put_object_from_file(oss_key, output_path)
            logger.info(f"Successfully uploaded {output_path} to OSS: {oss_key}")
        except Exception as e:
            logger.warn(f"Failed to upload {output_path} to OSS: {str(e)}")

    def replay_buffer_exporter(self, spans: list[dict[str, Any]], output_dir: str):
        """
        Process spans, only process spans with 'step_execution_' prefix, and
        group by task_id to output to different files.

        Args:
            spans: span data list
            output_dir: output directory path
        """
        # Ensure output directory exists.
        os.makedirs(output_dir, exist_ok=True)

        # OSS export is opt-in; credentials are only checked — and oss2 only
        # imported — when the feature flag is on. (Previously the client was
        # set up, and missing-credential warnings logged, unconditionally.)
        enable_oss_export = os.getenv("EXPORT_REPLAY_TRACE_TO_OSS", "false").lower() == "true"
        bucket = None
        if enable_oss_export:
            bucket = self._init_oss_bucket()
            if bucket is None:
                enable_oss_export = False

        # Group by task_id, then write each task's merged file.
        task_groups = self._group_spans_by_task(spans)
        for task_id, exp_groups in task_groups.items():
            output_path = self._resolve_output_path(task_id, output_dir)
            self._export_task(exp_groups, output_path)
            if enable_oss_export and bucket is not None:
                self._upload_to_oss(bucket, output_path)