Spaces:
Configuration error
Configuration error
File size: 7,995 Bytes
88f8604 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 |
"""
Antigravity Notebook - Context Manager
The "Brain" that makes whole-notebook reasoning possible.
This service implements the key NotebookLM functionality:
- If the notebook fits in context: Load EVERYTHING (true whole-notebook reasoning)
- If the notebook is too large: Intelligently select the most relevant parts
"""
from typing import Dict, List, Optional, Tuple
from uuid import UUID

import torch
from sqlalchemy.orm import Session

from backend.config import settings
from backend.models.clara import ClaraModel
from backend.services.storage import StorageService
class ContextManager:
    """
    Manages context preparation for notebook queries.

    The magic of NotebookLM is "whole-context awareness". Since CLaRa
    compresses text by ~16x, we can fit 10-20 books worth of content
    in a single context window.
    """

    def __init__(
        self,
        clara: ClaraModel,
        storage: StorageService,
        max_tokens: Optional[int] = None
    ):
        """
        Args:
            clara: Model wrapper providing token counting and latent ranking.
            storage: Service that loads compressed tensors for a notebook.
            max_tokens: Context token budget; when omitted, falls back to
                settings.MAX_CONTEXT_TOKENS.
        """
        self.clara = clara
        self.storage = storage
        # `or` (not `is None`) keeps the original semantics: an explicit 0
        # also falls back to the configured default.
        self.max_tokens = max_tokens or settings.MAX_CONTEXT_TOKENS
        print(f"✅ ContextManager initialized (max tokens: {self.max_tokens})")

    def prepare_notebook_context(
        self,
        db: Session,
        notebook_id: UUID,
        query: str
    ) -> Tuple[torch.Tensor, List[Dict], Dict]:
        """
        The Magic Function: Prepares context for a notebook query.

        This is where the NotebookLM magic happens. We decide what
        memories to load into the AI's brain.

        Args:
            db: Database session
            notebook_id: Notebook UUID
            query: User query

        Returns:
            Tuple of:
                - Combined latent tensor (stacked context)
                - List of source metadata (for citations)
                - Statistics dict (token usage, etc.)

        Raises:
            ValueError: If the notebook has no stored tensors.
        """
        print(f"\n🧠 Preparing context for notebook {notebook_id}")

        # 1. Fetch ALL compressed tensors for this notebook
        all_tensor_data = self.storage.get_notebook_tensors(db, notebook_id)
        if not all_tensor_data:
            raise ValueError(f"No tensors found for notebook {notebook_id}")
        print(f"📚 Found {len(all_tensor_data)} tensor segments across sources")

        # Extract tensors and the per-segment citation metadata in lockstep
        all_tensors = [td["tensor"] for td in all_tensor_data]
        source_map = [
            {
                "source_id": str(td["source_id"]),
                "filename": td["source_filename"],
                "source_type": td["source_type"],
                "segment_index": td["segment_index"]
            }
            for td in all_tensor_data
        ]

        # 2. Calculate total token count
        total_tokens = sum(self.clara.get_token_count(t) for t in all_tensors)
        print(f"📊 Total tokens: {total_tokens} / {self.max_tokens} max")

        # 3. Decision: Whole-notebook vs. Selective
        if total_tokens <= self.max_tokens:
            # ✨ SCENARIO A: The "Whole Notebook" fits!
            # We stack them all. The AI reads EVERYTHING.
            print("✅ Full notebook fits! Using WHOLE-NOTEBOOK reasoning")
            # NOTE(review): assumes latents are (batch, seq, dim) so dim=1
            # concatenates along the sequence axis — confirm with ClaraModel.
            combined_context = torch.cat(all_tensors, dim=1)
            selected_sources = source_map
            strategy = "full_notebook"
        else:
            # 🎯 SCENARIO B: Too big (e.g., 50 books)
            # We must use CLaRa's retrieval to pick the best parts
            print(f"⚠️ Notebook too large ({total_tokens} tokens). Using SELECTIVE retrieval")
            combined_context, selected_sources = self._selective_retrieval(
                query,
                all_tensors,
                source_map
            )
            strategy = "selective_retrieval"

        # 4. Generate statistics (count the selected tokens once, reuse below)
        selected_tokens = self.clara.get_token_count(combined_context)
        stats = {
            "total_segments": len(all_tensor_data),
            "total_tokens": total_tokens,
            "selected_segments": len(selected_sources),
            "selected_tokens": selected_tokens,
            "max_tokens": self.max_tokens,
            "context_usage_percent": round(
                (selected_tokens / self.max_tokens) * 100,
                2
            ),
            "strategy": strategy,
            "can_fit_full_context": total_tokens <= self.max_tokens
        }
        print(f"📊 Context prepared: {stats['selected_tokens']} tokens ({stats['context_usage_percent']}% usage)")
        return combined_context, selected_sources, stats

    def _selective_retrieval(
        self,
        query: str,
        tensors: List[torch.Tensor],
        source_map: List[Dict]
    ) -> Tuple[torch.Tensor, List[Dict]]:
        """
        Selective retrieval: Pick the most relevant tensors that fit in budget.

        Uses CLaRa's ranking to score tensors by relevance to the query,
        then greedily selects the highest-scoring tensors until we hit
        the token budget.

        Args:
            query: User query
            tensors: All available tensors
            source_map: Metadata for each tensor (parallel to `tensors`)

        Returns:
            Tuple of (combined tensor, selected source metadata)
        """
        print("🔍 Ranking tensors by relevance...")

        # Score all tensors against the query
        scores = self.clara.rank_latents(query, tensors)

        # Pair each tensor with its metadata, relevance score, and token cost
        scored_tensors = [
            {
                "tensor": tensor,
                "source": source,
                "score": score,
                "tokens": self.clara.get_token_count(tensor)
            }
            for tensor, source, score in zip(tensors, source_map, scores)
        ]

        # Sort by score (highest first)
        scored_tensors.sort(key=lambda x: x["score"], reverse=True)

        # Greedy selection (knapsack-style): take best-scoring segments that
        # still fit; keep scanning past ones that would blow the budget,
        # since a cheaper lower-ranked segment may still fit.
        selected = []
        total_tokens = 0
        for item in scored_tensors:
            if total_tokens + item["tokens"] <= self.max_tokens:
                selected.append(item)
                total_tokens += item["tokens"]

        if not selected:
            # BUG FIX: the old fallback returned tensors[0] with an EMPTY
            # source list, desynchronizing the context from its citations
            # (stats would also report 0 selected segments). Fall back to
            # the single highest-scoring segment *with* its metadata, even
            # though it exceeds the budget.
            selected = [scored_tensors[0]]
            total_tokens = scored_tensors[0]["tokens"]

        print(f"✅ Selected {len(selected)}/{len(tensors)} segments ({total_tokens} tokens)")

        # Combine selected tensors along the sequence axis
        combined = torch.cat([s["tensor"] for s in selected], dim=1)
        selected_sources = [s["source"] for s in selected]
        return combined, selected_sources

    def get_notebook_stats(
        self,
        db: Session,
        notebook_id: UUID
    ) -> Dict:
        """
        Get statistics about a notebook's context usage.

        Useful for showing users how much of their context budget
        is being used (like the memory meter in the UI).

        Args:
            db: Database session
            notebook_id: Notebook UUID

        Returns:
            Statistics dictionary
        """
        all_tensor_data = self.storage.get_notebook_tensors(db, notebook_id)
        if not all_tensor_data:
            # Empty notebook: report a zeroed meter instead of raising,
            # so the UI can still render.
            return {
                "total_segments": 0,
                "total_tokens": 0,
                "max_tokens": self.max_tokens,
                "context_usage_percent": 0.0,
                "can_fit_full_context": True
            }
        total_tokens = sum(
            self.clara.get_token_count(td["tensor"]) for td in all_tensor_data
        )
        return {
            "total_segments": len(all_tensor_data),
            "total_tokens": total_tokens,
            "max_tokens": self.max_tokens,
            "context_usage_percent": round((total_tokens / self.max_tokens) * 100, 2),
            "can_fit_full_context": total_tokens <= self.max_tokens
        }
def get_context_manager(
    clara: ClaraModel,
    storage: StorageService
) -> ContextManager:
    """Build a ContextManager wired to the given CLaRa model and storage.

    The token budget is left at its default, so the instance picks up
    settings.MAX_CONTEXT_TOKENS.
    """
    manager = ContextManager(clara, storage)
    return manager
|