# Source: DeepBoner / src/agents/state.py
# Last commit by VibecoderMcSwaggins:
#   feat: Wire LlamaIndex RAG into Simple Mode (Tiered Embedding) (#83)
#   7baf8ba (unverified)
"""Thread-safe state management for Magentic agents.
Uses contextvars to ensure isolation between concurrent requests (e.g., multiple users
searching simultaneously via Gradio).
"""
from contextvars import ContextVar
from typing import TYPE_CHECKING, Any, cast
from pydantic import BaseModel
from src.services.research_memory import ResearchMemory
if TYPE_CHECKING:
from src.services.embedding_protocol import EmbeddingServiceProtocol
from src.utils.models import Evidence
class MagenticState(BaseModel):
    """Per-session mutable state carried through a Magentic workflow.

    The model is a thin wrapper around a single ``memory`` object (a
    ``ResearchMemory`` instance, or ``None`` when nothing has been attached)
    and exposes proxy accessors that forward to it.
    """

    # Holds the ResearchMemory instance. Declared as Any so that pydantic does
    # not attempt to validate the complex object.
    memory: Any = None

    model_config = {"arbitrary_types_allowed": True}

    # --- Proxy methods kept for backwards compatibility with retrieval_agent.py ---

    async def add_evidence(self, evidence: list["Evidence"]) -> int:
        """Forward evidence to the wrapped memory and report how much was added.

        The work is delegated to ``ResearchMemory.store_evidence()``. That
        method lives outside this file; per the original notes it performs
        semantic deduplication (threshold 0.9), stores unique evidence in the
        vector store and caches it for retrieval -- verify against
        ResearchMemory itself.

        Args:
            evidence: List of Evidence objects to hand to the memory.

        Returns:
            The growth of ``memory.evidence_ids`` caused by this call, or 0
            when no memory is attached.
        """
        store: ResearchMemory | None = self.memory
        if store is None:
            return 0

        # Measure the id collection on both sides of the store call; the
        # difference is the number of items the memory actually accepted.
        count_before = len(store.evidence_ids)
        await store.store_evidence(evidence)
        count_after = len(store.evidence_ids)
        return count_after - count_before

    @property
    def embedding_service(self) -> "EmbeddingServiceProtocol | None":
        """Return the embedding service held by the memory, or None without memory."""
        store = self.memory
        if store is None:
            return None

        # Imported here because the module-level import is only active under
        # TYPE_CHECKING, while cast() below needs the real class at runtime.
        from src.services.embedding_protocol import EmbeddingServiceProtocol

        # `memory` is typed as Any (see the field above), so the result is cast
        # back to the protocol type for the benefit of type checkers.
        service = store._embedding_service
        return cast(EmbeddingServiceProtocol | None, service)
# Context-local slot for the active MagenticState. It defaults to None and is
# populated by init_magentic_state() for the current execution context.
_magentic_state_var: ContextVar[MagenticState | None] = ContextVar(
    "magentic_state",
    default=None,
)
def init_magentic_state(
    query: str, embedding_service: "EmbeddingServiceProtocol | None" = None
) -> MagenticState:
    """Create a fresh MagenticState and bind it to the current context.

    Args:
        query: Forwarded as ``query`` to the new ResearchMemory.
        embedding_service: Optional service forwarded to the new
            ResearchMemory; defaults to None.

    Returns:
        The newly created state, which has also been stored in the
        module-level ContextVar.
    """
    fresh_state = MagenticState(
        memory=ResearchMemory(query=query, embedding_service=embedding_service)
    )
    # Overwrites whatever state the current context held before.
    _magentic_state_var.set(fresh_state)
    return fresh_state
def get_magentic_state() -> MagenticState:
    """Return the MagenticState bound to the current context.

    Returns:
        The state previously stored in the module-level ContextVar.

    Raises:
        RuntimeError: If the ContextVar still holds its default of None,
            i.e. no state has been set in this context.
    """
    current = _magentic_state_var.get()
    if current is not None:
        return current
    raise RuntimeError("MagenticState not initialized. Call init_magentic_state() first.")