Spaces:
Sleeping
Sleeping
| import asyncio | |
| from typing import Any, ClassVar, Dict, List, Literal, Optional, Tuple, Union | |
| from loguru import logger | |
| from pydantic import BaseModel, Field, field_validator | |
| from surreal_commands import submit_command | |
| from surrealdb import RecordID | |
| from open_notebook.database.repository import ensure_record_id, repo_query | |
| from open_notebook.domain.base import ObjectModel | |
| from open_notebook.domain.models import model_manager | |
| from open_notebook.exceptions import DatabaseOperationError, InvalidInputError | |
| from open_notebook.utils import split_text | |
class Notebook(ObjectModel):
    """A notebook: the top-level container that groups sources, notes, and chat sessions."""

    table_name: ClassVar[str] = "notebook"
    name: str
    description: str
    archived: Optional[bool] = False

    # BUGFIX: `field_validator` was imported but never applied, so this method was a
    # dead instance method and empty names were silently accepted.
    @field_validator("name")
    @classmethod
    def name_must_not_be_empty(cls, v):
        """Reject names that are empty or whitespace-only."""
        if not v.strip():
            raise InvalidInputError("Notebook name cannot be empty")
        return v

    async def get_sources(self) -> List["Source"]:
        """Return sources linked to this notebook via `reference` edges, newest first.

        `full_text` is omitted from the query to keep payloads small.

        Raises:
            DatabaseOperationError: If the query fails.
        """
        try:
            srcs = await repo_query(
                """
                select * omit source.full_text from (
                    select in as source from reference where out=$id
                    fetch source
                ) order by source.updated desc
                """,
                {"id": ensure_record_id(self.id)},
            )
            return [Source(**src["source"]) for src in srcs] if srcs else []
        except Exception as e:
            logger.error(f"Error fetching sources for notebook {self.id}: {str(e)}")
            logger.exception(e)
            raise DatabaseOperationError(e)

    async def get_notes(self) -> List["Note"]:
        """Return notes linked to this notebook via `artifact` edges, newest first.

        `content` and `embedding` are omitted from the query to keep payloads small.

        Raises:
            DatabaseOperationError: If the query fails.
        """
        try:
            srcs = await repo_query(
                """
                select * omit note.content, note.embedding from (
                    select in as note from artifact where out=$id
                    fetch note
                ) order by note.updated desc
                """,
                {"id": ensure_record_id(self.id)},
            )
            return [Note(**src["note"]) for src in srcs] if srcs else []
        except Exception as e:
            logger.error(f"Error fetching notes for notebook {self.id}: {str(e)}")
            logger.exception(e)
            raise DatabaseOperationError(e)

    async def get_chat_sessions(self) -> List["ChatSession"]:
        """Return chat sessions linked to this notebook via `refers_to` edges, newest first.

        Raises:
            DatabaseOperationError: If the query fails.
        """
        try:
            srcs = await repo_query(
                """
                select * from (
                    select
                    <- chat_session as chat_session
                    from refers_to
                    where out=$id
                    fetch chat_session
                )
                order by chat_session.updated desc
                """,
                {"id": ensure_record_id(self.id)},
            )
            # The graph traversal returns chat_session as a one-element list per row.
            return (
                [ChatSession(**src["chat_session"][0]) for src in srcs] if srcs else []
            )
        except Exception as e:
            logger.error(
                f"Error fetching chat sessions for notebook {self.id}: {str(e)}"
            )
            logger.exception(e)
            raise DatabaseOperationError(e)
class Asset(BaseModel):
    """Reference to the material backing a Source: a local file path and/or a URL."""

    file_path: Optional[str] = None
    url: Optional[str] = None
class SourceEmbedding(ObjectModel):
    """A single embedded chunk of a source's text."""

    table_name: ClassVar[str] = "source_embedding"
    content: str

    async def get_source(self) -> "Source":
        """Fetch the parent Source this embedding chunk belongs to.

        Raises:
            DatabaseOperationError: If the lookup or record hydration fails.
        """
        try:
            rows = await repo_query(
                """
                select source.* from $id fetch source
                """,
                {"id": ensure_record_id(self.id)},
            )
            return Source(**rows[0]["source"])
        except Exception as exc:
            logger.error(f"Error fetching source for embedding {self.id}: {str(exc)}")
            logger.exception(exc)
            raise DatabaseOperationError(exc)
