#!/usr/bin/env python3
"""SIMPLE GraphRAG - No more complexity hell!

Just semantic search + graph expansion. That's it.

NOTE(review): this module was reconstructed from a whitespace-mangled source.
HTML markup inside a few string literals (the error pages and the tooltip
separator in ``generate_visualization_html``) had been stripped by extraction;
it has been restored with minimal plausible markup — confirm against the
original rendering if exact HTML matters.
"""

from sentence_transformers import SentenceTransformer
import networkx as nx
from sklearn.metrics.pairwise import cosine_similarity
import pickle
import os
from typing import List, Dict, Tuple, Optional
import numpy as np


class GraphRAG:
    """Dead simple GraphRAG - find stuff, expand from there."""

    def __init__(self, graph_path: str = "/tmp/topic_graph.gpickle"):
        self.graph_path = graph_path
        self.embedder = SentenceTransformer('all-MiniLM-L6-v2')
        self.graph: Optional[nx.DiGraph] = None
        self.node_embeddings: Dict[str, np.ndarray] = {}

    def load_graph_with_embeddings(self) -> bool:
        """Load the pickled graph and compute semantic embeddings.

        Returns:
            True on success, False when the file is missing or unreadable.
        """
        if not os.path.exists(self.graph_path):
            return False
        try:
            # SECURITY: pickle.load executes arbitrary code when the file is
            # attacker-controlled — only point graph_path at trusted files.
            with open(self.graph_path, "rb") as f:
                self.graph = pickle.load(f)
            self._compute_embeddings()
            return True
        except Exception:
            # Best-effort loader: any failure is reported as "no graph".
            return False

    def _compute_embeddings(self):
        """Compute AI embeddings for all nodes.

        Each node is embedded from the text "<label>: <name>" so that both
        the node type and its display name contribute to similarity.
        """
        for node, attrs in self.graph.nodes(data=True):
            label = attrs.get("label", "")
            name = attrs.get("name", str(node))
            text = f"{label}: {name}"
            embedding = self.embedder.encode(text)
            self.node_embeddings[node] = embedding

    def semantic_search(self, query: str, top_k: int = 10) -> List[Tuple[str, float]]:
        """Find nodes most similar to query.

        Returns up to ``top_k`` (node, similarity) pairs, best first, keeping
        only matches at or above a 0.2 cosine-similarity floor.
        """
        if not self.node_embeddings:
            return []
        query_embedding = self.embedder.encode(query)
        similarities = []
        for node, embedding in self.node_embeddings.items():
            sim = cosine_similarity([query_embedding], [embedding])[0][0]
            if sim >= 0.2:  # Reasonable threshold
                similarities.append((node, sim))
        return sorted(similarities, key=lambda x: x[1], reverse=True)[:top_k]

    def expand_from_nodes(self, start_nodes: List[str], max_nodes: int = 10,
                          direct_only: bool = False) -> set:
        """Expand from starting nodes following only relevant connections.

        Args:
            start_nodes: seed node ids (always included in the result).
            max_nodes: soft cap on the size of the returned set.
            direct_only: when True, add only immediate neighbors/predecessors
                of the seeds — no transitive expansion.

        Returns:
            Set of node ids reachable through "meaningful" edges (see the
            label/edge whitelist below).
        """
        connected = set(start_nodes)

        if direct_only:
            # Only add direct neighbors - no expansion of expansions
            for node in start_nodes:
                if node not in self.graph:
                    continue
                # Add direct neighbors (outgoing)
                for neighbor in self.graph.neighbors(node):
                    connected.add(neighbor)
                # Add direct predecessors (incoming)
                for predecessor in self.graph.predecessors(node):
                    connected.add(predecessor)
            return connected

        # Topic-centered expansion: only expand through meaningful relationships
        to_expand = list(start_nodes)
        while to_expand and len(connected) < max_nodes:
            current_node = to_expand.pop(0)  # BFS order
            if current_node not in self.graph:
                continue

            current_attrs = self.graph.nodes.get(current_node, {})
            current_label = current_attrs.get('label', '')

            # Add directly connected nodes based on meaningful relationships
            for neighbor in self.graph.neighbors(current_node):
                if len(connected) >= max_nodes:
                    break
                if neighbor not in connected:
                    neighbor_attrs = self.graph.nodes.get(neighbor, {})
                    neighbor_label = neighbor_attrs.get('label', '')

                    # Get the edge relationship
                    edge_data = self.graph.get_edge_data(current_node, neighbor, {})
                    edge_label = edge_data.get('label', '')

                    # Only include nodes with strong semantic relationships
                    should_include = False

                    if current_label == 'Topic':
                        # From Topic: include tasks directly labeled with this topic
                        if neighbor_label == 'Task' and edge_label == 'HAS_TASK':
                            should_include = True
                    elif current_label == 'Task':
                        # From Task: include assignees, dates, summaries - NOT other tasks
                        if (neighbor_label in ['Person', 'Date', 'Summary', 'Email Index']
                                and edge_label in ['RESPONSIBLE_TO', 'COLLABORATED_BY',
                                                   'DUE_ON', 'START_ON', 'BASED_ON',
                                                   'LINKED_TO']):
                            should_include = True
                    elif current_label == 'Person':
                        # From Person: include their role/department/organization hierarchy
                        if (neighbor_label in ['Role', 'Department', 'Organization']
                                and edge_label in ['HAS_ROLE', 'BELONGS_TO', 'IS_IN']):
                            should_include = True
                    elif current_label == 'Role':
                        # From Role: include department
                        if neighbor_label == 'Department' and edge_label == 'BELONGS_TO':
                            should_include = True
                    elif current_label == 'Department':
                        # From Department: include organization
                        if neighbor_label == 'Organization' and edge_label == 'IS_IN':
                            should_include = True

                    if should_include:
                        connected.add(neighbor)
                        # Queue for expansion to get full hierarchies
                        if neighbor_label in ['Task', 'Person', 'Role', 'Department']:
                            to_expand.append(neighbor)

            # Also check predecessors for reverse relationships
            for predecessor in self.graph.predecessors(current_node):
                if len(connected) >= max_nodes:
                    break
                if predecessor not in connected:
                    pred_attrs = self.graph.nodes.get(predecessor, {})
                    pred_label = pred_attrs.get('label', '')

                    # Get the edge relationship
                    edge_data = self.graph.get_edge_data(predecessor, current_node, {})
                    edge_label = edge_data.get('label', '')

                    # Include meaningful reverse relationships
                    should_include = False
                    if current_label == 'Task' and pred_label == 'Topic' and edge_label == 'HAS_TASK':
                        should_include = True
                    elif current_label in ['Date', 'Summary', 'Email Index'] and pred_label == 'Task':
                        should_include = True
                    elif current_label in ['Role', 'Department', 'Organization'] and pred_label == 'Person':
                        should_include = True

                    if should_include:
                        connected.add(predecessor)

        return connected

    def query(self, query: str, direct_only: bool = False, max_nodes: int = 25) -> Dict:
        """Topic-centered query for maximum accuracy.

        Reloads the graph (and embeddings) on every call, matches the query
        against Topic nodes by name, then expands from the matched topics.

        Returns:
            On success: dict with 'relevant_nodes', 'all_nodes',
            'confidence_score', 'explanation', 'method'.
            On failure: dict with 'error', 'nodes', and (for no-match) 'method'.
        """
        if not self.load_graph_with_embeddings():
            return {
                'query': query,
                'error': 'No graph found. Process emails first.',
                'nodes': []
            }

        # Step 1: Try topic name matching first (highest accuracy)
        topic_matches = self.search_topics_by_name(query, semantic_threshold=0.5)

        if topic_matches:
            # Found topic name matches - use ALL good matches for inclusive approach
            good_topics = [topic for topic, score in topic_matches if score >= 0.5]
            start_nodes = good_topics  # Include all related topics

            all_nodes = self.expand_from_nodes(
                start_nodes, max_nodes=max_nodes, direct_only=direct_only
            )

            confidence = topic_matches[0][1]  # Use best match confidence
            explanation = f"Found {len(all_nodes)} nodes from {len(good_topics)} related topic(s)"
            if direct_only:
                explanation += " (direct neighbors only)"

            return {
                'query': query,
                'relevant_nodes': [(topic, score) for topic, score in topic_matches if score >= 0.5],
                'all_nodes': list(all_nodes),
                'confidence_score': round(confidence, 3),
                'explanation': explanation,
                'method': 'topic_name_search'
            }

        # No topic matches found - show actual topics
        available_topics = []
        for node, attrs in self.graph.nodes(data=True):
            if attrs.get('label') == 'Topic':
                topic_name = attrs.get('name', str(node))
                available_topics.append(topic_name)

        if available_topics:
            topic_list = ", ".join(available_topics)
            error_msg = f'No topic found matching "{query}". Available: {topic_list}.'
        else:
            error_msg = f'No topic found matching "{query}". No topics in graph.'

        return {
            'query': query,
            'error': error_msg,
            'nodes': [],
            'method': 'no_match'
        }

    def generate_visualization_html(self, query: str, result: Dict) -> str:
        """Generate visualization HTML content directly without saving to file.

        Builds a pyvis network from the query result's 'all_nodes' subgraph.
        Returns an HTML document string, or a small HTML error fragment when
        pyvis is missing, no graph is loaded, or rendering fails.
        """
        try:
            from pyvis.network import Network
        except ImportError:
            # NOTE(review): error-fragment markup reconstructed — confirm.
            return "<p>pyvis not installed</p>"

        if not self.graph:
            return "<p>No graph loaded</p>"

        try:
            net = Network(height="600px", width="100%")

            # Show only connected nodes
            nodes_to_show = set(result.get('all_nodes', []))
            if not nodes_to_show:
                return "<p>No nodes found in query result</p>"

            subgraph = self.graph.subgraph(nodes_to_show)

            # Colors for topic-centered hierarchy
            colors = {
                'Topic': '#FF6B9D',        # Pink - most important
                'Task': '#90EE90',         # Light green
                'Person': '#87CEEB',       # Sky blue
                'Role': '#FFA500',         # Orange
                'Department': '#DDA0DD',   # Plum
                'Organization': '#F0E68C', # Khaki
                'Date': '#D3D3D3',         # Light gray
                'Summary': '#FFE4B5',      # Moccasin
                'Email Index': '#E6E6FA'   # Lavender
            }

            # Add nodes with topic-centered sizing
            for node, attrs in subgraph.nodes(data=True):
                label = attrs.get('label', '')
                name = attrs.get('name', str(node))
                color = colors.get(label, '#BDC3C7')

                # Smaller node sizing for better readability
                if label == 'Topic':
                    node_size = 25   # Reduced from 50
                elif label == 'Task':
                    node_size = 20   # Reduced from 35
                elif label == 'Person':
                    node_size = 15   # Reduced from 25
                else:
                    node_size = 12   # Reduced from 20

                # Shorter display names for better visibility
                if label == 'Task':
                    display_name = name[:25] + "..." if len(name) > 25 else name
                elif label == 'Summary':
                    display_name = name[:30] + "..." if len(name) > 30 else name
                else:
                    display_name = name[:20] + "..." if len(name) > 20 else name

                # Create detailed tooltip with all attributes
                tooltip_parts = [f"{label}: {name}"]
                for key, value in attrs.items():
                    if key not in ['label', 'name'] and value:
                        tooltip_parts.append(f"{key}: {value}")

                # For Person nodes, add FULL role/dept/org info to tooltip
                if label == 'Person':
                    person_details = _get_person_details(self.graph, node)
                    if person_details:
                        details_clean = person_details.strip('() ')
                        tooltip_parts.append(f"Full Details: {details_clean}")

                # NOTE(review): separator reconstructed as "<br>" (HTML line
                # break in pyvis tooltips) — original markup was stripped.
                tooltip = "<br>".join(tooltip_parts)

                net.add_node(
                    node,
                    label=display_name,
                    title=tooltip,
                    color=color,
                    size=node_size,
                    font={'size': 10, 'color': 'black'}  # Reduced from 14
                )

            # Add edges
            for u, v, edge_attrs in subgraph.edges(data=True):
                edge_label = edge_attrs.get('label', '')
                net.add_edge(u, v, label=edge_label)

            # Set heading and generate HTML
            net.heading = f"Query: {query}"

            # Generate HTML content directly
            html_content = net.generate_html()
            return html_content

        except Exception as e:
            return f"<p>Error generating visualization: {str(e)}</p>"

    def search_topics_by_name(self, query: str, semantic_threshold: float = 0.5) -> List[Tuple[str, float]]:
        """Search for topics using semantic similarity with flexible matching.

        Combines embedding cosine similarity with substring/word matching;
        exact substring hits are boosted to a minimum score of 0.9.

        Returns:
            (topic_node, score) pairs sorted best-first.
        """
        if not self.graph or not self.node_embeddings:
            return []

        # Encode the query
        query_embedding = self.embedder.encode(query)

        topic_matches = []
        for node, attrs in self.graph.nodes(data=True):
            if attrs.get('label') == 'Topic':
                # Get the embedding for this topic node
                if node in self.node_embeddings:
                    topic_embedding = self.node_embeddings[node]

                    # Calculate semantic similarity
                    similarity = cosine_similarity([query_embedding], [topic_embedding])[0][0]

                    # Also check for substring matches to catch variations
                    topic_name = attrs.get('name', str(node)).lower()
                    query_lower = query.lower()

                    # Boost similarity for substring matches or close variations
                    if (query_lower in topic_name or
                            any(word in topic_name for word in query_lower.split()) or
                            similarity >= semantic_threshold):
                        # Give higher score to exact or close matches
                        if query_lower in topic_name:
                            similarity = max(similarity, 0.9)
                        topic_matches.append((node, similarity))

        return sorted(topic_matches, key=lambda x: x[1], reverse=True)

    # Compatibility methods
    def query_with_semantic_reasoning(self, query: str) -> Dict:
        return self.query(query)


