Upload 5 files
- aworld/replay_buffer/README.md       +111 -0
- aworld/replay_buffer/__init__.py     +2 -0
- aworld/replay_buffer/base.py         +409 -0
- aworld/replay_buffer/processor.py    +190 -0
- aworld/replay_buffer/query_filter.py +228 -0
aworld/replay_buffer/README.md ADDED
@@ -0,0 +1,111 @@
# Replay Buffer

A multi-process capable replay buffer system for storing and sampling agent experience data.

## Features

- **Multi-process Support**: Safe concurrent access using shared memory and locks
- **Flexible Querying**: Powerful query builder for filtering stored data
- **Task-based Organization**: Data is organized by task_id and agent_id
- **Capacity Management**: FIFO eviction of the oldest task once max capacity is reached (sketched below)
- **Custom Sampling**: Implement custom sampling logic through the Sampler interface (see the example at the end)
- **Data Conversion**: Custom data conversion through the Converter interface (see the example at the end)
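When the buffer is full, eviction happens per task, oldest task first. A minimal sketch of this behavior, assuming the in-memory backend and a hypothetical `row` helper:

```python
import time

from aworld.core.common import ActionModel, Observation
from aworld.replay_buffer.base import (DataRow, Experience, ExpMeta,
                                       InMemoryStorage, ReplayBuffer)

def row(task_id):
    # Hypothetical helper: build a one-step DataRow for a task.
    return DataRow(
        exp_meta=ExpMeta(task_id, "demo", "agent_1", 0, time.time(), ""),
        exp_data=Experience(state=Observation(), actions=[ActionModel()]))

buffer = ReplayBuffer(storage=InMemoryStorage(max_capacity=2))
for t in ("task_1", "task_2", "task_3"):
    buffer.store(row(t))
# "task_1" was evicted when "task_3" arrived; only the two newest tasks remain.
```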
## Basic Usage

### Writing Data

```python
import time

from aworld.replay_buffer.base import ReplayBuffer, DataRow, ExpMeta, Experience
from aworld.core.common import ActionModel, Observation

replay_buffer = ReplayBuffer()

# Create a data row
data = DataRow(
    exp_meta=ExpMeta(
        task_id="task_1",
        task_name="my_task",
        agent_id="agent_1",
        step=1,
        execute_time=time.time()
    ),
    exp_data=Experience(
        state=Observation(),
        actions=[ActionModel()]
    )
)

# Store data
replay_buffer.store(data)
```
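Each `DataRow` serializes via `to_dict()`; the shape (values elided) looks like:

```python
data.to_dict()
# {"exp_meta": {"task_id": "task_1", "task_name": "my_task", "agent_id": "agent_1",
#               "step": 1, "execute_time": ..., "pre_agent": ...},
#  "exp_data": {"state": ..., "actions": [...], "reward_t": None, "adv_t": None,
#               "v_t": None, "messages": None},
#  "id": "<uuid4>"}
```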
### Reading Data

```python
from aworld.replay_buffer.base import (DefaultConverter, RandomTaskSample,
                                       ReplayBuffer)
from aworld.replay_buffer.query_filter import QueryBuilder

# Basic example
replay_buffer = ReplayBuffer()
query_condition = QueryBuilder().eq("exp_meta.task_name", "test_task").build()
data = replay_buffer.sample(sampler=RandomTaskSample(),
                            query_condition=query_condition,
                            converter=DefaultConverter(),
                            batch_size=1000)

# Query tasks by task_id
query = QueryBuilder().eq("exp_meta.task_id", "task_1").build()
data = replay_buffer.sample_task(query_condition=query, batch_size=10)

# Query tasks by agent_id
query = QueryBuilder().eq("exp_meta.agent_id", "agent_1").build()
data = replay_buffer.sample_task(query_condition=query, batch_size=5)
```

`sample` returns the converted result of the sampler's raw output, while `sample_task` groups rows by task (sorted by step) and converts each task's trajectory separately.
## Multi-processing Example

```python
import multiprocessing

from aworld.replay_buffer.base import ReplayBuffer
from aworld.replay_buffer.storage.multi_proc_mem import MultiProcMemoryStorage

manager = multiprocessing.Manager()
replay_buffer = ReplayBuffer(
    storage=MultiProcMemoryStorage(
        data_dict=manager.dict(),
        fifo_queue=manager.list(),
        lock=manager.Lock(),
        max_capacity=10000
    )
)

# Start writer processes (write_processing is user-defined; see the sketch below)
processes = [
    multiprocessing.Process(target=write_processing, args=(replay_buffer, f"task_{i}"))
    for i in range(4)
]
for p in processes:
    p.start()
```
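`write_processing` above is user-defined; a minimal sketch of such a worker (illustrative, not part of the library):

```python
import time

from aworld.core.common import ActionModel, Observation
from aworld.replay_buffer.base import DataRow, Experience, ExpMeta

def write_processing(replay_buffer, task_id):
    # Hypothetical writer: store a few steps of experience for one task.
    for step in range(10):
        replay_buffer.store(DataRow(
            exp_meta=ExpMeta(
                task_id=task_id,
                task_name="demo",
                agent_id="agent_1",
                step=step,
                execute_time=time.time(),
                pre_agent=""
            ),
            exp_data=Experience(state=Observation(), actions=[ActionModel()])
        ))
```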
## Query Builder Examples

### Simple Equality
```python
QueryBuilder().eq("exp_meta.task_id", "123").build()
```

### Complex Conditions
```python
(QueryBuilder()
 .eq("exp_meta.task_id", "123")
 .and_()
 .eq("exp_meta.agent_id", "456")
 .build())
```

### Nested Conditions
```python
(QueryBuilder()
 .eq("exp_meta.task_id", "123")
 .and_()
 .nested(
     QueryBuilder()
     .eq("exp_meta.agent_id", "111")
     .or_()
     .eq("exp_meta.agent_id", "222")
 )
 .build())
```
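For reference, the nested example above builds the following condition structure (per `QueryBuilder.build`):

```python
{
    "and_": [
        {"field": "exp_meta.task_id", "value": "123", "op": "eq"},
        {"or_": [
            {"field": "exp_meta.agent_id", "value": "111", "op": "eq"},
            {"field": "exp_meta.agent_id", "value": "222", "op": "eq"}
        ]}
    ]
}
```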
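## Custom Sampler and Converter

The `Sampler`/`TaskSampler` and `Converter` interfaces can be subclassed for custom behavior. A minimal sketch (the class names here are illustrative, not part of the library):

```python
from typing import List

from aworld.replay_buffer.base import (Converter, DataRow, ReplayBuffer,
                                       Storage, TaskSampler)
from aworld.replay_buffer.query_filter import QueryCondition

class FirstNTaskSampler(TaskSampler):
    """Pick the first `batch_size` distinct task_ids in storage order."""

    def sample_task_ids(self, storage: Storage, batch_size: int,
                        query_condition: QueryCondition = None) -> List[str]:
        task_ids = []
        for row in storage.get_all(query_condition):
            if row.exp_meta.task_id not in task_ids:
                task_ids.append(row.exp_meta.task_id)
            if len(task_ids) >= batch_size:
                break
        return task_ids

class RewardConverter(Converter):
    """Reduce each task trajectory to its list of per-step rewards."""

    def to_dataset_row(self, task_experience: List[DataRow]) -> List[float]:
        return [row.exp_data.reward_t for row in task_experience]

replay_buffer = ReplayBuffer()
rewards_per_task = replay_buffer.sample_task(
    sampler=FirstNTaskSampler(),
    converter=RewardConverter(),
    batch_size=8)
```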
aworld/replay_buffer/__init__.py ADDED
@@ -0,0 +1,2 @@
# coding: utf-8
# Copyright (c) 2025 inclusionAI.
aworld/replay_buffer/base.py ADDED
@@ -0,0 +1,409 @@
import random
import uuid
from dataclasses import dataclass, field
from typing import Dict, List, Optional, TypeVar
from abc import ABC, abstractmethod
from math import ceil

from aworld.core.common import ActionModel, Observation
from aworld.replay_buffer.query_filter import QueryCondition, QueryFilter


T = TypeVar('T')


@dataclass
class Experience:
    '''
    A single step of agent experience.
    '''
    state: Observation
    actions: List[ActionModel]
    reward_t: Optional[float] = None
    adv_t: Optional[float] = None
    v_t: Optional[float] = None
    messages: Optional[List[Dict]] = None

    def to_dict(self):
        return {
            "state": self.state,
            "actions": self.actions,
            "reward_t": self.reward_t,
            "adv_t": self.adv_t,
            "v_t": self.v_t,
            "messages": self.messages
        }


@dataclass
class ExpMeta:
    '''
    Experience metadata.
    '''
    task_id: str
    task_name: str
    agent_id: str
    step: int
    execute_time: float
    pre_agent: str = ""

    def to_dict(self):
        return {
            "task_id": self.task_id,
            "task_name": self.task_name,
            "agent_id": self.agent_id,
            "step": self.step,
            "execute_time": self.execute_time,
            "pre_agent": self.pre_agent
        }


@dataclass
class DataRow:
    '''
    Data row for storing data.
    '''
    exp_meta: ExpMeta
    exp_data: Experience
    id: str = field(default_factory=lambda: str(uuid.uuid4()))

    def to_dict(self):
        return {
            "exp_meta": self.exp_meta.to_dict(),
            "exp_data": self.exp_data.to_dict(),
            "id": self.id
        }


class Storage(ABC):
    '''
    Storage backend for storing and retrieving data.
    '''

    @abstractmethod
    def add(self, data: DataRow):
        '''
        Add data to the storage.

        Args:
            data (DataRow): Data to add.
        '''

    @abstractmethod
    def add_batch(self, data_batch: List[DataRow]):
        '''
        Add a batch of data to the storage.

        Args:
            data_batch (List[DataRow]): List of data to add.
        '''

    @abstractmethod
    def size(self, query_condition: QueryCondition = None) -> int:
        '''
        Get the number of rows matching the query condition.

        Returns:
            int: Size of the storage.
        '''

    @abstractmethod
    def get_paginated(self, page: int, page_size: int, query_condition: QueryCondition = None) -> List[DataRow]:
        '''
        Get paginated data from the storage.

        Args:
            page (int): Page number (1-based).
            page_size (int): Number of rows per page.
        Returns:
            List[DataRow]: List of data.
        '''

    @abstractmethod
    def get_all(self, query_condition: QueryCondition = None) -> List[DataRow]:
        '''
        Get all data from the storage.

        Returns:
            List[DataRow]: List of data.
        '''

    @abstractmethod
    def get_by_task_id(self, task_id: str) -> List[DataRow]:
        '''
        Get data by task_id from the storage.

        Args:
            task_id (str): Task id.
        Returns:
            List[DataRow]: List of data.
        '''

    @abstractmethod
    def get_bacth_by_task_ids(self, task_ids: List[str]) -> Dict[str, List[DataRow]]:
        '''
        Get a batch of data by task_ids from the storage.

        Args:
            task_ids (List[str]): List of task ids.
        Returns:
            Dict[str, List[DataRow]]: Mapping from task_id to its list of data rows.
        '''


class Sampler(ABC):
    '''
    Samples data from a storage backend.
    '''

    @abstractmethod
    def sample(self,
               storage: Storage,
               batch_size: int,
               query_condition: QueryCondition = None) -> List[DataRow]:
        '''
        Sample data from the storage.

        Args:
            storage (Storage): Storage to sample from.
            batch_size (int): Number of rows to sample.
            query_condition (QueryCondition, optional): Query condition. Defaults to None.
        Returns:
            List[DataRow]
        '''


class TaskSampler(Sampler):
    '''
    Samples whole-task data from storage as Dict[str, List[DataRow]] where:
    - key is task_id
    - value is the list of all data rows of that task
    '''

    def sorted_by_step(self, task_experience: List[DataRow]) -> List[DataRow]:
        '''
        Sort the task experience by step and execute_time.

        Args:
            task_experience (List[DataRow]): List of task experience.
        Returns:
            List[DataRow]: List of task experience sorted by step and execute_time.
        '''
        return sorted(task_experience, key=lambda x: (x.exp_meta.step, x.exp_meta.execute_time))

    def sample(self,
               storage: Storage,
               batch_size: int,
               query_condition: QueryCondition = None) -> Dict[str, List[DataRow]]:
        task_ids = self.sample_task_ids(storage, batch_size, query_condition)
        return storage.get_bacth_by_task_ids(task_ids)

    def sample_tasks(self,
                     storage: Storage,
                     batch_size: int,
                     query_condition: QueryCondition = None) -> Dict[str, List[DataRow]]:
        '''
        Sample whole tasks from the storage.

        Args:
            storage (Storage): Storage to sample from.
            batch_size (int): Number of tasks to sample.
            query_condition (QueryCondition, optional): Query condition. Defaults to None.
        Returns:
            Dict[str, List[DataRow]]: Mapping from task_id to its data rows,
            sorted by step and execute_time.
        '''
        task_ids = self.sample_task_ids(storage, batch_size, query_condition)
        raws = storage.get_bacth_by_task_ids(task_ids)
        return {task_id: self.sorted_by_step(rows) for task_id, rows in raws.items()}

    @abstractmethod
    def sample_task_ids(self,
                        storage: Storage,
                        batch_size: int,
                        query_condition: QueryCondition = None) -> List[str]:
        '''
        Sample task_ids from the storage.

        Args:
            storage (Storage): Storage to sample from.
            batch_size (int): Number of task_ids to sample.
            query_condition (QueryCondition, optional): Query condition. Defaults to None.
        Returns:
            List[str]: List of task_ids.
        '''


class Converter(ABC):
    '''
    Converts sampled data to a dataset row.
    '''

    @abstractmethod
    def to_dataset_row(self, task_experience: List[DataRow]) -> T:
        '''
        Convert task experience to a dataset row.

        Args:
            task_experience (List[DataRow]): List of task experience.
        Returns:
            T: the dataset row type produced by this converter.
        '''


class InMemoryStorage(Storage):
    '''
    In-memory storage for storing and sampling data.
    '''

    def __init__(self, max_capacity: int = 10000):
        self._data: Dict[str, List[DataRow]] = {}
        self._max_capacity = max_capacity
        self._fifo_queue = []  # task_ids in insertion order, oldest first

    def add(self, data: DataRow):
        if not data:
            raise ValueError("Data is required")
        if not data.exp_meta:
            raise ValueError("exp_meta is required")

        # Evict whole tasks, oldest first, until there is room.
        while self.size() >= self._max_capacity and self._fifo_queue:
            oldest_task_id = self._fifo_queue.pop(0)
            if oldest_task_id in self._data:
                del self._data[oldest_task_id]

        if data.exp_meta.task_id not in self._data:
            self._data[data.exp_meta.task_id] = []
            self._fifo_queue.append(data.exp_meta.task_id)
        self._data[data.exp_meta.task_id].append(data)

    def add_batch(self, data_batch: List[DataRow]):
        for data in data_batch:
            self.add(data)

    def size(self, query_condition: QueryCondition = None) -> int:
        return len(self.get_all(query_condition))

    def get_paginated(self, page: int, page_size: int, query_condition: QueryCondition = None) -> List[DataRow]:
        if page < 1:
            raise ValueError("Page must be greater than 0")
        if page_size < 1:
            raise ValueError("Page size must be greater than 0")
        all_data = self.get_all(query_condition)
        start_index = (page - 1) * page_size
        end_index = start_index + page_size
        return all_data[start_index:end_index]

    def get_all(self, query_condition: QueryCondition = None) -> List[DataRow]:
        all_data = []
        query_filter = None
        if query_condition:
            query_filter = QueryFilter(query_condition)
        for data in self._data.values():
            if query_filter:
                all_data.extend(query_filter.filter(data))
            else:
                all_data.extend(data)
        return all_data

    def get_by_task_id(self, task_id: str) -> List[DataRow]:
        return self._data.get(task_id, [])

    def get_bacth_by_task_ids(self, task_ids: List[str]) -> Dict[str, List[DataRow]]:
        return {task_id: self._data.get(task_id, []) for task_id in task_ids}

    def clear(self):
        self._data = {}
        self._fifo_queue = []


class RandomTaskSample(TaskSampler):
    '''
    Randomly samples task_ids from the storage, page by page.
    '''

    def sample_task_ids(self,
                        storage: Storage,
                        batch_size: int,
                        query_condition: QueryCondition = None) -> List[str]:
        total_size = storage.size(query_condition)
        if total_size <= batch_size:
            # Fewer rows than requested: return every distinct task_id.
            return list({data.exp_meta.task_id for data in storage.get_all(query_condition)})

        sampled_task_ids = set()
        page_size = min(100, batch_size * 2)
        total_pages = ceil(total_size / page_size)
        visited_pages = set()
        while len(sampled_task_ids) < batch_size and len(visited_pages) < total_pages:
            # Visit a random, not-yet-seen page and sample new task_ids from it.
            page = random.choice(
                [p for p in range(1, total_pages + 1) if p not in visited_pages])
            visited_pages.add(page)

            current_page = storage.get_paginated(
                page, page_size, query_condition)
            if not current_page:
                continue
            current_page_task_ids = set(
                data.exp_meta.task_id for data in current_page if data.exp_meta.task_id not in sampled_task_ids)
            sample_count = min(len(current_page_task_ids),
                               batch_size - len(sampled_task_ids))
            sampled_task_ids.update(random.sample(
                list(current_page_task_ids), sample_count))

        return list(sampled_task_ids)


class DefaultConverter(Converter):
    '''
    Default converter; returns the sampled rows unchanged.
    '''

    def to_dataset_row(self, task_experience: List[DataRow]) -> List[DataRow]:
        return task_experience


class ReplayBuffer:
    '''
    Replay buffer for storing and sampling data.
    '''

    def __init__(
        self,
        storage: Storage = None
    ):
        # Avoid a shared mutable default argument: each buffer gets its own store.
        self._storage = storage if storage is not None else InMemoryStorage()

    def store(self, data: DataRow):
        '''
        Store data in the replay buffer.
        '''
        if not data:
            raise ValueError("Data is required")
        self._storage.add(data)

    def store_batch(self, data_batch: List[DataRow]):
        '''
        Store a batch of data in the replay buffer.
        '''
        if not data_batch:
            raise ValueError("Data batch is required")
        self._storage.add_batch(data_batch)

    def sample_task(self,
                    sampler: TaskSampler = RandomTaskSample(),
                    query_condition: QueryCondition = None,
                    converter: Converter = DefaultConverter(),
                    batch_size: int = 1000) -> List[T]:
        '''
        Sample whole tasks from the replay buffer and convert each to a dataset row.
        With DefaultConverter each element is a List[DataRow].
        '''
        sampled_task = sampler.sample_tasks(
            self._storage, batch_size, query_condition)
        return [converter.to_dataset_row(task_experiences) for task_experiences in sampled_task.values()]

    def sample(self,
               sampler: Sampler = RandomTaskSample(),
               query_condition: QueryCondition = None,
               converter: Converter = DefaultConverter(),
               batch_size: int = 1000) -> List[T]:
        '''
        Sample data from the replay buffer and convert to a dataset row.
        With DefaultConverter the sampler's output is returned unchanged.
        '''
        sampled_data = sampler.sample(
            self._storage, batch_size, query_condition)
        return converter.to_dataset_row(sampled_data)
aworld/replay_buffer/processor.py ADDED
@@ -0,0 +1,190 @@
# coding: utf-8
"""
processor.py
Cleans raw trace data into the standard storage structure used for reinforcement learning training.
"""
import json
import os
import datetime
import threading
from typing import Any

from aworld.utils import import_package
from aworld.replay_buffer.base import DataRow, Experience, ExpMeta
from aworld.logs.util import logger
from aworld.utils.common import get_local_ip


class ReplayBufferExporter:
    def __init__(self):
        """Initialize a ReplayBufferExporter instance."""
        self._file_locks = {}
        self._lock_dict_lock = threading.Lock()
        self._task_output_paths = {}

    def _get_file_lock(self, file_path):
        """Get (or lazily create) the lock for the specified file."""
        with self._lock_dict_lock:
            if file_path not in self._file_locks:
                self._file_locks[file_path] = threading.Lock()
            return self._file_locks[file_path]

    def replay_buffer_exporter(self, spans: list[dict[str, Any]], output_dir: str):
        """
        Process spans: only spans whose name has the 'step_execution_' prefix are kept,
        grouped by task_id, and written to one output file per task.

        Args:
            spans: span data list
            output_dir: output directory path
        """
        import_package("oss2")
        import oss2

        # Ensure the output directory exists
        os.makedirs(output_dir, exist_ok=True)

        # Get OSS credentials from environment variables
        enable_oss_export = os.getenv("EXPORT_REPLAY_TRACE_TO_OSS", "false").lower() == "true"
        access_key_id = os.getenv('OSS_ACCESS_KEY_ID')
        access_key_secret = os.getenv('OSS_ACCESS_KEY_SECRET')
        endpoint = os.getenv('OSS_ENDPOINT')
        bucket_name = os.getenv('OSS_BUCKET_NAME')
        bucket = None

        if not all([access_key_id, access_key_secret, endpoint, bucket_name]):
            if enable_oss_export:
                logger.warn("OSS export enabled but required OSS environment variables are missing")
            enable_oss_export = False
        else:
            try:
                # Initialize the OSS client
                auth = oss2.Auth(access_key_id, access_key_secret)
                bucket = oss2.Bucket(auth, endpoint, bucket_name)
            except Exception as e:
                enable_oss_export = False
                logger.warn(f"Failed to initialize OSS client, endpoint: {endpoint}, bucket: {bucket_name}. Error: {str(e)}")

        # Group by task_id
        task_groups = {}

        for span_data in spans:
            # Only process spans with the 'step_execution_' prefix
            if not span_data['name'].startswith('step_execution_'):
                continue

            attr = span_data.get('attributes', {})
            exp_id = attr.get('exp_id')
            task_id = attr.get('task_id', '')

            if not exp_id or not task_id:
                continue

            if task_id not in task_groups:
                task_groups[task_id] = {}

            if exp_id not in task_groups[task_id]:
                task_groups[task_id][exp_id] = {
                    'exp_meta': None,
                    'exp_data': None
                }

            # Process the step_execution span
            task_name = attr.get('task_name', '')
            agent_id = attr.get('agent_id', '')
            step = attr.get('step', 0)
            execute_time = float(span_data.get('start_time', '0').split('.')[0].replace(' ', '').replace('-', '').replace(':', ''))
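            # The conversion above collapses the timestamp's digits, e.g.
            # "2025-01-01 12:00:00.123" -> 20250101120000.0 (illustrative input)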

            observation = {}
            action = []
            messages = []
            pre_agent = None
            if 'observation' in attr:
                try:
                    observation = json.loads(attr['observation'])
                except (json.JSONDecodeError, TypeError):
                    observation = attr['observation']

            if 'actions' in attr:
                try:
                    action = json.loads(attr['actions'])
                except (json.JSONDecodeError, TypeError):
                    action = attr['actions']

            if 'messages' in attr:
                try:
                    messages = json.loads(attr['messages'])
                except (json.JSONDecodeError, TypeError):
                    messages = attr['messages']

            pre_agent = attr.get('pre_agent', '')
            reward = attr.get('reward', 0.0)
            adv = attr.get('adv_t', 0.0)
            v = attr.get('v_t', 0.0)

            exp_meta = ExpMeta(task_id, task_name, agent_id, step, execute_time, pre_agent)
            exp_data = Experience(observation, action, reward, adv, v, messages)

            task_groups[task_id][exp_id]['exp_meta'] = exp_meta
            task_groups[task_id][exp_id]['exp_data'] = exp_data

        # Process the data for each task_id
        for task_id, exp_groups in task_groups.items():
            # Merge data and generate the final Experience objects
            data_rows = []

            # Resolve the output path for this task (reused across calls)
            output_path = self._task_output_paths.get(task_id)
            if not output_path:
                timestamp = datetime.datetime.now().strftime("%Y%m%d")
                replay_dir = os.path.join(output_dir or "./trace_data", timestamp, get_local_ip(), "replays")
                replay_dataset_path = os.getenv("REPLAY_TRACE_DATASET_PATH", replay_dir)
                export_dir = os.path.abspath(replay_dataset_path)
                os.makedirs(export_dir, exist_ok=True)
                output_path = os.path.join(export_dir, f"task_replay_{task_id}.json")
                self._task_output_paths[task_id] = output_path

            # Use a per-file lock to protect read and write operations
            file_lock = self._get_file_lock(output_path)
            with file_lock:
                # Read existing data (if any)
                if os.path.exists(output_path):
                    try:
                        with open(output_path, 'r', encoding='utf-8') as f:
                            existing_data = json.load(f)
                        data_rows.extend([DataRow(
                            ExpMeta(**row['exp_meta']),
                            Experience(**row['exp_data']),
                            row['id']
                        ) for row in existing_data])
                    except Exception as e:
                        logger.warn(f"Failed to read existing file {output_path}: {str(e)}")

                # Add the new data
                for exp_id, group in exp_groups.items():
                    if group['exp_meta'] and group['exp_data']:
                        row = DataRow(group['exp_meta'], group['exp_data'], exp_id)
                        data_rows.append(row)

                # Sort by execute_time
                data_rows.sort(key=lambda x: x.exp_meta.execute_time)

                # Export to json
                with open(output_path, 'w', encoding='utf-8') as f:
                    json.dump([row.to_dict() for row in data_rows], f, ensure_ascii=False, indent=2)
                logger.info(f"Processing completed, exported {len(data_rows)} experiences to {output_path}")

            if enable_oss_export:
                # Upload to OSS under a key derived from the local path
                try:
                    abs_path = os.path.abspath(output_path)
                    path_parts = abs_path.split(os.sep)
                    if len(path_parts) >= 4:
                        # Keep the last 4 parts of the path as the object key
                        oss_key = os.sep.join(path_parts[-4:])
                    else:
                        oss_key = f"replay_buffer/{os.path.basename(output_path)}"
                    bucket.put_object_from_file(oss_key, output_path)
                    logger.info(f"Successfully uploaded {output_path} to OSS: {oss_key}")
                except Exception as e:
                    logger.warn(f"Failed to upload {output_path} to OSS: {str(e)}")
aworld/replay_buffer/query_filter.py ADDED
@@ -0,0 +1,228 @@
from typing import Any, List, TypeVar, Union, Literal, TypedDict, Dict

DataRow = TypeVar('DataRow')


class BaseCondition(TypedDict):
    field: str
    value: Any
    op: Literal[
        'eq', 'ne', 'gt', 'gte', 'lt', 'lte',
        'in', 'not_in', 'like', 'not_like',
        'is_null', 'is_not_null'
    ]


class LogicalCondition(TypedDict, total=False):
    and_: List['QueryCondition']
    or_: List['QueryCondition']


QueryCondition = Union[BaseCondition, LogicalCondition]


class QueryBuilder:
    '''
    Query builder for the replay buffer. Example result:
    {
        "and_": [
            {"field": "field1", "value": "value1", "op": "eq"},
            {"or_": [{"field": "field2", "value": "value2", "op": "eq"}, {"field": "field3", "value": "value3", "op": "eq"}]}
        ]
    }
    '''

    def __init__(self) -> None:
        self.conditions: List[Dict[str, Any]] = []
        self.logical_ops: List[str] = []

    def eq(self, field: str, value: Any) -> 'QueryBuilder':
        self.conditions.append({"field": field, "value": value, "op": "eq"})
        return self

    def ne(self, field: str, value: Any) -> 'QueryBuilder':
        self.conditions.append({"field": field, "value": value, "op": "ne"})
        return self

    def gt(self, field: str, value: Any) -> 'QueryBuilder':
        self.conditions.append({"field": field, "value": value, "op": "gt"})
        return self

    def gte(self, field: str, value: Any) -> 'QueryBuilder':
        self.conditions.append({"field": field, "value": value, "op": "gte"})
        return self

    def lt(self, field: str, value: Any) -> 'QueryBuilder':
        self.conditions.append({"field": field, "value": value, "op": "lt"})
        return self

    def lte(self, field: str, value: Any) -> 'QueryBuilder':
        self.conditions.append({"field": field, "value": value, "op": "lte"})
        return self

    def in_(self, field: str, value: Any) -> 'QueryBuilder':
        self.conditions.append({"field": field, "value": value, "op": "in"})
        return self

    def not_in(self, field: str, value: Any) -> 'QueryBuilder':
        self.conditions.append(
            {"field": field, "value": value, "op": "not_in"})
        return self

    def like(self, field: str, value: Any) -> 'QueryBuilder':
        self.conditions.append({"field": field, "value": value, "op": "like"})
        return self

    def not_like(self, field: str, value: Any) -> 'QueryBuilder':
        self.conditions.append(
            {"field": field, "value": value, "op": "not_like"})
        return self

    def is_null(self, field: str) -> 'QueryBuilder':
        self.conditions.append({"field": field, "op": "is_null"})
        return self

    def is_not_null(self, field: str) -> 'QueryBuilder':
        self.conditions.append({"field": field, "op": "is_not_null"})
        return self

    def and_(self) -> 'QueryBuilder':
        self.logical_ops.append("and_")
        return self

    def or_(self) -> 'QueryBuilder':
        self.logical_ops.append("or_")
        return self

    def nested(self, builder: 'QueryBuilder') -> 'QueryBuilder':
        self.conditions.append({"nested": builder.build()})
        return self

    def build(self) -> QueryCondition:
        conditions = self.conditions  # all conditions (including nested)
        operators = self.logical_ops

        # Validate condition and operator counts (n conditions need n-1 operators)
        if len(operators) != len(conditions) - 1:
            raise ValueError("Mismatch between condition and operator counts")

        # Left-fold the conditions with a stack (simplified handling of and_/or_)
        stack: List[Union[Dict[str, Any], str]] = []

        for i, item in enumerate(conditions):
            if i == 0:
                # First element goes directly onto the stack (condition or nested)
                stack.append(item)
                continue

            # Pop the stack top as the left operand
            left = stack.pop()
            op = operators[i - 1]  # Current operator (and_/or_)
            right = item  # Right operand (current condition)

            # Build the logical expression: {op: [left, right]}
            expr = {op: [left, right]}
            # Push the result back onto the stack for further combination
            stack.append(expr)

        # Unwrap nested conditions (recursive unfolding)
        def process_nested(cond: Any) -> Any:
            if isinstance(cond, dict):
                if "nested" in cond:
                    # Recursively process the sub-condition
                    return process_nested(cond["nested"])
                # Recursively process child elements
                return {k: process_nested(v) for k, v in cond.items()}
            elif isinstance(cond, list):
                return [process_nested(item) for item in cond]
            return cond

        # Only one element remains on the stack; process nesting and return
        result = stack[0] if stack else None
        return process_nested(result) if result else None
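
    # Example (illustrative): QueryBuilder().eq("a", 1).or_().eq("b", 2).build()
    # -> {"or_": [{"field": "a", "value": 1, "op": "eq"},
    #             {"field": "b", "value": 2, "op": "eq"}]}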


class QueryFilter:
    '''
    Query filter for the replay buffer.
    '''

    def __init__(self, query_condition: QueryCondition) -> None:
        self.query_condition = query_condition

    def _get_field_value(self, row: DataRow, field: str) -> Any:
        '''
        Get a (possibly dotted) field value from a row,
        e.g. "exp_meta.task_id" walks row.exp_meta.task_id.
        '''
        obj = row
        for part in field.split('.'):
            obj = getattr(obj, part, None)
            if obj is None:
                break
        return obj

    def _do_check(self, row: DataRow, condition: QueryCondition) -> bool:
        """
        Check whether the row matches the condition.
        """
        if condition is None:
            return True
        if "field" in condition and "op" in condition:
            field_val = self._get_field_value(row, condition["field"])
            op = condition["op"]
            # is_null / is_not_null conditions carry no "value" key
            target_val = condition.get("value")

            if op == "eq":
                return field_val == target_val
            if op == "ne":
                return field_val != target_val
            if op == "gt":
                return field_val > target_val
            if op == "gte":
                return field_val >= target_val
            if op == "lt":
                return field_val < target_val
            if op == "lte":
                return field_val <= target_val
            if op == "in":
                return field_val in target_val
            if op == "not_in":
                return field_val not in target_val
            if op == "like":
                return target_val in field_val
            if op == "not_like":
                return target_val not in field_val
            if op == "is_null":
                return field_val is None
            if op == "is_not_null":
                return field_val is not None

            return False

        elif "and_" in condition or "or_" in condition:
            if "and_" in condition:
                return all(self._do_check(row, c) for c in condition["and_"])
            if "or_" in condition:
                return any(self._do_check(row, c) for c in condition["or_"])
            return False

        return False

    def check_condition(self, row: DataRow) -> bool:
        """
        Check whether the row matches the filter's condition.
        """
        return self._do_check(row, self.query_condition)

    def filter(self, rows: List[DataRow]) -> List[DataRow]:
        """Filter rows by the filter's condition.

        Args:
            rows (List[DataRow]): List of rows to filter.
        Returns:
            List[DataRow]: List of rows that match the condition.
        """
        condition = self.query_condition
        if not condition:
            return rows
        return [row for row in rows if self.check_condition(row)]