shekkari21's picture
added session and memory
64462d2
"""Core data models for the agent framework."""
from typing import Literal, Union, List, Dict, Optional, Any
from pydantic import BaseModel, Field
from dataclasses import dataclass, field
import uuid
from datetime import datetime
class Message(BaseModel):
    """One plain-text turn of the conversation."""

    # Fixed tag marking this content item as a message.
    type: Literal["message"] = "message"
    # Who produced the text; restricted to these three roles.
    role: Literal["system", "user", "assistant"]
    # The text itself.
    content: str
class ToolCall(BaseModel):
    """A request, issued by the LLM, to run a named tool."""

    # Fixed tag marking this content item as a tool call.
    type: Literal["tool_call"] = "tool_call"
    # Identifier of this particular call.
    tool_call_id: str
    # Name of the tool to run.
    name: str
    # Arguments for the tool; any dict is accepted, no schema is enforced here.
    arguments: dict
class ToolResult(BaseModel):
    """Outcome of running a tool."""

    # Fixed tag marking this content item as a tool result.
    type: Literal["tool_result"] = "tool_result"
    # Presumably the id of the ToolCall this result answers -- confirm
    # against the code that executes tools.
    tool_call_id: str
    # Name of the tool that was run.
    name: str
    # Whether the execution succeeded or failed.
    status: Literal["success", "error"]
    # Result payload; element types are not constrained here.
    content: list
# Any single item that may appear in an event's content list.
ContentItem = Union[
    Message,
    ToolCall,
    ToolResult,
]
class ToolConfirmation(BaseModel):
    """The user's verdict on a tool call that was waiting for approval."""

    # Id of the tool call this decision applies to.
    tool_call_id: str
    # True when the user allows the call to proceed.
    approved: bool
    # Optional dict supplied with the decision; presumably replacement
    # arguments for the call -- confirm against the consumer.
    modified_arguments: dict | None = None
    # Reason for rejection (if not approved).
    reason: str | None = None
class PendingToolCall(BaseModel):
    """A tool call that is on hold until the user confirms it."""

    # The call that is waiting for a decision.
    tool_call: ToolCall
    # Text accompanying the confirmation request.
    confirmation_message: str
class Event(BaseModel):
    """One recorded occurrence during an agent execution.

    ``id`` and ``timestamp`` are filled in automatically when omitted:
    a random UUID4 string and the current time as a POSIX timestamp.
    """

    # Unique id of this event.
    id: str = Field(
        default_factory=lambda: str(uuid.uuid4()),
    )
    # Id of the execution this event belongs to.
    execution_id: str
    # Creation time in seconds since the epoch.
    timestamp: float = Field(
        default_factory=lambda: datetime.now().timestamp(),
    )
    # "user" or agent name.
    author: str
    # Items carried by this event; empty by default.
    content: list[ContentItem] = Field(default_factory=list)
@dataclass
class ExecutionContext:
    """Mutable record of a single execution.

    Holds the event history, the step counter, a free-form state dict and
    the final result.  Every field has a default, so a bare
    ``ExecutionContext()`` is valid and gets a fresh random execution id.
    """

    # Random UUID4 string unless the caller supplies one.
    execution_id: str = field(default_factory=lambda: str(uuid.uuid4()))
    # Events recorded so far, in insertion order.
    events: list[Event] = field(default_factory=list)
    # Number of steps taken; starts at zero.
    current_step: int = 0
    # Free-form key/value storage.
    state: dict[str, Any] = field(default_factory=dict)
    # Result of the execution; None until one is assigned.
    final_result: str | BaseModel | None = None
    # Link to session for persistence.
    session_id: str | None = None

    def add_event(self, event: Event) -> None:
        """Record ``event`` at the end of the history."""
        self.events.append(event)

    def increment_step(self) -> None:
        """Advance the step counter by one."""
        self.current_step += 1
class Session(BaseModel):
    """Conversation state meant to outlive a single run() call."""

    # Identifier of the session.
    session_id: str
    # Optional identifier of the user the session belongs to.
    user_id: str | None = None
    # Events stored in the session; empty by default.
    events: list[Event] = Field(default_factory=list)
    # Free-form key/value storage; empty by default.
    state: dict[str, Any] = Field(default_factory=dict)
    # Both timestamps default to the (naive, local) time of construction.
    created_at: datetime = Field(default_factory=datetime.now)
    updated_at: datetime = Field(default_factory=datetime.now)
from abc import ABC, abstractmethod
class BaseSessionManager(ABC):
    """Interface that a session storage backend implements.

    Subclasses supply ``create``, ``get`` and ``save``.  ``get_or_create``
    is built on ``get`` and ``create`` and is inherited as-is.
    """

    @abstractmethod
    async def create(
        self,
        session_id: str,
        user_id: str | None = None,
    ) -> Session:
        """Create and return a new session with the given id."""

    @abstractmethod
    async def get(self, session_id: str) -> Session | None:
        """Look up a session by id; ``None`` means it was not found."""

    @abstractmethod
    async def save(self, session: Session) -> None:
        """Write the session's current contents to storage."""

    async def get_or_create(
        self,
        session_id: str,
        user_id: str | None = None,
    ) -> Session:
        """Return the session with this id, creating it when it is missing.

        ``user_id`` is only passed along when a new session has to be
        created; an existing session is returned as found.
        """
        existing = await self.get(session_id)
        if existing is not None:
            return existing
        return await self.create(session_id, user_id)
class InMemorySessionManager(BaseSessionManager):
    """Session manager that keeps every session in a dict on the instance.

    Nothing is written outside the manager object, so it is suited to
    development and testing rather than durable storage.
    """

    def __init__(self) -> None:
        # Maps session_id -> Session.
        self._sessions: dict[str, Session] = {}

    async def create(
        self,
        session_id: str,
        user_id: str | None = None,
    ) -> Session:
        """Create, store and return a new session.

        Args:
            session_id: Id to register the session under.
            user_id: Optional user id recorded on the session.

        Returns:
            The newly created session.

        Raises:
            ValueError: If a session with ``session_id`` is already stored.
        """
        if session_id in self._sessions:
            raise ValueError(f"Session {session_id} already exists")
        session = Session(session_id=session_id, user_id=user_id)
        self._sessions[session_id] = session
        return session

    async def get(self, session_id: str) -> Session | None:
        """Return the stored session for ``session_id``, or ``None``.

        The stored object itself is returned, not a copy.
        """
        return self._sessions.get(session_id)

    async def save(self, session: Session) -> None:
        """Store ``session`` under its id and mark it as just modified.

        Any session previously stored under the same id is replaced.

        Args:
            session: The session to store; its ``updated_at`` is set to
                the current time as a side effect.
        """
        # Refresh the modification time on every save; otherwise
        # ``updated_at`` keeps the value it received at construction.
        # Naive ``datetime.now()`` matches the Session field defaults.
        session.updated_at = datetime.now()
        self._sessions[session.session_id] = session