File size: 9,157 Bytes
f871fed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
"""
Knowledge Graph Domain Models

Models for concept nodes and edges that form the knowledge graph.
"""

from datetime import datetime
from typing import List, Optional, Literal
from pydantic import BaseModel, Field

from open_notebook.database.repository import repo


class ConceptNode(BaseModel):
    """A single vertex of the knowledge graph.

    Instances are read from and written to the ``concept_node`` table
    through the shared ``repo`` object.
    """

    id: Optional[str] = None  # record ID; save() creates a new row when unset
    label: str
    type: Literal["concept", "person", "event", "place", "organization"] = "concept"
    description: Optional[str] = None
    source_ids: List[str] = Field(default_factory=list)
    notebook_id: Optional[str] = None
    embedding: Optional[List[float]] = None
    importance: float = Field(default=0.5, ge=0, le=1)  # validated to lie in [0, 1]
    mentions: int = 1
    created: Optional[datetime] = None  # set by save() when still empty
    updated: Optional[datetime] = None  # overwritten on every save()

    class Config:
        table_name = "concept_node"

    @classmethod
    async def get(cls, id: str) -> Optional["ConceptNode"]:
        """Fetch one node by its record ID.

        Returns:
            The node built from the stored record, or ``None`` when the
            repository returns nothing (or an empty record).
        """
        record = await repo.get(id)
        return cls(**record) if record else None

    @classmethod
    async def find_by_notebook(cls, notebook_id: str) -> List["ConceptNode"]:
        """Return every node whose ``notebook_id`` matches.

        An empty list is returned when the query yields no rows.
        """
        rows = await repo.query(
            "SELECT * FROM concept_node WHERE notebook_id = $notebook_id",
            {"notebook_id": notebook_id},
        )
        if not rows:
            return []
        return [cls(**row) for row in rows]

    @classmethod
    async def find_by_label(cls, label: str, notebook_id: Optional[str] = None) -> Optional["ConceptNode"]:
        """Look up the first node carrying ``label``.

        Args:
            label: Exact label to match.
            notebook_id: When truthy, the search is limited to that notebook.

        Returns:
            The first matching node, or ``None`` if nothing matched.
        """
        if not notebook_id:
            rows = await repo.query(
                "SELECT * FROM concept_node WHERE label = $label LIMIT 1",
                {"label": label},
            )
        else:
            rows = await repo.query(
                "SELECT * FROM concept_node WHERE label = $label AND notebook_id = $notebook_id LIMIT 1",
                {"label": label, "notebook_id": notebook_id},
            )

        return cls(**rows[0]) if rows else None

    async def save(self) -> "ConceptNode":
        """Persist this node, updating when ``id`` is set and creating otherwise.

        ``updated`` is refreshed on every call; ``created`` is stamped only
        when it is still empty.

        Returns:
            A new ``ConceptNode`` built from the stored record when the
            repository hands back a dict (possibly wrapped in a list);
            otherwise this same instance.
        """
        self.updated = datetime.now()
        if not self.created:
            self.created = datetime.now()

        # The ID is addressed separately and unset fields are not written.
        payload = self.model_dump(exclude={"id"}, exclude_none=True)

        if self.id:
            stored = await repo.update(self.id, payload)
        else:
            stored = await repo.create("concept_node", payload)

        if not stored:
            return self

        # The repository may wrap the record in a list; unwrap the first item.
        if isinstance(stored, list) and stored:
            stored = stored[0]

        if not isinstance(stored, dict):
            return self
        return ConceptNode(**stored)

    async def delete(self) -> bool:
        """Delete this node via the repository; ``False`` when it has no ``id``."""
        if not self.id:
            return False
        return await repo.delete(self.id)


class ConceptEdge(BaseModel):
    """An edge/relationship between two concept nodes.

    Instances are read from and written to the ``concept_edge`` table
    through the shared ``repo`` object.
    """

    id: Optional[str] = None
    source_node: str  # concept_node ID
    target_node: str  # concept_node ID
    relationship: str  # is_a, part_of, causes, related_to, etc.
    description: Optional[str] = None
    weight: float = 1.0
    source_ids: List[str] = Field(default_factory=list)
    notebook_id: Optional[str] = None
    created: Optional[datetime] = None  # set by save() when still empty

    class Config:
        table_name = "concept_edge"

    @classmethod
    async def get(cls, id: str) -> Optional["ConceptEdge"]:
        """Get an edge by ID.

        Returns:
            The edge built from the stored record, or ``None`` when the
            repository returns nothing.
        """
        result = await repo.get(id)
        if result:
            return cls(**result)
        return None

    @classmethod
    async def find_by_notebook(cls, notebook_id: str) -> List["ConceptEdge"]:
        """Get all edges for a notebook (empty list when none are found)."""
        query = "SELECT * FROM concept_edge WHERE notebook_id = $notebook_id"
        results = await repo.query(query, {"notebook_id": notebook_id})

        return [cls(**r) for r in results] if results else []

    @classmethod
    async def find_by_node(cls, node_id: str) -> List["ConceptEdge"]:
        """Find all edges that have ``node_id`` as their source or target.

        Returns:
            The matching edges, or an empty list when the query yields
            nothing.
        """
        query = "SELECT * FROM concept_edge WHERE source_node = $node_id OR target_node = $node_id"
        results = await repo.query(query, {"node_id": node_id})
        # Guard against an empty/None result the same way find_by_notebook
        # does, so a missing result set yields [] instead of a TypeError.
        return [cls(**r) for r in results] if results else []

    async def save(self) -> "ConceptEdge":
        """Persist this edge, updating when ``id`` is set and creating otherwise.

        ``created`` is stamped only when it is still empty.

        Returns:
            A new ``ConceptEdge`` built from the stored record when the
            repository hands back a dict (possibly wrapped in a list);
            otherwise this same instance.
        """
        if not self.created:
            self.created = datetime.now()

        # The ID is addressed separately and unset fields are not written.
        data = self.model_dump(exclude={"id"}, exclude_none=True)

        if self.id:
            result = await repo.update(self.id, data)
        else:
            result = await repo.create("concept_edge", data)

        if result:
            # repo.create returns a list, extract first element
            if isinstance(result, list) and len(result) > 0:
                result = result[0]
            if isinstance(result, dict):
                return ConceptEdge(**result)
        return self

    async def delete(self) -> bool:
        """Delete this edge via the repository; ``False`` when it has no ``id``."""
        if self.id:
            return await repo.delete(self.id)
        return False


