File size: 2,864 Bytes
645a051
 
 
 
 
 
 
e0c585c
645a051
e0c585c
645a051
e0c585c
645a051
 
 
e0c585c
645a051
 
 
 
 
e0c585c
 
645a051
 
 
e0c585c
 
 
 
 
 
 
 
 
 
 
 
645a051
 
e0c585c
645a051
e0c585c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
645a051
 
 
 
 
 
e0c585c
 
 
645a051
e0c585c
 
645a051
 
 
 
 
 
 
 
e0c585c
645a051
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
"""Thread-safe state management for Magentic agents.

Uses contextvars to ensure isolation between concurrent requests (e.g., multiple users
searching simultaneously via Gradio).
"""

from contextvars import ContextVar
from typing import TYPE_CHECKING, Any, cast

from pydantic import BaseModel

from src.services.research_memory import ResearchMemory

if TYPE_CHECKING:
    from src.services.embeddings import EmbeddingService
    from src.utils.models import Evidence


class MagenticState(BaseModel):
    """Mutable state for a Magentic workflow session.

    Wraps a single ``ResearchMemory`` instance (``None`` until one is attached)
    and exposes a few proxy helpers kept for backwards compatibility with
    ``retrieval_agent.py``.
    """

    # We wrap ResearchMemory. Typed as Any to avoid pydantic validation issues
    # with complex objects; expected to hold a ResearchMemory instance or None.
    memory: Any = None

    # Permit arbitrary (non-pydantic) types on this model.
    model_config = {"arbitrary_types_allowed": True}

    # --- Proxy methods for backwards compatibility with retrieval_agent.py ---

    async def add_evidence(self, evidence: list["Evidence"]) -> int:
        """Add evidence to memory with deduplication and embedding storage.

        This method delegates to ResearchMemory.store_evidence() which:
        1. Performs semantic deduplication (threshold 0.9)
        2. Stores unique evidence in the vector store
        3. Caches evidence for retrieval

        The number of newly stored items is measured as the growth of
        ``memory.evidence_ids`` across the ``store_evidence`` call.

        Args:
            evidence: List of Evidence objects to store.

        Returns:
            Number of new (non-duplicate) evidence items stored, or 0 when no
            memory is attached.
        """
        if self.memory is None:
            return 0

        memory: ResearchMemory = self.memory
        initial_count = len(memory.evidence_ids)
        await memory.store_evidence(evidence)
        return len(memory.evidence_ids) - initial_count

    @property
    def embedding_service(self) -> "EmbeddingService | None":
        """Return the embedding service held by the attached memory.

        Returns:
            The memory's ``_embedding_service`` attribute, or ``None`` when no
            memory is attached.
        """
        if self.memory is None:
            return None
        # ``memory`` is typed as Any to avoid Pydantic issues, so cast for type
        # checkers. A string target keeps this a pure typing hint: no runtime
        # import of src.services.embeddings is needed, consistent with the
        # TYPE_CHECKING-only import of EmbeddingService at module level.
        return cast("EmbeddingService | None", self.memory._embedding_service)


# Context-local slot for the active MagenticState. Each execution context
# (e.g. a separate asyncio task) sees its own value; None means "not initialized".
_magentic_state_var: ContextVar[MagenticState | None] = ContextVar(
    "magentic_state",
    default=None,
)


def init_magentic_state(
    query: str, embedding_service: "EmbeddingService | None" = None
) -> MagenticState:
    """Create a fresh MagenticState and bind it to the current context.

    Args:
        query: Research query passed through to the new ResearchMemory.
        embedding_service: Optional embedding service handed to ResearchMemory.

    Returns:
        The newly created state, which is also stored in the context variable.
    """
    fresh_state = MagenticState(
        memory=ResearchMemory(query=query, embedding_service=embedding_service)
    )
    _magentic_state_var.set(fresh_state)
    return fresh_state


def get_magentic_state() -> MagenticState:
    """Return the MagenticState bound to the current context.

    Raises:
        RuntimeError: If init_magentic_state() has not been called in this context.
    """
    current = _magentic_state_var.get()
    if current is not None:
        return current
    raise RuntimeError("MagenticState not initialized. Call init_magentic_state() first.")