Spaces:
Sleeping
Sleeping
Upload 9 files
Browse files- aworld/output/README.md +156 -0
- aworld/output/__init__.py +31 -0
- aworld/output/artifact.py +155 -0
- aworld/output/base.py +405 -0
- aworld/output/code_artifact.py +277 -0
- aworld/output/observer.py +189 -0
- aworld/output/outputs.py +202 -0
- aworld/output/utils.py +80 -0
- aworld/output/workspace.py +522 -0
aworld/output/README.md
ADDED
|
@@ -0,0 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# AWorld Output Module
|
| 2 |
+
|
| 3 |
+
The Output module is a flexible and extensible system for managing outputs and artifacts in the AWorld framework.
|
| 4 |
+
|
| 5 |
+
## Key Features
|
| 6 |
+
|
| 7 |
+
- **Unified Output Management**: Centralized management of all output types (messages, artifacts, tool results, etc.) through a flexible Outputs interface.
|
| 8 |
+
- **Support for Multiple Output Types**: Handles text, code, files, tool calls, and custom outputs, enabling rich interaction and extensibility.
|
| 9 |
+
- **Async & Sync Streaming**: Provides both asynchronous and synchronous streaming of outputs, supporting real-time and batch processing scenarios.
|
| 10 |
+
- **Output Aggregation & Dispatch**: Aggregates outputs from various sources and dispatches them to different consumers or UIs.
|
| 11 |
+
- **Extensible Output Channels**: Easily extendable output channels and renderers for custom UI or integration needs.
|
| 12 |
+
- **Integration with Task & Workspace**: Seamlessly integrates with Task and WorkSpace modules for collaborative, versioned, and observable output management.
|
| 13 |
+
- **Real-time, Batch, and Streaming Modes**: Supports real-time, batch, and streaming output modes for flexible workflow requirements.
|
| 14 |
+
- **Observer Pattern Support**: Built-in observer pattern for real-time updates and notifications on output or artifact changes.
|
| 15 |
+
- **Easy Customization**: Designed for easy extension and customization to fit various application scenarios.
|
| 16 |
+
- **Rich UI Rendering**: Supports diverse output rendering with pluggable UI renderers for CLI, web, and custom frontends.
|
| 17 |
+
- **Decoupled UI & Output Types**: UI rendering is decoupled from output types, enabling flexible presentation and interaction.
|
| 18 |
+
- **Real-time Interactive Display**: Enables real-time, streaming, and interactive output display in various UI environments.
|
| 19 |
+
- **UI Extensibility**: Easy to extend and customize UI components to fit different user experiences and workflows.
|
| 20 |
+
|
| 21 |
+
## Class Diagram
|
| 22 |
+
|
| 23 |
+
```mermaid
|
| 24 |
+
classDiagram
|
| 25 |
+
direction TB
|
| 26 |
+
%% Output Related Classes
|
| 27 |
+
class Output {
|
| 28 |
+
+metadata: Dict
|
| 29 |
+
+parts: List[OutputPart]
|
| 30 |
+
}
|
| 31 |
+
class OutputPart {
|
| 32 |
+
+content: Any
|
| 33 |
+
+metadata: Dict
|
| 34 |
+
}
|
| 35 |
+
Output *-- OutputPart
|
| 36 |
+
class Outputs {
|
| 37 |
+
<<abstract>>
|
| 38 |
+
+add_output(output: Output)
|
| 39 |
+
+sync_add_output(output: Output)
|
| 40 |
+
+stream_events()
|
| 41 |
+
+sync_stream_events()
|
| 42 |
+
+mark_completed()
|
| 43 |
+
}
|
| 44 |
+
class AsyncOutputs {
|
| 45 |
+
+add_output(output: Output)
|
| 46 |
+
+sync_add_output(output: Output)
|
| 47 |
+
+stream_events()
|
| 48 |
+
+sync_stream_events()
|
| 49 |
+
}
|
| 50 |
+
class DefaultOutputs {
|
| 51 |
+
+_outputs: List[Output]
|
| 52 |
+
+add_output(output: Output)
|
| 53 |
+
+sync_add_output(output: Output)
|
| 54 |
+
+stream_events()
|
| 55 |
+
+sync_stream_events()
|
| 56 |
+
+mark_completed()
|
| 57 |
+
}
|
| 58 |
+
class StreamingOutputs {
|
| 59 |
+
+input: Any
|
| 60 |
+
+usage: dict
|
| 61 |
+
+is_complete: bool
|
| 62 |
+
+_output_queue: asyncio.Queue[Output]
|
| 63 |
+
+_visited_outputs: List[Output]
|
| 64 |
+
+_stored_exception: Exception
|
| 65 |
+
+_run_impl_task: asyncio.Task
|
| 66 |
+
+add_output(output: Output)
|
| 67 |
+
+stream_events()
|
| 68 |
+
+mark_completed()
|
| 69 |
+
}
|
| 70 |
+
class MessageOutput {
|
| 71 |
+
+source: Any
|
| 72 |
+
+reason_generator: Any
|
| 73 |
+
+response_generator: Any
|
| 74 |
+
+reasoning: str
|
| 75 |
+
+response: Any
|
| 76 |
+
+has_reasoning: bool
|
| 77 |
+
+finished: bool
|
| 78 |
+
}
|
| 79 |
+
class Artifact {
|
| 80 |
+
+artifact_id: str
|
| 81 |
+
+artifact_type: ArtifactType
|
| 82 |
+
+content: Any
|
| 83 |
+
+metadata: Dict
|
| 84 |
+
}
|
| 85 |
+
class ToolCallOutput {
|
| 86 |
+
+tool_id: str
|
| 87 |
+
+content: Any
|
| 88 |
+
+metadata: Dict
|
| 89 |
+
}
|
| 90 |
+
class ToolResultOutput {
|
| 91 |
+
+tool_id: str
|
| 92 |
+
+params: Any
|
| 93 |
+
+content: Any
|
| 94 |
+
+metadata: Dict
|
| 95 |
+
}
|
| 96 |
+
class SearchOutput {
|
| 97 |
+
+artifact_id: str
|
| 98 |
+
+artifact_type: ArtifactType
|
| 99 |
+
+content: Any
|
| 100 |
+
+metadata: Dict
|
| 101 |
+
}
|
| 102 |
+
|
| 103 |
+
%% Inheritance Relationships
|
| 104 |
+
Outputs <|-- AsyncOutputs
|
| 105 |
+
Outputs <|-- DefaultOutputs
|
| 106 |
+
AsyncOutputs <|-- StreamingOutputs
|
| 107 |
+
Output <|-- MessageOutput
|
| 108 |
+
Output <|-- Artifact
|
| 109 |
+
Output <|-- ToolCallOutput
|
| 110 |
+
Output <|-- ToolResultOutput
|
| 111 |
+
ToolResultOutput <|-- SearchOutput
|
| 112 |
+
|
| 113 |
+
%% Aggregation/Composition
|
| 114 |
+
Outputs o-- Output
|
| 115 |
+
DefaultOutputs o-- Output
|
| 116 |
+
StreamingOutputs o-- Output
|
| 117 |
+
|
| 118 |
+
%% Workspace Related Classes
|
| 119 |
+
class WorkSpace {
|
| 120 |
+
+workspace_id: str
|
| 121 |
+
+name: str
|
| 122 |
+
+created_at: str
|
| 123 |
+
+updated_at: str
|
| 124 |
+
+metadata: Dict
|
| 125 |
+
+artifacts: List[Artifact]
|
| 126 |
+
+observers: List[WorkspaceObserver]
|
| 127 |
+
+repository: ArtifactRepository
|
| 128 |
+
+create_artifact()
|
| 129 |
+
+add_artifact()
|
| 130 |
+
+get_artifact()
|
| 131 |
+
+update_artifact()
|
| 132 |
+
+delete_artifact()
|
| 133 |
+
+list_artifacts()
|
| 134 |
+
+add_observer()
|
| 135 |
+
+remove_observer()
|
| 136 |
+
+save()
|
| 137 |
+
+load()
|
| 138 |
+
}
|
| 139 |
+
class WorkspaceObserver
|
| 140 |
+
class ArtifactRepository
|
| 141 |
+
|
| 142 |
+
WorkSpace o-- Artifact
|
| 143 |
+
WorkSpace o-- WorkspaceObserver
|
| 144 |
+
WorkSpace o-- ArtifactRepository
|
| 145 |
+
|
| 146 |
+
%% Task Related Classes
|
| 147 |
+
class Task {
|
| 148 |
+
+name: str
|
| 149 |
+
+input: Any
|
| 150 |
+
+outputs: Outputs
|
| 151 |
+
}
|
| 152 |
+
|
| 153 |
+
class Outputs
|
| 154 |
+
|
| 155 |
+
Task o-- Outputs
|
| 156 |
+
```
|
aworld/output/__init__.py
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Public API of the ``aworld.output`` package."""
from aworld.output.base import (
    Output,
    SearchOutput,
    SearchItem,
    ToolResultOutput,
    MessageOutput,
    ToolCallOutput,
    RUN_FINISHED_SIGNAL,
)
from aworld.output.artifact import Artifact, ArtifactType
from aworld.output.code_artifact import CodeArtifact, ShellArtifact
from aworld.output.outputs import Outputs, StreamingOutputs
from aworld.output.workspace import WorkSpace
from aworld.output.observer import WorkspaceObserver, get_observer
from aworld.output.storage.artifact_repository import ArtifactRepository, LocalArtifactRepository
from aworld.output.ui.base import AworldUI, PrinterAworldUI

# Explicit public surface of the package.
__all__ = [
    "Output",
    "Artifact",
    "ArtifactType",
    "CodeArtifact",
    "ShellArtifact",
    "WorkSpace",
    "ArtifactRepository",
    "LocalArtifactRepository",
    "WorkspaceObserver",
    "get_observer",
    "SearchOutput",
    "SearchItem",
    "MessageOutput",
    "ToolCallOutput",
    "ToolResultOutput",
    "Outputs",
    "StreamingOutputs",
    "RUN_FINISHED_SIGNAL",
    "AworldUI",
    "PrinterAworldUI",
]
|
aworld/output/artifact.py
ADDED
|
@@ -0,0 +1,155 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import uuid
|
| 2 |
+
from datetime import datetime
|
| 3 |
+
from enum import Enum, auto
|
| 4 |
+
from typing import Dict, Any, Optional, ClassVar
|
| 5 |
+
from pydantic import Field, field_validator, model_validator, BaseModel
|
| 6 |
+
|
| 7 |
+
from aworld.output.base import Output
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class ArtifactType(Enum):
    """Defines supported artifact types"""
    # Plain / markup content
    TEXT = "TEXT"
    CODE = "CODE"
    MARKDOWN = "MARKDOWN"
    HTML = "HTML"
    SVG = "SVG"
    IMAGE = "IMAGE"
    # Structured data representations
    JSON = "JSON"
    CSV = "CSV"
    TABLE = "TABLE"
    CHART = "CHART"
    DIAGRAM = "DIAGRAM"
    # Tool / model interaction results
    MCP_CALL = "MCP_CALL"
    TOOL_CALL = "TOOL_CALL"
    LLM_OUTPUT = "LLM_OUTPUT"
    WEB_PAGES = "WEB_PAGES"
    # Filesystem entry and open-ended escape hatch
    DIR = "DIR"
    CUSTOM = "CUSTOM"
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class ArtifactStatus(Enum):
    """Lifecycle status of an Artifact (values are opaque auto() ints)."""
    DRAFT = auto()     # Draft status: initial, still being produced
    COMPLETE = auto()  # Completed status: content finalized
    EDITED = auto()    # Edited status: modified after completion
    ARCHIVED = auto()  # Archived status: retained but no longer active
|
| 38 |
+
|
| 39 |
+
class ArtifactAttachment(BaseModel):
    """A named file attached to an Artifact."""
    filename: str = Field(..., description="Filename")
    # exclude=True: attachment bodies can be large, so the content is omitted
    # from model serialization (model_dump / JSON).
    content: str = Field(..., description="Content", exclude=True)
    mime_type: str = Field(..., description="MIME type")
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class Artifact(Output):
    """
    Represents a specific content generation result (artifact).

    Artifacts are the basic units of Artifacts technology, representing a
    structured content unit. Content can be code, markdown, charts, and
    various other formats (see ArtifactType).
    """

    artifact_id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="Unique identifier for the artifact")
    artifact_type: ArtifactType = Field(..., description="Type of the artifact")
    content: Any = Field(..., description="Content of the artifact")
    metadata: Dict[str, Any] = Field(default_factory=dict, description="Metadata associated with the artifact")
    created_at: str = Field(default_factory=lambda: datetime.now().isoformat(), description="Creation timestamp")
    updated_at: str = Field(default_factory=lambda: datetime.now().isoformat(), description="Last updated timestamp")
    status: ArtifactStatus = Field(default=ArtifactStatus.COMPLETE, description="Current status of the artifact")
    current_version: str = Field(default="", description="Current version of the artifact")
    version_history: list = Field(default_factory=list, description="History of versions for the artifact")
    create_file: bool = Field(default=False, description="Flag to indicate if a file should be created")
    attachments: Optional[list[ArtifactAttachment]] = Field(default_factory=list, description="Attachments associated with the artifact")

    def _record_version(self, description: str) -> None:
        """Record the current state as a new version and bump updated_at.

        Bug fix: the version entry now snapshots ``content``. Previously
        ``revert_to_version`` read ``version["content"]`` which was never
        stored, so every revert raised KeyError.
        """
        version = {
            "timestamp": datetime.now().isoformat(),
            "description": description,
            "status": self.status,
            "content": self.content,
        }
        self.version_history.append(version)
        self.updated_at = version["timestamp"]

    def update_content(self, content: Any, description: str = "Content update") -> None:
        """
        Update artifact content and record version

        Args:
            content: New content
            description: Update description
        """
        self.content = content
        self.status = ArtifactStatus.EDITED
        self._record_version(description)

    def update_metadata(self, metadata: Dict[str, Any]) -> None:
        """
        Update artifact metadata

        Args:
            metadata: New metadata (will be merged with existing metadata)
        """
        self.metadata.update(metadata)
        self.updated_at = datetime.now().isoformat()

    def mark_complete(self) -> None:
        """Mark the artifact as complete"""
        self.status = ArtifactStatus.COMPLETE
        self.updated_at = datetime.now().isoformat()
        self._record_version("Marked as complete")

    def archive(self) -> None:
        """Archive the artifact"""
        self.status = ArtifactStatus.ARCHIVED
        self._record_version("Artifact archived")

    def get_version(self, index: int) -> Optional[Dict[str, Any]]:
        """Get version at the specified index, or None when out of range."""
        if 0 <= index < len(self.version_history):
            return self.version_history[index]
        return None

    def revert_to_version(self, index: int) -> bool:
        """Revert content/status to a previously recorded version.

        Returns False when the index is out of range or when the entry
        predates content snapshots (e.g. history restored via from_dict).
        """
        version = self.get_version(index)
        if version is None or "content" not in version:
            return False
        self.content = version["content"]
        self.status = version["status"]
        self._record_version(f"Reverted to version {index}")
        return True

    def to_dict(self) -> Dict[str, Any]:
        """Convert artifact to a plain dictionary (history summarized as a count)."""
        return {
            "artifact_id": self.artifact_id,
            "artifact_type": self.artifact_type.value,
            "content": self.content,
            "metadata": self.metadata,
            "created_at": self.created_at,
            "updated_at": self.updated_at,
            "status": self.status.name,
            "version_count": len(self.version_history)
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "Artifact":
        """Create an artifact instance from a dictionary.

        Required keys: artifact_type, content, metadata, created_at,
        updated_at, status. Optional: artifact_id, version_history.
        """
        artifact_type = ArtifactType(data["artifact_type"])
        artifact = cls(
            artifact_type=artifact_type,
            content=data["content"],
            metadata=data["metadata"],
            artifact_id=data.get("artifact_id", str(uuid.uuid4()))
        )
        artifact.created_at = data["created_at"]
        artifact.updated_at = data["updated_at"]
        # Status is persisted by enum *name* (see to_dict), hence the [] lookup.
        artifact.status = ArtifactStatus[data["status"]]

        # If version history exists, restore it as well
        if "version_history" in data:
            artifact.version_history = data["version_history"]

        return artifact
|
aworld/output/base.py
ADDED
|
@@ -0,0 +1,405 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import logging
|
| 3 |
+
from builtins import anext
|
| 4 |
+
from datetime import datetime
|
| 5 |
+
from typing import Any, Dict, Generator, AsyncGenerator, Optional
|
| 6 |
+
|
| 7 |
+
from pydantic import Field, BaseModel, model_validator
|
| 8 |
+
|
| 9 |
+
from aworld.models.model_response import ModelResponse, ToolCall
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class OutputPart(BaseModel):
    """One piece of an Output: arbitrary content plus optional metadata."""

    content: Any
    metadata: Optional[Dict[str, Any]] = Field(default_factory=dict, description="metadata")

    @model_validator(mode='after')
    def setup_metadata(self):
        """Guarantee ``metadata`` is a dict even when None was passed explicitly."""
        self.metadata = self.metadata if self.metadata is not None else {}
        return self
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class Output(BaseModel):
    """Base class for everything the framework can emit.

    Carries free-form metadata, an (excluded-from-dump) list of parts, and an
    opaque payload in ``data``. Subclasses override :meth:`output_type`.
    """

    metadata: Optional[Dict[str, Any]] = Field(default_factory=dict, description="metadata")
    parts: Any = Field(default_factory=list, exclude=True, description="parts of Output")
    data: Any = Field(default=None, exclude=True, description="Output Data")

    @model_validator(mode='after')
    def setup_defaults(self):
        """Normalize explicit-None metadata/parts to empty containers."""
        self.metadata = self.metadata if self.metadata is not None else {}
        self.parts = self.parts if self.parts is not None else []
        return self

    def add_part(self, content: Any):
        """Append ``content`` wrapped in an OutputPart, creating the list if needed."""
        if self.parts is None:
            self.parts = []
        self.parts.append(OutputPart(content=content))

    def output_type(self):
        """Discriminator string identifying the kind of output."""
        return "default"
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
class ToolCallOutput(Output):
    """Output emitted when the model requests a tool invocation."""

    @classmethod
    def from_tool_call(cls, tool_call: ToolCall):
        """Alternate constructor: wrap a raw ToolCall as the output's data."""
        return cls(data=tool_call)

    def output_type(self):
        return "tool_call"
