File size: 4,718 Bytes
e0c585c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
"""Shared research memory layer for all orchestration modes."""

from typing import Any

import structlog

from src.agents.graph.state import Conflict, Hypothesis
from src.services.embeddings import EmbeddingService
from src.utils.models import Citation, Evidence

# Module-level structured logger (structlog); used by ResearchMemory to log
# newly stored evidence, added hypotheses and added conflicts.
logger = structlog.get_logger()


class ResearchMemory:
    """Shared cognitive state for research workflows.

    This is the memory layer that ALL modes use.
    It mimics the LangGraph state management but for manual orchestration:
    evidence is persisted through an ``EmbeddingService`` (used for semantic
    search and deduplication) and mirrored in a local cache, while hypotheses
    and conflicts are tracked in plain lists.
    """

    # Source names accepted when rebuilding a Citation from stored metadata;
    # anything else falls back to "web". Built once rather than per search hit.
    _VALID_SOURCES = frozenset(
        {
            "pubmed",
            "clinicaltrials",
            "europepmc",
            "preprint",
            "openalex",
            "web",
        }
    )

    def __init__(self, query: str, embedding_service: EmbeddingService | None = None):
        """Initialize ResearchMemory with a query and optional embedding service.

        Args:
            query: The research query to track evidence for.
            embedding_service: Service for semantic search and deduplication.
                             Creates a new instance if not provided.
        """
        self.query = query
        # Hypotheses registered via add_hypothesis, in insertion order.
        self.hypotheses: list[Hypothesis] = []
        # Conflicts registered via add_conflict (any status).
        self.conflicts: list[Conflict] = []
        # IDs (citation URLs) of evidence stored by this instance, in storage order.
        self.evidence_ids: list[str] = []
        # Evidence objects keyed by the same IDs; backs get_all_evidence().
        self._evidence_cache: dict[str, Evidence] = {}
        # Not read or updated elsewhere in this class; presumably advanced by
        # the orchestrators that own this memory — confirm with callers.
        self.iteration_count: int = 0

        # Injected service. Only fall back to a fresh instance when none was
        # supplied, so a caller-provided service is never silently replaced.
        self._embedding_service = (
            embedding_service if embedding_service is not None else EmbeddingService()
        )

    async def store_evidence(self, evidence: list[Evidence]) -> list[str]:
        """Store evidence and return the IDs that were newly added (deduped).

        Candidates first go through the embedding service's ``deduplicate``.
        Each surviving item is then added to the service and to the local
        cache under its citation URL. Items whose URL this instance has
        already stored (in this or an earlier call) are skipped, so neither
        the return value nor ``evidence_ids`` ever contains duplicates.

        Args:
            evidence: Candidate evidence to store.

        Returns:
            IDs (citation URLs) of the evidence newly stored by this call, in
            storage order. Empty when nothing new was stored.
        """
        if not self._embedding_service or not evidence:
            return []

        unique = await self._embedding_service.deduplicate(evidence)
        new_ids: list[str] = []

        for ev in unique:
            ev_id = ev.citation.url
            if ev_id in self._evidence_cache:
                # Same URL already stored: re-adding would duplicate the ID in
                # evidence_ids and re-submit it to the vector store.
                continue
            await self._embedding_service.add_evidence(
                evidence_id=ev_id,
                content=ev.content,
                metadata={
                    "source": ev.citation.source,
                    "title": ev.citation.title,
                    "date": ev.citation.date,
                    # NOTE(review): authors are comma-joined here and split on
                    # "," when read back, so a name that itself contains a
                    # comma (e.g. "Smith, J.") will not round-trip intact.
                    "authors": ",".join(ev.citation.authors or []),
                    "url": ev.citation.url,
                },
            )
            # Record the ID in the cache and evidence_ids together, so a
            # failure part-way through the batch leaves them consistent.
            self._evidence_cache[ev_id] = ev
            self.evidence_ids.append(ev_id)
            new_ids.append(ev_id)

        if new_ids:
            logger.info("Stored new evidence", count=len(new_ids))
        return new_ids

    def get_all_evidence(self) -> list[Evidence]:
        """Get all accumulated evidence objects, in storage order."""
        return list(self._evidence_cache.values())

    async def get_relevant_evidence(self, n: int = 20) -> list[Evidence]:
        """Retrieve the stored evidence most similar to ``self.query``.

        Args:
            n: Maximum number of results to request from the embedding service.

        Returns:
            Evidence rebuilt from the service's search hits, in the order the
            service returned them. Empty when ``n`` is not positive.
        """
        if not self._embedding_service or n <= 0:
            return []

        results = await self._embedding_service.search_similar(self.query, n_results=n)
        return [self._evidence_from_search_result(r) for r in results]

    @classmethod
    def _evidence_from_search_result(cls, result: dict[str, Any]) -> Evidence:
        """Rebuild one Evidence object from a single search hit.

        Reads the ``metadata``, ``content``, ``id`` and ``distance`` keys of
        the hit and falls back to defaults for missing ones. Presumably this is
        the dict shape produced by ``EmbeddingService.search_similar``;
        confirm against that service.
        """
        # A hit's metadata may be missing or explicitly None; treat both as empty.
        meta = result.get("metadata") or {}
        authors_str = meta.get("authors", "")
        authors = authors_str.split(",") if authors_str else []

        # Unknown or missing source names fall back to "web".
        source_raw = meta.get("source", "web")
        source_name: Any = source_raw if source_raw in cls._VALID_SOURCES else "web"

        citation = Citation(
            source=source_name,
            title=meta.get("title", "Unknown"),
            url=meta.get("url", result.get("id", "")),
            date=meta.get("date", "Unknown"),
            authors=authors,
        )

        # Approximate conversion from vector distance to a relevance score.
        # Depending on the metric, distances can exceed 1 (cosine distance
        # spans [0, 2]), so clamp to [0, 1] instead of emitting a negative
        # relevance. A missing or None distance is treated as 0.5.
        distance = result.get("distance")
        if distance is None:
            distance = 0.5
        relevance = min(1.0, max(0.0, 1.0 - distance))

        return Evidence(
            content=result.get("content", ""),
            citation=citation,
            relevance=relevance,
        )

    def add_hypothesis(self, hypothesis: Hypothesis) -> None:
        """Add a hypothesis to tracking and log its id and confidence."""
        self.hypotheses.append(hypothesis)
        logger.info("Added hypothesis", id=hypothesis.id, confidence=hypothesis.confidence)

    def add_conflict(self, conflict: Conflict) -> None:
        """Add a detected conflict and log its id."""
        self.conflicts.append(conflict)
        logger.info("Added conflict", id=conflict.id)

    def get_open_conflicts(self) -> list[Conflict]:
        """Get unresolved conflicts (those whose status is ``"open"``)."""
        return [c for c in self.conflicts if c.status == "open"]

    def get_confirmed_hypotheses(self, threshold: float = 0.8) -> list[Hypothesis]:
        """Get high-confidence hypotheses.

        Args:
            threshold: Confidence a hypothesis must strictly exceed to count
                as confirmed. Defaults to 0.8.

        Returns:
            Matching hypotheses in insertion order.
        """
        return [h for h in self.hypotheses if h.confidence > threshold]