class KnowledgeGraphMeta(BaseModel):
    """Bookkeeping record describing one notebook's knowledge graph.

    Rows live in the ``knowledge_graph_meta`` table and are accessed
    through the shared ``repo`` object.
    """

    id: Optional[str] = None
    notebook_id: str
    node_count: int = 0
    edge_count: int = 0
    last_built: Optional[datetime] = None
    build_status: Literal["pending", "building", "completed", "error"] = "pending"
    error_message: Optional[str] = None
    created: Optional[datetime] = None  # set by save() when still empty
    updated: Optional[datetime] = None  # overwritten on every save()

    class Config:
        table_name = "knowledge_graph_meta"

    @classmethod
    async def get_for_notebook(cls, notebook_id: str) -> Optional["KnowledgeGraphMeta"]:
        """Return the first metadata row for ``notebook_id``, or ``None``."""
        rows = await repo.query(
            "SELECT * FROM knowledge_graph_meta WHERE notebook_id = $notebook_id LIMIT 1",
            {"notebook_id": notebook_id},
        )
        return cls(**rows[0]) if rows else None

    async def save(self) -> "KnowledgeGraphMeta":
        """Persist the metadata, updating when ``id`` is set and creating otherwise.

        ``updated`` is refreshed on every call; ``created`` is stamped only
        when it is still empty.

        Returns:
            A new ``KnowledgeGraphMeta`` built from the stored record when
            the repository hands back a dict (possibly wrapped in a list);
            otherwise this same instance.
        """
        self.updated = datetime.now()
        if not self.created:
            self.created = datetime.now()

        # The ID is addressed separately and unset fields are not written.
        payload = self.model_dump(exclude={"id"}, exclude_none=True)

        if self.id:
            stored = await repo.update(self.id, payload)
        else:
            stored = await repo.create("knowledge_graph_meta", payload)

        if not stored:
            return self

        # The repository may wrap the record in a list; unwrap the first item.
        if isinstance(stored, list) and stored:
            stored = stored[0]

        if not isinstance(stored, dict):
            return self
        return KnowledgeGraphMeta(**stored)


class KnowledgeGraph(BaseModel):
    """Aggregate of a notebook's nodes, edges and graph metadata."""

    notebook_id: str
    nodes: List[ConceptNode] = Field(default_factory=list)
    edges: List[ConceptEdge] = Field(default_factory=list)
    meta: Optional[KnowledgeGraphMeta] = None

    @classmethod
    async def load(cls, notebook_id: str) -> "KnowledgeGraph":
        """Assemble the graph for ``notebook_id``.

        Nodes, edges and metadata are fetched one after another, in that
        order, and bundled into a new instance.
        """
        loaded_nodes = await ConceptNode.find_by_notebook(notebook_id)
        loaded_edges = await ConceptEdge.find_by_notebook(notebook_id)
        loaded_meta = await KnowledgeGraphMeta.get_for_notebook(notebook_id)

        return cls(notebook_id=notebook_id, nodes=loaded_nodes, edges=loaded_edges, meta=loaded_meta)

    def to_graph_data(self) -> dict:
        """Flatten the graph into plain dicts for a visualization front end.

        Returns:
            A dict with a ``"nodes"`` list (one dict per node, including a
            size value and a display colour) and a ``"links"`` list (one
            dict per edge).
        """
        node_items = []
        for vertex in self.nodes:
            node_items.append(
                {
                    "id": vertex.id,
                    "label": vertex.label,
                    "type": vertex.type,
                    "description": vertex.description,
                    "importance": vertex.importance,
                    "mentions": vertex.mentions,
                    # Rendered size: importance scaled by ten plus mention count.
                    "val": vertex.importance * 10 + vertex.mentions,
                    "color": self._get_node_color(vertex.type),
                }
            )

        link_items = []
        for relation in self.edges:
            link_items.append(
                {
                    "source": relation.source_node,
                    "target": relation.target_node,
                    "relationship": relation.relationship,
                    "weight": relation.weight,
                }
            )

        return {"nodes": node_items, "links": link_items}

    @staticmethod
    def _get_node_color(node_type: str) -> str:
        """Map a node type to its hex colour, with a grey fallback."""
        palette = {
            "concept": "#60a5fa",  # blue-400
            "person": "#a78bfa",  # violet-400
            "event": "#fbbf24",  # amber-400
            "place": "#34d399",  # emerald-400
            "organization": "#f87171",  # red-400
        }
        if node_type in palette:
            return palette[node_type]
        return "#94a3b8"  # slate-400 default