Spaces:
Sleeping
Sleeping
File size: 4,846 Bytes
f871fed |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 |
from fastapi import APIRouter, HTTPException
from loguru import logger
from api.models import ContextRequest, ContextResponse
from open_notebook.domain.notebook import Note, Notebook, Source
from open_notebook.exceptions import InvalidInputError
from open_notebook.utils import token_count
# Router on which the notebook-context endpoint below is registered.
router = APIRouter()
def _with_table_prefix(record_id: str, table: str) -> str:
    """Return *record_id* qualified as ``table:id`` unless it already is."""
    prefix = f"{table}:"
    return record_id if record_id.startswith(prefix) else f"{prefix}{record_id}"


@router.post("/notebooks/{notebook_id}/context", response_model=ContextResponse)
async def get_notebook_context(notebook_id: str, context_request: ContextRequest):
    """Get context for a notebook based on configuration.

    When ``context_request.context_config`` is set, only the sources and
    notes it lists are considered, and each entry's status string selects
    what is included:

    * a status containing ``"not in"`` skips the entry;
    * for sources, ``"insights"`` requests the ``"short"`` context and
      ``"full content"`` requests the ``"long"`` context;
    * for notes, ``"full content"`` requests the ``"long"`` context;
    * entries with any other status are ignored.

    Without a configuration, every source and note returned by the notebook
    is included with its ``"short"`` context.

    Processing is best-effort per item: a failure on one source or note is
    logged and that item is skipped, without failing the whole request.

    Args:
        notebook_id: Identifier of the notebook, taken from the URL path.
        context_request: Request body carrying the optional ``context_config``.

    Returns:
        A ``ContextResponse`` with the collected source and note contexts and
        ``total_tokens`` computed by ``token_count`` over the concatenated
        string form of those contexts (0 when nothing was collected).

    Raises:
        HTTPException: 404 when the notebook lookup yields nothing, 400 when
            an ``InvalidInputError`` is raised, 500 for any other error.
    """
    try:
        # Verify notebook exists
        notebook = await Notebook.get(notebook_id)
        if not notebook:
            raise HTTPException(status_code=404, detail="Notebook not found")

        context_data: dict[str, list[dict[str, str]]] = {"note": [], "source": []}
        # Collect the pieces and join once at the end instead of repeatedly
        # concatenating onto an ever-growing string.
        content_parts: list[str] = []

        config = context_request.context_config
        if config:
            # Process sources
            for source_id, status in config.sources.items():
                if "not in" in status:
                    continue
                wants_insights = "insights" in status
                if not wants_insights and "full content" not in status:
                    # No recognised inclusion mode: nothing would be added,
                    # so skip the lookup entirely.
                    continue
                try:
                    full_source_id = _with_table_prefix(source_id, "source")
                    try:
                        source = await Source.get(full_source_id)
                    except Exception as e:
                        # Best-effort: an unresolvable source is skipped, but
                        # leave a trace instead of swallowing it silently.
                        logger.debug(
                            f"Skipping source {source_id}: lookup failed: {str(e)}"
                        )
                        continue
                    if not source:
                        # Same guard the note branch applies to empty lookups.
                        continue
                    if wants_insights:
                        source_context = await source.get_context(
                            context_size="short"
                        )
                    else:
                        source_context = await source.get_context(
                            context_size="long"
                        )
                    context_data["source"].append(source_context)
                    content_parts.append(str(source_context))
                except Exception as e:
                    logger.warning(f"Error processing source {source_id}: {str(e)}")

            # Process notes
            for note_id, status in config.notes.items():
                if "not in" in status or "full content" not in status:
                    # Notes are only ever included with their full content.
                    continue
                try:
                    note = await Note.get(_with_table_prefix(note_id, "note"))
                    if not note:
                        continue
                    note_context = note.get_context(context_size="long")
                    context_data["note"].append(note_context)
                    content_parts.append(str(note_context))
                except Exception as e:
                    logger.warning(f"Error processing note {note_id}: {str(e)}")
        else:
            # Default behavior - include all sources and notes with short context
            for source in await notebook.get_sources():
                try:
                    source_context = await source.get_context(context_size="short")
                    context_data["source"].append(source_context)
                    content_parts.append(str(source_context))
                except Exception as e:
                    logger.warning(f"Error processing source {source.id}: {str(e)}")

            for note in await notebook.get_notes():
                try:
                    note_context = note.get_context(context_size="short")
                    context_data["note"].append(note_context)
                    content_parts.append(str(note_context))
                except Exception as e:
                    logger.warning(f"Error processing note {note.id}: {str(e)}")

        # Calculate estimated token count
        total_content = "".join(content_parts)
        estimated_tokens = token_count(total_content) if total_content else 0

        return ContextResponse(
            notebook_id=notebook_id,
            sources=context_data["source"],
            notes=context_data["note"],
            total_tokens=estimated_tokens,
        )
    except HTTPException:
        # Already the intended HTTP error (e.g. the 404 above); pass through.
        raise
    except InvalidInputError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        logger.error(f"Error getting context for notebook {notebook_id}: {str(e)}")
        raise HTTPException(
            status_code=500, detail=f"Error getting context: {str(e)}"
        ) from e
|