Spaces:
Runtime error
Runtime error
| import asyncio | |
| import os | |
| from pathlib import Path | |
| from typing import Any, ClassVar, Dict, List, Literal, Optional, Tuple, Union | |
| from loguru import logger | |
| from pydantic import BaseModel, ConfigDict, Field, field_validator | |
| from surreal_commands import submit_command | |
| from surrealdb import RecordID | |
| from open_notebook.database.repository import ensure_record_id, repo_query | |
| from open_notebook.domain.base import ObjectModel | |
| from open_notebook.exceptions import DatabaseOperationError, InvalidInputError | |
class Notebook(ObjectModel):
    """A notebook that groups sources, notes and chat sessions.

    As the queries below show, sources are attached through ``reference``
    edges, notes through ``artifact`` edges and chat sessions through
    ``refers_to`` edges, each pointing at the notebook record (``out``).
    """

    table_name: ClassVar[str] = "notebook"
    name: str
    description: str
    archived: Optional[bool] = False

    # Without these decorators pydantic never invokes the check and the
    # method is just an ordinary instance method with a misleading signature.
    @field_validator("name")
    @classmethod
    def name_must_not_be_empty(cls, v):
        """Reject names that are empty or contain only whitespace.

        Raises:
            InvalidInputError: If ``v`` is blank after stripping.
        """
        if not v.strip():
            raise InvalidInputError("Notebook name cannot be empty")
        return v

    async def get_sources(self) -> List["Source"]:
        """Return the sources linked to this notebook, newest update first.

        The query omits ``full_text``, so the returned objects do not carry it.

        Raises:
            DatabaseOperationError: If the query or model construction fails.
        """
        try:
            rows = await repo_query(
                """
                select * omit source.full_text from (
                    select in as source from reference where out=$id
                    fetch source
                ) order by source.updated desc
                """,
                {"id": ensure_record_id(self.id)},
            )
            return [Source(**row["source"]) for row in rows] if rows else []
        except Exception as e:
            logger.error(f"Error fetching sources for notebook {self.id}: {str(e)}")
            logger.exception(e)
            raise DatabaseOperationError(e) from e

    async def get_notes(self) -> List["Note"]:
        """Return the notes linked to this notebook, newest update first.

        The query omits ``content`` and ``embedding`` from each note.

        Raises:
            DatabaseOperationError: If the query or model construction fails.
        """
        try:
            rows = await repo_query(
                """
                select * omit note.content, note.embedding from (
                    select in as note from artifact where out=$id
                    fetch note
                ) order by note.updated desc
                """,
                {"id": ensure_record_id(self.id)},
            )
            return [Note(**row["note"]) for row in rows] if rows else []
        except Exception as e:
            logger.error(f"Error fetching notes for notebook {self.id}: {str(e)}")
            logger.exception(e)
            raise DatabaseOperationError(e) from e

    async def get_chat_sessions(self) -> List["ChatSession"]:
        """Return the chat sessions that refer to this notebook, newest first.

        Raises:
            DatabaseOperationError: If the query or model construction fails.
        """
        try:
            rows = await repo_query(
                """
                select * from (
                    select
                    <- chat_session as chat_session
                    from refers_to
                    where out=$id
                    fetch chat_session
                )
                order by chat_session.updated desc
                """,
                {"id": ensure_record_id(self.id)},
            )
            # Each row carries a list under "chat_session"; the first entry
            # is used to build the model.
            return (
                [ChatSession(**row["chat_session"][0]) for row in rows]
                if rows
                else []
            )
        except Exception as e:
            logger.error(
                f"Error fetching chat sessions for notebook {self.id}: {str(e)}"
            )
            logger.exception(e)
            raise DatabaseOperationError(e) from e

    async def _source_reference_counts(self, notebook_id) -> List[Dict[str, Any]]:
        """Fetch, per linked source, how many OTHER notebooks reference it.

        Each returned row has ``id`` and ``assigned_others``; a value of 0
        means the source is exclusive to this notebook. A falsy query result
        is normalised to an empty list.
        """
        rows = await repo_query(
            """
            SELECT
                id,
                count(->reference[WHERE out != $notebook_id].out) as assigned_others
            FROM (SELECT VALUE <-reference.in AS sources FROM $notebook_id)[0]
            """,
            {"notebook_id": notebook_id},
        )
        return rows or []

    async def get_delete_preview(self) -> Dict[str, Any]:
        """
        Get counts of items that would be affected by deleting this notebook.

        Returns a dict with:
        - note_count: Number of notes that will be deleted
        - exclusive_source_count: Sources only in this notebook (can be deleted)
        - shared_source_count: Sources in other notebooks (will be unlinked only)

        Raises:
            DatabaseOperationError: If any of the queries fails.
        """
        try:
            notebook_id = ensure_record_id(self.id)

            # Count notes
            note_result = await repo_query(
                "SELECT count() as count FROM artifact WHERE out = $notebook_id GROUP ALL",
                {"notebook_id": notebook_id},
            )
            note_count = note_result[0]["count"] if note_result else 0

            # assigned_others == 0 -> exclusive to this notebook
            # assigned_others > 0  -> shared with other notebooks
            source_counts = await self._source_reference_counts(notebook_id)
            exclusive_count = sum(
                1 for src in source_counts if src.get("assigned_others", 0) == 0
            )
            shared_count = len(source_counts) - exclusive_count

            return {
                "note_count": note_count,
                "exclusive_source_count": exclusive_count,
                "shared_source_count": shared_count,
            }
        except Exception as e:
            logger.error(f"Error getting delete preview for notebook {self.id}: {e}")
            logger.exception(e)
            raise DatabaseOperationError(e) from e

    async def delete(self, delete_exclusive_sources: bool = False) -> Dict[str, int]:
        """
        Delete notebook with cascade deletion of notes and optional source deletion.

        Args:
            delete_exclusive_sources: If True, also delete sources that belong
                only to this notebook. Default is False.

        Returns:
            Dict with counts: deleted_notes, deleted_sources, unlinked_sources

        Raises:
            InvalidInputError: If the notebook has no ID.
            DatabaseOperationError: If any step of the deletion fails.
        """
        if self.id is None:
            raise InvalidInputError("Cannot delete notebook without an ID")

        try:
            notebook_id = ensure_record_id(self.id)
            deleted_notes = 0
            deleted_sources = 0
            unlinked_sources = 0

            # 1. Get and delete all notes linked to this notebook
            notes = await self.get_notes()
            for note in notes:
                await note.delete()
                deleted_notes += 1
            logger.info(f"Deleted {deleted_notes} notes for notebook {self.id}")

            # Delete artifact relationships
            await repo_query(
                "DELETE artifact WHERE out = $notebook_id",
                {"notebook_id": notebook_id},
            )

            # 2. Handle sources
            if delete_exclusive_sources:
                source_counts = await self._source_reference_counts(notebook_id)
                for src in source_counts:
                    source_id = src.get("id")
                    if source_id and src.get("assigned_others", 0) == 0:
                        # Exclusive source - delete it. A failure here is
                        # deliberately best-effort: log and keep going.
                        try:
                            source = await Source.get(str(source_id))
                            await source.delete()
                            deleted_sources += 1
                        except Exception as e:
                            logger.warning(
                                f"Failed to delete exclusive source {source_id}: {e}"
                            )
                    else:
                        unlinked_sources += 1
            else:
                # Just count sources that will be unlinked
                source_result = await repo_query(
                    "SELECT count() as count FROM reference WHERE out = $notebook_id GROUP ALL",
                    {"notebook_id": notebook_id},
                )
                unlinked_sources = source_result[0]["count"] if source_result else 0

            # Delete reference relationships (unlink all sources)
            await repo_query(
                "DELETE reference WHERE out = $notebook_id",
                {"notebook_id": notebook_id},
            )
            logger.info(
                f"Unlinked {unlinked_sources} sources, deleted {deleted_sources} "
                f"exclusive sources for notebook {self.id}"
            )

            # 3. Delete the notebook record itself
            await super().delete()
            logger.info(f"Deleted notebook {self.id}")

            return {
                "deleted_notes": deleted_notes,
                "deleted_sources": deleted_sources,
                "unlinked_sources": unlinked_sources,
            }
        except Exception as e:
            logger.error(f"Error deleting notebook {self.id}: {e}")
            logger.exception(e)
            raise DatabaseOperationError(f"Failed to delete notebook: {e}") from e
