rsm-wew068's picture
sdfghjfghhj
67cb4f3
#!/usr/bin/env python3
"""
SIMPLE GraphRAG - No more complexity hell!
Just semantic search + graph expansion. That's it.
"""
import html
import logging
import os
import pickle
from collections import deque
from typing import List, Dict, Tuple, Optional

import networkx as nx
import numpy as np
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
class GraphRAG:
    """Topic-centered retrieval over a pickled, labelled directed graph.

    Nodes are expected to carry ``label`` and ``name`` attributes and edges a
    ``label`` attribute; those are the only attributes this class reads.
    A query is answered by matching it against ``Topic`` nodes and then
    expanding outwards along a fixed set of label/edge rules.
    """

    # Label of the node being expanded -> (successor labels, edge labels) that
    # may be followed from it.
    _FORWARD_RULES = {
        'Topic': (('Task',), ('HAS_TASK',)),
        'Task': (
            ('Person', 'Date', 'Summary', 'Email Index'),
            ('RESPONSIBLE_TO', 'COLLABORATED_BY', 'DUE_ON',
             'START_ON', 'BASED_ON', 'LINKED_TO'),
        ),
        'Person': (
            ('Role', 'Department', 'Organization'),
            ('HAS_ROLE', 'BELONGS_TO', 'IS_IN'),
        ),
        'Role': (('Department',), ('BELONGS_TO',)),
        'Department': (('Organization',), ('IS_IN',)),
    }
    # Newly added successors with these labels are queued for further expansion.
    _EXPANDABLE_LABELS = ('Task', 'Person', 'Role', 'Department')

    # Presentation tables used by generate_visualization_html.
    _NODE_COLORS = {
        'Topic': '#FF6B9D',         # Pink - most important
        'Task': '#90EE90',          # Light green
        'Person': '#87CEEB',        # Sky blue
        'Role': '#FFA500',          # Orange
        'Department': '#DDA0DD',    # Plum
        'Organization': '#F0E68C',  # Khaki
        'Date': '#D3D3D3',          # Light gray
        'Summary': '#FFE4B5',       # Moccasin
        'Email Index': '#E6E6FA',   # Lavender
    }
    _DEFAULT_COLOR = '#BDC3C7'
    _NODE_SIZES = {'Topic': 25, 'Task': 20, 'Person': 15}
    _DEFAULT_NODE_SIZE = 12
    _NAME_LIMITS = {'Task': 25, 'Summary': 30}
    _DEFAULT_NAME_LIMIT = 20

    def __init__(self, graph_path: str = "/tmp/topic_graph.gpickle",
                 model_name: str = 'all-MiniLM-L6-v2'):
        """Create the embedder; the graph itself is loaded lazily.

        Args:
            graph_path: Location of the pickled graph file.
            model_name: SentenceTransformer model used for all embeddings.
        """
        self.graph_path = graph_path
        self.embedder = SentenceTransformer(model_name)
        self.graph: Optional[nx.DiGraph] = None
        self.node_embeddings: Dict[str, np.ndarray] = {}

    def load_graph_with_embeddings(self) -> bool:
        """Load the graph from ``graph_path`` and embed every node.

        Returns:
            True on success; False when the file is missing or loading or
            embedding fails (the failure is logged with its traceback).
        """
        if not os.path.exists(self.graph_path):
            return False
        try:
            # NOTE(security): pickle.load runs arbitrary code from the file and
            # the default path is under world-writable /tmp. Only use a
            # graph_path that untrusted users cannot write to.
            with open(self.graph_path, "rb") as f:
                self.graph = pickle.load(f)
            self._compute_embeddings()
        except Exception:
            logging.getLogger(__name__).exception(
                "Failed to load graph from %s", self.graph_path
            )
            return False
        return True

    def _compute_embeddings(self):
        """Rebuild ``node_embeddings`` from the current graph.

        The mapping is replaced wholesale, so nodes that disappeared from the
        graph since the previous load do not linger. All node texts are
        encoded in a single batched call.
        """
        if self.graph is None:
            self.node_embeddings = {}
            return

        nodes = []
        texts = []
        for node, attrs in self.graph.nodes(data=True):
            label = attrs.get("label", "")
            name = attrs.get("name", str(node))
            nodes.append(node)
            texts.append(f"{label}: {name}")

        if not nodes:
            self.node_embeddings = {}
            return

        vectors = self.embedder.encode(texts)
        self.node_embeddings = dict(zip(nodes, vectors))

    def semantic_search(self, query: str,
                        top_k: int = 10,
                        min_similarity: float = 0.2) -> List[Tuple[str, float]]:
        """Return the nodes most similar to ``query``.

        Args:
            query: Free-text query.
            top_k: Maximum number of results.
            min_similarity: Cosine-similarity floor; lower scores are dropped.

        Returns:
            ``(node, score)`` pairs sorted by descending score. Empty when no
            embeddings are loaded.
        """
        if not self.node_embeddings:
            return []

        query_embedding = self.embedder.encode(query)
        nodes = list(self.node_embeddings)
        matrix = np.vstack([self.node_embeddings[node] for node in nodes])
        # One call scores the query against every node embedding.
        scores = cosine_similarity([query_embedding], matrix)[0]

        hits = [
            (node, float(score))
            for node, score in zip(nodes, scores)
            if score >= min_similarity
        ]
        hits.sort(key=lambda item: item[1], reverse=True)
        return hits[:top_k]

    @staticmethod
    def _reverse_allowed(current_label, pred_label, edge_label) -> bool:
        """Decide whether a predecessor may be pulled in from the current node."""
        if current_label == 'Task':
            return pred_label == 'Topic' and edge_label == 'HAS_TASK'
        if current_label in ('Date', 'Summary', 'Email Index'):
            return pred_label == 'Task'
        if current_label in ('Role', 'Department', 'Organization'):
            return pred_label == 'Person'
        return False

    def expand_from_nodes(self, start_nodes: List[str], max_nodes: int = 10,
                          direct_only: bool = False) -> set:
        """Grow a node set outwards from ``start_nodes``.

        Args:
            start_nodes: Seed nodes; always part of the result, even if they
                are not in the graph.
            max_nodes: Soft cap on the result size for rule-based expansion
                (not applied to the seeds or in ``direct_only`` mode).
            direct_only: If True, add every successor and predecessor of each
                seed and stop; no label rules, no further hops.

        Returns:
            The set of collected nodes. With no graph loaded, just the seeds.
        """
        connected = set(start_nodes)
        graph = self.graph
        if graph is None:
            return connected

        if direct_only:
            for node in start_nodes:
                if node not in graph:
                    continue
                connected.update(graph.neighbors(node))
                connected.update(graph.predecessors(node))
            return connected

        # Breadth-first expansion restricted to the label/edge rule tables.
        pending = deque(start_nodes)
        while pending and len(connected) < max_nodes:
            current = pending.popleft()
            if current not in graph:
                continue
            current_label = graph.nodes.get(current, {}).get('label', '')

            rule = self._FORWARD_RULES.get(current_label)
            if rule is not None:
                allowed_labels, allowed_edges = rule
                for neighbor in graph.neighbors(current):
                    if len(connected) >= max_nodes:
                        break
                    if neighbor in connected:
                        continue
                    neighbor_label = graph.nodes.get(neighbor, {}).get('label', '')
                    edge_label = graph.get_edge_data(current, neighbor, {}).get('label', '')
                    if neighbor_label in allowed_labels and edge_label in allowed_edges:
                        connected.add(neighbor)
                        # Queue for expansion to get full hierarchies.
                        if neighbor_label in self._EXPANDABLE_LABELS:
                            pending.append(neighbor)

            # Reverse relationships: predecessors are added but never queued.
            for predecessor in graph.predecessors(current):
                if len(connected) >= max_nodes:
                    break
                if predecessor in connected:
                    continue
                pred_label = graph.nodes.get(predecessor, {}).get('label', '')
                edge_label = graph.get_edge_data(predecessor, current, {}).get('label', '')
                if self._reverse_allowed(current_label, pred_label, edge_label):
                    connected.add(predecessor)

        return connected

    def query(self, query: str, direct_only: bool = False, max_nodes: int = 25) -> Dict:
        """Answer ``query`` by topic matching followed by graph expansion.

        The graph is reloaded from disk on every call.

        Returns:
            On success a dict with ``query``, ``relevant_nodes``,
            ``all_nodes``, ``confidence_score``, ``explanation`` and
            ``method``. Otherwise a dict with ``query``, ``error`` and an
            empty ``nodes`` list (plus ``method`` when no topic matched).
        """
        if not self.load_graph_with_embeddings():
            return {
                'query': query,
                'error': 'No graph found. Process emails first.',
                'nodes': []
            }

        threshold = 0.5
        topic_matches = self.search_topics_by_name(query, semantic_threshold=threshold)
        # search_topics_by_name also returns weak word-overlap matches; only
        # those at or above the threshold count as a hit here.
        good_matches = [(topic, score) for topic, score in topic_matches
                        if score >= threshold]

        if good_matches:
            start_nodes = [topic for topic, _ in good_matches]
            all_nodes = self.expand_from_nodes(
                start_nodes, max_nodes=max_nodes, direct_only=direct_only
            )
            confidence = good_matches[0][1]  # matches are sorted best-first
            explanation = (
                f"Found {len(all_nodes)} nodes from {len(start_nodes)} related topic(s)"
            )
            if direct_only:
                explanation += " (direct neighbors only)"
            return {
                'query': query,
                'relevant_nodes': good_matches,
                'all_nodes': list(all_nodes),
                'confidence_score': round(float(confidence), 3),
                'explanation': explanation,
                'method': 'topic_name_search'
            }

        # No usable topic match: tell the caller which topics exist.
        available_topics = [
            str(attrs.get('name', str(node)))
            for node, attrs in self.graph.nodes(data=True)
            if attrs.get('label') == 'Topic'
        ]
        if available_topics:
            topic_list = ", ".join(available_topics)
            error_msg = f'No topic found matching "{query}". Available: {topic_list}.'
        else:
            error_msg = f'No topic found matching "{query}". No topics in graph.'
        return {
            'query': query,
            'error': error_msg,
            'nodes': [],
            'method': 'no_match'
        }

    def generate_visualization_html(self, query: str, result: Dict) -> str:
        """Render the nodes in ``result['all_nodes']`` as pyvis HTML.

        Nothing is written to disk. Node attributes and the query are
        HTML-escaped before being placed in tooltips and the heading.

        Returns:
            The generated HTML, or a short ``<p>...</p>`` message when pyvis
            is missing, no graph is loaded, the result has no nodes, or
            rendering raises.
        """
        try:
            from pyvis.network import Network
        except ImportError:
            return "<p>pyvis not installed</p>"
        if not self.graph:
            return "<p>No graph loaded</p>"
        try:
            net = Network(height="600px", width="100%")

            nodes_to_show = set(result.get('all_nodes', []))
            if not nodes_to_show:
                return "<p>No nodes found in query result</p>"
            subgraph = self.graph.subgraph(nodes_to_show)

            for node, attrs in subgraph.nodes(data=True):
                label = attrs.get('label', '')
                # Names may be non-strings; coerce so slicing and len() are safe.
                name = str(attrs.get('name', str(node)))
                color = self._NODE_COLORS.get(label, self._DEFAULT_COLOR)
                node_size = self._NODE_SIZES.get(label, self._DEFAULT_NODE_SIZE)

                limit = self._NAME_LIMITS.get(label, self._DEFAULT_NAME_LIMIT)
                display_name = name[:limit] + "..." if len(name) > limit else name

                # Tooltip lists every non-empty attribute; values are escaped
                # because they are interpolated into HTML markup.
                tooltip_parts = [
                    f"<b>{html.escape(str(label))}</b>: {html.escape(name)}"
                ]
                for key, value in attrs.items():
                    if key not in ('label', 'name') and value:
                        tooltip_parts.append(
                            f"{html.escape(str(key))}: {html.escape(str(value))}"
                        )
                if label == 'Person':
                    person_details = _get_person_details(self.graph, node)
                    if person_details:
                        details_clean = person_details.strip('() ')
                        tooltip_parts.append(
                            f"<b>Full Details:</b> {html.escape(details_clean)}"
                        )
                tooltip = "<br>".join(tooltip_parts)

                net.add_node(
                    node,
                    label=display_name,
                    title=tooltip,
                    color=color,
                    size=node_size,
                    font={'size': 10, 'color': 'black'}
                )

            for u, v, edge_attrs in subgraph.edges(data=True):
                net.add_edge(u, v, label=edge_attrs.get('label', ''))

            net.heading = f"Query: {html.escape(str(query))}"
            return net.generate_html()
        except Exception as e:
            return f"<p>Error generating visualization: {html.escape(str(e))}</p>"

    def search_topics_by_name(self, query: str,
                              semantic_threshold: float = 0.5) -> List[Tuple[str, float]]:
        """Score ``Topic`` nodes against ``query``.

        A topic is returned when the query is a substring of its name, when
        any query word occurs in its name, or when cosine similarity reaches
        ``semantic_threshold``. Full-substring matches are scored at least
        0.9; the other cases keep their raw similarity.

        Returns:
            ``(node, score)`` pairs sorted by descending score. Empty when no
            graph or embeddings are loaded or the query is blank.
        """
        if not self.graph or not self.node_embeddings:
            return []

        query_lower = query.strip().lower()
        if not query_lower:
            # An empty string is a substring of every name; without this guard
            # a blank query would match every topic at 0.9.
            return []
        query_words = query_lower.split()

        topics = [
            (node, attrs)
            for node, attrs in self.graph.nodes(data=True)
            if attrs.get('label') == 'Topic' and node in self.node_embeddings
        ]
        if not topics:
            return []

        query_embedding = self.embedder.encode(query)
        matrix = np.vstack([self.node_embeddings[node] for node, _ in topics])
        scores = cosine_similarity([query_embedding], matrix)[0]

        topic_matches = []
        for (node, attrs), score in zip(topics, scores):
            similarity = float(score)
            topic_name = str(attrs.get('name', str(node))).lower()
            is_substring = query_lower in topic_name
            shares_word = any(word in topic_name for word in query_words)
            if is_substring or shares_word or similarity >= semantic_threshold:
                if is_substring:
                    similarity = max(similarity, 0.9)
                topic_matches.append((node, similarity))

        topic_matches.sort(key=lambda item: item[1], reverse=True)
        return topic_matches

    # Compatibility methods
    def query_with_semantic_reasoning(self, query: str) -> Dict:
        """Compatibility alias for ``query`` with default options."""
        return self.query(query)
