Spaces:
Sleeping
Sleeping
| """ | |
| Knowledge Graph Domain Models | |
| Models for concept nodes and edges that form the knowledge graph. | |
| """ | |
| from datetime import datetime | |
| from typing import List, Optional, Literal | |
| from pydantic import BaseModel, Field | |
| from open_notebook.database.repository import repo | |
class ConceptNode(BaseModel):
    """A concept node in the knowledge graph.

    Instances are plain Pydantic models; loading and persistence are
    delegated to the module-level ``repo`` helper.
    """

    id: Optional[str] = None
    label: str
    type: Literal["concept", "person", "event", "place", "organization"] = "concept"
    description: Optional[str] = None
    source_ids: List[str] = Field(default_factory=list)
    notebook_id: Optional[str] = None
    embedding: Optional[List[float]] = None
    importance: float = Field(default=0.5, ge=0, le=1)  # validated to lie in [0, 1]
    mentions: int = 1
    created: Optional[datetime] = None
    updated: Optional[datetime] = None

    class Config:
        table_name = "concept_node"

    @classmethod
    async def get(cls, id: str) -> Optional["ConceptNode"]:
        """Get a concept node by ID.

        Args:
            id: Record identifier passed straight to ``repo.get``.

        Returns:
            A node built from the record ``repo.get`` returns, or ``None``
            when that result is falsy.
        """
        result = await repo.get(id)
        if result:
            return cls(**result)
        return None

    @classmethod
    async def find_by_notebook(cls, notebook_id: str) -> List["ConceptNode"]:
        """Get all concept nodes for a notebook.

        Args:
            notebook_id: Value bound to the ``$notebook_id`` query parameter.

        Returns:
            One node per returned row; an empty list when the query result
            is falsy.
        """
        query = "SELECT * FROM concept_node WHERE notebook_id = $notebook_id"
        results = await repo.query(query, {"notebook_id": notebook_id})
        return [cls(**r) for r in results] if results else []

    @classmethod
    async def find_by_label(
        cls, label: str, notebook_id: Optional[str] = None
    ) -> Optional["ConceptNode"]:
        """Find a concept by label, optionally scoped to a notebook.

        Args:
            label: Exact label to match.
            notebook_id: When truthy, the match is additionally restricted
                to this notebook.

        Returns:
            A node built from the first returned row, or ``None`` when the
            query result is falsy.
        """
        if notebook_id:
            query = "SELECT * FROM concept_node WHERE label = $label AND notebook_id = $notebook_id LIMIT 1"
            results = await repo.query(query, {"label": label, "notebook_id": notebook_id})
        else:
            query = "SELECT * FROM concept_node WHERE label = $label LIMIT 1"
            results = await repo.query(query, {"label": label})
        if results:
            return cls(**results[0])
        return None

    async def save(self) -> "ConceptNode":
        """Save the concept node.

        Stamps ``updated`` (and ``created`` if unset), then updates the
        existing record when ``id`` is set or creates a new one otherwise.

        Returns:
            A new ``ConceptNode`` built from the stored record when the
            repository hands one back as a dict (or a non-empty list whose
            first element is a dict); otherwise ``self``.
        """
        # One timestamp so a first save gets identical created/updated values.
        now = datetime.now()
        self.updated = now
        if not self.created:
            self.created = now

        # The id is passed separately on update and is not part of the payload.
        data = self.model_dump(exclude={"id"}, exclude_none=True)
        if self.id:
            result = await repo.update(self.id, data)
        else:
            result = await repo.create("concept_node", data)

        if result:
            # repo.create returns a list, extract first element
            if isinstance(result, list) and len(result) > 0:
                result = result[0]
            if isinstance(result, dict):
                return ConceptNode(**result)
        return self

    async def delete(self) -> bool:
        """Delete the concept node.

        Returns:
            The result of ``repo.delete`` when the node has an ``id``;
            ``False`` when it has none (nothing is sent to the repository).
        """
        if self.id:
            return await repo.delete(self.id)
        return False