class SourceInsight(ObjectModel):
    """An insight (summary, key points, etc.) extracted from a Source."""

    table_name: ClassVar[str] = "source_insight"
    insight_type: str
    content: str

    async def get_source(self) -> "Source":
        """Fetch the Source this insight was derived from.

        Raises:
            DatabaseOperationError: If the lookup or record hydration fails.
        """
        try:
            rows = await repo_query(
                """
                select source.* from $id fetch source
                """,
                {"id": ensure_record_id(self.id)},
            )
            return Source(**rows[0]["source"])
        except Exception as exc:
            logger.error(f"Error fetching source for insight {self.id}: {str(exc)}")
            logger.exception(exc)
            raise DatabaseOperationError(exc)

    async def save_as_note(self, notebook_id: Optional[str] = None) -> Any:
        """Persist this insight as a Note; optionally attach it to a notebook."""
        parent = await self.get_source()
        note = Note(
            title=f"{self.insight_type} from source {parent.title}",
            content=self.content,
        )
        await note.save()
        if notebook_id:
            await note.add_to_notebook(notebook_id)
        return note
class Source(ObjectModel):
    """Source material (file, URL, text) attached to notebooks.

    Heavy processing (extraction, vectorization) runs as background jobs via
    surreal-commands; `command` links this record to its processing job.
    """

    table_name: ClassVar[str] = "source"
    asset: Optional[Asset] = None
    title: Optional[str] = None
    topics: Optional[List[str]] = Field(default_factory=list)
    full_text: Optional[str] = None
    command: Optional[Union[str, RecordID]] = Field(
        default=None, description="Link to surreal-commands processing job"
    )

    class Config:
        # RecordID is not a pydantic-native type.
        arbitrary_types_allowed = True

    # BUGFIX: these two validators were missing their decorators, so they were dead
    # methods and string command/id values were never normalized on model creation.
    @field_validator("command", mode="before")
    @classmethod
    def parse_command(cls, value):
        """Parse command field to ensure RecordID format"""
        if isinstance(value, str) and value:
            return ensure_record_id(value)
        return value

    @field_validator("id", mode="before")
    @classmethod
    def parse_id(cls, value):
        """Parse id field to handle both string and RecordID inputs"""
        if value is None:
            return None
        if isinstance(value, RecordID):
            return str(value)
        return str(value) if value else None

    async def get_status(self) -> Optional[str]:
        """Get the processing status of the associated command.

        Returns:
            None if no command is linked; the command's status string; or
            "unknown" if the status lookup fails or returns nothing.
        """
        if not self.command:
            return None
        try:
            # Local import keeps the dependency off the module import path.
            from surreal_commands import get_command_status

            status = await get_command_status(str(self.command))
            return status.status if status else "unknown"
        except Exception as e:
            logger.warning(f"Failed to get command status for {self.command}: {e}")
            return "unknown"

    async def get_processing_progress(self) -> Optional[Dict[str, Any]]:
        """Get detailed processing information for the associated command.

        Returns:
            A dict with status, timestamps, error and raw result, or None if
            there is no command, no status, or the lookup fails (best-effort).
        """
        if not self.command:
            return None
        try:
            from surreal_commands import get_command_status

            status_result = await get_command_status(str(self.command))
            if not status_result:
                return None
            # Extract execution metadata if available
            result = getattr(status_result, "result", None)
            execution_metadata = result.get("execution_metadata", {}) if isinstance(result, dict) else {}
            return {
                "status": status_result.status,
                "started_at": execution_metadata.get("started_at"),
                "completed_at": execution_metadata.get("completed_at"),
                "error": getattr(status_result, "error_message", None),
                "result": result,
            }
        except Exception as e:
            logger.warning(f"Failed to get command progress for {self.command}: {e}")
            return None

    async def get_context(
        self, context_size: Literal["short", "long"] = "short"
    ) -> Dict[str, Any]:
        """Build an LLM-context dict; "long" additionally includes `full_text`."""
        insights_list = await self.get_insights()
        insights = [insight.model_dump() for insight in insights_list]
        if context_size == "long":
            return dict(
                id=self.id,
                title=self.title,
                insights=insights,
                full_text=self.full_text,
            )
        else:
            return dict(id=self.id, title=self.title, insights=insights)

    async def get_embedded_chunks(self) -> int:
        """Count this source's embedded chunks in `source_embedding`.

        Raises:
            DatabaseOperationError: If the count query fails.
        """
        try:
            result = await repo_query(
                """
                select count() as chunks from source_embedding where source=$id GROUP ALL
                """,
                {"id": ensure_record_id(self.id)},
            )
            if len(result) == 0:
                return 0
            return result[0]["chunks"]
        except Exception as e:
            logger.error(f"Error fetching chunks count for source {self.id}: {str(e)}")
            logger.exception(e)
            raise DatabaseOperationError(f"Failed to count chunks for source: {str(e)}")

    async def get_insights(self) -> List[SourceInsight]:
        """Fetch all insights derived from this source.

        Raises:
            DatabaseOperationError: If the query fails.
        """
        try:
            result = await repo_query(
                """
                SELECT * FROM source_insight WHERE source=$id
                """,
                {"id": ensure_record_id(self.id)},
            )
            return [SourceInsight(**insight) for insight in result]
        except Exception as e:
            logger.error(f"Error fetching insights for source {self.id}: {str(e)}")
            logger.exception(e)
            raise DatabaseOperationError("Failed to fetch insights for source")

    async def add_to_notebook(self, notebook_id: str) -> Any:
        """Attach this source to a notebook via a `reference` edge."""
        if not notebook_id:
            raise InvalidInputError("Notebook ID must be provided")
        return await self.relate("reference", notebook_id)

    async def vectorize(self) -> str:
        """
        Submit vectorization as a background job using the vectorize_source command.
        This method now leverages the job-based architecture to prevent HTTP connection
        pool exhaustion when processing large documents. The actual chunk processing
        happens in the background worker pool, with natural concurrency control.
        Returns:
            str: The command/job ID that can be used to track progress via the commands API
        Raises:
            ValueError: If source has no text to vectorize
            DatabaseOperationError: If job submission fails
        """
        logger.info(f"Submitting vectorization job for source {self.id}")
        try:
            if not self.full_text:
                raise ValueError(f"Source {self.id} has no text to vectorize")
            # Submit the vectorize_source command which will:
            # 1. Delete existing embeddings (idempotency)
            # 2. Split text into chunks
            # 3. Submit each chunk as an embed_chunk job
            command_id = submit_command(
                "open_notebook",  # app name
                "vectorize_source",  # command name
                {
                    "source_id": str(self.id),
                }
            )
            command_id_str = str(command_id)
            logger.info(
                f"Vectorization job submitted for source {self.id}: "
                f"command_id={command_id_str}"
            )
            return command_id_str
        except Exception as e:
            # NOTE: this also wraps the ValueError above in DatabaseOperationError.
            logger.error(f"Failed to submit vectorization job for source {self.id}: {e}")
            logger.exception(e)
            raise DatabaseOperationError(e)

    async def add_insight(self, insight_type: str, content: str) -> Any:
        """Create a `source_insight` record, embedding its content when possible.

        Raises:
            InvalidInputError: If insight_type or content is empty.
        """
        EMBEDDING_MODEL = await model_manager.get_embedding_model()
        if not EMBEDDING_MODEL:
            logger.warning("No embedding model found. Insight will not be searchable.")
        if not insight_type or not content:
            raise InvalidInputError("Insight type and content must be provided")
        try:
            embedding = (
                (await EMBEDDING_MODEL.aembed([content]))[0] if EMBEDDING_MODEL else []
            )
            return await repo_query(
                """
                CREATE source_insight CONTENT {
                        "source": $source_id,
                        "insight_type": $insight_type,
                        "content": $content,
                        "embedding": $embedding,
                };""",
                {
                    "source_id": ensure_record_id(self.id),
                    "insight_type": insight_type,
                    "content": content,
                    "embedding": embedding,
                },
            )
        except Exception as e:
            logger.error(f"Error adding insight to source {self.id}: {str(e)}")
            raise  # DatabaseOperationError(e)

    def _prepare_save_data(self) -> dict:
        """Override to ensure command field is always RecordID format for database"""
        data = super()._prepare_save_data()
        # Ensure command field is RecordID format if not None
        if data.get("command") is not None:
            data["command"] = ensure_record_id(data["command"])
        return data
