# translation_app/backend/services/context_manager.py
# Author: Athena1621 — commit 88f8604
# feat: Introduce new backend architecture with notebooks, sources, chat, and
# CLaRa models, alongside database schema and updated deployment scripts, while
# removing old frontend, deployment files, and previous backend components.
"""
Antigravity Notebook - Context Manager
The "Brain" that makes whole-notebook reasoning possible.
This service implements the key NotebookLM functionality:
- If the notebook fits in context: Load EVERYTHING (true whole-notebook reasoning)
- If the notebook is too large: Intelligently select the most relevant parts
"""
import torch
from typing import List, Dict, Tuple
from sqlalchemy.orm import Session
from uuid import UUID
from backend.config import settings
from backend.models.clara import ClaraModel
from backend.services.storage import StorageService
class ContextManager:
    """
    Manages context preparation for notebook queries.

    The magic of NotebookLM is "whole-context awareness". Since CLaRa
    compresses text by ~16x, we can fit 10-20 books worth of content
    in a single context window.
    """

    def __init__(
        self,
        clara: ClaraModel,
        storage: StorageService,
        max_tokens: int = None
    ):
        """
        Args:
            clara: Model providing token counting and latent ranking.
            storage: Service that loads compressed tensors for a notebook.
            max_tokens: Context budget; defaults to settings.MAX_CONTEXT_TOKENS.
        """
        self.clara = clara
        self.storage = storage
        # `or` also replaces an explicit 0 with the default, which is desirable:
        # a zero-token budget would make every selection impossible.
        self.max_tokens = max_tokens or settings.MAX_CONTEXT_TOKENS
        print(f"βœ… ContextManager initialized (max tokens: {self.max_tokens})")

    def prepare_notebook_context(
        self,
        db: Session,
        notebook_id: UUID,
        query: str
    ) -> Tuple[torch.Tensor, List[Dict], Dict]:
        """
        The Magic Function: Prepares context for a notebook query.

        This is where the NotebookLM magic happens. We decide what
        memories to load into the AI's brain.

        Args:
            db: Database session
            notebook_id: Notebook UUID
            query: User query

        Returns:
            Tuple of:
            - Combined latent tensor (stacked context)
            - List of source metadata (for citations)
            - Statistics dict (token usage, etc.)

        Raises:
            ValueError: If the notebook has no stored tensors.
        """
        print(f"\n🧠 Preparing context for notebook {notebook_id}")

        # 1. Fetch ALL compressed tensors for this notebook
        all_tensor_data = self.storage.get_notebook_tensors(db, notebook_id)
        if not all_tensor_data:
            raise ValueError(f"No tensors found for notebook {notebook_id}")
        print(f"πŸ“š Found {len(all_tensor_data)} tensor segments across sources")

        # Extract tensors and the parallel per-segment metadata used for citations.
        all_tensors = [td["tensor"] for td in all_tensor_data]
        source_map = [
            {
                "source_id": str(td["source_id"]),
                "filename": td["source_filename"],
                "source_type": td["source_type"],
                "segment_index": td["segment_index"]
            }
            for td in all_tensor_data
        ]

        # 2. Calculate total token count
        total_tokens = sum(self.clara.get_token_count(t) for t in all_tensors)
        print(f"πŸ“Š Total tokens: {total_tokens} / {self.max_tokens} max")

        # 3. Decision: Whole-notebook vs. Selective
        if total_tokens <= self.max_tokens:
            # ✨ SCENARIO A: The "Whole Notebook" fits!
            # We stack them all. The AI reads EVERYTHING.
            print("βœ… Full notebook fits! Using WHOLE-NOTEBOOK reasoning")
            # NOTE(review): assumes latents are (batch, seq, ...) so dim=1 is
            # the sequence axis — confirm against ClaraModel's tensor layout.
            combined_context = torch.cat(all_tensors, dim=1)
            selected_sources = source_map
            strategy = "full_notebook"
        else:
            # 🎯 SCENARIO B: Too big (e.g., 50 books)
            # We must use CLaRa's retrieval to pick the best parts
            print(f"⚠️ Notebook too large ({total_tokens} tokens). Using SELECTIVE retrieval")
            combined_context, selected_sources = self._selective_retrieval(
                query,
                all_tensors,
                source_map
            )
            strategy = "selective_retrieval"

        # 4. Generate statistics. Count the selected tokens once instead of
        # calling get_token_count twice on the same tensor.
        selected_tokens = self.clara.get_token_count(combined_context)
        stats = {
            "total_segments": len(all_tensor_data),
            "total_tokens": total_tokens,
            "selected_segments": len(selected_sources),
            "selected_tokens": selected_tokens,
            "max_tokens": self.max_tokens,
            "context_usage_percent": round(
                (selected_tokens / self.max_tokens) * 100,
                2
            ),
            "strategy": strategy,
            "can_fit_full_context": total_tokens <= self.max_tokens
        }
        print(f"πŸ“ˆ Context prepared: {stats['selected_tokens']} tokens ({stats['context_usage_percent']}% usage)")
        return combined_context, selected_sources, stats

    def _selective_retrieval(
        self,
        query: str,
        tensors: List[torch.Tensor],
        source_map: List[Dict]
    ) -> Tuple[torch.Tensor, List[Dict]]:
        """
        Selective retrieval: Pick the most relevant tensors that fit in budget.

        Uses CLaRa's ranking to score tensors by relevance to the query,
        then greedily selects the highest-scoring tensors until we hit
        the token budget.

        Args:
            query: User query
            tensors: All available tensors
            source_map: Metadata for each tensor (parallel to `tensors`)

        Returns:
            Tuple of (combined tensor, selected source metadata)
        """
        print("πŸ” Ranking tensors by relevance...")

        # Score all tensors against the query
        scores = self.clara.rank_latents(query, tensors)

        # Pair each tensor with its metadata, score, and token cost.
        scored_tensors = [
            {
                "tensor": tensor,
                "source": source,
                "score": score,
                "tokens": self.clara.get_token_count(tensor)
            }
            for tensor, source, score in zip(tensors, source_map, scores)
        ]

        # Sort by score (highest first)
        scored_tensors.sort(key=lambda x: x["score"], reverse=True)

        # Greedy selection (knapsack-style): take the best-scoring segments
        # that still fit; segments that would overflow the budget are skipped,
        # but later (smaller) segments may still be admitted.
        selected = []
        total_tokens = 0
        for item in scored_tensors:
            if total_tokens + item["tokens"] <= self.max_tokens:
                selected.append(item)
                total_tokens += item["tokens"]

        if not selected:
            # BUG FIX: the previous fallback returned tensors[0] with an EMPTY
            # source list, so the returned context and its citation metadata
            # disagreed. If even the best segment exceeds the budget, return
            # that single most-relevant segment together with its source entry
            # so callers always get consistent (context, sources).
            selected = [scored_tensors[0]]
            total_tokens = scored_tensors[0]["tokens"]

        print(f"βœ… Selected {len(selected)}/{len(tensors)} segments ({total_tokens} tokens)")

        # Combine selected tensors along the sequence axis.
        selected_tensors = [s["tensor"] for s in selected]
        selected_sources = [s["source"] for s in selected]
        combined = torch.cat(selected_tensors, dim=1)
        return combined, selected_sources

    def get_notebook_stats(
        self,
        db: Session,
        notebook_id: UUID
    ) -> Dict:
        """
        Get statistics about a notebook's context usage.

        Useful for showing users how much of their context budget
        is being used (like the memory meter in the UI).

        Args:
            db: Database session
            notebook_id: Notebook UUID

        Returns:
            Statistics dictionary
        """
        all_tensor_data = self.storage.get_notebook_tensors(db, notebook_id)
        if not all_tensor_data:
            # Empty notebook: report a zeroed meter rather than dividing by
            # anything or raising — the UI treats this as "nothing loaded yet".
            return {
                "total_segments": 0,
                "total_tokens": 0,
                "max_tokens": self.max_tokens,
                "context_usage_percent": 0.0,
                "can_fit_full_context": True
            }
        total_tokens = sum(
            self.clara.get_token_count(td["tensor"]) for td in all_tensor_data
        )
        return {
            "total_segments": len(all_tensor_data),
            "total_tokens": total_tokens,
            "max_tokens": self.max_tokens,
            "context_usage_percent": round((total_tokens / self.max_tokens) * 100, 2),
            "can_fit_full_context": total_tokens <= self.max_tokens
        }
def get_context_manager(
    clara: ClaraModel,
    storage: StorageService
) -> ContextManager:
    """Build a ContextManager wired to the given CLaRa model and storage service."""
    manager = ContextManager(clara, storage)
    return manager