# contextflow-rl / app/agents/knowledge_graph_agent.py
# namish10: Upload app/agents/knowledge_graph_agent.py with huggingface_hub
# commit eba543b (verified)
"""
Knowledge Graph Agent with GraphRAG
Manages the user's knowledge graph using GraphRAG:
- Nodes: concepts, doubts, topics, resources
- Edges: relationships, dependencies, associations
- GraphRAG for retrieval and generation
"""
import json
from collections import Counter, defaultdict, deque
from dataclasses import dataclass, field
from datetime import datetime
from typing import Dict, List, Any, Optional
@dataclass
class GraphNode:
    """A single vertex of the knowledge graph.

    ``properties`` and ``created_at`` are produced by default factories, so
    each node gets its own dict and its own construction-time timestamp.
    """

    # Identifier of the node.
    node_id: str
    # Category of the node, e.g. 'Concept', 'Doubt' or 'Resource'.
    node_type: str
    # Human-readable name.
    label: str
    # Free-form metadata attached to the node.
    properties: Dict = field(default_factory=dict)
    # Optional embedding vector; left unset unless supplied.
    embeddings: Optional[List[float]] = None
    # Timestamp taken from datetime.now() when the node is built.
    created_at: datetime = field(default_factory=datetime.now)
@dataclass
class GraphEdge:
    """A relationship between two nodes of the knowledge graph.

    ``properties`` and ``created_at`` are produced by default factories, so
    each edge gets its own dict and its own construction-time timestamp.
    """

    # Identifier of the edge.
    edge_id: str
    # Node id the edge starts from.
    source_id: str
    # Node id the edge points to.
    target_id: str
    # Name of the relationship, e.g. 'part_of' or 'prerequisite_of'.
    relation_type: str
    # Strength of the relationship; 1.0 unless specified.
    weight: float = 1.0
    # Free-form metadata attached to the edge.
    properties: Dict = field(default_factory=dict)
    # Timestamp taken from datetime.now() when the edge is built.
    created_at: datetime = field(default_factory=datetime.now)
@dataclass
class Ontology:
    """Domain ontology: the entity and relation types a graph may contain.

    Both lists come from default factories, so instances never share them.
    """

    # Descriptors of the allowed entity types (one dict per type).
    entity_types: List[Dict] = field(default_factory=list)
    # Descriptors of the allowed relation types (one dict per type).
    relation_types: List[Dict] = field(default_factory=list)