def format_response(result: Dict) -> str:
    """Format response like the old system with structured task details.

    Reloads the pickled graph from /tmp and renders every Task node found in
    ``result['all_nodes']`` as a markdown-style section with topic, dates,
    summary, email index, and person details, plus a confidence footer.
    """
    if 'error' in result:
        return result['error']

    if not result.get('all_nodes'):
        return "No information found."

    try:
        # SECURITY: pickle.load on a world-writable /tmp path — trusted
        # environments only. (Redundant local `import pickle` removed; the
        # module-level import is used.)
        with open("/tmp/topic_graph.gpickle", "rb") as f:
            graph = pickle.load(f)

        # Find all tasks in the result
        tasks = []
        for node in result.get('all_nodes', []):
            if node in graph:
                attrs = graph.nodes[node]
                if attrs.get('label') == 'Task':
                    tasks.append(node)

        if not tasks:
            return "No tasks found in the results."

        # Format each task in the structured format
        response_parts = []
        for task_node in tasks:
            task_attrs = graph.nodes[task_node]
            task_name = task_attrs.get('name', str(task_node))

            task_info = [f"**Task:** {task_name}"]

            # Find the topic for this task
            for neighbor in graph.neighbors(task_node):
                edge_data = graph.get_edge_data(task_node, neighbor, {})
                edge_label = edge_data.get('label', '')
                neighbor_attrs = graph.nodes[neighbor]
                if neighbor_attrs.get('label') == 'Topic':
                    topic_name = neighbor_attrs.get('name', neighbor)
                    task_info.append(f"**Topic:** {topic_name}")
                    break

            # Get all the direct neighbors with their relationships
            for neighbor in graph.neighbors(task_node):
                edge_data = graph.get_edge_data(task_node, neighbor, {})
                edge_label = edge_data.get('label', '')
                neighbor_attrs = graph.nodes[neighbor]
                neighbor_name = neighbor_attrs.get('name', neighbor)
                neighbor_label = neighbor_attrs.get('label', '')

                if edge_label == 'START_ON':
                    task_info.append(f"  • **Start Date:** {neighbor_name}")
                elif edge_label == 'DUE_ON':
                    task_info.append(f"  • **Due Date:** {neighbor_name}")
                elif edge_label == 'BASED_ON' or neighbor_label == 'Summary':
                    task_info.append(f"  • **Summary:** {neighbor_name}")
                elif edge_label == 'LINKED_TO' or neighbor_label == 'Email Index':
                    task_info.append(f"  • **Email Index:** {neighbor_name}")
                elif edge_label == 'RESPONSIBLE_TO':
                    # Get role/dept/org info for the person
                    person_details = _get_person_details(graph, neighbor)
                    task_info.append(f"  • **Responsible To:** {neighbor_name}{person_details}")
                elif edge_label == 'COLLABORATED_BY':
                    person_details = _get_person_details(graph, neighbor)
                    task_info.append(f"  • **Collaborated By:** {neighbor_name}{person_details}")

            response_parts.append("\n".join(task_info))

        # Add confidence at the end
        confidence = result.get('confidence_score', 0.0)
        conf_text = "🟢 High" if confidence > 0.7 else "🟡 Medium" if confidence > 0.4 else "🔴 Low"
        response_parts.append(f"\n**Confidence:** {conf_text} ({confidence})")

        return "\n\n".join(response_parts)

    except Exception as e:
        return f"📊 Error formatting response: {str(e)}"


