github-actions[bot]
Auto-sync from demo at Fri Dec 26 08:29:01 UTC 2025
7566ac3
import json
import os
from collections import defaultdict
from dataclasses import dataclass
from typing import Any, Dict, List, Set
try:
import kuzu
except ImportError:
kuzu = None
from graphgen.bases.base_storage import BaseGraphStorage
@dataclass
class KuzuStorage(BaseGraphStorage):
    """
    Graph storage implementation based on KuzuDB.

    KuzuDB is a structured (schema-first) graph database, while GraphGen uses
    dynamic dictionaries for node and edge properties. The data is therefore
    mapped onto a generic schema:

    - Node table ``Entity``: ``{id: STRING (primary key), data: STRING (JSON)}``
    - Rel table ``Relation``: ``{FROM Entity TO Entity, data: STRING (JSON)}``

    Property dicts are serialized with ``json.dumps`` on write and decoded on
    read; payloads that are missing, malformed, or not a JSON object are read
    back as ``{}``.
    """

    working_dir: str = None  # parent directory of the database directory
    namespace: str = None  # database directory is "<working_dir>/<namespace>_kuzu"
    _db: Any = None  # kuzu.Database, set by _init_db
    _conn: Any = None  # kuzu.Connection used for every query, set by _init_db

    def __post_init__(self):
        """Check that kuzu is importable, then open (or create) the database.

        Raises:
            ImportError: if the ``kuzu`` package is not installed.
        """
        if kuzu is None:
            raise ImportError(
                "KuzuDB is not installed. Please install it via `pip install kuzu`."
            )
        self.db_path = os.path.join(self.working_dir, f"{self.namespace}_kuzu")
        self._init_db()

    def _init_db(self):
        """Open the database at ``self.db_path``, connect, and ensure the schema."""
        # KuzuDB automatically creates the directory
        self._db = kuzu.Database(self.db_path)
        self._conn = kuzu.Connection(self._db)
        self._init_schema()
        print(f"KuzuDB initialized at {self.db_path}")

    def _init_schema(self):
        """Initialize the generic Node and Edge tables if they don't exist.

        Creating a table that already exists raises ``RuntimeError``; that case
        (and any other creation error) is reported and ignored so that an
        existing database can be reopened.
        """
        # Check and create Node table
        try:
            # We use a generic table name "Entity" to store all nodes
            self._conn.execute(
                "CREATE NODE TABLE Entity(id STRING, data STRING, PRIMARY KEY(id))"
            )
            print("Created KuzuDB Node Table 'Entity'")
        except RuntimeError as e:
            # Usually throws if table exists, verify safely or ignore
            print("Node Table 'Entity' already exists or error:", e)
        # Check and create Edge table
        try:
            # We use a generic table name "Relation" to store all edges
            self._conn.execute(
                "CREATE REL TABLE Relation(FROM Entity TO Entity, data STRING)"
            )
            print("Created KuzuDB Rel Table 'Relation'")
        except RuntimeError as e:
            print("Rel Table 'Relation' already exists or error:", e)

    def index_done_callback(self):
        """No-op: every statement is executed immediately on the connection
        (Kuzu's auto-commit mode), so there is nothing to flush here."""

    # ------------------------------------------------------------------ #
    # Internal helpers
    # ------------------------------------------------------------------ #

    @staticmethod
    def _safe_json_loads(data_str: str) -> dict:
        """Decode a stored JSON payload into a dict.

        Returns ``{}`` for non-string or blank input, malformed JSON, and JSON
        that does not decode to an object, so callers can always treat the
        result as a mutable dict (``update_node``/``update_edge`` call
        ``.update`` on it).
        """
        if not isinstance(data_str, str) or not data_str.strip():
            return {}
        try:
            data = json.loads(data_str)
        except json.JSONDecodeError as e:
            print(f"Error decoding JSON: {e}")
            return {}
        if not isinstance(data, dict):
            # Valid JSON but not an object (e.g. a list): unusable as properties.
            print(f"Expected a JSON object, got {type(data).__name__}; ignoring.")
            return {}
        return data

    @staticmethod
    def _iter_rows(result):
        """Yield every row of a Kuzu query result using ``has_next``/``get_next``."""
        while result.has_next():
            yield result.get_next()

    @staticmethod
    def _first_value(result, default: Any = 0) -> Any:
        """Return the first column of the first row, or ``default`` if there is no row."""
        if result.has_next():
            row = result.get_next()
            if row:
                return row[0]
        return default

    def _collect_edges(self, result) -> List[tuple]:
        """Decode rows of ``(src_id, dst_id, data_json)`` into ``(src, dst, dict)`` tuples."""
        edges = []
        for row in self._iter_rows(result):
            if row is None or len(row) < 3:
                continue
            src, dst, data_str = row[0], row[1], row[2]
            edges.append((src, dst, self._safe_json_loads(data_str)))
        return edges

    # ------------------------------------------------------------------ #
    # Graph statistics
    # ------------------------------------------------------------------ #

    def is_directed(self) -> bool:
        """Edges are stored as directed ``Relation`` rels."""
        return True

    def get_all_node_degrees(self) -> Dict[str, int]:
        """Map every node id to its total degree (incoming + outgoing edges).

        Nodes without edges are included with degree 0 (via OPTIONAL MATCH).
        """
        query = """
        MATCH (n:Entity)
        OPTIONAL MATCH (n)-[r]-()
        RETURN n.id, count(r) as degree
        """
        result = self._conn.execute(query)
        degree_map = {}
        for row in self._iter_rows(result):
            if row and len(row) >= 2:
                node_id, degree = row[0], row[1]
                degree_map[node_id] = int(degree)
        return degree_map

    def get_isolated_nodes(self) -> List[str]:
        """Return the ids of nodes that have no incident edge in either direction.

        Derived from ``get_all_node_degrees`` (degree 0 means no edges) instead of
        a Cypher pattern predicate, which Kuzu expresses differently
        (``EXISTS { MATCH ... }``).
        """
        return [
            node_id
            for node_id, degree in self.get_all_node_degrees().items()
            if degree == 0
        ]

    def get_node_count(self) -> int:
        """Return the number of ``Entity`` nodes."""
        result = self._conn.execute("MATCH (n:Entity) RETURN count(n)")
        return self._first_value(result)

    def get_edge_count(self) -> int:
        """Return the number of ``Relation`` edges (parallel edges counted separately)."""
        result = self._conn.execute("MATCH ()-[e:Relation]->() RETURN count(e)")
        return self._first_value(result)

    def get_connected_components(self, undirected: bool = True) -> List[Set[str]]:
        """Return the connected components of the graph as sets of node ids.

        Components are computed client-side with union-find (union by rank and
        path compression) over every ``Relation`` edge. Union-find merges both
        endpoints of an edge regardless of its direction, so the result is the
        set of weakly connected components for either value of ``undirected``;
        the parameter is accepted for interface compatibility. Each edge is
        therefore scanned once, in its stored direction.
        """
        parent: Dict[str, str] = {}
        rank: Dict[str, int] = {}

        def find(x: str) -> str:
            # Locate the root, then point every node on the path directly at it.
            root = x
            while parent[root] != root:
                root = parent[root]
            while x != root:
                next_x = parent[x]
                parent[x] = root
                x = next_x
            return root

        def union(x: str, y: str):
            root_x, root_y = find(x), find(y)
            if root_x == root_y:
                return
            if rank[root_x] < rank[root_y]:
                parent[root_x] = root_y
            elif rank[root_x] > rank[root_y]:
                parent[root_y] = root_x
            else:
                parent[root_y] = root_x
                rank[root_x] += 1

        all_nodes = list(self.get_all_node_degrees())
        for node_id in all_nodes:
            parent[node_id] = node_id
            rank[node_id] = 0

        query = """
        MATCH (a:Entity)-[e:Relation]->(b:Entity)
        RETURN DISTINCT a.id, b.id
        """
        result = self._conn.execute(query)
        for row in self._iter_rows(result):
            if row and len(row) >= 2:
                union(row[0], row[1])

        components_dict = defaultdict(set)
        for node_id in all_nodes:
            components_dict[find(node_id)].add(node_id)
        return list(components_dict.values())

    # ------------------------------------------------------------------ #
    # Lookups
    # ------------------------------------------------------------------ #

    def has_node(self, node_id: str) -> bool:
        """Return True if an ``Entity`` with this id exists."""
        result = self._conn.execute(
            "MATCH (a:Entity {id: $id}) RETURN count(a)", {"id": node_id}
        )
        return self._first_value(result) > 0

    def has_edge(self, source_node_id: str, target_node_id: str):
        """Return True if at least one edge ``source -> target`` exists."""
        result = self._conn.execute(
            "MATCH (a:Entity {id: $src})-[e:Relation]->(b:Entity {id: $dst}) RETURN count(e)",
            {"src": source_node_id, "dst": target_node_id},
        )
        return self._first_value(result) > 0

    def node_degree(self, node_id: str) -> int:
        """Return the total degree (incoming + outgoing) of ``node_id``; 0 if absent."""
        query = """
        MATCH (a:Entity {id: $id})-[e:Relation]-(b:Entity)
        RETURN count(e)
        """
        result = self._conn.execute(query, {"id": node_id})
        return self._first_value(result)

    def edge_degree(self, src_id: str, tgt_id: str) -> int:
        """Return the number of ``src -> tgt`` edges.

        Kuzu allows multiple edges between the same pair, so this is a count
        rather than a boolean.
        """
        query = """
        MATCH (a:Entity {id: $src})-[e:Relation]->(b:Entity {id: $dst})
        RETURN count(e)
        """
        result = self._conn.execute(query, {"src": src_id, "dst": tgt_id})
        return self._first_value(result)

    def get_node(self, node_id: str) -> Any:
        """Return the node's property dict, or None if the node does not exist.

        An existing node whose payload is empty or undecodable yields ``{}``.
        """
        result = self._conn.execute(
            "MATCH (a:Entity {id: $id}) RETURN a.data", {"id": node_id}
        )
        if not result.has_next():
            return None
        data_str = result.get_next()[0]
        return self._safe_json_loads(data_str)

    def update_node(self, node_id: str, node_data: dict[str, str]):
        """Merge ``node_data`` into an existing node's properties.

        Does nothing (beyond a printed message) if the node is missing or the
        merged dict cannot be serialized to JSON.
        """
        current_data = self.get_node(node_id)
        if current_data is None:
            print(f"Node {node_id} not found for update.")
            return
        # Merge existing data with new data
        current_data.update(node_data)
        try:
            json_data = json.dumps(current_data, ensure_ascii=False)
        except (TypeError, ValueError) as e:
            print(f"Error serializing JSON for node {node_id}: {e}")
            return
        self._conn.execute(
            "MATCH (a:Entity {id: $id}) SET a.data = $data",
            {"id": node_id, "data": json_data},
        )

    def get_all_nodes(self) -> Any:
        """Returns List[Tuple[id, data_dict]]"""
        result = self._conn.execute("MATCH (a:Entity) RETURN a.id, a.data")
        nodes = []
        for row in self._iter_rows(result):
            if row is None or len(row) < 2:
                continue
            node_id, data_str = row[0], row[1]
            nodes.append((node_id, self._safe_json_loads(data_str)))
        return nodes

    def get_edge(self, source_node_id: str, target_node_id: str):
        """Return the property dict of the edge ``source -> target``, or None.

        If multiple parallel edges exist, the first row Kuzu returns is used.
        """
        query = """
        MATCH (a:Entity {id: $src})-[e:Relation]->(b:Entity {id: $dst})
        RETURN e.data
        """
        result = self._conn.execute(
            query, {"src": source_node_id, "dst": target_node_id}
        )
        if not result.has_next():
            return None
        data_str = result.get_next()[0]
        return self._safe_json_loads(data_str)

    def update_edge(
        self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
    ):
        """Merge ``edge_data`` into the properties of the edge ``source -> target``.

        The merged payload is written to every matching (parallel) edge. Does
        nothing (beyond a printed message) if the edge is missing or the merged
        dict cannot be serialized to JSON.
        """
        current_data = self.get_edge(source_node_id, target_node_id)
        if current_data is None:
            print(f"Edge {source_node_id}->{target_node_id} not found for update.")
            return
        current_data.update(edge_data)
        try:
            json_data = json.dumps(current_data, ensure_ascii=False)
        except (TypeError, ValueError) as e:
            print(
                f"Error serializing JSON for edge {source_node_id}->{target_node_id}: {e}"
            )
            return
        self._conn.execute(
            """
            MATCH (a:Entity {id: $src})-[e:Relation]->(b:Entity {id: $dst})
            SET e.data = $data
            """,
            {"src": source_node_id, "dst": target_node_id, "data": json_data},
        )

    def get_all_edges(self) -> Any:
        """Returns List[Tuple[src, dst, data_dict]]"""
        query = "MATCH (a:Entity)-[e:Relation]->(b:Entity) RETURN a.id, b.id, e.data"
        return self._collect_edges(self._conn.execute(query))

    def get_node_edges(self, source_node_id: str) -> Any:
        """Return outgoing edges of ``source_node_id`` as List[Tuple[src, dst, data_dict]]."""
        query = """
        MATCH (a:Entity {id: $src})-[e:Relation]->(b:Entity)
        RETURN a.id, b.id, e.data
        """
        return self._collect_edges(self._conn.execute(query, {"src": source_node_id}))

    # ------------------------------------------------------------------ #
    # Mutations
    # ------------------------------------------------------------------ #

    def upsert_node(self, node_id: str, node_data: dict[str, str]):
        """
        Insert or update a node, replacing its whole property payload.

        Kuzu supports the MERGE clause (similar to Neo4j) to handle upserts.
        Nothing is written if ``node_data`` cannot be serialized to JSON.
        """
        try:
            json_data = json.dumps(node_data, ensure_ascii=False)
        except (TypeError, ValueError) as e:
            print(f"Error serializing JSON for node {node_id}: {e}")
            return
        query = """
        MERGE (a:Entity {id: $id})
        ON MATCH SET a.data = $data
        ON CREATE SET a.data = $data
        """
        self._conn.execute(query, {"id": node_id, "data": json_data})

    def upsert_edge(
        self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
    ):
        """
        Insert or update the edge ``source -> target``, replacing its payload.

        Missing endpoint nodes are created first with empty properties, because
        the edge MERGE below only matches existing nodes (GraphGen generally
        creates nodes before edges anyway). Nothing is written for the edge if
        ``edge_data`` cannot be serialized to JSON.
        """
        # Ensure source node exists
        if not self.has_node(source_node_id):
            self.upsert_node(source_node_id, {})
        # Ensure target node exists
        if not self.has_node(target_node_id):
            self.upsert_node(target_node_id, {})
        try:
            json_data = json.dumps(edge_data, ensure_ascii=False)
        except (TypeError, ValueError) as e:
            print(
                f"Error serializing JSON for edge {source_node_id}->{target_node_id}: {e}"
            )
            return
        query = """
        MATCH (a:Entity {id: $src}), (b:Entity {id: $dst})
        MERGE (a)-[e:Relation]->(b)
        ON MATCH SET e.data = $data
        ON CREATE SET e.data = $data
        """
        self._conn.execute(
            query, {"src": source_node_id, "dst": target_node_id, "data": json_data}
        )

    def delete_node(self, node_id: str):
        """Delete a node together with all of its incident edges."""
        # DETACH DELETE removes the node and all connected edges
        query = "MATCH (a:Entity {id: $id}) DETACH DELETE a"
        self._conn.execute(query, {"id": node_id})
        print(f"Node {node_id} deleted from KuzuDB.")

    def clear(self):
        """Delete all nodes and edges; the table schema is kept."""
        self._conn.execute("MATCH (n) DETACH DELETE n")
        print(f"Graph {self.namespace} cleared.")

    def reload(self):
        """No-op: the database is opened once in ``_init_db`` and needs no reload."""