Spaces:
Sleeping
Sleeping
| import hashlib | |
| import logging | |
| import os | |
| import sys | |
| import time | |
| from dataclasses import asdict | |
| from pathlib import Path | |
| from textwrap import dedent | |
| from typing import Any, cast | |
| import logfire | |
| from pydantic import BaseModel, ValidationError | |
| from surrealdb import ( | |
| AsyncHttpSurrealConnection, | |
| AsyncSurreal, | |
| AsyncWsSurrealConnection, | |
| BlockingHttpSurrealConnection, | |
| BlockingWsSurrealConnection, | |
| RecordID, | |
| Surreal, | |
| Value, | |
| ) | |
| from websockets.exceptions import ConnectionClosedError | |
| from surrealdb import ( | |
| RecordID as SurrealRecordID, | |
| ) | |
| from ..definitions import ( | |
| Analytics, | |
| GenericDocument, | |
| Node, | |
| Object, | |
| OriginalDocument, | |
| RecursiveResult, | |
| Relation, | |
| Relations, | |
| VectorTableDefinition, | |
| ) | |
| from ..embeddings import Embedder | |
| from ..llm import LLM | |
| from . import utils | |
| from .queries import COUNT_QUERY | |
# Module-level logger; handlers and levels are left to the host application.
logger: logging.Logger = logging.getLogger(name=__name__)
# Configure logfire once, at import time. The argument value
# "if-token-present" makes shipping telemetry to Logfire conditional on a
# token being configured (per the logfire `configure` documentation).
_ = logfire.configure(send_to_logfire="if-token-present")
| class DB: | |
| def __init__( | |
| self, | |
| url: str, | |
| username: str, | |
| password: str, | |
| namespace: str, | |
| database: str, | |
| embedder: Embedder | None = None, | |
| llm: LLM | None = None, | |
| *, | |
| analytics_table: str = "analytics", | |
| original_docs_table: str = "document", | |
| tables: list[str] | None = None, | |
| vector_tables: list[VectorTableDefinition] | None = None, | |
| graph_relations: list[Relation] | None = None, | |
| ): | |
| if hasattr(logfire, "instrument_surrealdb"): | |
| logfire.instrument_surrealdb() | |
| self._sync_conn: ( | |
| BlockingHttpSurrealConnection | BlockingWsSurrealConnection | None | |
| ) = None | |
| self._async_conn: ( | |
| AsyncHttpSurrealConnection | AsyncWsSurrealConnection | None | |
| ) = None | |
| self.url: str = url | |
| self.username: str = username | |
| self.password: str = password | |
| self.namespace: str = namespace | |
| self.database: str = database | |
| self.embedder: Embedder | None = embedder | |
| self.llm: LLM | None = llm | |
| self._original_docs_table: str = original_docs_table | |
| self._analytics_table: str = analytics_table | |
| self._tables: list[str] = tables or [] | |
| self._vector_tables: list[VectorTableDefinition] = vector_tables or [] | |
| self._graph_relations: list[Relation] = graph_relations or [] | |
| self._surql_cache: dict[str, str] = {} | |
| for filename in [ | |
| "create_index_hnsw.surql", | |
| "create_index_mtree.surql", | |
| "define_relation.surql", | |
| "graph_query_in.surql", | |
| "graph_siblings.surql", | |
| "vector_search.surql", | |
| "vector_search_simple.surql", | |
| ]: | |
| self._surql_cache[filename] = self._load_surql(filename) | |
| def init_db(self, force: bool = False) -> None: | |
| r"""This needs to be called to initialise the DB indexes. | |
| Only required if you defined `vector_tables` or `graph_relations`. | |
| """ | |
| if not force: | |
| # Check if the database is already initialized | |
| is_init = self.sync_conn.query("SELECT * FROM ONLY meta:initialized") | |
| # query return type is wrong, in this case it could return None | |
| if is_init is not None: | |
| return | |
| # vector index cheat sheet: https://surrealdb.com/docs/surrealdb/reference-guide/vector-search#vector-search-cheat-sheet | |
| if self._vector_tables and self.embedder is None: | |
| raise ValueError( | |
| "Embedder is not initialized, and is required for vector tables to be created" | |
| ) | |
| if self.embedder is not None: | |
| for vector_table in self._vector_tables: | |
| match vector_table.index_type: | |
| case "HNSW": | |
| surql_name = "create_index_hnsw.surql" | |
| case _: | |
| surql_name = "create_index_mtree.surql" | |
| _ = self.execute( | |
| surql_name, | |
| None, | |
| { | |
| "table": vector_table.name, | |
| "dimension": self.embedder.dimension, | |
| "distance_function": vector_table.dist_func, | |
| "vector_type": self.embedder.vector_type, | |
| }, | |
| ) | |
| for relation in self._graph_relations: | |
| print(f"Creating relation {relation.name}") | |
| _ = self.execute( | |
| "define_relation.surql", | |
| None, | |
| { | |
| "name": relation.name, | |
| "in_tb": relation.in_table, | |
| "out_tb": relation.out_table, | |
| }, | |
| ) | |
| # -- original documents table | |
| _ = self.execute( | |
| "define_table.surql", | |
| None, | |
| { | |
| "name": self._original_docs_table, | |
| "fields": dedent(f""" | |
| DEFINE FIELD OVERWRITE filename ON {self._original_docs_table} TYPE string; | |
| DEFINE FIELD OVERWRITE file ON {self._original_docs_table} TYPE bytes; | |
| """), | |
| }, | |
| ) | |
| # -- analytics table | |
| _ = self.execute( | |
| "define_table.surql", | |
| None, | |
| { | |
| "name": self._analytics_table, | |
| "fields": dedent(f""" | |
| DEFINE FIELD OVERWRITE input ON {self._analytics_table} TYPE string; | |
| DEFINE FIELD OVERWRITE output ON {self._analytics_table} TYPE string; | |
| DEFINE FIELD OVERWRITE key ON {self._analytics_table} TYPE string; | |
| DEFINE FIELD OVERWRITE score ON {self._analytics_table} TYPE float; | |
| """), | |
| }, | |
| ) | |
| _ = self.sync_conn.upsert("meta:initialized") | |
| logger.info("Database initialized") | |
| def clear(self) -> None: | |
| res = self.sync_conn.query("REMOVE TABLE IF EXISTS meta;") | |
| res = self.sync_conn.query(f"REMOVE TABLE IF EXISTS {self._analytics_table};") | |
| logger.debug(res) | |
| for table in self._tables: | |
| res = self.sync_conn.query(f"REMOVE TABLE IF EXISTS {table};") | |
| logger.debug(res) | |
| for table in self._vector_tables: | |
| res = self.sync_conn.query(f"REMOVE TABLE IF EXISTS {table.name};") | |
| logger.debug(res) | |
| res = self.sync_conn.query( | |
| f"REMOVE INDEX IF EXISTS idx_{table.name} ON {table.name};" | |
| ) | |
| logger.debug(res) | |
| def _vector_table(self) -> str: | |
| return self._vector_tables[0].name | |
| def original_docs_table(self) -> str: | |
| return self._original_docs_table | |
| # ========================================================================== | |
| # Connections | |
| # ========================================================================== | |
| async def async_conn( | |
| self, | |
| ) -> AsyncWsSurrealConnection | AsyncHttpSurrealConnection: | |
| if self._async_conn is None: | |
| self._async_conn = AsyncSurreal(self.url) | |
| if self.url != "mem://": | |
| _ = await self._async_conn.signin( | |
| {"username": self.username, "password": self.password} | |
| ) | |
| await self._async_conn.use(self.username, self.database) | |
| # await self._init_db() | |
| return self._async_conn | |
| def sync_conn( | |
| self, | |
| ) -> BlockingHttpSurrealConnection | BlockingWsSurrealConnection: | |
| if self._sync_conn is None: | |
| self._sync_conn = Surreal(self.url) | |
| if self.url != "mem://": | |
| _ = self._sync_conn.signin( | |
| {"username": self.username, "password": self.password} | |
| ) | |
| self._sync_conn.use(self.namespace, self.database) | |
| return self._sync_conn | |
| # ========================================================================== | |
| # Execute | |
| # ========================================================================== | |
| def _load_surql(self, filename_or_path: str | Path) -> str: | |
| if isinstance(filename_or_path, Path): | |
| filename = filename_or_path.name | |
| file_path = filename_or_path | |
| else: | |
| filename = filename_or_path | |
| file_path = Path(__file__).parent / "surql" / filename_or_path | |
| # check cache | |
| cached = self._surql_cache.get(filename) | |
| if cached is not None: | |
| return cached | |
| with open(file_path, "r") as file: | |
| return file.read() | |
| def _extract_result_and_time(self, res: Object) -> tuple[Value, float]: | |
| if "result" in res: | |
| result = res["result"] | |
| if result is not None and isinstance(result, list): | |
| result_inner = result[0] | |
| if isinstance(result_inner, dict): | |
| obj = result_inner["result"] | |
| time = str(result_inner["time"]) | |
| return obj, utils.parse_time(time) | |
| raise ValueError(f"unexpected result: {res}") | |
| def execute( | |
| self, | |
| file: str | Path, | |
| vars: Object | None = None, | |
| template_vars: Object | None = None, | |
| ) -> tuple[Value, float]: | |
| surql = self._load_surql(file) | |
| if template_vars is not None: | |
| surql = surql.format(**template_vars) | |
| max_retries = int(os.getenv("KG_DB_RETRY_ATTEMPTS", "3")) | |
| retry_delay = float(os.getenv("KG_DB_RETRY_DELAY", "1.0")) | |
| last_error: Exception | None = None | |
| for attempt in range(1, max_retries + 1): | |
| try: | |
| res: Object = self.sync_conn.query_raw( | |
| surql, cast(dict[str, Value], vars) | |
| ) | |
| return self._extract_result_and_time(res) | |
| except ConnectionClosedError as exc: | |
| last_error = exc | |
| logger.warning( | |
| "SurrealDB connection closed (attempt %s/%s), retrying", | |
| attempt, | |
| max_retries, | |
| ) | |
| self._sync_conn = None | |
| time.sleep(retry_delay) | |
| if last_error: | |
| raise last_error | |
| raise RuntimeError("SurrealDB query failed") | |
| async def async_execute( | |
| self, | |
| file: str | Path, | |
| vars: dict[str, Value] | None = None, | |
| template_vars: Object | None = None, | |
| ) -> tuple[Value, float]: | |
| surql = self._load_surql(file) | |
| if template_vars is not None: | |
| surql = surql.format(**template_vars) | |
| conn = await self.async_conn | |
| res: Object = await conn.query_raw(surql, vars) | |
| return self._extract_result_and_time(res) | |
| # ========================================================================== | |
| # Basic queries | |
| # ========================================================================== | |
| def query( | |
| self, | |
| query: str, | |
| vars: Object, | |
| record_type: type[utils.RecordType], | |
| ) -> list[utils.RecordType]: | |
| r'''Query a list of records and assert their expected type `record_type` | |
| Args: | |
| query (str): The query to execute. | |
| vars (Object): The variables to use in the query. | |
| record_type (type[utils.RecordType]): The expected type of the records. | |
| Returns: | |
| list[utils.RecordType]: The list of records. | |
| Raises: | |
| TypeError: If the records are not of the expected type. | |
| Example: | |
| ```python | |
| from kaig.db.queries import WhereClause, order_limit_start | |
| from surrealdb import RecordID | |
| where = WhereClause() | |
| where = where.and_("team", RecordID("team", "green")) | |
| where_clause, where_vars = where.build() | |
| order_limit_start_clause = order_limit_start("username", "DESC", 5, 0) | |
| query = dedent(f""" | |
| SELECT * | |
| FROM user | |
| {where_clause} | |
| {order_limit_start_clause} | |
| """) | |
| filtered = db.query(query, where_vars, User) | |
| ``` | |
| ''' | |
| return utils.query(self.sync_conn, query, vars, record_type) | |
| def query_one( | |
| self, | |
| query: str, | |
| vars: Object, | |
| record_type: type[utils.RecordType], | |
| ) -> utils.RecordType | None: | |
| return utils.query_one(self.sync_conn, query, vars, record_type) | |
| def count( | |
| self, | |
| table: str, | |
| where_clause: str, | |
| where_vars: Object, | |
| group_by: str | None = None, | |
| ) -> int: | |
| total_count_query = COUNT_QUERY.format( | |
| table=table, | |
| where_clause=where_clause, | |
| group_clause="GROUP ALL" if group_by is None else f"GROUP BY{group_by}", | |
| ) | |
| count_result = self.query_one(total_count_query, where_vars, dict[str, int]) | |
| total_count = count_result.get("count") if count_result else 0 | |
| assert isinstance(total_count, int), f"Expected int, got {type(total_count)}" | |
| total_count = int(total_count) | |
| return total_count | |
| def exists(self, record: SurrealRecordID) -> bool: | |
| exists = self.sync_conn.query( | |
| "RETURN record::exists($record)", | |
| {"record": record}, | |
| ) | |
| # query return type is wrong, in this case it could return a bool | |
| if not isinstance(exists, bool): | |
| return False | |
| return exists | |
| # ========================================================================== | |
| # Analytics | |
| # ========================================================================== | |
| def insert_analytics_data( | |
| self, key: str, input: str, output: str, score: float, tag: str | |
| ) -> None: | |
| try: | |
| _res = self.sync_conn.insert( | |
| self._analytics_table, | |
| asdict(Analytics(key, tag, input, output, score)), | |
| ) | |
| except Exception: | |
| # TODO: log error | |
| ... | |
| async def safe_insert_error(self, id: int, error: str): | |
| conn = await self.async_conn | |
| try: | |
| _ = await conn.query( | |
| "CREATE $record CONTENT $content", | |
| { | |
| "record": SurrealRecordID("error", id), | |
| "content": {"error": error}, | |
| }, | |
| ) | |
| except Exception as e: | |
| print(f"Error inserting error record: {e}", file=sys.stderr) | |
| # TODO: fix if surrealdb.py changes the type for the RecordID identifier | |
| async def error_exists(self, id: Any) -> bool: # pyright: ignore[reportExplicitAny, reportAny] | |
| conn = await self.async_conn | |
| res = await conn.query( | |
| "RETURN record::exists($record)", | |
| {"record": SurrealRecordID("error", id)}, | |
| ) | |
| # query return type is wrong, in this case it could return a bool | |
| if not isinstance(res, bool): | |
| raise RuntimeError(f"Unexpected result from error_exists: {type(res)}") | |
| return res | |
| # ========================================================================== | |
| # Original Documents | |
| # ========================================================================== | |
| def store_original_document( | |
| self, file: str, content_type: str | |
| ) -> tuple[OriginalDocument, bool]: | |
| """Returns a tuple of the document and a bool indicating whether the | |
| document was chached (True) or inserted (False)""" | |
| source = Path(file) | |
| # TODO: store files as files instead of bytes | |
| # https://surrealdb.com/docs/surrealql/datamodel/files | |
| file_content = bytes() | |
| with open(source, "rb") as f: | |
| md5_hash = hashlib.md5() | |
| while True: | |
| c = f.read(4096) | |
| if not c: | |
| break | |
| md5_hash.update(c) | |
| file_content += c | |
| doc_hash = md5_hash.hexdigest() | |
| return self.store_original_document_from_bytes( | |
| source.name, content_type, file_content, doc_hash | |
| ) | |
| def store_original_document_from_bytes( | |
| self, | |
| filename: str, | |
| content_type: str, | |
| file_bytes: bytes, | |
| precomputed_hash: str | None = None, | |
| ) -> tuple[OriginalDocument, bool]: | |
| """Returns a tuple of the document and a bool indicating whether the | |
| document was chached (True) or inserted (False)""" | |
| md5_hash = hashlib.md5() | |
| md5_hash.update(file_bytes) | |
| hex_hash = md5_hash.hexdigest() | |
| if precomputed_hash is not None and hex_hash != precomputed_hash: | |
| logger.warning( | |
| f"Hash mismatch for {filename}: {hex_hash} != {precomputed_hash}" | |
| ) | |
| # -- check if the document already exists | |
| record_id = SurrealRecordID(self._original_docs_table, hex_hash) | |
| cached = self.query_one( | |
| "SELECT * FROM ONLY $record", | |
| {"record": record_id}, | |
| OriginalDocument, | |
| ) | |
| if cached: | |
| # update the document to trigger process | |
| _ = self.query_one("UPDATE ONLY $record", {"record": record_id}, dict) | |
| return cached, True | |
| else: | |
| content = OriginalDocument( | |
| record_id, filename, content_type, file_bytes, None | |
| ) | |
| inserted = self.query_one( | |
| "CREATE ONLY $record CONTENT $content", | |
| {"record": record_id, "content": asdict(content)}, | |
| OriginalDocument, | |
| ) | |
| if not inserted: | |
| raise Exception("Failed to create document: CREATE returned NONE.") | |
| return inserted, False | |
| # ========================================================================== | |
| # Documents (or more precisely: chunks) | |
| # ========================================================================== | |
| async def get_document( | |
| self, doc_type: type[GenericDocument], id: int | |
| ) -> GenericDocument | None: | |
| conn = await self.async_conn | |
| res = await conn.query( | |
| "SELECT * FROM ONLY $record", | |
| {"record": SurrealRecordID(self._vector_table, id)}, | |
| ) | |
| if not res: | |
| return None | |
| if not isinstance(res, dict): | |
| raise RuntimeError(f"Unexpected result from get_documents: {type(res)}") | |
| return doc_type.model_validate(res) | |
| async def list_documents( | |
| self, | |
| doc_type: type[GenericDocument], | |
| start_after: int = 0, | |
| limit: int = 100, | |
| ) -> list[GenericDocument]: | |
| conn = await self.async_conn | |
| if start_after == 0: | |
| res = await conn.query( | |
| f"SELECT * FROM {self._vector_table} ORDER BY id LIMIT $limit", | |
| {"limit": limit}, | |
| ) | |
| else: | |
| res = await conn.query( | |
| f"SELECT * FROM type::thing({self._vector_table}, $start_after..) ORDER BY id LIMIT $limit", | |
| {"limit": limit, "start_after": start_after}, | |
| ) | |
| if not isinstance(res, list): | |
| raise RuntimeError(f"Unexpected result from list_documents: {type(res)}") | |
| return [doc_type.model_validate(record) for record in res] | |
| # ========================================================================== | |
| # Vector store | |
| # ========================================================================== | |
| # TODO: do we still need this? | |
| async def async_insert_document( | |
| self, | |
| document: BaseModel, | |
| id: int | str | None = None, | |
| table: str | None = None, | |
| ) -> None: | |
| conn = await self.async_conn | |
| if not table: | |
| table = self._vector_table | |
| _ = await conn.create( | |
| table if id is None else SurrealRecordID(table, id), | |
| document.model_dump(by_alias=True), | |
| ) | |
| # TODO: should we merge insert_document and _insert_embedded together? | |
| def insert_document( | |
| self, | |
| document: GenericDocument, | |
| id: int | str | None = None, | |
| table: str | None = None, | |
| ) -> GenericDocument: | |
| if not table: | |
| table = self._vector_table | |
| data_dict = document.model_dump(by_alias=True) | |
| if id is not None and data_dict["id"]: | |
| del data_dict["id"] | |
| res = self.sync_conn.create( | |
| table if id is None else SurrealRecordID(table, id), data_dict | |
| ) | |
| if isinstance(res, list): | |
| raise RuntimeError(f"Unexpected result from insert_document: {res}") | |
| return type(document).model_validate(res) | |
| def _insert_embedded( | |
| self, | |
| document: GenericDocument, | |
| id: int | str | None = None, | |
| table: str | None = None, | |
| ) -> GenericDocument: | |
| if not table: | |
| table = self._vector_table | |
| data_dict = document.model_dump() | |
| if id is not None and data_dict["id"]: | |
| del data_dict["id"] | |
| res = self.sync_conn.create( | |
| table if id is None else SurrealRecordID(table, id), data_dict | |
| ) | |
| if isinstance(res, list): | |
| raise RuntimeError( | |
| f"Unexpected result from _inserted_embedded: {res} with {table}:{id}" | |
| ) | |
| try: | |
| return type(document).model_validate(res, by_alias=True) | |
| except Exception as e: | |
| logger.debug(f"Error while validating document: {e}") | |
| raise | |
| def embed_and_insert( | |
| self, | |
| doc: GenericDocument, | |
| table: str | None = None, | |
| id: int | str | None = None, | |
| force: bool = False, | |
| ) -> GenericDocument: | |
| if self.embedder is None: | |
| raise ValueError("Embedder is not initialized") | |
| if not table: | |
| table = self._vector_table | |
| rec_id = SurrealRecordID(table, id) | |
| with logfire.span("Embed and insert {rec_id=}", rec_id=rec_id): | |
| if id is not None and not force: | |
| existing = self.query_one( | |
| "SELECT * FROM ONLY $record", | |
| {"record": rec_id}, | |
| type(doc), | |
| ) | |
| if existing: | |
| return existing | |
| if doc.content: | |
| content = doc.content | |
| while True: | |
| try: | |
| embedding = self.embedder.embed(content) | |
| break | |
| except Exception as e: | |
| # do we need to truncate the chunk to embed it? | |
| if "the input length exceeds the context length" in str(e): | |
| # retry | |
| content = content[: self.embedder.max_length] | |
| logger.info("Retry embedding chunk") | |
| continue | |
| logger.error( | |
| f"Error embedding doc with len={len(content)}: {type(e)} {e}" | |
| ) | |
| raise e | |
| doc.embedding = embedding | |
| return self._insert_embedded(doc, id, table) | |
| else: | |
| return self.insert_document(doc, id, table) | |
| def _extract_similarity_results( | |
| self, res: Value, doc_type: type[GenericDocument] | |
| ) -> list[tuple[GenericDocument, float]]: | |
| if not isinstance(res, list): | |
| raise RuntimeError(f"Unexpected result from vector search: {res}") | |
| results: list[tuple[GenericDocument, float]] = [] | |
| for record in res: | |
| if isinstance(record, dict): | |
| score = cast(float, record.get("score", 0)) | |
| else: | |
| score = 0 | |
| if isinstance(record, dict): | |
| results.append((doc_type.model_validate(record), score)) | |
| return results | |
| def vector_search_from_text( | |
| self, | |
| doc_type: type[GenericDocument], | |
| text: str, | |
| *, | |
| table: str, | |
| k: int, | |
| score_threshold: float = -1, | |
| effort: int | None = 40, | |
| ) -> tuple[list[tuple[GenericDocument, float]], float]: | |
| if self.embedder is None: | |
| raise ValueError("Embedder is not initialized") | |
| embedding = self.embedder.embed(text) | |
| res, time = self.execute( | |
| "vector_search.surql", | |
| { | |
| "embedding": cast(list[Value], embedding), | |
| "threshold": score_threshold, | |
| }, | |
| { | |
| "table": table, | |
| "k": k, | |
| "effort_param": f",{effort}" if effort is not None else "", | |
| }, | |
| ) | |
| return self._extract_similarity_results(res, doc_type), time | |
| def vector_search( | |
| self, | |
| doc_type: type[GenericDocument], | |
| query_embeddings: list[float], | |
| *, | |
| table: str | None = None, | |
| k: int = 5, | |
| effort: None = None, | |
| threshold: float = 0, | |
| ) -> tuple[list[tuple[GenericDocument, float]], float]: | |
| res, time = self.execute( | |
| "vector_search.surql", | |
| { | |
| "embedding": cast(list[Value], query_embeddings), | |
| "threshold": threshold, | |
| }, | |
| { | |
| "k": k, | |
| "table": table if table is not None else self._vector_table, | |
| "effort_param": f",{effort}" if effort else "", | |
| }, | |
| ) | |
| return self._extract_similarity_results(res, doc_type), time | |
| async def async_vector_search( | |
| self, | |
| doc_type: type[GenericDocument], | |
| query_embeddings: list[float], | |
| *, | |
| table: str | None = None, | |
| k: int = 5, | |
| effort: None = None, | |
| threshold: float = 0, | |
| ) -> tuple[list[tuple[GenericDocument, float]], float]: | |
| res, time = await self.async_execute( | |
| "vector_search.surql", | |
| { | |
| "embedding": cast(list[Value], query_embeddings), | |
| "threshold": threshold, | |
| }, | |
| { | |
| "k": k, | |
| "table": table if table is not None else self._vector_table, | |
| "effort_param": f",{effort}" if effort else "", | |
| }, | |
| ) | |
| return self._extract_similarity_results(res, doc_type), time | |
| # ========================================================================== | |
| # Graph | |
| # ========================================================================== | |
| def relate( | |
| self, | |
| in_: SurrealRecordID, | |
| relation: str, | |
| out: SurrealRecordID | list[SurrealRecordID], | |
| ) -> None: | |
| all = [out] if not isinstance(out, list) else out | |
| for out in all: | |
| _res = self.sync_conn.insert_relation(relation, {"in": in_, "out": out}) | |
| # TODO: batch relate when supported | |
| # _res = self.sync_conn.query("relate $in->$rel->$out", {"in":in_, "out":out, "rel":relation}) | |
| def _add_graph_nodes( | |
| self, | |
| src_table: str, | |
| dest_table: str, | |
| destinations: list[Node], | |
| edge_name: str, | |
| relations: Relations, | |
| ) -> None: | |
| for dest in destinations: | |
| node = asdict(dest) | |
| try: | |
| _ = self.sync_conn.upsert( | |
| SurrealRecordID(dest_table, dest.content), node | |
| ) | |
| except Exception as e: | |
| print(f"Failed: {e} with {node}") | |
| for doc_id, cats in relations.items(): | |
| try: | |
| self.relate( | |
| SurrealRecordID(src_table, doc_id), | |
| edge_name, | |
| [SurrealRecordID(dest_table, cat) for cat in cats], | |
| ) | |
| except Exception as e: | |
| print(f"Failed: {e}") | |
| def add_graph_nodes( | |
| self, | |
| src_table: str, | |
| dest_table: str, | |
| destinations: set[str], | |
| edge_name: str, | |
| relations: Relations, | |
| ) -> None: | |
| node_destinations = [Node(dest, None) for dest in destinations] | |
| return self._add_graph_nodes( | |
| src_table, | |
| dest_table, | |
| node_destinations, | |
| edge_name, | |
| relations, | |
| ) | |
| def add_graph_nodes_with_embeddings( | |
| self, | |
| src_table: str, | |
| dest_table: str, | |
| destinations: set[str], | |
| edge_name: str, | |
| relations: Relations, | |
| ) -> None: | |
| if self.embedder is None: | |
| raise ValueError("Embedder is not initialized") | |
| node_destinations = [ | |
| Node(dest, self.embedder.embed(dest)) for dest in destinations if dest | |
| ] | |
| return self._add_graph_nodes( | |
| src_table, | |
| dest_table, | |
| node_destinations, | |
| edge_name, | |
| relations, | |
| ) | |
| def recursive_graph_query( | |
| self, | |
| doc_type: type[GenericDocument], | |
| id: RecordID, | |
| rel: str, | |
| levels: int = 5, | |
| ) -> list[RecursiveResult[GenericDocument]]: | |
| rels = ", ".join( | |
| [f"@.{{{i}}}(->{rel}->?) AS bucket{i}" for i in range(1, levels + 1)] | |
| ) | |
| query = f"SELECT *, {rels} FROM $record" | |
| res = self.sync_conn.query(query, {"record": id}) | |
| if not isinstance(res, list): | |
| raise RuntimeError( | |
| f"Unexpected result from recursive_graph_query: {res} with {query}" | |
| ) | |
| results: list[RecursiveResult[GenericDocument]] = [] | |
| for item in res: | |
| buckets: list[RecordID] = [] | |
| for i in range(1, levels + 1): | |
| if isinstance(item, dict) and f"bucket{i}" in item: | |
| bucket = item.get(f"bucket{i}") | |
| if bucket is not None and isinstance(bucket, SurrealRecordID): | |
| buckets.append(bucket) | |
| try: | |
| results.append( | |
| RecursiveResult[GenericDocument]( | |
| buckets=buckets, inner=doc_type.model_validate(item) | |
| ) | |
| ) | |
| except ValidationError as e: | |
| print(f"Validation error: {e}") | |
| return results | |
| def graph_query_inward( | |
| self, | |
| doc_type: type[GenericDocument], | |
| id: RecordID | list[RecordID], | |
| rel: str, | |
| src: str, | |
| embedding: list[float] | None, | |
| ) -> tuple[list[GenericDocument], float]: | |
| res, time = self.execute( | |
| "graph_query_in.surql", | |
| { | |
| "record": cast(RecordID, id), | |
| "embedding": cast(list[Value], embedding), | |
| }, | |
| {"relation": rel, "src": src}, | |
| ) | |
| if isinstance(res, list): | |
| return list(map(lambda x: doc_type.model_validate(x), res)), time | |
| raise ValueError(f"Unexpected result from graph_query_inward: {res}") | |
| def graph_siblings( | |
| self, | |
| doc_type: type[GenericDocument], | |
| id: RecordID, | |
| relation: str, | |
| src: str, | |
| dest: str, | |
| ) -> tuple[list[GenericDocument], float]: | |
| res, time = self.execute( | |
| "graph_siblings.surql", | |
| {"record": id}, | |
| { | |
| "relation": relation, | |
| "src": src, | |
| "dest": dest, | |
| }, | |
| ) | |
| if isinstance(res, list): | |
| return list(map(lambda x: doc_type.model_validate(x), res)), time | |
| raise ValueError(f"Unexpected result from graph_siblings: {res}") | |