class KnowledgeGraphAgent:
    """
    Agent that manages a per-user, in-memory knowledge graph.

    Features:
    - Ingestion of doubts, resources and topics as nodes linked to concepts
    - Keyword-scored retrieval with one-hop related-node expansion
    - Shortest-path search between concepts
    - Per-concept mastery estimates and graph statistics
    - Export/import of the graph as plain dictionaries

    Nodes live in ``self.nodes`` and edges in ``self.edges``, both keyed by id.
    """

    def __init__(self, user_id: str, config: Optional[Dict] = None):
        """Create an empty graph for *user_id* with the default ontology.

        Args:
            user_id: Owner of the graph; embedded in ``graph_id``.
            config: Optional settings dict, stored as-is (``{}`` when falsy).
        """
        self.user_id = user_id
        self.config = config or {}
        self.nodes: Dict[str, GraphNode] = {}
        self.edges: Dict[str, GraphEdge] = {}
        # The creation timestamp is part of the id, so two agents created for
        # the same user at different times get different graph ids.
        self.graph_id = f"cf_graph_{user_id}_{datetime.now().timestamp()}"
        self._initialize_default_ontology()

    def _initialize_default_ontology(self):
        """Install the built-in learning ontology on ``self.ontology``."""
        self.ontology = Ontology(
            entity_types=[
                {'name': 'Concept', 'description': 'A learning concept or topic'},
                {'name': 'Doubt', 'description': 'A question or confusion point'},
                {'name': 'Resource', 'description': 'Learning resource or material'},
                {'name': 'Topic', 'description': 'Main subject area'},
                {'name': 'Skill', 'description': 'Developed skill or competency'},
            ],
            relation_types=[
                {'name': 'prerequisite_of', 'description': 'Is prerequisite for'},
                {'name': 'related_to', 'description': 'Is related to'},
                {'name': 'part_of', 'description': 'Is part of'},
                {'name': 'helps_understand', 'description': 'Helps understand'},
                {'name': 'contrasts_with', 'description': 'Contrasts with'},
            ],
        )

    @staticmethod
    def _concept_id(concept: str) -> str:
        """Return the node id under which the concept named *concept* is stored."""
        return f"concept_{concept.lower().replace(' ', '_')}"

    def _resolve_node_id(self, ref: str) -> str:
        """Map *ref* to a node id.

        *ref* is used unchanged when it already is the id of a stored node
        (e.g. ``doubt_42`` or ``concept_recursion``); otherwise it is treated
        as a concept name and converted with ``_concept_id``.
        """
        if ref in self.nodes:
            return ref
        return self._concept_id(ref)

    @staticmethod
    def _format_timestamp(value: Any) -> Optional[str]:
        """Return *value* as an ISO-8601 string, or ``None`` if it is not a datetime."""
        return value.isoformat() if isinstance(value, datetime) else None

    @staticmethod
    def _parse_timestamp(value: Any) -> Optional[datetime]:
        """Return a datetime for *value* (datetime or ISO string), else ``None``."""
        if isinstance(value, datetime):
            return value
        if isinstance(value, str):
            try:
                return datetime.fromisoformat(value)
            except ValueError:
                return None
        return None

    def add_doubt_to_graph(self, doubt_data: Dict) -> GraphNode:
        """Add a captured doubt to the graph and link it to its concepts.

        A ``Doubt`` node is stored under ``doubt_<id>`` (a timestamp stands in
        for a missing ``id``). For every entry of ``conceptTags`` a concept
        node is created if needed and a ``part_of`` edge concept -> doubt is
        added.

        Args:
            doubt_data: Doubt payload. Keys read: ``id``, ``formattedTitle``,
                ``rawText``, ``formattedSummary``, ``doubtType``,
                ``conceptTags``, ``page.url``, ``mastered``, ``reviewCount``.

        Returns:
            The newly stored doubt node.
        """
        node_id = f"doubt_{doubt_data.get('id', datetime.now().timestamp())}"
        # `or` fallbacks also cover keys that are present but None/empty.
        concept_tags = doubt_data.get('conceptTags') or []
        page = doubt_data.get('page') or {}
        label = doubt_data.get('formattedTitle') or doubt_data.get('rawText') or ''

        node = GraphNode(
            node_id=node_id,
            node_type='Doubt',
            label=label,
            properties={
                'raw_text': doubt_data.get('rawText', ''),
                'summary': doubt_data.get('formattedSummary', ''),
                'doubt_type': doubt_data.get('doubtType', 'concept'),
                'concepts': concept_tags,
                'url': page.get('url', ''),
                'mastered': doubt_data.get('mastered', False),
                'review_count': doubt_data.get('reviewCount', 0),
            },
        )
        self.nodes[node_id] = node

        for concept in concept_tags:
            self._ensure_concept_node(concept)
            self._add_edge(source=concept, target=node_id, relation='part_of')
        return node

    def _ensure_concept_node(self, concept: str) -> GraphNode:
        """Return the concept node for *concept*, creating it when absent.

        New concept nodes start with ``mastery_level`` 0.0, ``importance`` 0.5
        and no ``last_reviewed`` value. The first spelling seen becomes the
        label; later spellings that normalise to the same id reuse that node.
        """
        concept_id = self._concept_id(concept)
        existing = self.nodes.get(concept_id)
        if existing is not None:
            return existing

        node = GraphNode(
            node_id=concept_id,
            node_type='Concept',
            label=concept,
            properties={
                'mastery_level': 0.0,
                'importance': 0.5,
                'last_reviewed': None,
            },
        )
        self.nodes[concept_id] = node
        return node

    def _add_edge(
        self,
        source: str,
        target: str,
        relation: str,
        weight: float = 1.0
    ) -> Optional[GraphEdge]:
        """Add an edge between two existing nodes.

        Args:
            source: Node id or concept name of the start node.
            target: Node id or concept name of the end node.
            relation: Relation type stored on the edge.
            weight: Edge weight.

        Returns:
            The stored edge, or ``None`` when either endpoint is not a node of
            this graph (nothing is stored in that case).
        """
        # Endpoints may be ids of any node type (doubt_*, resource_*, ...),
        # not only concepts, so look them up before assuming a concept name.
        source_id = self._resolve_node_id(source)
        target_id = self._resolve_node_id(target)
        if source_id not in self.nodes or target_id not in self.nodes:
            return None

        # Key on the resolved ids so different spellings of the same concept
        # ("Recursion" / "recursion") do not produce duplicate edges.
        edge_id = f"edge_{source_id}_{target_id}_{relation}"
        edge = GraphEdge(
            edge_id=edge_id,
            source_id=source_id,
            target_id=target_id,
            relation_type=relation,
            weight=weight,
        )
        self.edges[edge_id] = edge
        return edge

    def add_resource(self, resource_data: Dict) -> GraphNode:
        """Add a learning resource to the graph and link it to its topics.

        A ``Resource`` node is stored under ``resource_<id>`` (a timestamp
        stands in for a missing ``id``). For every entry of ``topics`` a
        concept node is created if needed and a ``part_of`` edge
        topic -> resource is added.

        Args:
            resource_data: Resource payload. Keys read: ``id``, ``title``,
                ``url``, ``type``, ``topics``, ``difficulty``.

        Returns:
            The newly stored resource node.
        """
        node_id = f"resource_{resource_data.get('id', datetime.now().timestamp())}"
        topics = resource_data.get('topics') or []  # tolerate an explicit None

        node = GraphNode(
            node_id=node_id,
            node_type='Resource',
            label=resource_data.get('title', 'Untitled Resource'),
            properties={
                'url': resource_data.get('url', ''),
                'type': resource_data.get('type', 'webpage'),
                'topics': topics,
                'difficulty': resource_data.get('difficulty', 0.5),
            },
        )
        self.nodes[node_id] = node

        for topic in topics:
            self._ensure_concept_node(topic)
            self._add_edge(topic, node_id, 'part_of')
        return node

    def add_topic(self, topic: str, parent: Optional[str] = None) -> GraphNode:
        """Add a topic as a concept node, optionally linked to a parent topic.

        When *parent* is given, it is created if needed and an edge
        topic -> parent with relation ``prerequisite_of`` is added.

        Returns:
            The concept node for *topic*.
        """
        node = self._ensure_concept_node(topic)
        if parent:
            self._ensure_concept_node(parent)
            self._add_edge(topic, parent, 'prerequisite_of')
        return node

    def graphrag_retrieve(
        self,
        query: str,
        top_k: int = 5
    ) -> List[Dict]:
        """Return the nodes that best match *query* by keyword overlap.

        Scoring, per whitespace-separated lower-cased query term:
        +1.0 if the term occurs in the node label, +0.5 if it occurs in the
        string form of the node properties. Mastered doubts are scaled by 0.8.
        Nodes with a zero score are dropped.

        Args:
            query: Free-text query.
            top_k: Maximum number of results; values <= 0 yield no results.

        Returns:
            Up to *top_k* dicts (highest score first) with keys ``node_id``,
            ``type``, ``label``, ``properties``, ``score`` and ``related``
            (up to three directly connected nodes).
        """
        query_terms = query.lower().split()
        if not query_terms or top_k <= 0:
            return []

        scored = []
        for node in self.nodes.values():
            # Both haystacks are invariant across terms; build them once.
            label_lower = (node.label or '').lower()
            props_lower = str(node.properties).lower()

            score = 0.0
            for term in query_terms:
                if term in label_lower:
                    score += 1.0
                if term in props_lower:
                    score += 0.5

            if node.node_type == 'Doubt' and node.properties.get('mastered'):
                score *= 0.8

            if score > 0:
                scored.append((score, node))

        # Stable sort: equally scored nodes keep their insertion order.
        scored.sort(key=lambda item: item[0], reverse=True)

        return [
            {
                'node_id': node.node_id,
                'type': node.node_type,
                'label': node.label,
                'properties': node.properties,
                'score': score,
                'related': self._get_related_nodes(node.node_id, limit=3),
            }
            for score, node in scored[:top_k]
        ]

    def _get_related_nodes(self, node_id: str, limit: int = 3) -> List[Dict]:
        """Return nodes directly connected to *node_id*, in edge order.

        Edges are followed in both directions. Each entry has the keys
        ``node_id``, ``type``, ``label`` and ``relation``. Edges whose other
        endpoint is not a stored node are ignored.
        """
        related = []
        for edge in self.edges.values():
            if edge.source_id == node_id:
                other_id = edge.target_id
            elif edge.target_id == node_id:
                other_id = edge.source_id
            else:
                continue

            other = self.nodes.get(other_id)
            if other is None:
                continue

            related.append({
                'node_id': other.node_id,
                'type': other.node_type,
                'label': other.label,
                'relation': edge.relation_type,
            })
            # Stop scanning once a positive limit has been reached.
            if 0 < limit <= len(related):
                break
        return related[:limit]

    def find_learning_path(
        self,
        from_topic: str,
        to_topic: str
    ) -> List[str]:
        """Find a shortest path between two topics using breadth-first search.

        Edges are treated as undirected and every edge counts, so a path may
        pass through non-concept nodes (e.g. a doubt shared by two concepts).

        Args:
            from_topic: Name of the start concept.
            to_topic: Name of the destination concept.

        Returns:
            Node labels along the path, start and destination included, or
            ``[]`` when either concept is unknown or no path exists.
        """
        start_id = self._concept_id(from_topic)
        goal_id = self._concept_id(to_topic)
        if start_id not in self.nodes or goal_id not in self.nodes:
            return []

        # One pass over the edges instead of rescanning them for every node.
        adjacency = defaultdict(list)
        for edge in self.edges.values():
            adjacency[edge.source_id].append(edge.target_id)
            adjacency[edge.target_id].append(edge.source_id)

        # parents doubles as the visited set and as the back-pointer map.
        parents = {start_id: None}
        queue = deque([start_id])
        while queue:
            current = queue.popleft()
            if current == goal_id:
                break
            for neighbor in adjacency.get(current, ()):
                # Skip dangling endpoints: they have no label to report.
                if neighbor in parents or neighbor not in self.nodes:
                    continue
                parents[neighbor] = current
                queue.append(neighbor)
        else:
            # Queue exhausted without reaching the goal.
            return []

        path = []
        step = goal_id
        while step is not None:
            path.append(self.nodes[step].label)
            step = parents[step]
        path.reverse()
        return path

    def get_topic_mastery(self) -> Dict[str, float]:
        """Estimate mastery for each concept, keyed by concept label.

        Mastery is ``1 - (doubts linked to the concept / all doubts)``.
        When the graph holds no doubts at all, every concept gets ``0.0``.
        """
        # Invariant across concepts, so count once.
        total_doubts = sum(
            1 for node in self.nodes.values() if node.node_type == 'Doubt'
        )

        mastery = {}
        for node_id, node in self.nodes.items():
            if node.node_type != 'Concept':
                continue
            if total_doubts > 0:
                mastery[node.label] = 1.0 - (self._get_doubt_count(node_id) / total_doubts)
            else:
                mastery[node.label] = 0.0
        return mastery

    def _get_doubt_count(self, concept_id: str) -> int:
        """Count ``part_of`` edges leading from *concept_id* to a doubt node."""
        count = 0
        for edge in self.edges.values():
            if edge.source_id != concept_id or edge.relation_type != 'part_of':
                continue
            target = self.nodes.get(edge.target_id)
            if target is not None and target.node_type == 'Doubt':
                count += 1
        return count

    def get_graph_stats(self) -> Dict:
        """Return summary statistics for the graph.

        Keys: ``graph_id``, ``total_nodes``, ``total_edges``, ``node_types``
        (count per node type), ``relation_types`` (count per relation type)
        and ``mastery`` (see ``get_topic_mastery``).
        """
        node_types = dict(Counter(node.node_type for node in self.nodes.values()))
        relation_types = dict(Counter(edge.relation_type for edge in self.edges.values()))
        return {
            'graph_id': self.graph_id,
            'total_nodes': len(self.nodes),
            'total_edges': len(self.edges),
            'node_types': node_types,
            'relation_types': relation_types,
            'mastery': self.get_topic_mastery(),
        }

    def export_graph(self) -> Dict:
        """Export the graph as plain dictionaries for persistence.

        Timestamps are written as ISO-8601 strings. The result can be fed
        back to ``import_graph``.
        """
        return {
            'graph_id': self.graph_id,
            'nodes': [
                {
                    'node_id': n.node_id,
                    'node_type': n.node_type,
                    'label': n.label,
                    'properties': n.properties,
                    'embeddings': n.embeddings,
                    'created_at': self._format_timestamp(n.created_at),
                }
                for n in self.nodes.values()
            ],
            'edges': [
                {
                    'edge_id': e.edge_id,
                    'source_id': e.source_id,
                    'target_id': e.target_id,
                    'relation_type': e.relation_type,
                    'weight': e.weight,
                    'properties': e.properties,
                    'created_at': self._format_timestamp(e.created_at),
                }
                for e in self.edges.values()
            ],
            'ontology': {
                'entity_types': self.ontology.entity_types,
                'relation_types': self.ontology.relation_types,
            },
        }

    def import_graph(self, graph_data: Dict):
        """Replace the current graph with the contents of *graph_data*.

        Accepts the structure produced by ``export_graph``. The optional keys
        ``embeddings``, ``created_at``, edge ``properties`` and ``ontology``
        are restored when present, so exports that lack them still load.

        The payload is parsed completely before the live graph is touched, so
        a malformed entry leaves the existing graph intact.

        Raises:
            KeyError: If a node lacks ``node_id``/``node_type``/``label`` or
                an edge lacks ``edge_id``/``source_id``/``target_id``/
                ``relation_type``.
        """
        nodes = {}
        for node_data in graph_data.get('nodes') or []:
            created = self._parse_timestamp(node_data.get('created_at'))
            # Without a usable timestamp the dataclass default applies.
            extra = {} if created is None else {'created_at': created}
            node = GraphNode(
                node_id=node_data['node_id'],
                node_type=node_data['node_type'],
                label=node_data['label'],
                properties=dict(node_data.get('properties') or {}),
                embeddings=node_data.get('embeddings'),
                **extra,
            )
            nodes[node.node_id] = node

        edges = {}
        for edge_data in graph_data.get('edges') or []:
            created = self._parse_timestamp(edge_data.get('created_at'))
            extra = {} if created is None else {'created_at': created}
            edge = GraphEdge(
                edge_id=edge_data['edge_id'],
                source_id=edge_data['source_id'],
                target_id=edge_data['target_id'],
                relation_type=edge_data['relation_type'],
                weight=edge_data.get('weight', 1.0),
                properties=dict(edge_data.get('properties') or {}),
                **extra,
            )
            edges[edge.edge_id] = edge

        # Commit only after everything parsed. The dicts are updated in place
        # so references to self.nodes / self.edges held elsewhere stay valid.
        self.graph_id = graph_data.get('graph_id', self.graph_id)
        self.nodes.clear()
        self.nodes.update(nodes)
        self.edges.clear()
        self.edges.update(edges)

        ontology_data = graph_data.get('ontology')
        if ontology_data:
            self.ontology = Ontology(
                entity_types=list(ontology_data.get('entity_types') or []),
                relation_types=list(ontology_data.get('relation_types') or []),
            )

    async def sync_to_zep(self):
        """Sync graph to Zep Cloud for advanced GraphRAG.

        Placeholder: currently does nothing.
        """
        pass

    async def sync_to_graph(self):
        """Sync current state.

        Placeholder: currently does nothing.
        """
        pass