File size: 7,995 Bytes
88f8604
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
"""
Antigravity Notebook - Context Manager
The "Brain" that makes whole-notebook reasoning possible.

This service implements the key NotebookLM functionality:
- If the notebook fits in context: Load EVERYTHING (true whole-notebook reasoning)
- If the notebook is too large: Intelligently select the most relevant parts
"""

from typing import Dict, List, Optional, Tuple
from uuid import UUID

import torch
from sqlalchemy.orm import Session

from backend.config import settings
from backend.models.clara import ClaraModel
from backend.services.storage import StorageService


class ContextManager:
    """
    Manages context preparation for notebook queries.

    The magic of NotebookLM is "whole-context awareness". Since CLaRa
    compresses text by ~16x, we can fit 10-20 books worth of content
    in a single context window.
    """

    def __init__(
        self,
        clara: ClaraModel,
        storage: StorageService,
        max_tokens: Optional[int] = None
    ):
        """
        Args:
            clara: Model used for token counting and relevance ranking
            storage: Service that retrieves compressed notebook tensors
            max_tokens: Context budget; defaults to settings.MAX_CONTEXT_TOKENS
        """
        self.clara = clara
        self.storage = storage
        # `or` also replaces an explicit 0 with the configured default,
        # which is intended: a zero-token budget would be unusable.
        self.max_tokens = max_tokens or settings.MAX_CONTEXT_TOKENS

        print(f"βœ… ContextManager initialized (max tokens: {self.max_tokens})")

    def prepare_notebook_context(
        self,
        db: Session,
        notebook_id: UUID,
        query: str
    ) -> Tuple[torch.Tensor, List[Dict], Dict]:
        """
        The Magic Function: Prepares context for a notebook query.

        This is where the NotebookLM magic happens. We decide what
        memories to load into the AI's brain.

        Args:
            db: Database session
            notebook_id: Notebook UUID
            query: User query

        Returns:
            Tuple of:
            - Combined latent tensor (stacked context)
            - List of source metadata (for citations)
            - Statistics dict (token usage, etc.)

        Raises:
            ValueError: If the notebook has no stored tensors.
        """
        print(f"\n🧠 Preparing context for notebook {notebook_id}")

        # 1. Fetch ALL compressed tensors for this notebook
        all_tensor_data = self.storage.get_notebook_tensors(db, notebook_id)

        if not all_tensor_data:
            raise ValueError(f"No tensors found for notebook {notebook_id}")

        print(f"πŸ“š Found {len(all_tensor_data)} tensor segments across sources")

        # Extract tensors and per-segment metadata (kept parallel by index;
        # the metadata is what citations are rendered from later).
        all_tensors = [td["tensor"] for td in all_tensor_data]
        source_map = [
            {
                "source_id": str(td["source_id"]),
                "filename": td["source_filename"],
                "source_type": td["source_type"],
                "segment_index": td["segment_index"]
            }
            for td in all_tensor_data
        ]

        # 2. Calculate total token count
        total_tokens = sum(self.clara.get_token_count(t) for t in all_tensors)

        print(f"πŸ“Š Total tokens: {total_tokens} / {self.max_tokens} max")

        # 3. Decision: Whole-notebook vs. Selective
        if total_tokens <= self.max_tokens:
            # ✨ SCENARIO A: The "Whole Notebook" fits!
            # We stack them all. The AI reads EVERYTHING.
            print("βœ… Full notebook fits! Using WHOLE-NOTEBOOK reasoning")

            # Segments are concatenated along dim=1 (presumably the
            # sequence/token axis of the latent — TODO confirm against ClaraModel).
            combined_context = torch.cat(all_tensors, dim=1)
            selected_sources = source_map
            strategy = "full_notebook"

        else:
            # 🎯 SCENARIO B: Too big (e.g., 50 books)
            # We must use CLaRa's retrieval to pick the best parts
            print(f"⚠️  Notebook too large ({total_tokens} tokens). Using SELECTIVE retrieval")

            combined_context, selected_sources = self._selective_retrieval(
                query,
                all_tensors,
                source_map
            )
            strategy = "selective_retrieval"

        # 4. Generate statistics (token count hoisted — it may not be free)
        selected_tokens = self.clara.get_token_count(combined_context)
        stats = {
            "total_segments": len(all_tensor_data),
            "total_tokens": total_tokens,
            "selected_segments": len(selected_sources),
            "selected_tokens": selected_tokens,
            "max_tokens": self.max_tokens,
            "context_usage_percent": round(
                (selected_tokens / self.max_tokens) * 100,
                2
            ),
            "strategy": strategy,
            "can_fit_full_context": total_tokens <= self.max_tokens
        }

        print(f"πŸ“ˆ Context prepared: {stats['selected_tokens']} tokens ({stats['context_usage_percent']}% usage)")

        return combined_context, selected_sources, stats

    def _selective_retrieval(
        self,
        query: str,
        tensors: List[torch.Tensor],
        source_map: List[Dict]
    ) -> Tuple[torch.Tensor, List[Dict]]:
        """
        Selective retrieval: Pick the most relevant tensors that fit in budget.

        Uses CLaRa's ranking to score tensors by relevance to the query,
        then greedily selects the highest-scoring tensors until we hit
        the token budget.

        Args:
            query: User query
            tensors: All available tensors (must be non-empty)
            source_map: Metadata for each tensor, parallel to ``tensors``

        Returns:
            Tuple of (combined tensor, selected source metadata). The two
            are always consistent: every segment in the combined tensor has
            a matching entry in the metadata list.
        """
        print("πŸ” Ranking tensors by relevance...")

        # Score all tensors against the query
        scores = self.clara.rank_latents(query, tensors)

        # Create scored list (index-aligned across tensors/source_map/scores)
        scored_tensors = [
            {
                "tensor": tensors[i],
                "source": source_map[i],
                "score": scores[i],
                "tokens": self.clara.get_token_count(tensors[i])
            }
            for i in range(len(tensors))
        ]

        # Sort by score (highest first)
        scored_tensors.sort(key=lambda x: x["score"], reverse=True)

        # Greedy selection (knapsack-style): take best-scoring segments
        # while they fit; a skipped large segment doesn't stop smaller,
        # lower-scoring ones from being picked later.
        selected = []
        total_tokens = 0

        for item in scored_tensors:
            if total_tokens + item["tokens"] <= self.max_tokens:
                selected.append(item)
                total_tokens += item["tokens"]

        if not selected:
            # Even the single best segment exceeds the budget. Previously this
            # fell back to tensors[0] with an EMPTY source list, leaving the
            # returned context inconsistent with its citation metadata. Fall
            # back to the highest-scoring segment instead, keeping its source
            # so citations and stats stay aligned with the actual context.
            selected = [scored_tensors[0]]
            total_tokens = scored_tensors[0]["tokens"]

        print(f"βœ… Selected {len(selected)}/{len(tensors)} segments ({total_tokens} tokens)")

        # Combine selected tensors (dim=1 matches the whole-notebook path)
        selected_tensors = [s["tensor"] for s in selected]
        selected_sources = [s["source"] for s in selected]

        combined = torch.cat(selected_tensors, dim=1)

        return combined, selected_sources

    def get_notebook_stats(
        self,
        db: Session,
        notebook_id: UUID
    ) -> Dict:
        """
        Get statistics about a notebook's context usage.

        Useful for showing users how much of their context budget
        is being used (like the memory meter in the UI).

        Args:
            db: Database session
            notebook_id: Notebook UUID

        Returns:
            Statistics dictionary
        """
        all_tensor_data = self.storage.get_notebook_tensors(db, notebook_id)

        # Empty notebook: report zero usage rather than raising, so the UI
        # meter can render before any sources are uploaded.
        if not all_tensor_data:
            return {
                "total_segments": 0,
                "total_tokens": 0,
                "max_tokens": self.max_tokens,
                "context_usage_percent": 0.0,
                "can_fit_full_context": True
            }

        all_tensors = [td["tensor"] for td in all_tensor_data]
        total_tokens = sum(self.clara.get_token_count(t) for t in all_tensors)

        return {
            "total_segments": len(all_tensor_data),
            "total_tokens": total_tokens,
            "max_tokens": self.max_tokens,
            "context_usage_percent": round((total_tokens / self.max_tokens) * 100, 2),
            "can_fit_full_context": total_tokens <= self.max_tokens
        }


def get_context_manager(
    clara: ClaraModel,
    storage: StorageService
) -> ContextManager:
    """Build a ContextManager wired to the given CLaRa model and storage service."""
    manager = ContextManager(clara, storage)
    return manager