class Asset(BaseModel):
    """Location of a source's underlying asset; both fields are optional."""

    # Path to a file on disk, when the asset is a file.
    file_path: Optional[str] = None
    # URL of the asset, when it is addressed by URL.
    url: Optional[str] = None
class SourceEmbedding(ObjectModel):
    """A ``source_embedding`` record holding a piece of text content."""

    table_name: ClassVar[str] = "source_embedding"
    content: str

    async def get_source(self) -> "Source":
        """Load the source this embedding points to via its ``source`` field.

        Raises:
            DatabaseOperationError: If the lookup fails or returns no row.
        """
        try:
            params = {"id": ensure_record_id(self.id)}
            rows = await repo_query(
                """
                select source.* from $id fetch source
                """,
                params,
            )
            first_row = rows[0]
            return Source(**first_row["source"])
        except Exception as e:
            logger.error(f"Error fetching source for embedding {self.id}: {str(e)}")
            logger.exception(e)
            raise DatabaseOperationError(e)
class SourceInsight(ObjectModel):
    """A ``source_insight`` record: a typed piece of content about a source."""

    table_name: ClassVar[str] = "source_insight"
    insight_type: str
    content: str

    async def get_source(self) -> "Source":
        """Load the source this insight points to via its ``source`` field.

        Raises:
            DatabaseOperationError: If the lookup fails or returns no row.
        """
        try:
            params = {"id": ensure_record_id(self.id)}
            rows = await repo_query(
                """
                select source.* from $id fetch source
                """,
                params,
            )
            first_row = rows[0]
            return Source(**first_row["source"])
        except Exception as e:
            logger.error(f"Error fetching source for insight {self.id}: {str(e)}")
            logger.exception(e)
            raise DatabaseOperationError(e)

    async def save_as_note(self, notebook_id: Optional[str] = None) -> Any:
        """Create and save a note from this insight.

        The note title combines the insight type with the source title; the
        note is attached to ``notebook_id`` only when one is given.

        Returns:
            The saved note.
        """
        parent_source = await self.get_source()
        note_title = f"{self.insight_type} from source {parent_source.title}"
        new_note = Note(title=note_title, content=self.content)
        await new_note.save()
        if not notebook_id:
            return new_note
        await new_note.add_to_notebook(notebook_id)
        return new_note
