Spaces:
Sleeping
Sleeping
| """Core data models for the agent framework.""" | |
| from typing import Literal, Union, List, Dict, Optional, Any | |
| from pydantic import BaseModel, Field | |
| from dataclasses import dataclass, field | |
| import uuid | |
| from datetime import datetime | |
class Message(BaseModel):
    """Plain-text conversational turn authored by the system, the user or the assistant."""

    # Fixed tag identifying this item as a text message (always "message").
    type: Literal["message"] = Field(default="message")
    # Who produced the text.
    role: Literal["system", "user", "assistant"]
    # The message text itself.
    content: str
class ToolCall(BaseModel):
    """Request emitted by the LLM asking for a named tool to be run with some arguments."""

    # Fixed tag identifying this item as a tool call (always "tool_call").
    type: Literal["tool_call"] = Field(default="tool_call")
    # Identifier for this call; ToolResult and ToolConfirmation carry a field
    # of the same name, presumably to refer back to it.
    tool_call_id: str
    # Name of the tool to invoke.
    name: str
    # Arguments for the tool; keys/values are not constrained here.
    arguments: dict
class ToolResult(BaseModel):
    """Outcome of running a tool, reported back into the conversation."""

    # Fixed tag identifying this item as a tool result (always "tool_result").
    type: Literal["tool_result"] = Field(default="tool_result")
    # Identifier of the tool call this result answers (presumably matches ToolCall.tool_call_id).
    tool_call_id: str
    # Name of the tool that was run.
    name: str
    # Whether the tool run succeeded or failed.
    status: Literal["success", "error"]
    # Output payload; element type is not constrained by this model.
    content: list
# Any single item that can appear in an Event's content list.
ContentItem = Union[
    Message,
    ToolCall,
    ToolResult,
]
class ToolConfirmation(BaseModel):
    """The user's approve/reject decision for a tool call that is awaiting confirmation."""

    # Identifier of the tool call being decided on.
    tool_call_id: str
    # True if the user allows the call to proceed.
    approved: bool
    # Optional replacement arguments supplied by the user.
    modified_arguments: Optional[dict] = None
    # Reason for rejection (if not approved).
    reason: Optional[str] = None
class PendingToolCall(BaseModel):
    """Pairs a tool call that needs user confirmation with the prompt shown for it."""

    # The tool call that is on hold.
    tool_call: ToolCall
    # Text presented to the user when asking for confirmation.
    confirmation_message: str
class Event(BaseModel):
    """Single entry in an agent run's history, recorded with its author and content items."""

    # Random UUID4 string, generated per event unless supplied.
    id: str = Field(default_factory=lambda: str(uuid.uuid4()))
    # Identifier of the execution this event belongs to.
    execution_id: str
    # POSIX timestamp (seconds) taken when the event object is created.
    timestamp: float = Field(default_factory=lambda: datetime.now().timestamp())
    author: str  # "user" or the name of the agent
    # Ordered messages / tool calls / tool results carried by this event.
    content: list[ContentItem] = Field(default_factory=list)
@dataclass
class ExecutionContext:
    """Central storage for all execution state.

    Holds the event history, the step counter, a free-form state dict and the
    final result of a single agent execution.

    The ``@dataclass`` decorator is required: without it the
    ``field(default_factory=...)`` defaults below stay as ``dataclasses.Field``
    objects on the class. In that case ``self.events.append`` fails, and no
    ``__init__`` accepting these fields is generated.
    """

    # Unique identifier for this execution; a fresh UUID4 string per instance.
    execution_id: str = field(default_factory=lambda: str(uuid.uuid4()))
    # Events recorded so far, in the order they were added.
    events: List[Event] = field(default_factory=list)
    # Number of completed steps; advanced by increment_step().
    current_step: int = 0
    # Free-form per-execution state; a new dict per instance.
    state: Dict[str, Any] = field(default_factory=dict)
    # Final answer of the run, either plain text or a structured model.
    final_result: Optional[str | BaseModel] = None
    session_id: Optional[str] = None  # Link to session for persistence

    def add_event(self, event: Event) -> None:
        """Append ``event`` to the execution history (``self.events``)."""
        self.events.append(event)

    def increment_step(self) -> None:
        """Advance ``current_step`` by one."""
        self.current_step += 1
class Session(BaseModel):
    """Conversation state that outlives a single run() call.

    Accumulates events and arbitrary key/value state for one session id so
    that later runs can continue where earlier ones stopped.
    """

    # Identifier the session is stored and looked up under.
    session_id: str
    # Optional owner of the session.
    user_id: Optional[str] = None
    # Events accumulated across runs; a new list per instance.
    events: List[Event] = Field(default_factory=list)
    # Arbitrary persisted key/value state; a new dict per instance.
    state: Dict[str, Any] = Field(default_factory=dict)
    # Naive local-time timestamps taken via datetime.now at construction.
    created_at: datetime = Field(default_factory=datetime.now)
    updated_at: datetime = Field(default_factory=datetime.now)
| from abc import ABC, abstractmethod | |
class BaseSessionManager(ABC):
    """Abstract base class for session management.

    Subclasses must implement ``create``, ``get`` and ``save``. The
    ``@abstractmethod`` markers make Python refuse to instantiate the base class,
    or any subclass that leaves one of them out. Without the markers, the
    no-op bodies would silently return ``None``.
    ``get_or_create`` is implemented here on top of ``get`` and ``create``.
    """

    @abstractmethod
    async def create(
        self,
        session_id: str,
        user_id: str | None = None
    ) -> Session:
        """Create a new session.

        Args:
            session_id: Identifier for the new session.
            user_id: Optional owner of the session.

        Returns:
            The newly created session.
        """
        ...

    @abstractmethod
    async def get(self, session_id: str) -> Session | None:
        """Retrieve a session by ID. Returns None if not found."""
        ...

    @abstractmethod
    async def save(self, session: Session) -> None:
        """Persist session changes to storage."""
        ...

    async def get_or_create(
        self,
        session_id: str,
        user_id: str | None = None
    ) -> Session:
        """Get existing session or create new one.

        ``user_id`` is only used when a new session has to be created; an
        existing session is returned unchanged.
        """
        session = await self.get(session_id)
        if session is None:
            session = await self.create(session_id, user_id)
        return session
class InMemorySessionManager(BaseSessionManager):
    """Session manager backed by a plain dict held on the instance.

    Meant for development and testing: sessions live only as long as this
    manager object does.
    """

    def __init__(self):
        # session_id -> Session
        self._sessions: dict[str, Session] = {}

    async def create(self, session_id: str, user_id: str | None = None) -> Session:
        """Register and return a brand-new session.

        Raises:
            ValueError: if a session with ``session_id`` is already stored.
        """
        already_present = session_id in self._sessions
        if already_present:
            raise ValueError(f"Session {session_id} already exists")
        new_session = Session(session_id=session_id, user_id=user_id)
        self._sessions[session_id] = new_session
        return new_session

    async def get(self, session_id: str) -> Session | None:
        """Look up a stored session, returning None when the id is unknown."""
        try:
            return self._sessions[session_id]
        except KeyError:
            return None

    async def save(self, session: Session) -> None:
        """Store ``session`` under its own ``session_id``, replacing any previous entry."""
        key = session.session_id
        self._sessions[key] = session