|
|
"""Thread-safe state management for Magentic agents. |
|
|
|
|
|
Uses contextvars to ensure isolation between concurrent requests (e.g., multiple users |
|
|
searching simultaneously via Gradio). |
|
|
""" |
|
|
|
|
|
from contextvars import ContextVar |
|
|
from typing import TYPE_CHECKING, Any |
|
|
|
|
|
from pydantic import BaseModel, Field |
|
|
|
|
|
from src.utils.models import Citation, Evidence |
|
|
|
|
|
if TYPE_CHECKING: |
|
|
from src.services.embeddings import EmbeddingService |
|
|
|
|
|
|
|
|
class MagenticState(BaseModel):
    """Mutable state for a Magentic workflow session.

    Holds the evidence collected so far (deduplicated by citation URL) and an
    optional embedding service used by ``search_related``.
    """

    # Evidence accumulated during the session; ``default_factory`` gives each
    # instance its own list.
    evidence: list[Evidence] = Field(default_factory=list)

    # Typed as ``Any`` so pydantic does not validate it; init_magentic_state
    # passes an ``EmbeddingService | None`` here. A falsy value disables
    # ``search_related`` (it returns an empty list).
    embedding_service: Any = None

    model_config = {"arbitrary_types_allowed": True}

    def add_evidence(self, new_evidence: list[Evidence]) -> int:
        """Add new evidence, deduplicating by citation URL.

        Items whose URL is already present, either in the stored evidence or
        earlier in ``new_evidence``, are skipped.

        Args:
            new_evidence: Candidate evidence items to append.

        Returns:
            Number of *new* items added.
        """
        existing_urls = {e.citation.url for e in self.evidence}
        count = 0
        for item in new_evidence:
            url = item.citation.url
            if url in existing_urls:
                continue
            self.evidence.append(item)
            # Track the URL so duplicates within ``new_evidence`` are skipped too.
            existing_urls.add(url)
            count += 1
        return count

    @staticmethod
    def _parse_authors(raw: Any) -> list[str]:
        """Normalize a stored ``authors`` metadata value into a list of names.

        Accepts a comma-joined string or a list/tuple of names; ``None`` and
        empty values yield an empty list. Blank names are dropped.
        """
        if raw is None:
            return []
        if isinstance(raw, (list, tuple)):
            parts = [str(a) for a in raw]
        else:
            parts = str(raw).split(",")
        return [p.strip() for p in parts if p.strip()]

    @staticmethod
    def _relevance_from_distance(distance: Any) -> float:
        """Map a similarity-search distance to a relevance score in [0, 1]."""
        if distance is None:
            # Default used when the backend reports no distance (relevance 0.5).
            distance = 0.5
        # ``1 - distance`` can fall outside [0, 1] (e.g. a negative distance
        # from an inner-product metric), so clamp at both ends.
        return min(1.0, max(0.0, 1.0 - float(distance)))

    async def search_related(self, query: str, n_results: int = 5) -> list[Evidence]:
        """Search for semantically related evidence using the embedding service.

        Args:
            query: Free-text query passed to ``embedding_service.search_similar``.
            n_results: Maximum number of results to request.

        Returns:
            Evidence built from the search results, or an empty list when no
            embedding service is configured.
        """
        if not self.embedding_service:
            return []

        results = await self.embedding_service.search_similar(query, n_results=n_results)

        evidence_list: list[Evidence] = []
        for item in results:
            # Each result is expected to be a mapping with "id" and "content",
            # plus optional "metadata" and "distance"; a None metadata is tolerated.
            meta = item.get("metadata") or {}
            evidence_list.append(
                Evidence(
                    content=item["content"],
                    citation=Citation(
                        title=meta.get("title", "Related Evidence"),
                        # The stored document id doubles as the citation URL.
                        url=item["id"],
                        # NOTE(review): source is hard-coded to "pubmed" regardless
                        # of where the evidence originally came from - confirm.
                        source="pubmed",
                        date=meta.get("date", "n.d."),
                        authors=self._parse_authors(meta.get("authors")),
                    ),
                    relevance=self._relevance_from_distance(item.get("distance")),
                )
            )
        return evidence_list
|
|
|
|
|
|
|
|
|
|
|
# Context-local slot holding the current session's state. ``None`` means no
# state has been set in this context yet.
_magentic_state_var: ContextVar[MagenticState | None] = ContextVar(
    "magentic_state",
    default=None,
)
|
|
|
|
|
|
|
|
def init_magentic_state(embedding_service: "EmbeddingService | None" = None) -> MagenticState:
    """Create a fresh MagenticState and bind it to the current context.

    Any state previously bound in this context is replaced.

    Args:
        embedding_service: Optional embedding service stored on the new state.

    Returns:
        The newly created (and now current) state.
    """
    fresh_state = MagenticState(embedding_service=embedding_service)
    _magentic_state_var.set(fresh_state)
    return fresh_state
|
|
|
|
|
|
|
|
def get_magentic_state() -> MagenticState:
    """Return the MagenticState bound to the current context.

    If no state has been set in this context yet, a new one is created via
    ``init_magentic_state()`` (without an embedding service), bound to the
    context, and returned. A missing state is never an error.

    Returns:
        The current context's state.
    """
    state = _magentic_state_var.get()
    if state is None:
        # Lazy initialization: the fallback state has no embedding service,
        # so its ``search_related`` returns an empty list.
        return init_magentic_state()
    return state
|
|
|