|
| 55 |
+
|
| 56 |
+
class ToolResultOutput(Output):
    """Output carrying the result of a tool invocation.

    Fix: ``image``/``tool_type``/``tool_name`` were annotated ``str`` while
    defaulting to None, which rejected an explicitly passed None; widened to
    ``Optional[str]`` to match actual usage. A dead trailing ``pass`` was
    removed.
    """

    origin_tool_call: Optional[ToolCall] = Field(default=None, description="origin tool call", exclude=True)

    image: Optional[str] = Field(default=None, description="primary image produced by the tool")

    images: list[str] = Field(default_factory=list, description="all images produced by the tool")

    tool_type: Optional[str] = Field(default=None, description="category of the tool")

    tool_name: Optional[str] = Field(default=None, description="name of the tool")

    def output_type(self):
        return "tool_call_result"
|
| 72 |
+
|
| 73 |
+
class RunFinishedSignal(Output):
    """Sentinel output used to signal that a run has finished streaming."""

    def output_type(self):
        return "finished_signal"

# Module-level singleton; producers enqueue this instance so consumers can
# detect end-of-stream.
RUN_FINISHED_SIGNAL = RunFinishedSignal()
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
class MessageOutput(Output):
|
| 82 |
+
|
| 83 |
+
"""
|
| 84 |
+
MessageOutput structure of LLM output
|
| 85 |
+
if you want to get the only response, you must first call reasoning_generator or set parameter only_response to True , then call response_generator
|
| 86 |
+
if you model not reasoning, you do not need care about reasoning_generator and reasoning
|
| 87 |
+
|
| 88 |
+
1. source: async/sync generator of the message
|
| 89 |
+
2. reasoning_generator: async/sync reasoning generator of the message
|
| 90 |
+
3. response_generator: async/sync response generator of the message;
|
| 91 |
+
4. reasoning: reasoning of the message
|
| 92 |
+
5. response: response of the message
|
| 93 |
+
6. tool_calls
|
| 94 |
+
|
| 95 |
+
"""
|
| 96 |
+
|
| 97 |
+
source: Any = Field(default=None, exclude=True, description="Source of the message")
|
| 98 |
+
|
| 99 |
+
reason_generator: Any = Field(default=None, exclude=True, description="reasoning generator of the message")
|
| 100 |
+
response_generator: Any = Field(default=None, exclude=True, description="response generator of the message")
|
| 101 |
+
|
| 102 |
+
"""
|
| 103 |
+
result
|
| 104 |
+
"""
|
| 105 |
+
reasoning: str = Field(default=None, description="reasoning of the message")
|
| 106 |
+
response: Any = Field(default=None, description="response of the message")
|
| 107 |
+
tool_calls: list[ToolCallOutput] = Field(default_factory=list, description="tool_calls")
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
"""
|
| 111 |
+
other config
|
| 112 |
+
"""
|
| 113 |
+
reasoning_format_start: str = Field(default="<think>", description="reasoning format start of the message")
|
| 114 |
+
reasoning_format_end: str = Field(default="</think>", description="reasoning format end of the message")
|
| 115 |
+
|
| 116 |
+
json_parse: bool = Field(default=False, description="json parse of the message", exclude=True)
|
| 117 |
+
has_reasoning: bool = Field(default=False, description="has reasoning of the message")
|
| 118 |
+
finished: bool = Field(default=False, description="finished of the message")
|
| 119 |
+
|
| 120 |
+
@model_validator(mode='after')
|
| 121 |
+
def setup_generators(self):
|
| 122 |
+
"""
|
| 123 |
+
Setup generators for reasoning and response
|
| 124 |
+
"""
|
| 125 |
+
source = self.source
|
| 126 |
+
|
| 127 |
+
# if ModelResponse
|
| 128 |
+
if isinstance(self.source, ModelResponse):
|
| 129 |
+
source = self.source.content
|
| 130 |
+
if self.source.tool_calls:
|
| 131 |
+
[self.tool_calls.append(ToolCallOutput.from_tool_call(tool_call)) for tool_call in
|
| 132 |
+
self.source.tool_calls]
|
| 133 |
+
|
| 134 |
+
if source is not None and isinstance(source, AsyncGenerator):
|
| 135 |
+
# Create empty generators first, they will be initialized when actually used
|
| 136 |
+
self.reason_generator = self.__aget_reasoning_generator()
|
| 137 |
+
self.response_generator = self.__aget_response_generator()
|
| 138 |
+
elif source is not None and isinstance(source, Generator):
|
| 139 |
+
self.reason_generator, self.response_generator = self.__split_reasoning_and_response__()
|
| 140 |
+
elif source is not None and isinstance(source, str):
|
| 141 |
+
self.reasoning, self.response = self.__resolve_think__(source)
|
| 142 |
+
return self
|
| 143 |
+
|
| 144 |
+
async def get_finished_reasoning(self):
|
| 145 |
+
if self.reasoning:
|
| 146 |
+
return self.reasoning
|
| 147 |
+
else:
|
| 148 |
+
if self.has_reasoning and not self.reasoning:
|
| 149 |
+
async for reason in self.reason_generator:
|
| 150 |
+
pass
|
| 151 |
+
return self.reasoning
|
| 152 |
+
else:
|
| 153 |
+
return self.reasoning
|
| 154 |
+
|
| 155 |
+
async def get_finished_response(self):
|
| 156 |
+
if self.response:
|
| 157 |
+
return self.response
|
| 158 |
+
else:
|
| 159 |
+
if self.response_generator:
|
| 160 |
+
async for item in self.response_generator:
|
| 161 |
+
pass
|
| 162 |
+
return self.response
|
| 163 |
+
|
| 164 |
+
async def __aget_reasoning_generator(self) -> AsyncGenerator[str, None]:
|
| 165 |
+
"""
|
| 166 |
+
Get reasoning content as async generator
|
| 167 |
+
"""
|
| 168 |
+
if not self.has_reasoning:
|
| 169 |
+
yield ""
|
| 170 |
+
self.reasoning = ""
|
| 171 |
+
return
|
| 172 |
+
|
| 173 |
+
reasoning_buffer = ""
|
| 174 |
+
is_in_reasoning = False
|
| 175 |
+
if self.reasoning and len(self.reasoning) > 0:
|
| 176 |
+
yield self.reasoning
|
| 177 |
+
return
|
| 178 |
+
|
| 179 |
+
try:
|
| 180 |
+
while True:
|
| 181 |
+
chunk = await anext(self.source)
|
| 182 |
+
chunk_content = self.get_chunk_content(chunk)
|
| 183 |
+
if not chunk_content:
|
| 184 |
+
continue
|
| 185 |
+
if chunk_content.startswith(self.reasoning_format_start):
|
| 186 |
+
is_in_reasoning = True
|
| 187 |
+
reasoning_buffer = chunk_content
|
| 188 |
+
yield chunk_content
|
| 189 |
+
elif chunk_content.endswith(self.reasoning_format_end) and is_in_reasoning:
|
| 190 |
+
reasoning_buffer += chunk_content
|
| 191 |
+
yield chunk_content
|
| 192 |
+
self.reasoning = reasoning_buffer
|
| 193 |
+
break
|
| 194 |
+
elif is_in_reasoning:
|
| 195 |
+
reasoning_buffer += chunk_content
|
| 196 |
+
yield chunk_content
|
| 197 |
+
except StopAsyncIteration:
|
| 198 |
+
logging.info("StopAsyncIteration")
|
| 199 |
+
|
| 200 |
+
async def __aget_response_generator(self) -> AsyncGenerator[str, None]:
|
| 201 |
+
"""
|
| 202 |
+
Get response content as async generator
|
| 203 |
+
|
| 204 |
+
if has_reasoning is True, system will first call reasoning_generator if you not call it;
|
| 205 |
+
else it will return content contains reasoning and response
|
| 206 |
+
"""
|
| 207 |
+
response_buffer = ""
|
| 208 |
+
|
| 209 |
+
if self.response and len(self.response) > 0:
|
| 210 |
+
yield self.response
|
| 211 |
+
return
|
| 212 |
+
|
| 213 |
+
# if has_reasoning is True, system will first call reasoning_generator if you not call it;
|
| 214 |
+
if self.has_reasoning and not self.reasoning:
|
| 215 |
+
async for reason in self.reason_generator:
|
| 216 |
+
pass
|
| 217 |
+
|
| 218 |
+
try:
|
| 219 |
+
while True:
|
| 220 |
+
chunk = await anext(self.source)
|
| 221 |
+
chunk_content = self.get_chunk_content(chunk)
|
| 222 |
+
|
| 223 |
+
if not chunk_content:
|
| 224 |
+
continue
|
| 225 |
+
response_buffer += chunk_content
|
| 226 |
+
yield chunk_content
|
| 227 |
+
except StopAsyncIteration:
|
| 228 |
+
self.finished = True
|
| 229 |
+
self.response = self.__resolve_json__(response_buffer, self.json_parse)
|
| 230 |
+
|
| 231 |
+
def get_chunk_content(self, chunk):
|
| 232 |
+
if isinstance(chunk, ModelResponse):
|
| 233 |
+
return chunk.content
|
| 234 |
+
else:
|
| 235 |
+
return chunk
|
| 236 |
+
|
| 237 |
+
def __split_reasoning_and_response__(self) -> tuple[Generator[str, None, None], Generator[str, None, None]]: # type: ignore
|
| 238 |
+
"""
|
| 239 |
+
Split source into reasoning and response generators for sync source
|
| 240 |
+
Returns:
|
| 241 |
+
tuple: (reasoning_generator, response_generator)
|
| 242 |
+
"""
|
| 243 |
+
if not self.has_reasoning:
|
| 244 |
+
yield ""
|
| 245 |
+
self.reasoning = ""
|
| 246 |
+
return
|
| 247 |
+
|
| 248 |
+
if not isinstance(self.source, Generator):
|
| 249 |
+
raise ValueError("Source must be a Generator")
|
| 250 |
+
|
| 251 |
+
def reasoning_generator():
|
| 252 |
+
if self.reasoning and len(self.reasoning) > 0:
|
| 253 |
+
yield self.reasoning
|
| 254 |
+
return
|
| 255 |
+
|
| 256 |
+
reasoning_buffer = ""
|
| 257 |
+
is_in_reasoning = False
|
| 258 |
+
|
| 259 |
+
try:
|
| 260 |
+
while True:
|
| 261 |
+
chunk = next(self.source)
|
| 262 |
+
chunk_content = self.get_chunk_content(chunk)
|
| 263 |
+
if chunk_content.startswith(self.reasoning_format_start):
|
| 264 |
+
is_in_reasoning = True
|
| 265 |
+
reasoning_buffer = chunk_content
|
| 266 |
+
yield chunk_content
|
| 267 |
+
elif chunk_content.endswith(self.reasoning_format_end) and is_in_reasoning:
|
| 268 |
+
reasoning_buffer += chunk_content
|
| 269 |
+
self.reasoning = reasoning_buffer
|
| 270 |
+
yield chunk_content
|
| 271 |
+
break
|
| 272 |
+
elif is_in_reasoning:
|
| 273 |
+
yield chunk_content
|
| 274 |
+
reasoning_buffer += chunk_content
|
| 275 |
+
except StopIteration:
|
| 276 |
+
print("StopIteration")
|
| 277 |
+
self.reasoning = reasoning_buffer
|
| 278 |
+
|
| 279 |
+
def response_generator():
|
| 280 |
+
if self.response and len(self.response) > 0:
|
| 281 |
+
yield self.response
|
| 282 |
+
return
|
| 283 |
+
|
| 284 |
+
# if has_reasoning is True, system will first call reasoning_generator if you not call it;
|
| 285 |
+
if self.has_reasoning and not self.reasoning:
|
| 286 |
+
for reason in self.reason_generator:
|
| 287 |
+
pass
|
| 288 |
+
|
| 289 |
+
response_buffer = ""
|
| 290 |
+
try:
|
| 291 |
+
while True:
|
| 292 |
+
chunk = next(self.source)
|
| 293 |
+
chunk_content = self.get_chunk_content(chunk)
|
| 294 |
+
response_buffer += chunk_content
|
| 295 |
+
self.response = response_buffer
|
| 296 |
+
yield chunk_content
|
| 297 |
+
except StopIteration:
|
| 298 |
+
self.response = self.__resolve_json__(response_buffer,self.json_parse)
|
| 299 |
+
self.finished = True
|
| 300 |
+
|
| 301 |
+
|
| 302 |
+
return reasoning_generator(), response_generator()
|
| 303 |
+
|
| 304 |
+
def __resolve_think__(self, content):
|
| 305 |
+
import re
|
| 306 |
+
start_tag = self.reasoning_format_start.replace("<", "").replace(">", "")
|
| 307 |
+
end_tag = self.reasoning_format_end.replace("<", "").replace(">", "")
|
| 308 |
+
|
| 309 |
+
llm_think = ""
|
| 310 |
+
match = re.search(
|
| 311 |
+
rf"<{re.escape(start_tag)}(.*?)>(.|\n)*?<{re.escape(end_tag)}>",
|
| 312 |
+
content,
|
| 313 |
+
flags=re.DOTALL,
|
| 314 |
+
)
|
| 315 |
+
if match:
|
| 316 |
+
llm_think = match.group(0).replace("<think>", "").replace("</think>", "")
|
| 317 |
+
llm_result = re.sub(
|
| 318 |
+
rf"<{re.escape(start_tag)}(.*?)>(.|\n)*?<{re.escape(end_tag)}>",
|
| 319 |
+
"",
|
| 320 |
+
content,
|
| 321 |
+
flags=re.DOTALL,
|
| 322 |
+
)
|
| 323 |
+
llm_result = self.__resolve_json__(llm_result, self.json_parse)
|
| 324 |
+
|
| 325 |
+
return llm_think, llm_result
|
| 326 |
+
|
| 327 |
+
def __resolve_json__(self, content, json_parse = False):
|
| 328 |
+
if json_parse:
|
| 329 |
+
if content.__contains__("```json"):
|
| 330 |
+
content = content.replace("```json", "").replace("```", "")
|
| 331 |
+
return json.loads(content)
|
| 332 |
+
return content
|
| 333 |
+
|
| 334 |
+
def output_type(self):
    """Discriminator tag identifying this output as a message output."""
    return "message_output"