def _get_person_details(graph, person_node):
    """Get role, department, organization details for a person.

    Walks Person -> Role -> Department -> Organization (first match at each
    level only, hence the `break`s) and returns a string like
    " (Role: X, Department: Y, Organization: Z)", or "" when nothing found.
    """
    details = []

    for neighbor in graph.neighbors(person_node):
        edge_data = graph.get_edge_data(person_node, neighbor, {})
        edge_label = edge_data.get('label', '')
        neighbor_attrs = graph.nodes[neighbor]

        if edge_label == 'HAS_ROLE' or neighbor_attrs.get('label') == 'Role':
            role_name = neighbor_attrs.get('name', neighbor)
            details.append(f"Role: {role_name}")

            # Get department for this role
            for dept_neighbor in graph.neighbors(neighbor):
                dept_edge = graph.get_edge_data(neighbor, dept_neighbor, {})
                dept_edge_label = dept_edge.get('label', '')
                dept_attrs = graph.nodes[dept_neighbor]

                if dept_edge_label == 'BELONGS_TO' or dept_attrs.get('label') == 'Department':
                    dept_name = dept_attrs.get('name', dept_neighbor)
                    details.append(f"Department: {dept_name}")

                    # Get organization for this department
                    for org_neighbor in graph.neighbors(dept_neighbor):
                        org_edge = graph.get_edge_data(dept_neighbor, org_neighbor, {})
                        org_edge_label = org_edge.get('label', '')
                        org_attrs = graph.nodes[org_neighbor]

                        if org_edge_label == 'IS_IN' or org_attrs.get('label') == 'Organization':
                            org_name = org_attrs.get('name', org_neighbor)
                            details.append(f"Organization: {org_name}")
                            break
                    break
            break

    if details:
        return f" ({', '.join(details)})"
    return ""


def format_graphrag_response(result: Dict) -> str:
    """Compatibility function."""
    return format_response(result)