"""Shared research memory layer for all orchestration modes.

Design Pattern: Dependency Injection
- Receives embedding service via constructor
- Uses service_loader.get_embedding_service() as default (Strategy Pattern)
- Allows testing with mock services

SOLID Principles:
- Dependency Inversion: Depends on EmbeddingServiceProtocol, not concrete class
- Open/Closed: Works with any service implementing the protocol
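
Example (illustrative sketch only; ``FakeEmbeddingService`` is a hypothetical
test stub, not part of this codebase; any object implementing the protocol works):

    class FakeEmbeddingService:
        async def deduplicate(self, evidence):
            # Pass-through dedup keeps tests deterministic
            return evidence

        async def search_similar(self, query, n_results=20):
            return []

    memory = ResearchMemory("test query", embedding_service=FakeEmbeddingService())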
"""

from typing import TYPE_CHECKING, Any, get_args

import structlog

from src.agents.graph.state import Conflict, Hypothesis
from src.utils.models import Citation, Evidence, SourceName

if TYPE_CHECKING:
    from src.services.embedding_protocol import EmbeddingServiceProtocol

logger = structlog.get_logger()


class ResearchMemory:
    """Shared cognitive state for research workflows.

    This is the memory layer that ALL orchestration modes use.
    It mirrors LangGraph's state management, adapted for manual orchestration.

    The embedding service is selected via get_embedding_service(), which returns:
    - LlamaIndexRAGService (premium tier) if OPENAI_API_KEY is available
    - EmbeddingService (free tier) as fallback
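
    Example (illustrative; ``evidence_batch`` is a placeholder list of Evidence
    objects, and an async caller is assumed):

        memory = ResearchMemory("What drives aurora activity?")
        new_ids = await memory.store_evidence(evidence_batch)
        relevant = await memory.get_relevant_evidence(n=10)
        context = await memory.get_context_summary()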
    """

    def __init__(self, query: str, embedding_service: "EmbeddingServiceProtocol | None" = None):
        """Initialize ResearchMemory with a query and optional embedding service.

        Args:
            query: The research query to track evidence for.
            embedding_service: Service for semantic search and deduplication.
                             Uses get_embedding_service() if not provided,
                             which selects the best available service.
        """
        self.query = query
        self.hypotheses: list[Hypothesis] = []
        self.conflicts: list[Conflict] = []
        self.evidence_ids: list[str] = []
        self._evidence_cache: dict[str, Evidence] = {}
        self.iteration_count: int = 0

        # Use service loader for tiered service selection (Strategy Pattern)
        if embedding_service is None:
            from src.utils.service_loader import get_embedding_service

            # Annotation is quoted because the protocol is imported only under TYPE_CHECKING
            self._embedding_service: "EmbeddingServiceProtocol" = get_embedding_service()
        else:
            self._embedding_service = embedding_service

    async def store_evidence(self, evidence: list[Evidence]) -> list[str]:
        """Store evidence and return new IDs (deduped)."""
        if not self._embedding_service:
            return []

        # Deduplicate and store (deduplicate() already calls add_evidence() internally)
        unique = await self._embedding_service.deduplicate(evidence)

        # Track IDs and cache (evidence already stored by deduplicate())
        new_ids = []
        for ev in unique:
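            # The citation URL doubles as the evidence's stable identifier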
            ev_id = ev.citation.url
            new_ids.append(ev_id)
            self._evidence_cache[ev_id] = ev

        self.evidence_ids.extend(new_ids)
        if new_ids:
            logger.info("Stored new evidence", count=len(new_ids))
        return new_ids

    def get_all_evidence(self) -> list[Evidence]:
        """Get all accumulated evidence objects."""
        return list(self._evidence_cache.values())

    async def get_relevant_evidence(self, n: int = 20) -> list[Evidence]:
        """Retrieve relevant evidence for current query."""
        if not self._embedding_service:
            return []

        results = await self._embedding_service.search_similar(self.query, n_results=n)
        evidence_list = []
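
        # Assumed search-result shape, inferred from the lookups below (the
        # exact keys depend on the embedding service implementation):
        #   {"id": str, "content": str, "distance": float,
        #    "metadata": {"title", "url", "date", "source", "authors"}}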

        for r in results:
            meta = r.get("metadata", {})
            authors_str = meta.get("authors", "")
            authors = [a.strip() for a in authors_str.split(",")] if authors_str else []

            # Reconstruct Evidence object
            source_raw = meta.get("source", "web")

            # Validate source against canonical SourceName type (avoids drift)
            valid_sources = get_args(SourceName)
            source_name: Any = source_raw if source_raw in valid_sources else "web"

            citation = Citation(
                source=source_name,
                title=meta.get("title", "Unknown"),
                url=meta.get("url", r.get("id", "")),
                date=meta.get("date", "Unknown"),
                authors=authors,
            )

            evidence_list.append(
                Evidence(
                    content=r.get("content", ""),
                    citation=citation,
                    relevance=max(0.0, 1.0 - r.get("distance", 0.5)),  # Approx conversion; clamped since distance can exceed 1.0
                )
            )

        return evidence_list

    async def get_context_summary(self) -> str:
        """Generate a summary of all collected evidence for the final report."""
        if not self.evidence_ids:
            return "No evidence collected."

        summary = [f"Research Query: {self.query}\n"]

        # Add Hypotheses
        if self.hypotheses:
            summary.append("## Hypotheses")
            for h in self.hypotheses:
                summary.append(f"- {h.statement} (Conf: {h.confidence})")
            summary.append("")

        # Add Top Evidence (limit to avoid token overflow)
        # We use get_all_evidence() but might need to summarize if too large
        evidence = self.get_all_evidence()
        summary.append(f"## Evidence ({len(evidence)} items)")

        # Enumerate the top items with brief snippets
        for i, ev in enumerate(evidence[:20], 1):  # Limit to top 20 items
            summary.append(f"{i}. {ev.citation.title} ({ev.citation.date})")
            summary.append(f"   {ev.content[:200]}...")  # Brief snippet

        return "\n".join(summary)

    def add_hypothesis(self, hypothesis: Hypothesis) -> None:
        """Add a hypothesis to tracking."""
        self.hypotheses.append(hypothesis)
        logger.info("Added hypothesis", id=hypothesis.id, confidence=hypothesis.confidence)

    def add_conflict(self, conflict: Conflict) -> None:
        """Add a detected conflict."""
        self.conflicts.append(conflict)
        logger.info("Added conflict", id=conflict.id)

    def get_open_conflicts(self) -> list[Conflict]:
        """Get unresolved conflicts."""
        return [c for c in self.conflicts if c.status == "open"]

    def get_confirmed_hypotheses(self) -> list[Hypothesis]:
        """Get high-confidence hypotheses."""
        return [h for h in self.hypotheses if h.confidence > 0.8]