|
| 336 |
+
|
| 337 |
+
class StepOutput(Output):
    """Output emitted at step boundaries of an agent run.

    Tracks the step's name, number, display alias and lifecycle status
    (START / FINISHED / FAILED) together with start/finish timestamps.
    """
    name: str
    step_num: int
    alias_name: Optional[str] = Field(default=None, description="alias_name of the step")
    status: Optional[str] = Field(default="START", description="step_status")
    started_at: str = Field(default_factory=lambda: datetime.now().isoformat(), description="started at")
    finished_at: str = Field(default_factory=lambda: datetime.now().isoformat(), description="finished at")

    @classmethod
    def build_start_output(cls, name, step_num, alias_name=None, data=None):
        """Create a StepOutput in its default START state."""
        return cls(name=name, step_num=step_num, alias_name=alias_name, data=data)

    @classmethod
    def build_finished_output(cls, name, step_num, alias_name=None, data=None):
        """Create a StepOutput marked FINISHED."""
        return cls(name=name, step_num=step_num, alias_name=alias_name, status='FINISHED', data=data)

    @classmethod
    def build_failed_output(cls, name, step_num, alias_name=None, data=None):
        """Create a StepOutput marked FAILED."""
        return cls(name=name, step_num=step_num, alias_name=alias_name, status='FAILED', data=data)

    def output_type(self):
        """Discriminator tag for step outputs."""
        return "step_output"

    @property
    def show_name(self):
        """Display name: the alias when set (non-empty), otherwise the raw name."""
        return self.name if not self.alias_name else self.alias_name