def format_response(result: Dict, graph_path: str = "/tmp/topic_graph.gpickle") -> str:
    """Render a query result as Markdown-style task summaries.

    Args:
        result: Dict produced by a query. An ``error`` key short-circuits to
            that message; otherwise ``all_nodes`` and ``confidence_score``
            are read.
        graph_path: Pickled graph file used to look up node and edge
            attributes.

    Returns:
        One section per ``Task`` node in ``all_nodes`` followed by a
        confidence line, or a short message when there is nothing to show.
        Any exception while loading or formatting is returned as an error
        string rather than raised.
    """
    if 'error' in result:
        return result['error']
    if not result.get('all_nodes'):
        return "No information found."
    try:
        # NOTE(security): pickle.load runs arbitrary code from the file; only
        # point graph_path at a location untrusted users cannot write to.
        with open(graph_path, "rb") as f:
            graph = pickle.load(f)

        tasks = [
            node for node in result.get('all_nodes', [])
            if node in graph and graph.nodes[node].get('label') == 'Task'
        ]
        if not tasks:
            return "No tasks found in the results."

        response_parts = []
        for task_node in tasks:
            task_attrs = graph.nodes[task_node]
            task_name = task_attrs.get('name', str(task_node))
            task_info = [f"**Task:** {task_name}"]

            # The topic may sit on either side of the task: successors are
            # checked first, then predecessors (Topic -HAS_TASK-> Task).
            candidates = list(graph.neighbors(task_node))
            predecessors = getattr(graph, 'predecessors', None)
            if predecessors is not None:
                candidates.extend(predecessors(task_node))
            for candidate in candidates:
                candidate_attrs = graph.nodes[candidate]
                if candidate_attrs.get('label') == 'Topic':
                    topic_name = candidate_attrs.get('name', candidate)
                    task_info.append(f"**Topic:** {topic_name}")
                    break

            # One bullet per outgoing relationship of the task.
            for neighbor in graph.neighbors(task_node):
                edge_data = graph.get_edge_data(task_node, neighbor, {})
                edge_label = edge_data.get('label', '')
                neighbor_attrs = graph.nodes[neighbor]
                neighbor_name = neighbor_attrs.get('name', neighbor)
                neighbor_label = neighbor_attrs.get('label', '')
                if edge_label == 'START_ON':
                    task_info.append(f" • **Start Date:** {neighbor_name}")
                elif edge_label == 'DUE_ON':
                    task_info.append(f" • **Due Date:** {neighbor_name}")
                elif edge_label == 'BASED_ON' or neighbor_label == 'Summary':
                    task_info.append(f" • **Summary:** {neighbor_name}")
                elif edge_label == 'LINKED_TO' or neighbor_label == 'Email Index':
                    task_info.append(f" • **Email Index:** {neighbor_name}")
                elif edge_label == 'RESPONSIBLE_TO':
                    # Append role/department/organization info for the person.
                    person_details = _get_person_details(graph, neighbor)
                    task_info.append(f" • **Responsible To:** {neighbor_name}{person_details}")
                elif edge_label == 'COLLABORATED_BY':
                    person_details = _get_person_details(graph, neighbor)
                    task_info.append(f" • **Collaborated By:** {neighbor_name}{person_details}")
            response_parts.append("\n".join(task_info))

        # Confidence banding: > 0.7 high, > 0.4 medium, otherwise low.
        confidence = result.get('confidence_score', 0.0)
        if confidence > 0.7:
            conf_text = "🟢 High"
        elif confidence > 0.4:
            conf_text = "🟡 Medium"
        else:
            conf_text = "🔴 Low"
        response_parts.append(f"\n**Confidence:** {conf_text} ({confidence})")
        return "\n\n".join(response_parts)
    except Exception as e:
        return f"📊 Error formatting response: {str(e)}"
