from fastapi import APIRouter, HTTPException from loguru import logger from api.models import ContextRequest, ContextResponse from open_notebook.domain.notebook import Note, Notebook, Source from open_notebook.exceptions import InvalidInputError from open_notebook.utils import token_count router = APIRouter() @router.post("/notebooks/{notebook_id}/context", response_model=ContextResponse) async def get_notebook_context(notebook_id: str, context_request: ContextRequest): """Get context for a notebook based on configuration.""" try: # Verify notebook exists notebook = await Notebook.get(notebook_id) if not notebook: raise HTTPException(status_code=404, detail="Notebook not found") context_data: dict[str, list[dict[str, str]]] = {"note": [], "source": []} total_content = "" # Process context configuration if provided if context_request.context_config: # Process sources for source_id, status in context_request.context_config.sources.items(): if "not in" in status: continue try: # Add table prefix if not present full_source_id = ( source_id if source_id.startswith("source:") else f"source:{source_id}" ) try: source = await Source.get(full_source_id) except Exception: continue if "insights" in status: source_context = await source.get_context(context_size="short") context_data["source"].append(source_context) total_content += str(source_context) elif "full content" in status: source_context = await source.get_context(context_size="long") context_data["source"].append(source_context) total_content += str(source_context) except Exception as e: logger.warning(f"Error processing source {source_id}: {str(e)}") continue # Process notes for note_id, status in context_request.context_config.notes.items(): if "not in" in status: continue try: # Add table prefix if not present full_note_id = ( note_id if note_id.startswith("note:") else f"note:{note_id}" ) note = await Note.get(full_note_id) if not note: continue if "full content" in status: note_context = note.get_context(context_size="long") context_data["note"].append(note_context) total_content += str(note_context) except Exception as e: logger.warning(f"Error processing note {note_id}: {str(e)}") continue else: # Default behavior - include all sources and notes with short context sources = await notebook.get_sources() for source in sources: try: source_context = await source.get_context(context_size="short") context_data["source"].append(source_context) total_content += str(source_context) except Exception as e: logger.warning(f"Error processing source {source.id}: {str(e)}") continue notes = await notebook.get_notes() for note in notes: try: note_context = note.get_context(context_size="short") context_data["note"].append(note_context) total_content += str(note_context) except Exception as e: logger.warning(f"Error processing note {note.id}: {str(e)}") continue # Calculate estimated token count estimated_tokens = token_count(total_content) if total_content else 0 return ContextResponse( notebook_id=notebook_id, sources=context_data["source"], notes=context_data["note"], total_tokens=estimated_tokens, ) except HTTPException: raise except InvalidInputError as e: raise HTTPException(status_code=400, detail=str(e)) except Exception as e: logger.error(f"Error getting context for notebook {notebook_id}: {str(e)}") raise HTTPException(status_code=500, detail=f"Error getting context: {str(e)}")