|
| 363 |
+
|
| 364 |
+
|
| 365 |
+
|
| 366 |
+
class SearchItem(BaseModel):
    """A single search hit returned by a search tool.

    ``content`` and ``raw_content`` are excluded from serialization
    (``exclude=True``) since they may hold large page bodies.
    """
    title: str = Field(default="", description="search result title")
    url: str = Field(default="", description="search result url")
    snippet: str = Field(default="", description="search result snippet")
    content: str = Field(default="", description="search content", exclude=True)
    raw_content: Optional[str] = Field(default="", description="search raw content", exclude=True)
    metadata: Optional[Dict[str, Any]] = Field(default_factory=dict, description="metadata")
|
| 373 |
+
|
| 374 |
+
|
| 375 |
+
class SearchOutput(ToolResultOutput):
    """Tool output holding a search query and its list of results."""

    query: str = Field(..., description="Search query string")
    results: list[SearchItem] = Field(default_factory=list, description="List of search results")

    @classmethod
    def from_dict(cls, data: dict) -> "SearchOutput":
        """Build a SearchOutput from a plain dict.

        Non-dict input is treated as empty; a missing ``query`` or an
        unsupported result element raises ``ValueError``.
        """
        if not isinstance(data, dict):
            data = {}

        query = data.get("query")
        if query is None:
            raise ValueError("query is required")

        def _as_item(result):
            # Accept ready-made SearchItem instances or raw dicts.
            if isinstance(result, SearchItem):
                return result
            if isinstance(result, dict):
                return SearchItem(**result)
            raise ValueError(f"Invalid result type: {type(result)}")

        search_items = [_as_item(result) for result in data.get("results", [])]

        return cls(
            query=query,
            results=search_items
        )

    def output_type(self):
        """Discriminator tag for search outputs."""
        return "search_output"
|
aworld/output/code_artifact.py
ADDED
|
@@ -0,0 +1,277 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import uuid
|
| 2 |
+
from typing import Any, Optional, Dict, List
|
| 3 |
+
|
| 4 |
+
from pydantic import Field
|
| 5 |
+
|
| 6 |
+
from aworld.output.artifact import Artifact, ArtifactType, ArtifactAttachment
|
| 7 |
+
|
| 8 |
+
# Maps a lower-case language name (as produced by markdown fence info strings)
# to the file extension used when the code is written out as a file.
# Callers fall back to "txt" for languages not listed here.
CODE_FILE_EXTENSION_MAP = {
    "python": "py",
    "java": "java",
    "javascript": "js",
    "typescript": "ts",
    "html": "html",
    "css": "css",
    "c": "c",
    "cpp": "cpp",
    "csharp": "cs",
    "go": "go",
    "rust": "rs",
    "ruby": "rb",
    "php": "php",
    "swift": "swift",
    "kotlin": "kt",
    "scala": "scala",
    "markdown": "md",
    "txt": "txt",
    "shell": "sh",
    "bash": "sh",
    "sh": "sh",
    "zsh": "zsh",
    "powershell": "ps1",
    "cmd": "cmd",
    "bat": "bat"
}
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class CodeArtifact(Artifact):
    """Artifact specialization for source code.

    Stores the code itself as content and records language/version/executor
    information (plus a filename guessed from the first comment line) in
    metadata.
    """

    # Executor instance; populated by init_code_interceptor (currently a no-op hook).
    code_interceptor: Any = Field(default=None, description="code executor type")

    def __init__(self, artifact_type: ArtifactType, content: Any, code_type: Optional[str], code_version: Optional[str],
                 code_interceptor_provider: Optional[str] = None,
                 artifact_id: Optional[str] = None, render_type: Optional[str] = None, **kwargs):
        # Extract filename from the first line of the content
        filename = self.extract_filename(content)

        # Base metadata for every code artifact; filename may be None.
        metadata = {
            "code_type": code_type,
            "code_version": code_version,
            "code_interceptor_provider": code_interceptor_provider,
            "filename": filename  # Store filename in metadata
        }

        # Merge additional metadata from kwargs if provided
        if 'metadata' in kwargs:
            metadata.update(kwargs['metadata'])
            del kwargs['metadata']  # Remove metadata from kwargs to avoid multiple values

        super().__init__(
            artifact_type=artifact_type,
            content=content,
            metadata=metadata,
            artifact_id=artifact_id,
            render_type=render_type,
            **kwargs
        )
        self.archive()
        self.code_interceptor = self.init_code_interceptor(code_interceptor_provider)

    @staticmethod
    def extract_filename(content: Any) -> Optional[str]:
        """Extract filename from the first line of the code block comment.

        Returns None for non-string content, shebang lines, or when the
        first line is not a recognized comment style.
        """
        if isinstance(content, str):
            lines = content.splitlines()
            if lines:
                first_line = lines[0].strip()
                # Check if the first line is a shebang for bash or other interpreters
                if first_line in ["# /bin/bash", "#!/bin/bash", "#!/usr/bin/env bash",
                                  "#!/bin/sh", "#!/usr/bin/env python",
                                  "#!/usr/bin/env python3"]:
                    return None  # Do not return a filename
                # Check for common comment styles in various languages
                if first_line.startswith("#"):  # Python, Ruby, Shell
                    return first_line[1:].strip()  # Remove the comment symbol
                elif first_line.startswith("//"):  # Java, JavaScript, C, C++
                    return first_line[2:].strip()  # Remove the comment symbol
                elif first_line.startswith("/*") and "*/" in first_line:  # C, C++
                    return first_line.split("*/")[0][2:].strip()  # Remove comment symbols
                elif first_line.startswith("<!--"):  # HTML
                    return first_line[4:].strip()  # Remove the comment symbol
                # Add more languages as needed
        return None  # Return None if filename is unknown

    @classmethod
    def build_artifact(cls,
                       content: Any,
                       code_type: Optional[str] = None,
                       code_version: Optional[str] = None,
                       code_interceptor_provider: Optional[str] = None,
                       artifact_id: Optional[str] = None,
                       render_type: Optional[str] = None,
                       **kwargs) -> "CodeArtifact":
        """Factory that picks the right subclass for the given code_type."""

        # Shell-family scripts get their own subclass (adds shell_result).
        if code_type in ['shell', 'sh', 'bash', 'zsh']:
            return ShellArtifact(
                artifact_type=ArtifactType.CODE,
                content=content,
                code_version=code_version,
                code_interceptor_provider=code_interceptor_provider,
                artifact_id=artifact_id,
                render_type=render_type,
                **kwargs
            )
        elif code_type in ['html']:
            return HtmlArtifact(
                content=content,
                artifact_id=artifact_id,
                **kwargs
            )

        return cls(
            artifact_type=ArtifactType.CODE,
            content=content,
            code_type=code_type,
            code_version=code_version,
            code_interceptor_provider=code_interceptor_provider,
            artifact_id=artifact_id,
            render_type=render_type,
            **kwargs
        )

    @classmethod
    def from_code_content(cls, artifact_type: ArtifactType,
                          content: Any,
                          render_type: Optional[str] = None,
                          **kwargs) -> List["CodeArtifact"]:
        """Parse markdown model output and build one artifact per code block."""
        code_blocks = cls.extract_model_output_to_code_content(content)  # Extract code blocks
        artifacts = []  # List to store CodeArtifact instances

        for block in code_blocks:
            code_type = block['language']
            code_version = "1.0"

            if code_type in ['python', 'javascript', 'java']:
                code_interceptor_provider = "default_interceptor"
            elif code_type in ['shell', 'sh', 'bash', 'zsh']:
                code_interceptor_provider = "shell_interceptor"
            else:
                code_interceptor_provider = "generic_interceptor"

            # BUGFIX: this previously called cls.create_artifact(...), which is
            # not defined anywhere. build_artifact is the intended factory; it
            # chooses the artifact type itself, so artifact_type is not passed.
            artifact = cls.build_artifact(
                content=block['content'],
                code_type=code_type,
                code_version=code_version,
                code_interceptor_provider=code_interceptor_provider,
                artifact_id=block['artifact_id'],  # Use extracted artifact_id
                render_type=render_type,
                **kwargs
            )
            artifacts.append(artifact)  # Add to the list

        return artifacts  # Return the list of CodeArtifact instances

    def init_code_interceptor(self, code_interceptor_provider):
        """Hook to build the code executor for this artifact; currently a no-op."""
        pass

    @classmethod
    def extract_model_output_to_code_content(cls, content):
        """
        Extract code blocks from markdown content using mistune.

        First extracts all code blocks enclosed in triple backticks,
        then determines the language for each block.
        """

        try:
            import mistune
        except ImportError:
            # NOTE(review): installing a dependency at runtime via pip is
            # fragile (offline environments, permissions, mixed interpreter
            # setups); consider declaring mistune as a normal dependency.
            import subprocess
            subprocess.run(["pip", "install", "mistune>=3.0.0"], check=True)
            import mistune

        code_blocks = []

        # Raw blocks collected by the renderer callback below.
        extracted_blocks = []

        # Custom renderer that records each fenced code block instead of
        # rendering it to HTML.
        class CustomRenderer(mistune.HTMLRenderer):
            def block_code(self, code, info=None):
                # The fence info string's first token is the language tag.
                language = info.split()[0] if info else 'unknown'
                extracted_blocks.append({
                    "content": code,
                    "language": language
                })
                return ""

        # Create the markdown parser with the recording renderer.
        renderer = CustomRenderer()
        markdown = mistune.create_markdown(
            renderer=renderer
        )

        # Parsing the content populates extracted_blocks as a side effect.
        markdown(content)

        # Attach an id and a file suffix to each extracted block.
        for block in extracted_blocks:
            artifact_id = str(uuid.uuid4())
            language = block['language']
            file_suffix = CODE_FILE_EXTENSION_MAP.get(language, "txt")

            code_blocks.append({
                "artifact_id": artifact_id,
                "content": block['content'],
                "language": language,
                "file_suffix": file_suffix
            })

        return code_blocks
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
class ShellArtifact(CodeArtifact):
    """CodeArtifact specialization for shell scripts.

    Forces code_type to "shell", defaults the filename to "terminal.txt"
    and stores the script's execution result in ``shell_result``.
    """

    shell_result: str = Field(default="", description="shell execution result")

    def __init__(self, artifact_type: ArtifactType, content: Any, code_version: str,
                 code_interceptor_provider: Optional[str] = None,
                 artifact_id: Optional[str] = None, render_type: Optional[str] = None,
                 shell_result: str = "", **kwargs):

        code_type = "shell"

        # extract filename from the first comment line of the script
        filename = self.extract_filename(content)

        # default set terminal.txt
        if not filename:
            filename = "terminal.txt"

        # BUGFIX: pop (not get) metadata out of kwargs — otherwise a caller
        # that supplies metadata makes super().__init__ receive the keyword
        # twice (TypeError: multiple values for 'metadata').
        metadata = kwargs.pop('metadata', {})
        metadata['filename'] = filename

        # setting code_interceptor_provider
        if code_interceptor_provider is None:
            code_interceptor_provider = "shell_interceptor"

        super().__init__(artifact_type, content, code_type, code_version,
                         code_interceptor_provider, artifact_id, render_type, metadata=metadata, **kwargs)
        self.shell_result = shell_result

    def execute(self):
        """Run the shell script via the configured interceptor."""
        # todo add
        pass