class ConceptEdge(BaseModel):
    """An edge/relationship between concept nodes.

    Instances are plain Pydantic models; loading and persistence are
    delegated to the module-level ``repo`` helper.
    """

    id: Optional[str] = None
    source_node: str  # concept_node ID
    target_node: str  # concept_node ID
    relationship: str  # is_a, part_of, causes, related_to, etc.
    description: Optional[str] = None
    weight: float = 1.0
    source_ids: List[str] = Field(default_factory=list)
    notebook_id: Optional[str] = None
    created: Optional[datetime] = None

    class Config:
        table_name = "concept_edge"

    @classmethod
    async def get(cls, id: str) -> Optional["ConceptEdge"]:
        """Get an edge by ID.

        Args:
            id: Record identifier passed straight to ``repo.get``.

        Returns:
            An edge built from the record ``repo.get`` returns, or ``None``
            when that result is falsy.
        """
        result = await repo.get(id)
        if result:
            return cls(**result)
        return None

    @classmethod
    async def find_by_notebook(cls, notebook_id: str) -> List["ConceptEdge"]:
        """Get all edges for a notebook.

        Args:
            notebook_id: Value bound to the ``$notebook_id`` query parameter.

        Returns:
            One edge per returned row; an empty list when the query result
            is falsy.
        """
        query = "SELECT * FROM concept_edge WHERE notebook_id = $notebook_id"
        results = await repo.query(query, {"notebook_id": notebook_id})
        return [cls(**r) for r in results] if results else []

    @classmethod
    async def find_by_node(cls, node_id: str) -> List["ConceptEdge"]:
        """Find all edges connected to a node.

        An edge matches when the node is either its source or its target.

        Args:
            node_id: Value bound to the ``$node_id`` query parameter.

        Returns:
            One edge per returned row; an empty list when the query result
            is falsy.
        """
        query = "SELECT * FROM concept_edge WHERE source_node = $node_id OR target_node = $node_id"
        results = await repo.query(query, {"node_id": node_id})
        # Guard against a falsy result, matching find_by_notebook.
        return [cls(**r) for r in results] if results else []

    async def save(self) -> "ConceptEdge":
        """Save the edge.

        Stamps ``created`` if unset, then updates the existing record when
        ``id`` is set or creates a new one otherwise.

        Returns:
            A new ``ConceptEdge`` built from the stored record when the
            repository hands one back as a dict (or a non-empty list whose
            first element is a dict); otherwise ``self``.
        """
        if not self.created:
            self.created = datetime.now()

        # The id is passed separately on update and is not part of the payload.
        data = self.model_dump(exclude={"id"}, exclude_none=True)
        if self.id:
            result = await repo.update(self.id, data)
        else:
            result = await repo.create("concept_edge", data)

        if result:
            # repo.create returns a list, extract first element
            if isinstance(result, list) and len(result) > 0:
                result = result[0]
            if isinstance(result, dict):
                return ConceptEdge(**result)
        return self

    async def delete(self) -> bool:
        """Delete the edge.

        Returns:
            The result of ``repo.delete`` when the edge has an ``id``;
            ``False`` when it has none (nothing is sent to the repository).
        """
        if self.id:
            return await repo.delete(self.id)
        return False