class Source(ObjectModel):
    """A ``source`` record: an asset plus its extracted text and metadata.

    ``command`` optionally links the source to a surreal-commands processing
    job, whose status can be queried with :meth:`get_status`.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)
    table_name: ClassVar[str] = "source"
    asset: Optional[Asset] = None
    title: Optional[str] = None
    topics: Optional[List[str]] = Field(default_factory=list)
    full_text: Optional[str] = None
    command: Optional[Union[str, RecordID]] = Field(
        default=None, description="Link to surreal-commands processing job"
    )

    # The two parsers below take ``cls`` and were clearly written as pydantic
    # validators; without the decorators they are never invoked.
    @field_validator("command", mode="before")
    @classmethod
    def parse_command(cls, value):
        """Parse command field to ensure RecordID format"""
        if isinstance(value, str) and value:
            return ensure_record_id(value)
        return value

    # NOTE(review): ``id`` is assumed to be declared on ObjectModel - confirm.
    @field_validator("id", mode="before")
    @classmethod
    def parse_id(cls, value):
        """Parse id field to handle both string and RecordID inputs"""
        if value is None:
            return None
        if isinstance(value, RecordID):
            return str(value)
        # Any other falsy value (e.g. an empty string) is treated as "no id".
        return str(value) if value else None

    async def get_status(self) -> Optional[str]:
        """Get the processing status of the associated command.

        Returns:
            None when no command is linked; otherwise the command's status,
            or ``"unknown"`` when no status is returned or the lookup fails.
        """
        if not self.command:
            return None
        try:
            from surreal_commands import get_command_status

            status = await get_command_status(str(self.command))
            return status.status if status else "unknown"
        except Exception as e:
            # Best-effort: a status lookup failure must not break callers.
            logger.warning(f"Failed to get command status for {self.command}: {e}")
            return "unknown"

    async def get_processing_progress(self) -> Optional[Dict[str, Any]]:
        """Get detailed processing information for the associated command.

        Returns:
            None when no command is linked, no status is returned, or the
            lookup fails; otherwise a dict with ``status``, ``started_at``,
            ``completed_at``, ``error`` and ``result``.
        """
        if not self.command:
            return None
        try:
            from surreal_commands import get_command_status

            status_result = await get_command_status(str(self.command))
            if not status_result:
                return None

            # Extract execution metadata if available
            result = getattr(status_result, "result", None)
            execution_metadata = (
                result.get("execution_metadata", {}) if isinstance(result, dict) else {}
            )
            return {
                "status": status_result.status,
                "started_at": execution_metadata.get("started_at"),
                "completed_at": execution_metadata.get("completed_at"),
                "error": getattr(status_result, "error_message", None),
                "result": result,
            }
        except Exception as e:
            logger.warning(f"Failed to get command progress for {self.command}: {e}")
            return None

    async def get_context(
        self, context_size: Literal["short", "long"] = "short"
    ) -> Dict[str, Any]:
        """Build a context dict with id, title and dumped insights.

        ``full_text`` is included only when ``context_size`` is ``"long"``.
        """
        insights_list = await self.get_insights()
        context: Dict[str, Any] = dict(
            id=self.id,
            title=self.title,
            insights=[insight.model_dump() for insight in insights_list],
        )
        if context_size == "long":
            context["full_text"] = self.full_text
        return context

    async def get_embedded_chunks(self) -> int:
        """Return the number of ``source_embedding`` rows for this source.

        Raises:
            DatabaseOperationError: If the count query fails.
        """
        try:
            result = await repo_query(
                """
                select count() as chunks from source_embedding where source=$id GROUP ALL
                """,
                {"id": ensure_record_id(self.id)},
            )
            # Truthiness check also covers a None result, not just [].
            if not result:
                return 0
            return result[0]["chunks"]
        except Exception as e:
            logger.error(f"Error fetching chunks count for source {self.id}: {str(e)}")
            logger.exception(e)
            raise DatabaseOperationError(
                f"Failed to count chunks for source: {str(e)}"
            ) from e

    async def get_insights(self) -> List[SourceInsight]:
        """Return all ``source_insight`` records that point at this source.

        Raises:
            DatabaseOperationError: If the query or model construction fails.
        """
        try:
            result = await repo_query(
                """
                SELECT * FROM source_insight WHERE source=$id
                """,
                {"id": ensure_record_id(self.id)},
            )
            return [SourceInsight(**insight) for insight in result]
        except Exception as e:
            logger.error(f"Error fetching insights for source {self.id}: {str(e)}")
            logger.exception(e)
            raise DatabaseOperationError("Failed to fetch insights for source") from e

    async def add_to_notebook(self, notebook_id: str) -> Any:
        """Link this source to a notebook through a ``reference`` relation.

        Raises:
            InvalidInputError: If ``notebook_id`` is empty.
        """
        if not notebook_id:
            raise InvalidInputError("Notebook ID must be provided")
        return await self.relate("reference", notebook_id)

    async def vectorize(self) -> str:
        """
        Submit vectorization as a background job using the embed_source command.

        This method leverages the job-based architecture to prevent HTTP connection
        pool exhaustion when processing large documents. The embed_source command:
        1. Detects content type from file path
        2. Chunks text using content-type aware splitter
        3. Generates all embeddings in batches
        4. Bulk inserts source_embedding records

        Returns:
            str: The command/job ID that can be used to track progress via the commands API

        Raises:
            ValueError: If source has no text to vectorize
            DatabaseOperationError: If job submission fails
        """
        logger.info(f"Submitting embed_source job for source {self.id}")
        try:
            if not self.full_text or not self.full_text.strip():
                raise ValueError(f"Source {self.id} has no text to vectorize")

            # Submit the embed_source command
            command_id = submit_command(
                "open_notebook",
                "embed_source",
                {"source_id": str(self.id)},
            )
            command_id_str = str(command_id)
            logger.info(
                f"Embed source job submitted for source {self.id}: "
                f"command_id={command_id_str}"
            )
            return command_id_str
        except ValueError:
            # Input errors are the caller's problem; do not wrap them.
            raise
        except Exception as e:
            logger.error(
                f"Failed to submit embed_source job for source {self.id}: {e}"
            )
            logger.exception(e)
            raise DatabaseOperationError(e) from e

    async def add_insight(self, insight_type: str, content: str) -> Optional[str]:
        """
        Submit insight creation as an async command (fire-and-forget).

        Submits a create_insight command that handles database operations with
        automatic retry logic for transaction conflicts. The command also submits
        an embed_insight command for async embedding.

        This method returns immediately after submitting the command - it does NOT
        wait for the insight to be created. Use this for batch operations where
        throughput is more important than immediate confirmation.

        Args:
            insight_type: Type/category of the insight
            content: The insight content text

        Returns:
            command_id for optional tracking, or None if submission failed

        Raises:
            InvalidInputError: If insight_type or content is empty
        """
        if not insight_type or not content:
            raise InvalidInputError("Insight type and content must be provided")
        try:
            # Submit create_insight command (fire-and-forget)
            # Command handles retries internally for transaction conflicts
            command_id = submit_command(
                "open_notebook",
                "create_insight",
                {
                    "source_id": str(self.id),
                    "insight_type": insight_type,
                    "content": content,
                },
            )
            logger.info(
                f"Submitted create_insight command {command_id} for source {self.id} "
                f"(type={insight_type})"
            )
            return str(command_id)
        except Exception as e:
            # Documented contract: submission failure yields None, not an error.
            logger.error(f"Error submitting create_insight for source {self.id}: {e}")
            return None

    def _prepare_save_data(self) -> dict:
        """Override to ensure command field is always RecordID format for database"""
        data = super()._prepare_save_data()
        # Ensure command field is RecordID format if not None
        if data.get("command") is not None:
            data["command"] = ensure_record_id(data["command"])
        return data

    async def delete(self) -> bool:
        """Delete source and clean up associated file, embeddings, and insights.

        File and embedding/insight cleanup are best-effort: failures are
        logged and the source record is still deleted.
        """
        # Clean up uploaded file if it exists
        if self.asset and self.asset.file_path:
            file_path = Path(self.asset.file_path)
            if file_path.exists():
                try:
                    os.unlink(file_path)
                    logger.info(f"Deleted file for source {self.id}: {file_path}")
                except Exception as e:
                    logger.warning(
                        f"Failed to delete file {file_path} for source {self.id}: {e}. "
                        "Continuing with database deletion."
                    )
            else:
                logger.debug(
                    f"File {file_path} not found for source {self.id}, skipping cleanup"
                )

        # Delete associated embeddings and insights to prevent orphaned records
        try:
            source_id = ensure_record_id(self.id)
            await repo_query(
                "DELETE source_embedding WHERE source = $source_id",
                {"source_id": source_id},
            )
            await repo_query(
                "DELETE source_insight WHERE source = $source_id",
                {"source_id": source_id},
            )
            logger.debug(f"Deleted embeddings and insights for source {self.id}")
        except Exception as e:
            logger.warning(
                f"Failed to delete embeddings/insights for source {self.id}: {e}. "
                "Continuing with source deletion."
            )

        # Call parent delete to remove database record
        return await super().delete()
class Note(ObjectModel):
    """A ``note`` record with optional title, type and content."""

    table_name: ClassVar[str] = "note"
    title: Optional[str] = None
    note_type: Optional[Literal["human", "ai"]] = None
    content: Optional[str] = None

    # Without these decorators pydantic never runs the check.
    @field_validator("content")
    @classmethod
    def content_must_not_be_empty(cls, v):
        """Reject content that is present but blank; ``None`` is allowed.

        Raises:
            InvalidInputError: If ``v`` is a whitespace-only or empty string.
        """
        if v is not None and not v.strip():
            raise InvalidInputError("Note content cannot be empty")
        return v

    async def save(self) -> Optional[str]:
        """
        Save the note and submit embedding command.

        Overrides ObjectModel.save() to submit an async embed_note command
        after saving, instead of inline embedding.

        Returns:
            Optional[str]: The command_id if embedding was submitted, None otherwise
        """
        # Call parent save (without embedding)
        await super().save()

        # Submit embedding command (fire-and-forget) if note has content
        if self.id and self.content and self.content.strip():
            command_id = submit_command(
                "open_notebook",
                "embed_note",
                {"note_id": str(self.id)},
            )
            logger.debug(f"Submitted embed_note command {command_id} for {self.id}")
            # Stringify so the value matches the declared return type, as the
            # other command-submitting methods in this module do.
            return str(command_id)
        return None

    async def add_to_notebook(self, notebook_id: str) -> Any:
        """Link this note to a notebook through an ``artifact`` relation.

        Raises:
            InvalidInputError: If ``notebook_id`` is empty.
        """
        if not notebook_id:
            raise InvalidInputError("Notebook ID must be provided")
        return await self.relate("artifact", notebook_id)

    def get_context(
        self, context_size: Literal["short", "long"] = "short"
    ) -> Dict[str, Any]:
        """Return id, title and content for this note.

        With ``"long"`` the full content is returned; otherwise the content is
        cut to its first 100 characters (or ``None`` when there is none).
        """
        if context_size == "long":
            return dict(id=self.id, title=self.title, content=self.content)
        return dict(
            id=self.id,
            title=self.title,
            content=self.content[:100] if self.content else None,
        )
class ChatSession(ObjectModel):
    """A ``chat_session`` record that can refer to notebooks and sources."""

    table_name: ClassVar[str] = "chat_session"
    nullable_fields: ClassVar[set[str]] = {"model_override"}
    title: Optional[str] = None
    model_override: Optional[str] = None

    async def relate_to_notebook(self, notebook_id: str) -> Any:
        """Create a ``refers_to`` relation from this session to a notebook.

        Raises:
            InvalidInputError: If ``notebook_id`` is empty.
        """
        if notebook_id:
            return await self.relate("refers_to", notebook_id)
        raise InvalidInputError("Notebook ID must be provided")

    async def relate_to_source(self, source_id: str) -> Any:
        """Create a ``refers_to`` relation from this session to a source.

        Raises:
            InvalidInputError: If ``source_id`` is empty.
        """
        if source_id:
            return await self.relate("refers_to", source_id)
        raise InvalidInputError("Source ID must be provided")
async def text_search(
    keyword: str, results: int, source: bool = True, note: bool = True
):
    """Run the database's ``fn::text_search`` function and return its rows.

    Args:
        keyword: Text to search for; must be non-empty.
        results: Passed to the database function as ``$results``.
        source: Passed to the database function as ``$source``.
        note: Passed to the database function as ``$note``.

    Raises:
        InvalidInputError: If ``keyword`` is empty.
        DatabaseOperationError: If the query fails.
    """
    if not keyword:
        raise InvalidInputError("Search keyword cannot be empty")
    try:
        query_vars = dict(keyword=keyword, results=results, source=source, note=note)
        return await repo_query(
            """
            select *
            from fn::text_search($keyword, $results, $source, $note)
            """,
            query_vars,
        )
    except Exception as e:
        logger.error(f"Error performing text search: {str(e)}")
        logger.exception(e)
        raise DatabaseOperationError(e)
async def vector_search(
    keyword: str,
    results: int,
    source: bool = True,
    note: bool = True,
    minimum_score=0.2,
):
    """Embed ``keyword`` and run the database's ``fn::vector_search`` function.

    Args:
        keyword: Text to embed and search for; must be non-empty.
        results: Passed to the database function as ``$results``.
        source: Passed to the database function as ``$source``.
        note: Passed to the database function as ``$note``.
        minimum_score: Passed to the database function as ``$minimum_score``.

    Raises:
        InvalidInputError: If ``keyword`` is empty.
        DatabaseOperationError: If embedding or the query fails.
    """
    if not keyword:
        raise InvalidInputError("Search keyword cannot be empty")
    try:
        # Imported lazily inside the try, so an import failure is reported
        # like any other search failure.
        from open_notebook.utils.embedding import generate_embedding

        # Use unified embedding function (handles chunking if query is very long)
        query_vector = await generate_embedding(keyword)
        query_vars = dict(
            embed=query_vector,
            results=results,
            source=source,
            note=note,
            minimum_score=minimum_score,
        )
        return await repo_query(
            """
            SELECT * FROM fn::vector_search($embed, $results, $source, $note, $minimum_score);
            """,
            query_vars,
        )
    except Exception as e:
        logger.error(f"Error performing vector search: {str(e)}")
        logger.exception(e)
        raise DatabaseOperationError(e)