|
| 258 |
+
|
| 259 |
+
class HtmlArtifact(CodeArtifact):
    """CodeArtifact specialization for HTML content.

    Strips markdown ```html fences, clears the inline content and stores the
    HTML as a file attachment instead.
    """

    def __init__(self, content: Any, artifact_id: Optional[str] = None, **kwargs):
        # Remove artifact_type from kwargs if it exists to avoid conflicts
        kwargs.pop('artifact_type', None)

        super().__init__(
            artifact_type=ArtifactType.HTML,
            content=content,
            code_type='html',
            code_version='1.0',
            artifact_id=artifact_id,
            **kwargs
        )
        content = content.replace("```html", "").replace("```", "")
        self.content = None
        # BUGFIX: when artifact_id is not given, fall back to the id assigned
        # by the base class (presumably generated there — confirm), or a fresh
        # uuid, so the attachment is never literally named "None.html".
        attachment_id = artifact_id or getattr(self, "artifact_id", None) or str(uuid.uuid4())
        self.attachments.append(
            ArtifactAttachment(filename=f"{attachment_id}.html", content=content, mime_type="text/html")
        )
|
aworld/output/observer.py
ADDED
|
@@ -0,0 +1,189 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import traceback
|
| 3 |
+
from typing import Callable, List, Dict, Any, Optional, Union
|
| 4 |
+
import inspect
|
| 5 |
+
|
| 6 |
+
from aworld.output.artifact import Artifact
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class WorkspaceObserver:
    """Base class for workspace observers.

    Subclasses override the hooks below to react to artifact lifecycle
    events inside a workspace; the default implementations are no-ops.
    """

    async def on_create(self, workspace_id: str, artifact: Artifact) -> None:
        """Called after an artifact is created in the workspace."""
        pass

    async def on_update(self, workspace_id: str, artifact: Artifact) -> None:
        """Called after an artifact is updated."""
        pass

    async def on_delete(self, workspace_id: str, artifact: Artifact) -> None:
        """Called after an artifact is deleted."""
        pass
|
| 20 |
+
|
| 21 |
+
class Handler:
    """Handler wrapper to support both functions and class methods.

    Applies optional workspace and metadata filtering before dispatching the
    wrapped callable, and adapts the call to the callable's parameter count.
    """
    def __init__(self, func: Callable, instance=None, workspace_id: Optional[str] = None, filters: Optional[Dict[str, Any]] = None):
        self.func = func
        self.instance = instance  # Class instance if method
        self.workspace_id = workspace_id  # Specific workspace to monitor
        self.filters = filters or {}  # Additional filters (e.g., artifact_type)

    async def __call__(self, artifact: Artifact, **kwargs) -> Any:
        """Call the handler with appropriate arguments.

        Returns the handler's result, or None when the artifact is filtered
        out (wrong workspace or non-matching filters).
        """
        # Check if this handler should process the artifact
        if self.workspace_id and kwargs.get('workspace_id') != self.workspace_id:
            return None

        # Check additional filters
        for key, value in self.filters.items():
            if key == 'artifact_type':
                # NOTE(review): assumes artifact_type exposes .value (enum-like);
                # a plain-string artifact_type that differs from `value` would
                # raise AttributeError here — confirm.
                if artifact.artifact_type != value and artifact.artifact_type.value != value:
                    return None
            elif key in artifact.metadata and artifact.metadata[key] != value:
                return None

        # Get function signature to determine what arguments it expects
        sig = inspect.signature(self.func)
        param_count = len(sig.parameters)

        # Call based on whether it's a method or function, and parameter count.
        # NOTE(review): param_count includes an explicit `self` for unbound
        # functions — the branches below assume `func` is already bound when
        # `instance` is set (self.instance is never passed in); confirm how
        # methods are registered.
        if self.instance:
            if param_count == 0:  # Just self
                return await self.func() if inspect.iscoroutinefunction(self.func) else self.func()
            elif param_count == 1:  # Self + artifact
                return await self.func(artifact) if inspect.iscoroutinefunction(self.func) else self.func(artifact)
            else:  # Self + artifact + kwargs
                return await self.func(artifact, **kwargs) if inspect.iscoroutinefunction(self.func) else self.func(artifact, **kwargs)
        else:
            if param_count == 0:  # No parameters
                return await self.func() if inspect.iscoroutinefunction(self.func) else self.func()
            elif param_count == 1:  # Just artifact
                return await self.func(artifact) if inspect.iscoroutinefunction(self.func) else self.func(artifact)
            else:  # Artifact + kwargs
                return await self.func(artifact, **kwargs) if inspect.iscoroutinefunction(self.func) else self.func(artifact, **kwargs)
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
class DecoratorBasedObserver(WorkspaceObserver):
    """Enhanced decorator-based observer implementation.

    Keeps three independent handler lists (create/update/delete) and fans
    each event out to every registered handler, collecting non-None results.
    A failing handler is logged and skipped so it cannot break the others.
    """
    def __init__(self):
        self.create_handlers: List[Handler] = []
        self.update_handlers: List[Handler] = []
        self.delete_handlers: List[Handler] = []

    async def on_create(self, workspace_id: str, artifact: Artifact, **kwargs) -> List[Any]:
        """Process artifact creation with all handlers"""
        results = []
        for handler in self.create_handlers:
            try:
                result = await handler(workspace_id=workspace_id, artifact=artifact, **kwargs)
                if result is not None:
                    results.append(result)
            except Exception as e:
                # Use logging (consistent with the register_* methods) instead of print.
                logging.error(f"Create handler failed: error is {e}: {traceback.format_exc()}")
        return results

    async def on_update(self, artifact: Artifact, **kwargs) -> List[Any]:
        """Process artifact update with all handlers"""
        results = []
        for handler in self.update_handlers:
            try:
                result = await handler(artifact, **kwargs)
                if result is not None:
                    results.append(result)
            except Exception as e:
                logging.error(f"Update handler failed: {e}")
        return results

    async def on_delete(self, artifact: Artifact, **kwargs) -> List[Any]:
        """Process artifact deletion with all handlers"""
        results = []
        for handler in self.delete_handlers:
            try:
                result = await handler(artifact, **kwargs)
                if result is not None:
                    results.append(result)
            except Exception as e:
                logging.error(f"Delete handler failed: {e}")
        return results

    def register_create_handler(self, func, instance=None, workspace_id=None, filters=None):
        """Register a handler for artifact creation"""
        logging.info(f"[📂WORKSPACE]✨ Registering create handler for {func}")
        self.create_handlers.append(Handler(func, instance, workspace_id, filters))
        return func

    def un_register_create_handler(self, func, instance=None, workspace_id=None):
        """Unregister every create handler wrapping *func*.

        BUGFIX: the original removed elements from the list while iterating
        it, which skips the element following each removal; build a new list
        instead so all matching handlers are dropped.
        """
        logging.info(f"[📂WORKSPACE] UnRegister create handler for {func}")
        remaining = []
        for handler in self.create_handlers:
            if handler.func == func:
                logging.info(f"[📂WORKSPACE] UnRegister create handler for {func} success")
            else:
                remaining.append(handler)
        self.create_handlers = remaining

    def register_update_handler(self, func, instance=None, workspace_id=None, filters=None):
        """Register a handler for artifact update"""
        logging.info(f"[📂WORKSPACE]✨ Registering update handler for {func}")
        self.update_handlers.append(Handler(func, instance, workspace_id, filters))
        return func

    def register_delete_handler(self, func, instance=None, workspace_id=None, filters=None):
        """Register a handler for artifact deletion"""
        logging.info(f"[📂WORKSPACE]✨ Registering delete handler for {func}")
        self.delete_handlers.append(Handler(func, instance, workspace_id, filters))
        return func