def _get_person_details(graph, person_node):
    """Summarise a person's role, department and organization.

    Walks person -> role -> department -> organization, taking at each step
    the first successor whose edge label or node label matches. The walk
    stops at the first level where nothing matches.

    Returns:
        ``" (Role: ..., Department: ..., Organization: ...)"`` containing the
        levels that were found, or ``""`` when no role was found.
    """

    def first_match(source, wanted_edge, wanted_label):
        """Return ``(node, attrs)`` for the first matching successor, else None."""
        for candidate in graph.neighbors(source):
            relation = graph.get_edge_data(source, candidate, {}).get('label', '')
            candidate_attrs = graph.nodes[candidate]
            if relation == wanted_edge or candidate_attrs.get('label') == wanted_label:
                return candidate, candidate_attrs
        return None

    # (edge label, node label) for each level of the hierarchy, in walk order.
    hierarchy = (
        ('HAS_ROLE', 'Role'),
        ('BELONGS_TO', 'Department'),
        ('IS_IN', 'Organization'),
    )

    found_parts = []
    cursor = person_node
    for edge_name, level_name in hierarchy:
        hit = first_match(cursor, edge_name, level_name)
        if hit is None:
            break
        cursor, level_attrs = hit
        found_parts.append(f"{level_name}: {level_attrs.get('name', cursor)}")

    if not found_parts:
        return ""
    return f" ({', '.join(found_parts)})"
def format_graphrag_response(result: Dict) -> str:
    """Backward-compatible alias: return ``format_response(result)`` unchanged."""
    formatted = format_response(result)
    return formatted