"""semantic_image_search/modules/knowledge_graph.py

Knowledge graph of image-concept relationships, backed by the public
ConceptNet HTTP API.
"""
import networkx as nx
import requests
from typing import Dict, List, Optional, Tuple
class KnowledgeGraph:
    """Manages a knowledge graph of image-concept relationships.

    Nodes are either images (``node_type='image'``) or concepts
    (``node_type='concept'``); edges carry a ``relationship`` label and a
    ``weight`` used for path-based relevance scoring.  Concept-to-concept
    edges are fetched on demand from the public ConceptNet HTTP API.
    """

    def __init__(self):
        # MultiDiGraph: one concept pair can hold several parallel edges,
        # one per relationship type (IsA, RelatedTo, ...).
        self.graph = nx.MultiDiGraph()
        self.conceptnet_api_base = "http://api.conceptnet.io"
        # Base edge weight per ConceptNet relationship type; anything not
        # listed falls back to 'default'.
        self.relationship_weights = {
            'IsA': 1.0,
            'HasProperty': 0.8,
            'RelatedTo': 0.7,
            'PartOf': 0.9,
            'UsedFor': 0.8,
            'CapableOf': 0.8,
            'AtLocation': 0.7,
            'default': 0.5
        }
        self.decay_factor = 0.8  # Multiplicative weight decay per hop of path depth

    def get_relationship_weight(self, relationship: str, confidence: float = 1.0) -> float:
        """Return the edge weight for ``relationship`` scaled by ``confidence``.

        Unknown relationship types use the 'default' base weight.
        """
        base_weight = self.relationship_weights.get(
            relationship, self.relationship_weights['default'])
        return base_weight * confidence

    def _ensure_concept_node(self, node: str) -> None:
        """Add ``node`` as a concept node if it is not in the graph yet."""
        if not self.graph.has_node(node):
            self.graph.add_node(node, node_type='concept')

    def _anchor_source(self, source: str, rel: Dict) -> None:
        """Connect ``source`` to a relationship's start concept.

        BUG FIX: both add_relationships variants originally ignored their
        ``source`` argument entirely, so the temporary query nodes built by
        search()/search_with_depth() were never added to (or connected into)
        the graph — no path from them could exist, and the cleanup
        ``remove_node`` then raised NetworkXError.
        """
        if not source or source == rel['source']:
            return
        self._ensure_concept_node(source)
        # Guard against piling up duplicate parallel edges on repeat calls.
        if not self.graph.has_edge(source, rel['source']):
            self.graph.add_edge(
                source,
                rel['source'],
                relationship='query_concept',
                weight=1.0,
                confidence=1.0
            )

    def add_weighted_relationships(self, source: str, relationships: List[Dict]) -> None:
        """Add relationships with confidence-scaled edge weights.

        Args:
            source: Anchor node (e.g. a temporary query node); it is
                connected to each relationship's start concept so searches
                that begin at ``source`` can reach them.
            relationships: Dicts with 'source', 'target', 'relationship'
                and an optional ConceptNet 'weight' (confidence).
        """
        for rel in relationships:
            # ConceptNet's 'weight' field acts as a confidence score.
            confidence = rel.get('weight', 1.0)
            rel_type = rel['relationship']
            weight = self.get_relationship_weight(rel_type, confidence)
            self._ensure_concept_node(rel['source'])
            self._ensure_concept_node(rel['target'])
            self.graph.add_edge(
                rel['source'],
                rel['target'],
                relationship=rel_type,
                weight=weight,
                confidence=confidence
            )
            self._anchor_source(source, rel)

    def add_image_node(self, image_id: str, caption: str) -> None:
        """Add an image node and link it to the words of its caption.

        Args:
            image_id: Unique identifier for the image.
            caption: BLIP-2 generated caption for the image.
        """
        self.graph.add_node(
            image_id,
            node_type='image',
            caption=caption
        )
        # Naive concept extraction: lower-cased whitespace tokens.  This
        # keeps stopwords/punctuation; a real tokenizer would improve recall.
        concepts = [word.lower() for word in caption.split()]
        for concept in concepts:
            self.graph.add_node(concept, node_type='concept')
            self.graph.add_edge(image_id, concept, relationship='has_concept')
            # BUG FIX: also add the reverse edge.  The graph is directed and
            # searches walk query -> concepts -> image, so without a
            # concept->image edge no image was ever reachable.  Mirrors the
            # "reverse_*" convention used by add_relationships().
            self.graph.add_edge(concept, image_id,
                                relationship='reverse_has_concept')

    def expand_concept(self, concept: str) -> List[Dict]:
        """Query ConceptNet for relationships about ``concept``.

        Returns:
            List of {'source', 'target', 'relationship'} dicts restricted
            to English concepts; an empty list on any network error.
        """
        # ConceptNet URI form: lowercase, spaces -> underscores.
        formatted_concept = f"/c/en/{concept.lower().replace(' ', '_')}"
        try:
            response = requests.get(
                f"{self.conceptnet_api_base}{formatted_concept}",
                params={'limit': 50},
                timeout=10  # don't hang the whole search on a slow API
            )
            response.raise_for_status()
            edges = response.json().get('edges', [])
            relationships = []
            for edge in edges:
                try:
                    # Robust access: API edges can be missing fields.
                    start = edge.get('start', {})
                    end = edge.get('end', {})
                    # Only keep edges whose endpoints are both English.
                    if (start.get('language', '') == 'en' and
                            end.get('language', '') == 'en'):
                        relationships.append({
                            'source': start.get('label', ''),
                            'target': end.get('label', ''),
                            'relationship': edge.get('rel', {}).get('label', '')
                        })
                except (KeyError, TypeError):
                    continue  # Skip malformed edges
            return relationships
        except requests.exceptions.RequestException as e:
            print(f"Error querying ConceptNet: {e}")
            return []

    def add_relationships(self, source: str, relationships: List[Dict]) -> None:
        """Add forward and reverse relationship edges to the graph.

        Args:
            source: Anchor node (see _anchor_source); connected to each
                relationship's start concept.
            relationships: Relationship dicts from expand_concept().
        """
        for rel in relationships:
            self._ensure_concept_node(rel['source'])
            self._ensure_concept_node(rel['target'])
            self.graph.add_edge(
                rel['source'],
                rel['target'],
                relationship=rel['relationship'],
                weight=1.0  # ConceptNet confidence not propagated here
            )
            # Reverse edge so directed searches can traverse either way.
            self.graph.add_edge(
                rel['target'],
                rel['source'],
                relationship=f"reverse_{rel['relationship']}",
                weight=1.0
            )
            self._anchor_source(source, rel)

    def search(self, query: str, limit: int = 5) -> List[str]:
        """Search for images semantically related to ``query``.

        Args:
            query: Search query string (a single concept).
            limit: Maximum number of results to return.
        Returns:
            Image IDs ordered by relevance (inverse shortest-path length).
        """
        query_relationships = self.expand_concept(query)
        temp_query_node = f"_query_{query}"
        # BUG FIX: the query node must actually be added before hanging
        # relationships off it; the original never created it, so the
        # finally-block remove_node raised and no paths existed.
        self.graph.add_node(temp_query_node, node_type='query')
        self.add_relationships(temp_query_node, query_relationships)
        try:
            image_nodes = [n for n, attr in self.graph.nodes(data=True)
                           if attr.get('node_type') == 'image']
            image_scores = []
            for image_id in image_nodes:
                try:
                    # Shorter paths = more relevant.
                    path_length = nx.shortest_path_length(
                        self.graph,
                        source=temp_query_node,
                        target=image_id
                    )
                    # Convert distance to a similarity score in (0, 1].
                    image_scores.append((image_id, 1.0 / (1.0 + path_length)))
                except nx.NetworkXNoPath:
                    continue
            image_scores.sort(key=lambda x: x[1], reverse=True)
            return [img_id for img_id, _ in image_scores[:limit]]
        finally:
            # Guarded cleanup: a failed expansion must not raise here.
            if self.graph.has_node(temp_query_node):
                self.graph.remove_node(temp_query_node)

    def _collect_related(self, concept: str,
                         relationship_type: Optional[str]) -> List[Tuple[str, str]]:
        """Collect (target, relationship) pairs on out-edges of ``concept``,
        skipping temporary query nodes and image nodes."""
        related = []
        for _, target, edge_data in self.graph.out_edges(concept, data=True):
            rel_type = edge_data.get('relationship', '')
            if relationship_type is not None and rel_type != relationship_type:
                continue
            if target.startswith('_query_'):
                continue
            if self.graph.nodes[target].get('node_type') == 'image':
                continue
            related.append((target, rel_type))
        return related

    def get_related_concepts(self, concept: str, relationship_type: Optional[str] = None) -> List[Tuple[str, str]]:
        """Get concepts related to ``concept``, optionally filtered by type.

        Args:
            concept: The source concept to find relations for.
            relationship_type: Optional filter for a specific relationship type.
        Returns:
            List of (related_concept, relationship_type) tuples.
        """
        related_concepts = []
        if self.graph.has_node(concept):
            related_concepts = self._collect_related(concept, relationship_type)
        # Nothing locally: pull relationships from ConceptNet, then retry.
        if not related_concepts:
            new_relationships = self.expand_concept(concept)
            self.add_relationships(concept, new_relationships)
            related_concepts = self._collect_related(concept, relationship_type)
        return related_concepts

    def search_with_depth(self, query: str, max_depth: int = 3, limit: int = 5) -> List[Tuple[str, float]]:
        """Depth-bounded search with weighted path scoring.

        Args:
            query: Search query string.
            max_depth: Maximum path length (hops) to consider.
            limit: Maximum number of results to return.
        Returns:
            List of (image_id, relevance_score) tuples, best first.  The
            score of an image is the best _calculate_path_score() over all
            simple paths of length <= max_depth from the query node.
        """
        temp_query_node = f"_query_{query}"
        try:
            query_relationships = self.expand_concept(query)
            if not query_relationships:
                print(f"No relationships found for query: {query}")
                return []
            # Query node must exist before relationships are attached.
            self.graph.add_node(temp_query_node, node_type='query')
            self.add_weighted_relationships(temp_query_node, query_relationships)
            image_scores = {}
            image_nodes = [n for n, attr in self.graph.nodes(data=True)
                           if attr.get('node_type') == 'image']
            for image_id in image_nodes:
                try:
                    # all_simple_paths returns a lazy generator; any error
                    # surfaces while iterating it below.
                    paths = nx.all_simple_paths(
                        self.graph,
                        source=temp_query_node,
                        target=image_id,
                        cutoff=max_depth
                    )
                    path_scores = [self._calculate_path_score(p) for p in paths]
                    if path_scores:
                        # Best path wins.
                        image_scores[image_id] = max(path_scores)
                except (nx.NetworkXNoPath, nx.NodeNotFound):
                    # BUG FIX: all_simple_paths raises NodeNotFound for a
                    # missing endpoint; NetworkXNoPath kept for safety.
                    continue
            sorted_results = sorted(
                image_scores.items(),
                key=lambda x: x[1],
                reverse=True
            )
            return sorted_results[:limit]
        finally:
            # Clean up: remove temporary query node (guarded).
            if self.graph.has_node(temp_query_node):
                self.graph.remove_node(temp_query_node)

    def search_with_depth1(self, query: str, max_depth: int = 3, limit: int = 5) -> List[Tuple[str, float]]:
        """Deprecated duplicate of search_with_depth(); kept for backward
        compatibility and now delegating to it.

        The original copy re-implemented the same logic but crashed in its
        ``finally`` block (unconditional remove_node) whenever the query
        expansion returned nothing.
        """
        return self.search_with_depth(query, max_depth=max_depth, limit=limit)

    def _calculate_path_score(self, path: List[str]) -> float:
        """Score a node path by relationship weights with depth decay.

        Each hop contributes its best parallel-edge weight multiplied by
        ``decay_factor ** hop_index``; the total is normalised by the
        number of hops.  A single-node path scores 0.
        """
        total_score = 0.0
        hops = len(path) - 1
        for i in range(hops):
            # get_edge_data on a MultiDiGraph returns {key: data_dict} for
            # all parallel edges between the pair.
            edges = self.graph.get_edge_data(path[i], path[i + 1])
            if edges:
                max_weight = max(
                    data.get('weight', self.relationship_weights['default'])
                    for data in edges.values()
                )
                # Hops further from the query contribute less.
                total_score += max_weight * self.decay_factor ** i
        return total_score / hops if hops > 0 else 0