|
| 132 |
+
|
| 133 |
+
# Global observer instance
# Module-level singleton shared by the @on_artifact_* decorators below;
# importing this module anywhere registers handlers against the same instance.
_observer = DecoratorBasedObserver()

def get_observer() -> DecoratorBasedObserver:
    """Get the global observer instance"""
    return _observer
|
| 139 |
+
|
| 140 |
+
def on_artifact_create(func=None, workspace_id=None, filters=None):
    """
    Decorator for artifact creation handlers

    Can be used as a simple decorator (@on_artifact_create) or with parameters:
        @on_artifact_create(workspace_id='abc', filters={'artifact_type': 'WEB_PAGES'})
    """
    def decorator(f):
        return _observer.register_create_handler(f, None, workspace_id, filters)

    # Bare usage (@on_artifact_create): register the function immediately.
    if func is not None:
        return _observer.register_create_handler(func)

    # Parameterized usage: hand back the real decorator.
    return decorator
|
| 155 |
+
|
| 156 |
+
def on_artifact_update(func=None, workspace_id=None, filters=None):
    """
    Decorator for artifact update handlers

    Can be used as a simple decorator (@on_artifact_update) or with parameters:
        @on_artifact_update(workspace_id='abc', filters={'artifact_type': 'WEB_PAGES'})
    """
    def decorator(f):
        return _observer.register_update_handler(f, None, workspace_id, filters)

    # Bare usage (@on_artifact_update): register the function immediately.
    if func is not None:
        return _observer.register_update_handler(func)

    # Parameterized usage: hand back the real decorator.
    return decorator
|
| 171 |
+
|
| 172 |
+
def on_artifact_delete(func=None, workspace_id=None, filters=None):
    """
    Decorator for artifact deletion handlers

    Can be used as a simple decorator (@on_artifact_delete) or with parameters:
        @on_artifact_delete(workspace_id='abc', filters={'artifact_type': 'WEB_PAGES'})
    """
    def decorator(f):
        return _observer.register_delete_handler(f, None, workspace_id, filters)

    # Bare usage (@on_artifact_delete): register the function immediately.
    if func is not None:
        return _observer.register_delete_handler(func)

    # Parameterized usage: hand back the real decorator.
    return decorator
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
|
aworld/output/outputs.py
ADDED
|
@@ -0,0 +1,202 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import abc
|
| 2 |
+
import asyncio
|
| 3 |
+
from abc import abstractmethod
|
| 4 |
+
from dataclasses import field, dataclass
|
| 5 |
+
from typing import AsyncIterator, Any, Union, Iterator
|
| 6 |
+
|
| 7 |
+
from aworld.logs.util import logger
|
| 8 |
+
from aworld.output import Output
|
| 9 |
+
from aworld.output.base import RUN_FINISHED_SIGNAL
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
@dataclass
class Outputs(abc.ABC):
    """Base class for managing output streams in the AWorld framework.
    Provides abstract methods for adding and streaming outputs both synchronously and asynchronously.
    reference: https://github.com/openai/openai-agents-python/blob/main/src/agents/result.py
    """
    # Free-form metadata attached to the whole output stream.
    _metadata: dict = field(default_factory=dict)

    @abstractmethod
    async def add_output(self, output: Output):
        """Add an output asynchronously to the output stream.

        Args:
            output (Output): The output to be added
        """
        pass

    @abstractmethod
    def sync_add_output(self, output: Output):
        """Add an output synchronously to the output stream.

        Args:
            output (Output): The output to be added
        """
        pass

    @abstractmethod
    async def stream_events(self) -> Union[AsyncIterator[Output], list]:
        """Stream outputs asynchronously.

        Returns:
            AsyncIterator[Output]: An async iterator of outputs
        """
        pass

    @abstractmethod
    def sync_stream_events(self) -> Union[Iterator[Output], list]:
        """Stream outputs synchronously.

        Returns:
            Iterator[Output]: An iterator of outputs
        """
        pass

    @abstractmethod
    async def mark_completed(self):
        """Mark the stream as finished; no further outputs are expected."""
        pass

    async def get_metadata(self) -> dict:
        """Return the metadata dict attached to this output stream."""
        return self._metadata

    async def set_metadata(self, metadata: dict):
        """Replace the metadata dict attached to this output stream."""
        self._metadata = metadata
|
| 65 |
+
|
| 66 |
+
@dataclass
class AsyncOutputs(Outputs):
    """Intermediate class that implements the Outputs interface with async support.

    This class serves as a base for more specific async output implementations.
    All hooks are no-ops; subclasses override the ones they need.
    """

    async def add_output(self, output: Output):
        """No-op async add; override in subclasses."""
        pass

    def sync_add_output(self, output: Output):
        """No-op sync add; override in subclasses."""
        pass

    async def stream_events(self) -> Union[AsyncIterator[Output], list]:
        """No-op async stream; override in subclasses."""
        pass

    def sync_stream_events(self) -> Union[Iterator[Output], list]:
        # CONSISTENCY FIX: annotation matches the base class signature
        # (was Union[Iterator[Output]], which collapses to a bare Iterator).
        pass

    async def mark_completed(self):
        # FIX: implement the remaining abstract method so concrete subclasses
        # (e.g. StreamingOutputs) are actually instantiable; subclasses may
        # still override.
        pass
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
@dataclass
class DefaultOutputs(Outputs):
    """Simple in-memory Outputs implementation backed by a plain list."""

    _outputs: list = field(default_factory=list)

    async def add_output(self, output: Output):
        """Append an output to the in-memory buffer (async variant)."""
        self._outputs.append(output)

    def sync_add_output(self, output: Output):
        """Append an output to the in-memory buffer (sync variant)."""
        self._outputs.append(output)

    async def stream_events(self) -> Union[AsyncIterator[Output], list]:
        """Return all collected outputs as a list."""
        return self._outputs

    def sync_stream_events(self) -> Union[Iterator[Output], list]:
        """Return all collected outputs as a list."""
        return self._outputs

    async def mark_completed(self):
        """Nothing to finalize for the in-memory implementation."""
        pass
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
@dataclass
class StreamingOutputs(AsyncOutputs):
    """Concrete implementation of AsyncOutputs that provides streaming functionality.
    Manages a queue of outputs and handles streaming with error checking and task management."""

    # Task and input related fields
    # task: Task = Field(default=None) # The task associated with these outputs
    input: Any = field(default=None)  # Input data for the task
    usage: dict = field(default=None)  # Usage statistics

    # State tracking
    is_complete: bool = field(default=False)  # Flag indicating if streaming is complete

    # Queue for managing outputs.  NOTE: default_factory binds the queue to the
    # event loop that first uses it.
    _output_queue: asyncio.Queue[Output] = field(
        default_factory=asyncio.Queue, repr=False
    )

    # Internal state management
    _visited_outputs: list[Output] = field(default_factory=list)  # replay cache for repeated stream_events() calls
    _stored_exception: Exception | None = field(default=None, repr=False)  # Stores any exceptions that occur
    _run_impl_task: asyncio.Task[Any] | None = field(default=None, repr=False)  # The running task

    async def add_output(self, output: Output):
        """Add an output to the queue asynchronously.

        Args:
            output (Output): The output to be added to the queue
        """
        # put_nowait never blocks; the queue is unbounded by default.
        self._output_queue.put_nowait(output)

    async def stream_events(self) -> AsyncIterator[Output]:
        """Stream outputs asynchronously, handling cached outputs and new outputs from the queue.
        Includes error checking and task cleanup.

        Yields:
            Output: The next output in the stream

        Raises:
            Exception: Any stored exception that occurred during streaming
        """
        # First yield any cached outputs (supports calling stream_events() again
        # after a previous consumer already drained part of the stream).
        for output in self._visited_outputs:
            if output == RUN_FINISHED_SIGNAL:
                # NOTE(review): task_done() here is called for an item that was
                # already acknowledged when it was first dequeued; a second full
                # replay may raise ValueError("task_done() called too many
                # times") — verify against asyncio.Queue semantics.
                self._output_queue.task_done()
                return
            yield output

        # Main streaming loop
        while True:
            self._check_errors()
            if self._stored_exception:
                logger.debug("Breaking due to stored exception")
                self.is_complete = True
                break

            # Producer finished and nothing left to drain.
            if self.is_complete and self._output_queue.empty():
                break

            try:
                # Blocks until the producer enqueues the next output (or the
                # RUN_FINISHED_SIGNAL sentinel).
                output = await self._output_queue.get()
                self._visited_outputs.append(output)

            except asyncio.CancelledError:
                break

            if output == RUN_FINISHED_SIGNAL:
                self._output_queue.task_done()
                self._check_errors()
                break

            yield output
            self._output_queue.task_done()

        self._cleanup_tasks()

        # Re-raise any failure captured from the producer task so callers see it.
        if self._stored_exception:
            raise self._stored_exception

    def _check_errors(self):
        """Check for errors in the streaming process.
        Verifies step count and checks for exceptions in the running task.
        """
        # Check the task for any exceptions
        if self._run_impl_task and self._run_impl_task.done():
            exc = self._run_impl_task.exception()
            if exc and isinstance(exc, Exception):
                self._stored_exception = exc

    def _cleanup_tasks(self):
        """Clean up any running tasks by cancelling them if they're not done."""
        if self._run_impl_task and not self._run_impl_task.done():
            self._run_impl_task.cancel()

    async def mark_completed(self) -> None:
        """Mark the streaming process as completed by adding a RUN_FINISHED_SIGNAL to the queue."""
        await self._output_queue.put(RUN_FINISHED_SIGNAL)