class KnowledgeGraphMeta(BaseModel):
    """Metadata about a notebook's knowledge graph.

    Instances are plain Pydantic models; loading and persistence are
    delegated to the module-level ``repo`` helper.
    """

    id: Optional[str] = None
    notebook_id: str
    node_count: int = 0
    edge_count: int = 0
    last_built: Optional[datetime] = None
    build_status: Literal["pending", "building", "completed", "error"] = "pending"
    error_message: Optional[str] = None
    created: Optional[datetime] = None
    updated: Optional[datetime] = None

    class Config:
        table_name = "knowledge_graph_meta"

    @classmethod
    async def get_for_notebook(cls, notebook_id: str) -> Optional["KnowledgeGraphMeta"]:
        """Get metadata for a notebook's knowledge graph.

        Args:
            notebook_id: Value bound to the ``$notebook_id`` query parameter.

        Returns:
            Metadata built from the first returned row, or ``None`` when the
            query result is falsy.
        """
        query = "SELECT * FROM knowledge_graph_meta WHERE notebook_id = $notebook_id LIMIT 1"
        results = await repo.query(query, {"notebook_id": notebook_id})
        if results:
            return cls(**results[0])
        return None

    async def save(self) -> "KnowledgeGraphMeta":
        """Save the metadata.

        Stamps ``updated`` (and ``created`` if unset), then updates the
        existing record when ``id`` is set or creates a new one otherwise.

        Returns:
            A new ``KnowledgeGraphMeta`` built from the stored record when
            the repository hands one back as a dict (or a non-empty list
            whose first element is a dict); otherwise ``self``.
        """
        # One timestamp so a first save gets identical created/updated values.
        now = datetime.now()
        self.updated = now
        if not self.created:
            self.created = now

        # The id is passed separately on update and is not part of the payload.
        data = self.model_dump(exclude={"id"}, exclude_none=True)
        if self.id:
            result = await repo.update(self.id, data)
        else:
            result = await repo.create("knowledge_graph_meta", data)

        if result:
            # repo.create returns a list, extract first element
            if isinstance(result, list) and len(result) > 0:
                result = result[0]
            if isinstance(result, dict):
                return KnowledgeGraphMeta(**result)
        return self
class KnowledgeGraph(BaseModel):
    """A complete knowledge graph with nodes and edges."""

    notebook_id: str
    nodes: List[ConceptNode] = Field(default_factory=list)
    edges: List[ConceptEdge] = Field(default_factory=list)
    meta: Optional[KnowledgeGraphMeta] = None

    @classmethod
    async def load(cls, notebook_id: str) -> "KnowledgeGraph":
        """Load the complete knowledge graph for a notebook.

        Args:
            notebook_id: Notebook whose nodes, edges and metadata are fetched.

        Returns:
            A graph holding the results of ``ConceptNode.find_by_notebook``,
            ``ConceptEdge.find_by_notebook`` and
            ``KnowledgeGraphMeta.get_for_notebook`` (``meta`` may be ``None``).
        """
        nodes = await ConceptNode.find_by_notebook(notebook_id)
        edges = await ConceptEdge.find_by_notebook(notebook_id)
        meta = await KnowledgeGraphMeta.get_for_notebook(notebook_id)
        return cls(
            notebook_id=notebook_id,
            nodes=nodes,
            edges=edges,
            meta=meta,
        )

    def to_graph_data(self) -> dict:
        """Convert to a format suitable for visualization libraries.

        Returns:
            A dict with two keys: ``"nodes"`` (one dict per node with its id,
            label, type, description, importance, mentions, a derived ``val``
            size and a ``color``) and ``"links"`` (one dict per edge with its
            source, target, relationship and weight).
        """
        return {
            "nodes": [
                {
                    "id": node.id,
                    "label": node.label,
                    "type": node.type,
                    "description": node.description,
                    "importance": node.importance,
                    "mentions": node.mentions,
                    "val": node.importance * 10 + node.mentions,  # Size for force graph
                    "color": self._get_node_color(node.type),
                }
                for node in self.nodes
            ],
            "links": [
                {
                    "source": edge.source_node,
                    "target": edge.target_node,
                    "relationship": edge.relationship,
                    "weight": edge.weight,
                }
                for edge in self.edges
            ],
        }

    @staticmethod
    def _get_node_color(node_type: str) -> str:
        """Get the hex color for a node type.

        Declared static because it takes no ``self`` yet is invoked as
        ``self._get_node_color(...)`` from ``to_graph_data``.

        Args:
            node_type: One of the ``ConceptNode.type`` values.

        Returns:
            The color mapped to ``node_type``, or the slate default for any
            type not in the table.
        """
        colors = {
            "concept": "#60a5fa",  # blue-400
            "person": "#a78bfa",  # violet-400
            "event": "#fbbf24",  # amber-400
            "place": "#34d399",  # emerald-400
            "organization": "#f87171",  # red-400
        }
        return colors.get(node_type, "#94a3b8")  # slate-400 default