class Note(ObjectModel):
    """A human- or AI-authored note, attachable to notebooks."""

    table_name: ClassVar[str] = "note"
    title: Optional[str] = None
    note_type: Optional[Literal["human", "ai"]] = None
    content: Optional[str] = None

    # BUGFIX: missing decorators made this a dead method — whitespace-only
    # content was silently accepted.
    @field_validator("content")
    @classmethod
    def content_must_not_be_empty(cls, v):
        """Reject content that is set but whitespace-only (None stays allowed)."""
        if v is not None and not v.strip():
            raise InvalidInputError("Note content cannot be empty")
        return v

    async def add_to_notebook(self, notebook_id: str) -> Any:
        """Attach this note to a notebook via an `artifact` edge."""
        if not notebook_id:
            raise InvalidInputError("Notebook ID must be provided")
        return await self.relate("artifact", notebook_id)

    def get_context(
        self, context_size: Literal["short", "long"] = "short"
    ) -> Dict[str, Any]:
        """Build an LLM-context dict; "short" truncates content to 100 characters."""
        if context_size == "long":
            return dict(id=self.id, title=self.title, content=self.content)
        else:
            return dict(
                id=self.id,
                title=self.title,
                content=self.content[:100] if self.content else None,
            )

    def needs_embedding(self) -> bool:
        # Notes are always candidates for embedding.
        return True

    def get_embedding_content(self) -> Optional[str]:
        """Return the text to embed for this note."""
        return self.content