|
aworld/output/utils.py
ADDED
|
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from typing import AsyncGenerator, Generator, Callable, Any
|
| 3 |
+
|
| 4 |
+
from aworld.output.workspace import WorkSpace
|
| 5 |
+
from aworld.output.base import OutputPart, MessageOutput, Output
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
async def consume_output(__output__, callback):
    """Recursively consume an Output, forwarding leaf content to *callback*.

    Order of consumption: nested ``parts`` first (depth-first), then for a
    MessageOutput either the streaming generators (reasoning, then response)
    or the materialized ``reasoning``/``response`` values, then tool calls.
    Any other Output contributes its generic ``data`` payload.  Non-Output
    arguments are ignored.

    Args:
        __output__: The Output (or MessageOutput) to consume.
        callback: Async callable invoked with each leaf value.
    """
    if not isinstance(__output__, Output):
        return

    ## parts
    if __output__.parts:
        for part in __output__.parts:
            await consume_part(part, callback)

    ## MessageOutput
    if isinstance(__output__, MessageOutput):
        if __output__.reason_generator or __output__.response_generator:
            if __output__.reason_generator:
                await consume_content(__output__.reason_generator, callback)
            # BUG FIX: this guard previously re-checked reason_generator, so a
            # response stream without a reasoning stream was never consumed.
            if __output__.response_generator:
                await consume_content(__output__.response_generator, callback)
        else:
            await consume_content(__output__.reasoning, callback)
            await consume_content(__output__.response, callback)
        if __output__.tool_calls:
            await consume_content(__output__.tool_calls, callback)
    else:
        await consume_content(__output__.data, callback)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
async def consume_part(part, callback):
    """Consume one OutputPart: nested Outputs recurse, plain content is
    forwarded to the content consumer."""
    inner = part.content
    if isinstance(inner, Output):
        await consume_output(__output__=inner, callback=callback)
        return
    await consume_content(__content__=inner, callback=callback)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
async def consume_content(__content__, callback: Callable[..., Any]):
    """Walk arbitrary content and feed each leaf value to *callback*.

    Falsy content is skipped.  Async generators, sync generators and lists
    are iterated item by item; nested OutputPart/Output items are unwrapped
    recursively; everything else (strings included) goes straight to the
    callback.
    """
    if not __content__:
        return

    async def _dispatch(item):
        # One shared dispatcher for both the async and sync iteration paths.
        if isinstance(item, OutputPart):
            await consume_part(item, callback)
        elif isinstance(item, Output):
            await consume_output(item, callback)
        else:
            await callback(item)

    if isinstance(__content__, AsyncGenerator):
        async for item in __content__:
            await _dispatch(item)
    elif isinstance(__content__, (Generator, list)):
        for item in __content__:
            await _dispatch(item)
    else:
        # Strings and any other scalar payloads are leaves.
        await callback(__content__)
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
async def load_workspace(workspace_id: str, workspace_type: str, workspace_parent_path: str):
    """Resolve a workspace by id, backed by either local or OSS storage.

    Args:
        workspace_id: Identifier of the workspace (required).
        workspace_type: Storage backend, either ``"local"`` or ``"oss"``.
        workspace_parent_path: Parent directory/prefix under which the
            workspace's storage path is built.

    Returns:
        The loaded WorkSpace instance.

    Raises:
        RuntimeError: If ``workspace_id`` is None or the type is unknown.
    """
    if workspace_id is None:
        raise RuntimeError("workspace_id is None")

    storage_path = os.path.join(workspace_parent_path, workspace_id)
    if workspace_type == "local":
        return WorkSpace.from_local_storages(workspace_id, storage_path=storage_path)
    if workspace_type == "oss":
        return WorkSpace.from_oss_storages(workspace_id, storage_path=storage_path)
    raise RuntimeError("Invalid workspace type")
