Spaces:
Running
Running
| import json | |
| import os | |
| from collections import defaultdict | |
| from dataclasses import dataclass | |
| from typing import Any, Dict, List, Set | |
| try: | |
| import kuzu | |
| except ImportError: | |
| kuzu = None | |
| from graphgen.bases.base_storage import BaseGraphStorage | |
@dataclass
class KuzuStorage(BaseGraphStorage):
    """
    Graph storage implementation based on KuzuDB.

    Since KuzuDB is a structured graph database and GraphGen uses dynamic
    dictionaries for properties, the data is mapped to a generic schema:

    - Node Table 'Entity': {id: STRING, data: STRING (JSON)}
    - Rel Table 'Relation': {FROM Entity TO Entity, data: STRING (JSON)}

    Every node/edge property dictionary is serialized to JSON and stored in
    the single ``data`` column; it is decoded again on read.
    """

    # Directory under which the database path "<namespace>_kuzu" is created.
    working_dir: str = None
    # Prefix used to build the database path and reported in log messages.
    namespace: str = None
    # kuzu.Database handle, populated by _init_db().
    _db: Any = None
    # kuzu.Connection handle used by every query in this class.
    _conn: Any = None

    def __post_init__(self):
        """Resolve the database path and open the database.

        Raises:
            ImportError: if the optional ``kuzu`` package is not installed.
        """
        if kuzu is None:
            raise ImportError(
                "KuzuDB is not installed. Please install it via `pip install kuzu`."
            )
        # Make sure the parent directory is present before Kuzu creates its
        # own database path inside it. An empty working_dir means "current
        # directory", which needs no creation.
        if self.working_dir:
            os.makedirs(self.working_dir, exist_ok=True)
        self.db_path = os.path.join(self.working_dir, f"{self.namespace}_kuzu")
        self._init_db()

    def _init_db(self):
        """Open the database at ``self.db_path``, connect, and create the schema."""
        # KuzuDB automatically creates the directory
        self._db = kuzu.Database(self.db_path)
        self._conn = kuzu.Connection(self._db)
        self._init_schema()
        print(f"KuzuDB initialized at {self.db_path}")

    def _init_schema(self):
        """Initialize the generic Node and Edge tables if they don't exist.

        A ``RuntimeError`` from either CREATE statement is reported and
        otherwise ignored; this is the expected outcome when the table was
        created by an earlier run.
        """
        try:
            # We use a generic table name "Entity" to store all nodes
            self._conn.execute(
                "CREATE NODE TABLE Entity(id STRING, data STRING, PRIMARY KEY(id))"
            )
            print("Created KuzuDB Node Table 'Entity'")
        except RuntimeError as e:
            # Usually throws if table exists, verify safely or ignore
            print("Node Table 'Entity' already exists or error:", e)

        try:
            # We use a generic table name "Relation" to store all edges
            self._conn.execute(
                "CREATE REL TABLE Relation(FROM Entity TO Entity, data STRING)"
            )
            print("Created KuzuDB Rel Table 'Relation'")
        except RuntimeError as e:
            print("Rel Table 'Relation' already exists or error:", e)

    def index_done_callback(self):
        """KuzuDB is ACID, changes are immediate, but we can verify generic persistence here."""

    @staticmethod
    def _safe_json_loads(data_str: str) -> dict:
        """Decode a JSON ``data`` column value, returning ``{}`` on any problem.

        Non-string, empty, whitespace-only, and malformed inputs all yield an
        empty dict; malformed JSON is additionally reported on stdout.
        """
        if not isinstance(data_str, str) or not data_str.strip():
            return {}
        try:
            return json.loads(data_str)
        except json.JSONDecodeError as e:
            print(f"Error decoding JSON: {e}")
            return {}

    @staticmethod
    def _iter_rows(result):
        """Yield every non-empty row of a query result.

        Uses the ``has_next()`` / ``get_next()`` cursor protocol so that all
        methods consume results the same way.
        """
        while result.has_next():
            row = result.get_next()
            if row:
                yield row

    def _collect_edges(self, result) -> list:
        """Turn rows of ``(src_id, dst_id, data_json)`` into decoded edge tuples."""
        edges = []
        for row in self._iter_rows(result):
            if len(row) < 3:
                continue
            edges.append((row[0], row[1], self._safe_json_loads(row[2])))
        return edges

    def is_directed(self) -> bool:
        """Report that edges are stored with a direction (always ``True``)."""
        return True

    def get_all_node_degrees(self) -> Dict[str, int]:
        """Return a mapping of node id to its degree, ignoring edge direction.

        Nodes without any edge are included with a degree of 0 thanks to the
        OPTIONAL MATCH.
        """
        query = """
        MATCH (n:Entity)
        OPTIONAL MATCH (n)-[r]-()
        RETURN n.id, count(r) as degree
        """
        result = self._conn.execute(query)
        return {
            row[0]: int(row[1]) for row in self._iter_rows(result) if len(row) >= 2
        }

    def get_isolated_nodes(self) -> List[str]:
        """Return the ids of all nodes that have no incoming or outgoing edge."""
        query = """
        MATCH (n:Entity)
        WHERE NOT (n)--()
        RETURN n.id
        """
        result = self._conn.execute(query)
        return [row[0] for row in self._iter_rows(result)]

    def get_node_count(self) -> int:
        """Return the number of rows in the 'Entity' node table."""
        result = self._conn.execute("MATCH (n:Entity) RETURN count(n)")
        return result.get_next()[0]

    def get_edge_count(self) -> int:
        """Return the number of directed 'Relation' edges."""
        result = self._conn.execute("MATCH ()-[e:Relation]->() RETURN count(e)")
        return result.get_next()[0]

    def get_connected_components(self, undirected: bool = True) -> List[Set[str]]:
        """Group node ids into connected components with a union-find.

        Args:
            undirected: selects whether the edge query matches both
                directions or only source -> target. Because the union-find
                merges both endpoints of every returned pair regardless of
                order, either setting yields components in which direction
                is ignored.

        Returns:
            One set of node ids per component; isolated nodes form
            single-element sets.
        """
        parent: Dict[str, str] = {}
        rank: Dict[str, int] = {}

        def find(x: str) -> str:
            # Iterative lookup with path compression (no recursion needed).
            root = x
            while parent[root] != root:
                root = parent[root]
            while parent[x] != root:
                next_x = parent[x]
                parent[x] = root
                x = next_x
            return root

        def union(x: str, y: str):
            root_x, root_y = find(x), find(y)
            if root_x == root_y:
                return
            # Union by rank: attach the shallower tree under the deeper one.
            if rank[root_x] < rank[root_y]:
                parent[root_x] = root_y
            elif rank[root_x] > rank[root_y]:
                parent[root_y] = root_x
            else:
                parent[root_y] = root_x
                rank[root_x] += 1

        # Only the ids are needed here, so skip the degree aggregation.
        id_result = self._conn.execute("MATCH (n:Entity) RETURN n.id")
        all_nodes = [row[0] for row in self._iter_rows(id_result)]
        for node_id in all_nodes:
            parent[node_id] = node_id
            rank[node_id] = 0

        query = (
            """
            MATCH (a:Entity)-[e:Relation]-(b:Entity)
            RETURN DISTINCT a.id, b.id
            """
            if undirected
            else """
            MATCH (a:Entity)-[e:Relation]->(b:Entity)
            RETURN DISTINCT a.id, b.id
            """
        )
        result = self._conn.execute(query)
        for row in self._iter_rows(result):
            if len(row) >= 2:
                union(row[0], row[1])

        components_dict = defaultdict(set)
        for node_id in all_nodes:
            components_dict[find(node_id)].add(node_id)
        return list(components_dict.values())

    def has_node(self, node_id: str) -> bool:
        """Return ``True`` if a node with this id exists."""
        result = self._conn.execute(
            "MATCH (a:Entity {id: $id}) RETURN count(a)", {"id": node_id}
        )
        return result.get_next()[0] > 0

    def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
        """Return ``True`` if at least one edge source -> target exists."""
        result = self._conn.execute(
            "MATCH (a:Entity {id: $src})-[e:Relation]->(b:Entity {id: $dst}) RETURN count(e)",
            {"src": source_node_id, "dst": target_node_id},
        )
        return result.get_next()[0] > 0

    def node_degree(self, node_id: str) -> int:
        """Return the total degree (incoming + outgoing) of a node, 0 if none."""
        query = """
        MATCH (a:Entity {id: $id})-[e:Relation]-(b:Entity)
        RETURN count(e)
        """
        result = self._conn.execute(query, {"id": node_id})
        if result.has_next():
            return result.get_next()[0]
        return 0

    def edge_degree(self, src_id: str, tgt_id: str) -> int:
        """Return how many edges run from ``src_id`` to ``tgt_id``.

        Kuzu supports multi-edges, so this is a count rather than a flag.
        """
        query = """
        MATCH (a:Entity {id: $src})-[e:Relation]->(b:Entity {id: $dst})
        RETURN count(e)
        """
        result = self._conn.execute(query, {"src": src_id, "dst": tgt_id})
        if result.has_next():
            return result.get_next()[0]
        return 0

    def get_node(self, node_id: str) -> Any:
        """Return the decoded property dict of a node, or ``None`` if absent."""
        result = self._conn.execute(
            "MATCH (a:Entity {id: $id}) RETURN a.data", {"id": node_id}
        )
        if not result.has_next():
            return None
        return self._safe_json_loads(result.get_next()[0])

    def update_node(self, node_id: str, node_data: dict[str, str]):
        """Merge ``node_data`` into an existing node's stored properties.

        Does nothing (apart from printing a message) when the node does not
        exist or the merged data cannot be serialized to JSON.
        """
        current_data = self.get_node(node_id)
        if current_data is None:
            print(f"Node {node_id} not found for update.")
            return

        # Merge existing data with new data
        current_data.update(node_data)
        try:
            json_data = json.dumps(current_data, ensure_ascii=False)
        except (TypeError, ValueError) as e:
            print(f"Error serializing JSON for node {node_id}: {e}")
            return

        self._conn.execute(
            "MATCH (a:Entity {id: $id}) SET a.data = $data",
            {"id": node_id, "data": json_data},
        )

    def get_all_nodes(self) -> Any:
        """Returns List[Tuple[id, data_dict]]"""
        result = self._conn.execute("MATCH (a:Entity) RETURN a.id, a.data")
        nodes = []
        for row in self._iter_rows(result):
            if len(row) < 2:
                continue
            nodes.append((row[0], self._safe_json_loads(row[1])))
        return nodes

    def get_edge(self, source_node_id: str, target_node_id: str):
        """Return the decoded property dict of the edge source -> target.

        Returns ``None`` when no such edge exists.
        Warning: If multiple edges exist, this returns the first one found.
        """
        query = """
        MATCH (a:Entity {id: $src})-[e:Relation]->(b:Entity {id: $dst})
        RETURN e.data
        """
        result = self._conn.execute(
            query, {"src": source_node_id, "dst": target_node_id}
        )
        if not result.has_next():
            return None
        return self._safe_json_loads(result.get_next()[0])

    def update_edge(
        self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
    ):
        """Merge ``edge_data`` into an existing edge's stored properties.

        Does nothing (apart from printing a message) when the edge does not
        exist or the merged data cannot be serialized to JSON.
        """
        current_data = self.get_edge(source_node_id, target_node_id)
        if current_data is None:
            print(f"Edge {source_node_id}->{target_node_id} not found for update.")
            return

        current_data.update(edge_data)
        try:
            json_data = json.dumps(current_data, ensure_ascii=False)
        except (TypeError, ValueError) as e:
            print(
                f"Error serializing JSON for edge {source_node_id}->{target_node_id}: {e}"
            )
            return

        self._conn.execute(
            """
            MATCH (a:Entity {id: $src})-[e:Relation]->(b:Entity {id: $dst})
            SET e.data = $data
            """,
            {"src": source_node_id, "dst": target_node_id, "data": json_data},
        )

    def get_all_edges(self) -> Any:
        """Returns List[Tuple[src, dst, data_dict]]"""
        query = "MATCH (a:Entity)-[e:Relation]->(b:Entity) RETURN a.id, b.id, e.data"
        return self._collect_edges(self._conn.execute(query))

    def get_node_edges(self, source_node_id: str) -> Any:
        """Returns generic edges connected to this node (outgoing)"""
        query = """
        MATCH (a:Entity {id: $src})-[e:Relation]->(b:Entity)
        RETURN a.id, b.id, e.data
        """
        return self._collect_edges(
            self._conn.execute(query, {"src": source_node_id})
        )

    def upsert_node(self, node_id: str, node_data: dict[str, str]):
        """
        Insert or Update node.
        Kuzu supports MERGE clause (similar to Neo4j) to handle upserts.

        The stored data is replaced by ``node_data`` (not merged with it).
        Nothing is written if ``node_data`` cannot be serialized to JSON.
        """
        try:
            json_data = json.dumps(node_data, ensure_ascii=False)
        except (TypeError, ValueError) as e:
            print(f"Error serializing JSON for node {node_id}: {e}")
            return

        query = """
        MERGE (a:Entity {id: $id})
        ON MATCH SET a.data = $data
        ON CREATE SET a.data = $data
        """
        self._conn.execute(query, {"id": node_id, "data": json_data})

    def upsert_edge(
        self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
    ):
        """
        Insert or Update edge.
        Note: We explicitly ensure nodes exist before merging the edge to avoid errors,
        although GraphGen generally creates nodes before edges.

        Missing endpoints are created with empty data. This happens before
        ``edge_data`` is serialized, so the endpoints may be created even when
        serialization fails and no edge is written.
        """
        # Ensure both endpoints exist
        for endpoint_id in (source_node_id, target_node_id):
            if not self.has_node(endpoint_id):
                self.upsert_node(endpoint_id, {})

        try:
            json_data = json.dumps(edge_data, ensure_ascii=False)
        except (TypeError, ValueError) as e:
            print(
                f"Error serializing JSON for edge {source_node_id}->{target_node_id}: {e}"
            )
            return

        query = """
        MATCH (a:Entity {id: $src}), (b:Entity {id: $dst})
        MERGE (a)-[e:Relation]->(b)
        ON MATCH SET e.data = $data
        ON CREATE SET e.data = $data
        """
        self._conn.execute(
            query, {"src": source_node_id, "dst": target_node_id, "data": json_data}
        )

    def delete_node(self, node_id: str):
        """Delete a node together with every edge attached to it."""
        # DETACH DELETE removes the node and all connected edges
        query = "MATCH (a:Entity {id: $id}) DETACH DELETE a"
        self._conn.execute(query, {"id": node_id})
        print(f"Node {node_id} deleted from KuzuDB.")

    def clear(self):
        """Clear all data but keep schema (or drop tables)."""
        self._conn.execute("MATCH (n) DETACH DELETE n")
        print(f"Graph {self.namespace} cleared.")

    def reload(self):
        """For databases that need reloading, KuzuDB auto-manages this."""