class ChatSession(ObjectModel):
    """A chat conversation, linkable to notebooks and sources."""

    table_name: ClassVar[str] = "chat_session"
    nullable_fields: ClassVar[set[str]] = {"model_override"}
    title: Optional[str] = None
    model_override: Optional[str] = None

    async def relate_to_notebook(self, notebook_id: str) -> Any:
        """Link this session to a notebook via a `refers_to` edge."""
        if notebook_id:
            return await self.relate("refers_to", notebook_id)
        raise InvalidInputError("Notebook ID must be provided")

    async def relate_to_source(self, source_id: str) -> Any:
        """Link this session to a source via a `refers_to` edge."""
        if source_id:
            return await self.relate("refers_to", source_id)
        raise InvalidInputError("Source ID must be provided")
async def text_search(
    keyword: str, results: int, source: bool = True, note: bool = True
):
    """Full-text search over sources and/or notes via fn::text_search.

    Raises:
        InvalidInputError: If keyword is empty.
        DatabaseOperationError: If the query fails.
    """
    if not keyword:
        raise InvalidInputError("Search keyword cannot be empty")
    params = {"keyword": keyword, "results": results, "source": source, "note": note}
    try:
        return await repo_query(
            """
            select *
            from fn::text_search($keyword, $results, $source, $note)
            """,
            params,
        )
    except Exception as exc:
        logger.error(f"Error performing text search: {str(exc)}")
        logger.exception(exc)
        raise DatabaseOperationError(exc)
async def vector_search(
    keyword: str,
    results: int,
    source: bool = True,
    note: bool = True,
    minimum_score: float = 0.2,
):
    """Semantic search over sources and/or notes via fn::vector_search.

    Embeds `keyword` with the configured embedding model and queries the
    database's vector-search function.

    Args:
        keyword: Text to search for; must be non-empty.
        results: Maximum number of results.
        source: Include source embeddings.
        note: Include note embeddings.
        minimum_score: Minimum similarity score for a match.

    Raises:
        InvalidInputError: If keyword is empty.
        DatabaseOperationError: If no embedding model is configured or the
            query fails (the internal ValueError is wrapped here too).
    """
    if not keyword:
        raise InvalidInputError("Search keyword cannot be empty")
    try:
        EMBEDDING_MODEL = await model_manager.get_embedding_model()
        if EMBEDDING_MODEL is None:
            raise ValueError("EMBEDDING_MODEL is not configured")
        embed = (await EMBEDDING_MODEL.aembed([keyword]))[0]
        search_results = await repo_query(
            """
            SELECT * FROM fn::vector_search($embed, $results, $source, $note, $minimum_score);
            """,
            {
                "embed": embed,
                "results": results,
                "source": source,
                "note": note,
                "minimum_score": minimum_score,
            },
        )
        return search_results
    except Exception as e:
        logger.error(f"Error performing vector search: {str(e)}")
        logger.exception(e)
        raise DatabaseOperationError(e)