|
aworld/output/workspace.py
ADDED
|
@@ -0,0 +1,522 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import os
|
| 3 |
+
import traceback
|
| 4 |
+
import uuid
|
| 5 |
+
from datetime import datetime
|
| 6 |
+
from typing import Dict, Any, Optional, List, Union
|
| 7 |
+
|
| 8 |
+
from pydantic import BaseModel, Field, ConfigDict
|
| 9 |
+
|
| 10 |
+
from aworld.output.artifact import ArtifactType, Artifact
|
| 11 |
+
from aworld.output.code_artifact import CodeArtifact
|
| 12 |
+
from aworld.output.storage.artifact_repository import ArtifactRepository, LocalArtifactRepository
|
| 13 |
+
from aworld.output.observer import WorkspaceObserver, get_observer
|
| 14 |
+
from aworld.output.storage.oss_artifact_repository import OSSArtifactRepository
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class WorkSpace(BaseModel):
    """
    Artifact workspace, managing a group of related artifacts.

    Provides collaborative editing features, supporting version management,
    update notifications, etc. for multiple Artifacts.  State is persisted
    through an ArtifactRepository (local filesystem or OSS); observers are
    notified on create/update/delete/complete events.
    """

    workspace_id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="unique identifier for the workspace")
    name: str = Field(default="", description="name of the workspace")
    created_at: str = Field(default_factory=lambda: datetime.now().isoformat())
    updated_at: str = Field(default_factory=lambda: datetime.now().isoformat())
    metadata: Dict[str, Any] = Field(default={}, description="metadata")
    artifacts: List[Artifact] = Field(default=[], description="list of artifacts")

    # Derived index artifact_id -> position in self.artifacts; rebuilt on load/save.
    artifact_id_index: Dict[str, int] = Field(default={}, description="artifact id index", exclude=True)
    observers: Optional[List[WorkspaceObserver]] = Field(default=[], description="list of observers", exclude=True)
    repository: Optional[ArtifactRepository] = Field(default=None, description="local artifact repository", exclude=True)

    model_config = ConfigDict(arbitrary_types_allowed=True)

    def __init__(
        self,
        workspace_id: Optional[str] = None,
        name: Optional[str] = None,
        storage_path: Optional[str] = None,
        observers: Optional[List[WorkspaceObserver]] = None,
        use_default_observer: bool = True,
        clear_existing: bool = False,
        repository: Optional[ArtifactRepository] = None
    ):
        """Build a workspace, wiring up repository, persisted state and observers.

        Args:
            workspace_id: Optional id; a UUID4 is generated when omitted.
            name: Optional display name; defaults to "Workspace-<id prefix>".
            storage_path: Directory for the default local repository.
            observers: Extra observers to register (deduplicated).
            use_default_observer: Whether to register the global default observer.
            clear_existing: When True, start empty instead of loading persisted data.
            repository: Pre-built repository; overrides storage_path-based default.
        """
        # NOTE: pydantic's __init__ is called with no field values; fields are
        # then assigned manually below.
        super().__init__()
        self.workspace_id = workspace_id or str(uuid.uuid4())
        self.name = name or f"Workspace-{self.workspace_id[:8]}"
        self.created_at = datetime.now().isoformat()
        self.updated_at = self.created_at

        # Initialize repository first
        storage_dir = storage_path or os.path.join("data", "workspaces", self.workspace_id)
        if repository is None:
            self.repository = LocalArtifactRepository(storage_dir)
        else:
            self.repository = repository

        # Initialize artifacts and metadata
        if clear_existing:
            self.artifacts = []
            self.metadata = {}
        else:
            # Try to load existing workspace data
            workspace_data = self._load_workspace_data()
            if workspace_data:
                self.artifacts = workspace_data.get('artifacts', [])
                self.metadata = workspace_data.get('metadata', {})
                self.created_at = workspace_data.get('created_at', self.created_at)
                self.updated_at = workspace_data.get('updated_at', self.updated_at)
            else:
                self.artifacts = []
                self.metadata = {}

        # Build artifact_id_index after loading artifacts
        self._rebuild_artifact_id_index()

        # Initialize observers
        self.observers: List[WorkspaceObserver] = []
        if use_default_observer:
            self.observers.append(get_observer())

        if observers:
            for observer in observers:
                if observer not in self.observers:  # Avoid duplicates
                    self.add_observer(observer)

    def _load_workspace_data(self) -> Optional[Dict[str, Any]]:
        """
        Load workspace data from repository.

        Returns:
            Dictionary containing workspace data if exists, None otherwise
            (including on any load error, which is printed rather than raised).
        """
        try:
            # Get workspace versions

            workspace_data = self.repository.load_index()

            if not workspace_data:
                return None

            # Load artifacts
            artifacts = []
            # First load the artifacts list from workspace data
            workspace_artifacts = workspace_data.get("artifacts", [])
            for artifact_data in workspace_artifacts:
                artifact_id = artifact_data.get("artifact_id")
                if artifact_id:
                    # Index entries hold only metadata; fetch the latest full
                    # artifact payload from the repository.
                    artifact_data = self.repository.retrieve_latest_artifact(artifact_id)
                    if artifact_data:
                        artifacts.append(Artifact.from_dict(artifact_data))

            return {
                "artifacts": artifacts,
                "metadata": workspace_data.get("metadata", {}),
                "created_at": workspace_data.get("created_at"),
                "updated_at": workspace_data.get("updated_at")
            }
        except Exception as e:
            # Best-effort load: corrupt/missing storage degrades to an empty workspace.
            traceback.print_exc()
            print(f"Error loading workspace data: {e}")
            return None

    @classmethod
    def from_local_storages(cls, workspace_id: Optional[str] = None,
                            name: Optional[str] = None,
                            storage_path: Optional[str] = None,
                            observers: Optional[List[WorkspaceObserver]] = None,
                            use_default_observer: bool = True
                            ) -> "WorkSpace":
        """
        Create a workspace instance from local storage.

        Args:
            workspace_id: Optional workspace ID
            name: Optional workspace name
            storage_path: Optional storage path (defaults to data/workspaces/<id>)
            observers: Optional list of observers
            use_default_observer: Whether to use default observer

        Returns:
            WorkSpace instance
        """
        if storage_path is None:
            storage_path = os.path.join("data", "workspaces", workspace_id)
        workspace = cls(
            workspace_id=workspace_id,
            name=name,
            storage_path=storage_path,
            observers=observers,
            use_default_observer=use_default_observer,
            clear_existing=False  # Always try to load existing data
        )
        return workspace

    @classmethod
    def from_oss_storages(cls,
                          workspace_id: Optional[str] = None,
                          name: Optional[str] = None,
                          storage_path: Optional[str] = "aworld/workspaces/",
                          observers: Optional[List[WorkspaceObserver]] = None,
                          use_default_observer: bool = True,
                          oss_config: Optional[Dict[str, Any]] = None,
                          ) -> "WorkSpace":
        """Create a workspace backed by an OSS artifact repository.

        When ``oss_config`` is omitted, credentials/endpoint/bucket are read
        from the OSS_* environment variables.
        """
        if oss_config is None:
            oss_config = {
                "access_key_id": os.getenv("OSS_ACCESS_KEY_ID"),
                "access_key_secret": os.getenv("OSS_ACCESS_KEY_SECRET"),
                "endpoint": os.getenv("OSS_ENDPOINT"),
                "bucket_name": os.getenv("OSS_BUCKET_NAME"),
            }
        repository = OSSArtifactRepository(
            access_key_id=oss_config["access_key_id"],
            access_key_secret=oss_config["access_key_secret"],
            endpoint=oss_config["endpoint"],
            bucket_name=oss_config["bucket_name"],
            storage_path=storage_path
        )
        workspace = cls(
            workspace_id=workspace_id,
            name=name,
            storage_path=storage_path,
            observers=observers,
            use_default_observer=use_default_observer,
            repository=repository
        )
        return workspace

    async def create_artifact(
        self,
        artifact_type: Union[ArtifactType, str],
        artifact_id: Optional[str] = None,
        content: Optional[Any] = None,
        metadata: Optional[Dict[str, Any]] = None
    ) -> List[Artifact]:
        """
        Create a new artifact.

        Args:
            artifact_type: Artifact type (enum or string)
            artifact_id: Optional artifact ID (will be generated if not provided)
            content: Artifact content
            metadata: Metadata dictionary

        Returns:
            List of created artifact objects (CODE content may expand into
            several artifacts; other types yield a single-element list).
        """
        # If a string is passed, convert to enum type
        if isinstance(artifact_type, str):
            artifact_type = ArtifactType(artifact_type)

        # Create new artifacts
        artifacts = []

        # Ensure metadata is a dictionary
        if metadata is None:
            metadata = {}

        # Ensure artifact_id is a valid string
        if artifact_id is None:
            artifact_id = str(uuid.uuid4())

        if artifact_type == ArtifactType.CODE:
            # NOTE(review): the generated artifact_id and metadata above are
            # not passed to CodeArtifact.from_code_content — confirm intended.
            artifacts = CodeArtifact.from_code_content(artifact_type, content)
        else:
            artifact = Artifact(
                artifact_id=artifact_id,
                artifact_type=artifact_type,
                content=content,
                metadata=metadata
            )
            artifacts.append(artifact)  # Add single artifact to the list

        # Add to workspace
        for artifact in artifacts:
            # Store in repository
            await self._store_artifact(artifact)
            logging.info(f"[📂WORKSPACE]💾 Storing artifact in repository: {artifact.artifact_id}")


        # Update workspace time
        self.updated_at = datetime.now().isoformat()

        # Save workspace state to create new version
        self.save()

        return artifacts  # Return the list of created artifacts

    async def add_artifact(
        self,
        artifact: Artifact
    ) -> None:
        """
        Add an existing artifact to the workspace and persist it.

        Args:
            artifact: The artifact to store; observers are notified with a
                "create" event after the workspace index is saved.
        """

        # Store in repository
        await self._store_artifact(artifact)

        # Update workspace time
        self.updated_at = datetime.now().isoformat()

        # Save workspace state to create new version
        self.save()

        await self._notify_observers("create", artifact)

    async def mark_as_completed(self, artifact_id: str) -> None:
        """
        Mark an artifact as completed, persist it and notify observers.
        Unknown ids are ignored silently.
        """
        artifact = self.get_artifact(artifact_id)
        if artifact:
            artifact.mark_complete()
            self.repository.store_artifact(artifact)
            logging.info(f"[📂WORKSPACE]🎉 Marking artifact as completed: {artifact_id}")
            await self._notify_observers("complete", artifact)
            self.save()

    def get_artifact(self, artifact_id: str) -> Optional[Artifact]:
        """Get artifact with the specified ID (linear scan; None if absent)."""
        for artifact in self.artifacts:
            if artifact.artifact_id == artifact_id:
                return artifact
        return None


    def get_terminal(self) -> str:
        """Not implemented yet; currently returns None despite the str annotation."""
        pass

    def get_webpage_groups(self) -> list[Any] | None:
        """Return all WEB_PAGES artifacts."""
        return self.list_artifacts(ArtifactType.WEB_PAGES)


    async def update_artifact(
        self,
        artifact_id: str,
        content: Any,
        description: str = "Content update"
    ) -> Optional[Artifact]:
        """
        Update artifact content.

        Args:
            artifact_id: Artifact ID
            content: New content
            description: Update description

        Returns:
            Updated artifact, or None if it doesn't exist
        """
        artifact = self.get_artifact(artifact_id)
        if artifact:
            artifact.update_content(content, description)

            # Update storage
            await self._store_artifact(artifact)

            return artifact
        return None

    async def delete_artifact(self, artifact_id: str) -> bool:
        """
        Delete an artifact from the workspace.

        Args:
            artifact_id: Artifact ID

        Returns:
            Whether deletion was successful.  NOTE: returns True as well when
            the artifact is not in the index (treated as already deleted).
        """
        existed = self._check_artifact_exists(artifact_id)
        if not existed:
            return True
        for i, artifact in enumerate(self.artifacts):
            if artifact.artifact_id == artifact_id:
                # Remove from list
                self.artifacts.pop(i)

                # Update workspace time
                self.updated_at = datetime.now().isoformat()

                self.repository.delete_artifact(artifact_id)
                # Save workspace state to create new version
                self.save()

                # Notify observers
                await self._notify_observers("delete", artifact)
                return True
        return False

    def list_artifacts(self, filter_type: Optional[ArtifactType] = None) -> List[Artifact]:
        """
        List all artifacts in the workspace.

        Args:
            filter_type: Optional filter type

        Returns:
            List of artifacts (the live list when no filter is given)
        """
        if filter_type:
            return [a for a in self.artifacts if a.artifact_type == filter_type]
        return self.artifacts

    def add_observer(self, observer: WorkspaceObserver) -> None:
        """
        Add a workspace observer.

        Args:
            observer: Observer object implementing WorkspaceObserver interface

        Raises:
            TypeError: If the argument is not a WorkspaceObserver.
        """
        if not isinstance(observer, WorkspaceObserver):
            raise TypeError("Observer must be an instance of WorkspaceObserver")
        self.observers.append(observer)

    def remove_observer(self, observer: WorkspaceObserver) -> None:
        """Remove an observer (no-op when not registered)."""
        if observer in self.observers:
            self.observers.remove(observer)

    async def _notify_observers(self, operation: str, artifact: Artifact) -> List[Any]:
        """
        Notify all observers of workspace changes.

        Args:
            operation: Type of operation (create, update, delete)
            artifact: Affected artifact

        Returns:
            List of truthy results from handlers; observer errors are printed
            and swallowed so one failing observer cannot block the rest.
        """
        results = []
        for observer in self.observers:
            try:
                if operation == "create":
                    result = await observer.on_create(workspace_id=self.workspace_id, artifact=artifact)
                    if result:
                        results.append(result)
                elif operation == "update":
                    result = await observer.on_update(workspace_id=self.workspace_id, artifact=artifact)
                    if result:
                        results.append(result)
                elif operation == "delete":
                    result = await observer.on_delete(workspace_id=self.workspace_id, artifact=artifact)
                    if result:
                        results.append(result)
            except Exception as e:
                print(f"Observer notification failed: {e}")
        return results

    def _check_artifact_exists(self, artifact_id: str) -> bool:
        """True when artifact_id is present in the (last saved) id index."""
        return self.artifact_id_index.get(artifact_id, -1) >= 0


    def _append_artifact(self, artifact: Artifact) -> None:
        # In-memory append only; persistence happens in _store_artifact/save.
        self.artifacts.append(artifact)
        logging.info(f"[📂WORKSPACE]🆕 Appending artifact in repository: {artifact.artifact_id}")


    def _update_artifact(self, artifact: Artifact) -> None:
        # Replace the stored instance that shares this artifact_id, if any.
        for i, a in enumerate(self.artifacts):
            if a.artifact_id == artifact.artifact_id:
                self.artifacts[i] = artifact
                logging.info(f"[📂WORKSPACE]🔄 Updating artifact in repository: {artifact.artifact_id}")
                break


    async def _store_artifact(self, artifact: Artifact) -> None:
        """Insert or update the artifact in memory, persist it, and notify observers."""
        if self._check_artifact_exists(artifact.artifact_id):
            self._update_artifact(artifact)
            await self._notify_observers("update", artifact)
        else:
            self._append_artifact(artifact)
            await self._notify_observers("create", artifact)

        """Store artifact in repository"""
        artifact_data = artifact.to_dict()

        # Include complete version history
        artifact_data["version_history"] = artifact.version_history

        version_id = self.repository.store_artifact(artifact)

        # Record the version the repository assigned to this store.
        artifact.current_version = version_id

    def save(self) -> None:
        """
        Save workspace state: serialize the index (ids + per-artifact
        metadata) to the repository and rebuild the id index.  Returns None.
        """
        workspace_data = {
            "workspace_id": self.workspace_id,
            "name": self.name,
            "created_at": self.created_at,
            "updated_at": self.updated_at,
            "metadata": self.metadata,
            "artifact_ids": [a.artifact_id for a in self.artifacts],
            "artifacts": [
                {
                    "artifact_id": a.artifact_id,
                    "type": str(a.artifact_type),
                    "metadata": a.metadata,
                    # "version": a.current_version
                } for a in self.artifacts
            ]
        }

        # Store workspace information with workspace_id in metadata
        self.repository.save_index(workspace_data)
        self._rebuild_artifact_id_index()

    def get_file_content_by_artifact_id(self, artifact_id: str) -> str:
        """
        Get concatenated content of all artifacts with the same filename.

        Args:
            artifact_id: artifact_id (its metadata 'filename' selects the group;
                when the id is unknown, the id itself is used as the filename)

        Returns:
            Raw unescaped concatenated content of all matching artifacts
        """
        filename = artifact_id
        for artifact in self.artifacts:
            if artifact.artifact_id == artifact_id:
                filename = artifact.metadata.get('filename')
                break

        result = ""
        for artifact in self.artifacts:
            if artifact.metadata.get('filename') == filename:
                if artifact.content:
                    result = result + artifact.content
        # NOTE(review): unicode_escape decoding assumes content stores escaped
        # sequences and will mangle non-ASCII text — verify; the print below
        # looks like leftover debug output.
        decoded_string = result.encode('utf-8').decode('unicode_escape')
        print(result)

        return decoded_string

    def generate_tree_data(self) -> Dict[str, Any]:
        """
        Generate a directory tree structure using the repository's implementation.
        Returns:
            A dictionary representing the directory tree.
        """
        return self.repository.generate_tree_data(self.name)

    def _rebuild_artifact_id_index(self) -> None:
        """
        Rebuild the artifact_id_index mapping artifact_id to its index in self.artifacts.
        """
        self.artifact_id_index = {artifact.artifact_id: idx for idx, artifact in enumerate(self.artifacts)}
|