Spaces:
Sleeping
Sleeping
Upload 2 files
Browse files- aworld/checkpoint/README.md +98 -0
- aworld/checkpoint/inmemory.py +101 -0
aworld/checkpoint/README.md
ADDED
|
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Checkpoint Module
|
| 2 |
+
|
| 3 |
+
## Overview
|
| 4 |
+
The Checkpoint module provides a robust and extensible framework for managing state snapshots (checkpoints) in Python applications. It is designed for scenarios where you need to persist, restore, and version the state of a process, session, or task.
|
| 5 |
+
|
| 6 |
+
```mermaid
|
| 7 |
+
sequenceDiagram
|
| 8 |
+
participant Application
|
| 9 |
+
participant CheckpointRepository
|
| 10 |
+
participant BackendStorage
|
| 11 |
+
|
| 12 |
+
Note over Application,BackendStorage: Create and store a checkpoint
|
| 13 |
+
%% Create and store a checkpoint
|
| 14 |
+
Application->>CheckpointRepository: create checkpoint
|
| 15 |
+
CheckpointRepository->>BackendStorage: put(checkpoint)
|
| 16 |
+
BackendStorage-->>CheckpointRepository: success
|
| 17 |
+
CheckpointRepository-->>Application: ack
|
| 18 |
+
|
| 19 |
+
Note over Application,BackendStorage: Retrieve the latest checkpoint by session
|
| 20 |
+
|
| 21 |
+
%% Retrieve the latest checkpoint by session
|
| 22 |
+
Application->>CheckpointRepository: get checkpoint by session_id
|
| 23 |
+
CheckpointRepository->>BackendStorage: get_by_session(session_id)
|
| 24 |
+
BackendStorage-->>CheckpointRepository: Checkpoint
|
| 25 |
+
CheckpointRepository-->>Application: Checkpoint
|
| 26 |
+
|
| 27 |
+
```
|
| 28 |
+
|
| 29 |
+
## Key Features
|
| 30 |
+
|
| 31 |
+
- **Structured Data Model**: Uses Pydantic's `BaseModel` for strong typing and validation of checkpoint data and metadata.
|
| 32 |
+
- **Versioning Support**: Built-in version management utilities for checkpoint evolution and comparison.
|
| 33 |
+
- **Extensible Repository Pattern**: Abstract base class (`BaseCheckpointRepository`) defines a standard interface for checkpoint storage, supporting both synchronous and asynchronous operations.
|
| 34 |
+
- **In-Memory Implementation**: Includes a simple, ready-to-use in-memory repository for development and testing.
|
| 35 |
+
- **Utility Functions**: Helper methods for creating, copying, and managing checkpoints.
|
| 36 |
+
|
| 37 |
+
## Data Structures
|
| 38 |
+
|
| 39 |
+
```mermaid
|
| 40 |
+
classDiagram
|
| 41 |
+
class Application {
|
| 42 |
+
+CheckpointRepository repo
|
| 43 |
+
+create_checkpoint()
|
| 44 |
+
+get_checkpoint_by_session()
|
| 45 |
+
}
|
| 46 |
+
class CheckpointRepository {
|
| 47 |
+
+put(checkpoint)
|
| 48 |
+
+get_by_session(session_id)
|
| 49 |
+
+delete_by_session(session_id)
|
| 50 |
+
-BackendStorage backend
|
| 51 |
+
}
|
| 52 |
+
class BackendStorage {
|
| 53 |
+
+put(checkpoint)
|
| 54 |
+
+get_by_session(session_id)
|
| 55 |
+
+delete_by_session(session_id)
|
| 56 |
+
}
|
| 57 |
+
Application --> CheckpointRepository : uses
|
| 58 |
+
CheckpointRepository --> BackendStorage : delegates
|
| 59 |
+
class Checkpoint {
|
| 60 |
+
+id: str
|
| 61 |
+
+ts: str
|
| 62 |
+
+metadata: CheckpointMetadata
|
| 63 |
+
+values: dict
|
| 64 |
+
+version: int
|
| 65 |
+
+parent_id: str
|
| 66 |
+
+namespace: str
|
| 67 |
+
}
|
| 68 |
+
class CheckpointMetadata {
|
| 69 |
+
+session_id: str
|
| 70 |
+
+task_id: str
|
| 71 |
+
}
|
| 72 |
+
Checkpoint o-- CheckpointMetadata
|
| 73 |
+
CheckpointRepository o-- Checkpoint
|
| 74 |
+
BackendStorage o-- Checkpoint
|
| 75 |
+
```
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
## Usage Example
|
| 79 |
+
|
| 80 |
+
```python
|
| 81 |
+
from aworld.checkpoint import (
|
| 82 |
+
Checkpoint, CheckpointMetadata, empty_checkpoint, create_checkpoint, InMemoryCheckpointRepository
|
| 83 |
+
)
|
| 84 |
+
|
| 85 |
+
# Create a new checkpoint
|
| 86 |
+
metadata = CheckpointMetadata(session_id="session-123", task_id="task-456")
|
| 87 |
+
values = {"step": 1, "score": 100}
|
| 88 |
+
checkpoint = create_checkpoint(values=values, metadata=metadata)
|
| 89 |
+
|
| 90 |
+
# Store and retrieve using the in-memory repository
|
| 91 |
+
repo = InMemoryCheckpointRepository()
|
| 92 |
+
repo.put(checkpoint)
|
| 93 |
+
restored = repo.get(checkpoint.id)
|
| 94 |
+
```
|
| 95 |
+
|
| 96 |
+
## Extensibility
|
| 97 |
+
- Implement custom repositories by inheriting from `BaseCheckpointRepository` (e.g., for database, file, or cloud storage).
|
| 98 |
+
- Extend versioning logic via the `VersionUtils` class.
|
aworld/checkpoint/inmemory.py
ADDED
|
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Any, Dict, List, Optional
|
| 2 |
+
from . import Checkpoint, BaseCheckpointRepository, VersionUtils
|
| 3 |
+
|
| 4 |
+
class InMemoryCheckpointRepository(BaseCheckpointRepository):
    """
    In-memory implementation of BaseCheckpointRepository.

    Checkpoints are held in a plain dictionary keyed by checkpoint id, with a
    secondary index mapping each session id to the ids stored for that
    session, in the order they were stored. Nothing is persisted.
    Thread safety is not guaranteed.
    """

    def __init__(self) -> None:
        """
        Initialize the in-memory checkpoint repository.
        """
        # checkpoint id -> checkpoint
        self._checkpoints: Dict[str, Checkpoint] = {}
        # session id -> checkpoint ids, oldest first (last entry is the latest)
        self._session_index: Dict[str, List[str]] = {}

    @staticmethod
    def _field_value(checkpoint: Checkpoint, key: str) -> Any:
        """
        Resolve the value a filter key refers to on a checkpoint.

        ``session_id`` and ``task_id`` are read from ``checkpoint.metadata``;
        any other key is read as an attribute of the checkpoint itself, and
        resolves to None when the checkpoint has no such attribute.
        """
        if key in ('session_id', 'task_id'):
            return getattr(checkpoint.metadata, key)
        # Attribute lookup rather than dict-style ``.get``: the rest of this
        # class accesses checkpoints by attribute (id, version, metadata).
        return getattr(checkpoint, key, None)

    def get(self, checkpoint_id: str) -> Optional[Checkpoint]:
        """
        Retrieve a checkpoint by its unique identifier.
        Args:
            checkpoint_id (str): The unique identifier of the checkpoint.
        Returns:
            Optional[Checkpoint]: The checkpoint if found, otherwise None.
        """
        return self._checkpoints.get(checkpoint_id)

    def list(self, params: Dict[str, Any]) -> List[Checkpoint]:
        """
        List checkpoints matching the given parameters.

        A checkpoint matches when every key/value pair in ``params`` equals
        the corresponding value on the checkpoint. An empty ``params``
        matches every stored checkpoint.
        Args:
            params (dict): Parameters to filter checkpoints. ``session_id``
                and ``task_id`` are compared against the checkpoint metadata;
                other keys are compared against checkpoint attributes.
        Returns:
            List[Checkpoint]: List of matching checkpoints.
        """
        return [
            cp
            for cp in self._checkpoints.values()
            if all(self._field_value(cp, k) == v for k, v in params.items())
        ]

    def put(self, checkpoint: Checkpoint) -> None:
        """
        Store a checkpoint.

        The checkpoint becomes the latest one for its session. Storing a
        checkpoint whose id is already indexed for the session replaces it
        without leaving a duplicate entry in the session index.
        Args:
            checkpoint (Checkpoint): The checkpoint to store.
        Raises:
            ValueError: If the checkpoint's version is lower than the version
                of the latest checkpoint already stored for the same session.
        """
        session_id = checkpoint.metadata.session_id

        # Compare against the latest checkpoint of the session (optimistic locking).
        # NOTE(review): the message says "must be greater", but only strictly
        # lower versions are rejected; equal versions are accepted. Confirm intent.
        last_checkpoint = self.get_by_session(session_id)
        if last_checkpoint is not None and VersionUtils.is_version_less(checkpoint, last_checkpoint.version):
            raise ValueError(f"New checkpoint version {checkpoint.version} must be greater than last version {last_checkpoint.version}")

        # Store the new checkpoint
        self._checkpoints[checkpoint.id] = checkpoint

        # Update session index
        if session_id:
            ids = self._session_index.setdefault(session_id, [])
            # Move an already-indexed id to the end instead of appending a
            # duplicate, so repeated puts do not grow the index.
            if checkpoint.id in ids:
                ids.remove(checkpoint.id)
            ids.append(checkpoint.id)

    def get_by_session(self, session_id: str) -> Optional[Checkpoint]:
        """
        Get the latest checkpoint for a session.
        Args:
            session_id (str): The session identifier.
        Returns:
            Optional[Checkpoint]: The latest checkpoint if found, otherwise None.
        """
        ids = self._session_index.get(session_id)
        if not ids:
            return None
        # The index is kept in storage order, so the last id is the latest.
        return self._checkpoints.get(ids[-1])

    def delete_by_session(self, session_id: str) -> None:
        """
        Delete all checkpoints related to a session.

        Does nothing when the session has no stored checkpoints.
        Args:
            session_id (str): The session identifier.
        """
        for cid in self._session_index.pop(session_id, []):
            self._checkpoints.pop(cid, None)
|