# Source: CSRC-Car-Manual-RAG / src/knowledge_graph.py
# (uploaded by Bryceeee in commit 0cfa3a6, "Upload 17 files")
"""
Knowledge Graph Visualization Module
Creates knowledge maps and similarity heatmaps from document relationships
"""
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import os
import re
import json
from typing import Tuple, Optional, Dict, List
from openai import OpenAI
from pathlib import Path
class KnowledgeGraphGenerator:
    """Generates knowledge graphs and visualizations.

    Documents are discovered by querying an OpenAI vector store through the
    Responses API ``file_search`` tool. Each is summarized with GPT and reduced
    to a few technical topics. The results are rendered as a document/topic
    network and a document-to-document cosine-similarity heatmap, both saved
    as PNG files in ``output_dir``.
    """

    # Words ignored when deriving fallback topics from a file name
    # (compared case-insensitively).
    _FALLBACK_STOPWORDS = frozenset({'the', 'and', 'for', 'with', 'function', 'of'})

    # Patterns stripped from long document titles before truncating them for
    # display. The word boundary keeps e.g. "Assistant" intact.
    _LABEL_PREFIX_RE = re.compile(r'(?:Function|Operating|Setting|Activating|Deactivating) of ')
    _LABEL_ASSIST_RE = re.compile(r' Assist\b')

    def __init__(self, client: OpenAI, vector_store_id: str, output_dir: str = "output"):
        """Store the API client and vector store id, and create ``output_dir``.

        Args:
            client: OpenAI client used for every API call.
            vector_store_id: Id of the vector store searched via ``file_search``.
            output_dir: Directory receiving the generated PNG files; missing
                parent directories are created too.
        """
        self.client = client
        self.vector_store_id = vector_store_id
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)

    @staticmethod
    def _first_message_content(response) -> Optional[object]:
        """Return the first content part of the first ``message`` output item.

        A ``file_search`` response may contain tool-call items before the
        assistant message, so the message is searched for rather than assumed
        to be at a fixed index. Returns ``None`` if no message content exists.
        """
        for item in getattr(response, 'output', None) or []:
            if getattr(item, 'type', None) not in (None, 'message'):
                continue
            content = getattr(item, 'content', None)
            if content:
                return content[0]
        return None

    def get_files_from_vector_store(self) -> List[str]:
        """Return the sorted, de-duplicated names of documents in the vector store.

        Names come from the file citations attached to a ``file_search``
        answer. Only documents the search actually cites are therefore listed.
        A trailing ``.pdf`` extension is removed. Any error is reported and
        yields an empty list.
        """
        try:
            response = self.client.responses.create(
                input="List all documents in the manual",
                model="gpt-4o-mini",
                tools=[{
                    "type": "file_search",
                    "vector_store_ids": [self.vector_store_id],
                    "max_num_results": 25,
                }],
            )
            content = self._first_message_content(response)
            names = set()
            for annotation in getattr(content, 'annotations', None) or []:
                filename = getattr(annotation, 'filename', None)
                if not filename:
                    continue
                # Strip only the extension, not ".pdf" occurring mid-name.
                names.add(filename[:-len('.pdf')] if filename.endswith('.pdf') else filename)
            return sorted(names)
        except Exception as e:
            print(f"❌ Error getting files: {str(e)}")
            return []

    def extract_topics_from_content(self, file_list: List[str]) -> Tuple[Dict[str, List[str]], List[str]]:
        """Extract key technical topics for each document using GPT.

        Each document is first summarized through ``file_search``. All
        summaries are then sent to ``gpt-4o`` in one request, asking for a JSON
        object that maps document names to concept lists. Documents left
        without topics get topics derived from their file name.

        Args:
            file_list: Document names, e.g. from ``get_files_from_vector_store``.

        Returns:
            ``(file_topics, all_topics)``. ``file_topics`` maps every name in
            ``file_list`` to a non-empty list of topics. ``all_topics`` is the
            sorted list of distinct topics across ``file_topics``.
        """
        print("📖 Getting content descriptions for each file...")
        file_descriptions = {file: self._describe_file(file) for file in file_list}

        file_topics = self._topics_from_descriptions(file_descriptions, file_list)

        # Ensure all files have topics
        for file in file_list:
            if not file_topics.get(file):
                file_topics[file] = self._topics_from_filename(file)

        all_topics = sorted({topic for topics in file_topics.values() for topic in topics})
        return file_topics, all_topics

    def _describe_file(self, file: str) -> str:
        """Ask ``file_search`` for a brief technical summary of one document.

        Returns a generic placeholder when the call fails or yields no text,
        so one bad document does not abort the run.
        """
        fallback = f"Information about {file}"
        try:
            query = f"What is the main purpose and key concepts covered in the document titled '{file}'? Be brief and focused on technical concepts."
            response = self.client.responses.create(
                input=query,
                model="gpt-4o-mini",
                tools=[{
                    "type": "file_search",
                    "vector_store_ids": [self.vector_store_id],
                }],
            )
            description = getattr(self._first_message_content(response), 'text', None)
        except Exception as e:
            print(f" ⚠️ Error getting description for {file}: {e}")
            return fallback
        if not description:
            return fallback
        print(f" ✓ Got description for {file}")
        return description

    @staticmethod
    def _normalize_doc_key(name) -> str:
        """Normalize a document name for lenient matching (case, whitespace, ``.pdf``)."""
        key = str(name).strip().casefold()
        return key[:-len('.pdf')] if key.endswith('.pdf') else key

    def _topics_from_descriptions(self, file_descriptions: Dict[str, str],
                                  file_list: List[str]) -> Dict[str, List[str]]:
        """Ask GPT for 3-5 concepts per document; fall back to file-name topics on failure.

        Only keys that match a name in ``file_list`` are kept, and only string
        topics are accepted. Otherwise stray or malformed entries would become
        untyped graph nodes.
        """
        prompt = (
            "Extract key technical concepts (single words or short phrases) from these document descriptions. Focus on functional concepts, components, and technologies.\n\n"
            + "".join(f"Document: {file}\nDescription: {desc}\n\n"
                      for file, desc in file_descriptions.items())
            + "\nFor each document, list 3-5 key technical concepts. Format as a JSON object where keys are document names and values are arrays of concepts."
        )
        try:
            response = self.client.chat.completions.create(
                model="gpt-4o",
                messages=[
                    {"role": "system", "content": "You extract key technical concepts from document descriptions in a structured way."},
                    {"role": "user", "content": prompt},
                ],
                temperature=0.3,
            )
            topics_text = response.choices[0].message.content or ""
        except Exception as e:
            print(f"⚠️ Error extracting topics: {e}, using fallback")
            return self._create_fallback_topics(file_list)

        # The model may wrap the JSON in prose or code fences; take the
        # outermost {...} span.
        json_match = re.search(r'\{.*\}', topics_text, re.DOTALL)
        if not json_match:
            return self._create_fallback_topics(file_list)
        try:
            raw_topics = json.loads(json_match.group(0))
        except json.JSONDecodeError:
            print("⚠️ Error parsing JSON response, using fallback")
            return self._create_fallback_topics(file_list)

        by_key = {self._normalize_doc_key(name): name for name in file_list}
        file_topics: Dict[str, List[str]] = {}
        for key, value in raw_topics.items():
            name = by_key.get(self._normalize_doc_key(key))
            if name is None or not isinstance(value, list):
                continue
            topics = [t.strip() for t in value if isinstance(t, str) and t.strip()]
            file_topics[name] = list(dict.fromkeys(topics))  # de-duplicate, keep order
        print(f"✅ Successfully extracted topics for {len(file_topics)} documents")
        return file_topics

    def _topics_from_filename(self, file: str) -> List[str]:
        """Derive topics from the 3+-letter words of a file name, minus stopwords."""
        words = [word for word in re.findall(r'\b[A-Za-z]{3,}\b', file)
                 if word.lower() not in self._FALLBACK_STOPWORDS]
        return words if words else ["Topic"]

    def _create_fallback_topics(self, file_list: List[str]) -> Dict[str, List[str]]:
        """Create fallback topics from filenames"""
        return {file: self._topics_from_filename(file) for file in file_list}

    def analyze_document_relationships(self, file_list: List[str],
                                       file_topics: Dict[str, List[str]]) -> np.ndarray:
        """Compute pairwise cosine similarity of documents' binary topic vectors.

        Returns:
            An ``(n, n)`` array, where ``n = len(file_list)``, aligned with
            ``file_list``. The diagonal is 1.0. Pairs where either document
            has no topics score 0.0.
        """
        n = len(file_list)
        topic_sets = [set(file_topics.get(file, [])) for file in file_list]
        topic_index = {topic: i for i, topic in enumerate(set().union(*topic_sets))}

        vectors = np.zeros((n, len(topic_index)))
        for row, topics in enumerate(topic_sets):
            for topic in topics:
                vectors[row, topic_index[topic]] = 1.0

        norms = np.linalg.norm(vectors, axis=1)
        denominators = np.outer(norms, norms)
        # Leave 0.0 wherever a document has an all-zero vector.
        similarity_matrix = np.divide(vectors @ vectors.T, denominators,
                                      out=np.zeros((n, n)), where=denominators > 0)
        np.fill_diagonal(similarity_matrix, 1.0)
        return similarity_matrix

    def create_knowledge_graph(self, file_list: List[str], file_topics: Dict[str, List[str]],
                               similarity_matrix: np.ndarray,
                               similarity_threshold: float = 0.25,
                               similarity_edge_scale: float = 5.0) -> nx.Graph:
        """Build an undirected graph of document and topic nodes.

        Document-topic edges have weight 3. Document pairs whose similarity
        exceeds ``similarity_threshold`` are linked with weight
        ``similarity * similarity_edge_scale``.
        """
        G = nx.Graph()
        for file in file_list:
            G.add_node(file, type='document', size=700)

        # Iterate file_list (not file_topics) so every edge source is a typed
        # document node.
        for file in file_list:
            for topic in file_topics.get(file, []):
                if topic == file:
                    continue  # would be a self-loop on the document node
                if topic not in G:
                    G.add_node(topic, type='topic', size=500)
                G.add_edge(file, topic, weight=3)

        n = len(file_list)
        for i in range(n):
            for j in range(i + 1, n):
                sim = float(similarity_matrix[i][j])
                if sim > similarity_threshold:
                    G.add_edge(file_list[i], file_list[j], weight=sim * similarity_edge_scale)
        return G

    @classmethod
    def _shorten_label(cls, label: str, max_len: int) -> str:
        """Shorten a title longer than ``max_len`` for display.

        Drops "<Verb> of " prefixes and the word "Assist", then truncates to
        ``max_len - 2`` characters plus '...'.
        """
        if len(label) <= max_len:
            return label
        shortened = cls._LABEL_PREFIX_RE.sub('', label)
        shortened = cls._LABEL_ASSIST_RE.sub('', shortened)
        if len(shortened) > max_len:
            shortened = shortened[:max_len - 2] + '...'
        return shortened

    def save_knowledge_graph(self, G: nx.Graph) -> str:
        """Render ``G`` to ``<output_dir>/knowledge_graph.png`` and return that path."""
        fig = plt.figure(figsize=(16, 12))
        try:
            pos = nx.kamada_kawai_layout(G)
            document_nodes = [n for n, attr in G.nodes(data=True) if attr.get('type') == 'document']
            topic_nodes = [n for n, attr in G.nodes(data=True) if attr.get('type') == 'topic']
            edge_widths = [G[u][v].get('weight', 1) * 0.6 for u, v in G.edges()]

            nx.draw_networkx_nodes(G, pos, nodelist=document_nodes, node_color='#5B9BD5',
                                   node_size=800, alpha=0.8)
            nx.draw_networkx_nodes(G, pos, nodelist=topic_nodes, node_color='#70AD47',
                                   node_size=600, alpha=0.8)
            nx.draw_networkx_edges(G, pos, width=edge_widths, alpha=0.7, edge_color='#A5A5A5')

            for node in document_nodes:
                x, y = pos[node]
                plt.text(x, y, self._shorten_label(node, 20), fontsize=9, fontweight='bold',
                         ha='center', va='center',
                         bbox=dict(facecolor='white', alpha=0.7, edgecolor='none', boxstyle='round,pad=0.3'))
            for node in topic_nodes:
                x, y = pos[node]
                plt.text(x, y, node, fontsize=8, ha='center', va='center',
                         bbox=dict(facecolor='#E8F4E5', alpha=0.9, edgecolor='none', boxstyle='round,pad=0.2'))

            plt.title("System Knowledge Map", fontsize=18)
            plt.axis('off')
            plt.tight_layout()
            output_path = self.output_dir / "knowledge_graph.png"
            plt.savefig(output_path, dpi=300, bbox_inches='tight')
        finally:
            # Release the figure even if layout/drawing fails.
            plt.close(fig)
        print(f"✅ Knowledge graph saved to {output_path}")
        return str(output_path)

    def save_similarity_heatmap(self, matrix: np.ndarray, labels: List[str]) -> str:
        """Render ``matrix`` to ``<output_dir>/similarity_heatmap.png`` and return that path.

        ``labels`` name the rows/columns in order. Off-diagonal cells are
        annotated with their value.
        """
        fig = plt.figure(figsize=(12, 10))
        try:
            # Pin the colour scale to [0, 1] so the 0.5 text-colour threshold
            # below matches the cell shading.
            plt.imshow(matrix, cmap='Blues', vmin=0, vmax=1)
            plt.colorbar(label='Similarity')

            shortened_labels = [self._shorten_label(label, 15) for label in labels]
            ticks = range(len(labels))
            plt.xticks(ticks, shortened_labels, rotation=45, ha='right')
            plt.yticks(ticks, shortened_labels)

            for i in ticks:
                for j in ticks:
                    if i != j:
                        value = matrix[i, j]
                        plt.text(j, i, f'{value:.2f}',
                                 ha="center", va="center",
                                 color="white" if value > 0.5 else "black")

            plt.title("Document Similarity Heatmap", fontsize=16)
            plt.tight_layout()
            output_path = self.output_dir / "similarity_heatmap.png"
            plt.savefig(output_path, dpi=300, bbox_inches='tight')
        finally:
            plt.close(fig)
        print(f"✅ Similarity heatmap saved to {output_path}")
        return str(output_path)

    def generate_visualizations(self) -> Tuple[Optional[str], Optional[str]]:
        """Run the full pipeline and save both images.

        Returns:
            ``(graph_path, heatmap_path)``, or ``(None, None)`` when no
            documents were found in the vector store.
        """
        print("🔄 Generating knowledge graph visualizations...")
        file_list = self.get_files_from_vector_store()
        if not file_list:
            print("⚠️ No files found. Cannot create knowledge map.")
            return None, None

        print("📊 Extracting topics from content...")
        file_topics, _all_topics = self.extract_topics_from_content(file_list)

        print("🔗 Analyzing document relationships...")
        similarity_matrix = self.analyze_document_relationships(file_list, file_topics)

        print("🎨 Creating knowledge graph...")
        G = self.create_knowledge_graph(file_list, file_topics, similarity_matrix)

        print("💾 Saving visualizations...")
        graph_path = self.save_knowledge_graph(G)
        heatmap_path = self.save_similarity_heatmap(similarity_matrix, file_list)
        print("✅ Dynamic visualizations complete!")
        return graph_path, heatmap_path