import string
from typing import Dict, List, Optional, Tuple

import networkx as nx
import requests


class KnowledgeGraph:
    """Manages the knowledge graph for image-concept relationships.

    The graph is a ``networkx.MultiDiGraph``. Nodes carry a ``node_type``
    attribute:

    * ``'image'``   -- added by :meth:`add_image_node` (also stores ``caption``)
    * ``'concept'`` -- caption words and ConceptNet labels
    * ``'query'``   -- the temporary node a search creates and then removes

    Every edge carries a ``relationship`` attribute; ConceptNet-derived and
    query edges also carry a ``weight``. Relationships and image/concept links
    are stored in both directions (the backward edge's relationship is
    prefixed ``reverse_``) so that searches, which walk *from* the query node
    *to* image nodes, can reach images through the concepts they mention.
    """

    # Prefix of the temporary node created for each search query.
    _QUERY_PREFIX = "_query_"

    def __init__(self, request_timeout: float = 10.0):
        """Create an empty knowledge graph.

        Args:
            request_timeout: Seconds to wait for ConceptNet before giving up.
                ``requests`` applies no timeout by default, so without one a
                stalled server would block the caller indefinitely.
        """
        self.graph = nx.MultiDiGraph()
        self.conceptnet_api_base = "http://api.conceptnet.io"
        self.request_timeout = request_timeout
        # Base weight per ConceptNet relation; unknown relations use 'default'.
        self.relationship_weights = {
            'IsA': 1.0,
            'HasProperty': 0.8,
            'RelatedTo': 0.7,
            'PartOf': 0.9,
            'UsedFor': 0.8,
            'CapableOf': 0.8,
            'AtLocation': 0.7,
            'default': 0.5,
        }
        self.decay_factor = 0.8  # Weight decay for depth

    # ------------------------------------------------------------------
    # Small internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _normalize_words(text: str) -> List[str]:
        """Lowercase the words of *text* and strip surrounding punctuation, dropping empties."""
        words = (word.strip(string.punctuation).lower() for word in text.split())
        return [word for word in words if word]

    @classmethod
    def _tokenize(cls, text: str) -> List[str]:
        """Return the unique normalized words of *text* in order of first appearance."""
        return list(dict.fromkeys(cls._normalize_words(text)))

    def _ensure_concept_node(self, node: str) -> None:
        """Add *node* as a concept node unless it already exists (never overwrites its attributes)."""
        if not self.graph.has_node(node):
            self.graph.add_node(node, node_type='concept')

    def _image_nodes(self) -> List[str]:
        """Return all nodes whose ``node_type`` is ``'image'``."""
        return [n for n, attr in self.graph.nodes(data=True)
                if attr.get('node_type') == 'image']

    def _attach_query_node(self, query_node: str, query: str) -> None:
        """Create the temporary query node and link it to the concept named by *query*.

        The query phrase is normalized the same way captions are (lowercased,
        punctuation stripped) so "Dog" matches the caption concept "dog".
        The link is only added if that concept node already exists, i.e. it
        came from a caption or from ConceptNet relationships added earlier.
        """
        self.graph.add_node(query_node, node_type='query')
        phrase = ' '.join(self._normalize_words(query))
        if phrase and phrase != query_node and self.graph.has_node(phrase):
            self.graph.add_edge(query_node, phrase, relationship='query_match', weight=1.0)

    def _remove_query_node(self, query_node: str) -> None:
        """Remove the temporary query node (and its edges) if it is present."""
        if self.graph.has_node(query_node):
            self.graph.remove_node(query_node)

    @staticmethod
    def _parse_edge(edge) -> Optional[Dict]:
        """Convert one ConceptNet API edge into a relationship dict.

        Returns None when the edge is malformed, is not English on both ends,
        or lacks a start label, end label or relation label.
        """
        if not isinstance(edge, dict):
            return None
        start = edge.get('start') or {}
        end = edge.get('end') or {}
        rel = edge.get('rel') or {}
        if not all(isinstance(part, dict) for part in (start, end, rel)):
            return None
        # Only keep relationships between English concepts.
        if start.get('language', '') != 'en' or end.get('language', '') != 'en':
            return None
        source = start.get('label', '')
        target = end.get('label', '')
        rel_type = rel.get('label', '')
        if not (source and target and rel_type):
            return None
        weight = edge.get('weight', 1.0)
        if not isinstance(weight, (int, float)):
            weight = 1.0
        return {
            'source': source,
            'target': target,
            'relationship': rel_type,
            'weight': weight,
        }

    # ------------------------------------------------------------------
    # Graph construction
    # ------------------------------------------------------------------

    def get_relationship_weight(self, relationship: str, confidence: float = 1.0) -> float:
        """
        Calculate edge weight based on relationship type and confidence.

        Args:
            relationship: ConceptNet relation label (e.g. ``'IsA'``); labels
                missing from ``relationship_weights`` use its ``'default'``.
            confidence: Multiplier applied to the base weight.

        Returns:
            ``base_weight * confidence``.
        """
        base_weight = self.relationship_weights.get(relationship, self.relationship_weights['default'])
        return base_weight * confidence

    def add_weighted_relationships(self, source: str, relationships: List[Dict]) -> None:
        """
        Add relationships whose edge weights depend on relation type and confidence.

        Weighted counterpart of :meth:`add_relationships`. Each relationship is
        stored as a forward edge plus a ``reverse_``-prefixed backward edge.
        Both edges carry the same ``weight`` and ``confidence``.

        Args:
            source: Unused; kept for backward compatibility. Edges are added
                between each relationship's own ``source`` and ``target``.
            relationships: Dicts with ``source``, ``target`` and
                ``relationship`` keys and an optional ``weight`` (ConceptNet
                confidence, default 1.0), as returned by :meth:`expand_concept`.
        """
        for rel in relationships:
            rel_type = rel['relationship']
            # ConceptNet edge weights are not bounded by 1; clamp so an edge
            # never outweighs the base weight of its relationship type.
            confidence = min(max(float(rel.get('weight', 1.0)), 0.0), 1.0)
            weight = self.get_relationship_weight(rel_type, confidence)

            for node in (rel['source'], rel['target']):
                self._ensure_concept_node(node)

            self.graph.add_edge(
                rel['source'],
                rel['target'],
                relationship=rel_type,
                weight=weight,
                confidence=confidence,
            )
            # Backward edge so paths can traverse the relationship both ways.
            self.graph.add_edge(
                rel['target'],
                rel['source'],
                relationship=f"reverse_{rel_type}",
                weight=weight,
                confidence=confidence,
            )

    def add_image_node(self, image_id: str, caption: str) -> None:
        """
        Add an image node with its caption to the graph.

        The caption is tokenized into unique lowercased words with surrounding
        punctuation removed, so "beach." and "Beach" both become "beach". Each
        word becomes a concept node linked to the image in both directions.

        Args:
            image_id: Unique identifier for the image
            caption: BLIP-2 generated caption for the image
        """
        self.graph.add_node(image_id, node_type='image', caption=caption)

        for concept in self._tokenize(caption):
            self._ensure_concept_node(concept)
            self.graph.add_edge(image_id, concept, relationship='has_concept')
            # Searches walk from the query towards images, so each concept
            # also needs an edge pointing back at the image.
            self.graph.add_edge(concept, image_id, relationship='reverse_has_concept')

    def expand_concept(self, concept: str) -> List[Dict]:
        """
        Query ConceptNet for relationships about a concept.

        Args:
            concept: The concept to query; whitespace runs become underscores
                and the text is lowercased to form the ``/c/en/<term>`` URI.

        Returns:
            List of dicts with ``source``, ``target``, ``relationship`` and
            ``weight`` keys, keeping only English-to-English edges with
            non-empty labels. Returns ``[]`` for an empty concept or when the
            request or JSON decoding fails (the error is printed).
        """
        term = '_'.join(concept.lower().split())
        if not term:
            return []

        try:
            response = requests.get(
                f"{self.conceptnet_api_base}/c/en/{term}",
                params={'limit': 50},
                timeout=self.request_timeout,
            )
            response.raise_for_status()
            payload = response.json()
        except (requests.exceptions.RequestException, ValueError) as e:
            # ValueError covers invalid JSON bodies on older requests versions.
            print(f"Error querying ConceptNet: {e}")
            return []

        edges = payload.get('edges', []) if isinstance(payload, dict) else []
        relationships = []
        for edge in edges:
            parsed = self._parse_edge(edge)
            if parsed is not None:
                relationships.append(parsed)
        return relationships

    def add_relationships(self, source: str, relationships: List[Dict]) -> None:
        """
        Add relationships from ConceptNet to the graph with a flat weight of 1.0.

        Each relationship is stored as a forward edge plus a
        ``reverse_``-prefixed backward edge for bidirectional search.

        Args:
            source: Unused; kept for backward compatibility. Edges are added
                between each relationship's own ``source`` and ``target``.
            relationships: List of relationship dictionaries from expand_concept
        """
        for rel in relationships:
            for node in (rel['source'], rel['target']):
                self._ensure_concept_node(node)

            self.graph.add_edge(
                rel['source'],
                rel['target'],
                relationship=rel['relationship'],
                weight=1.0,
            )
            # Reverse relationship for bidirectional search.
            self.graph.add_edge(
                rel['target'],
                rel['source'],
                relationship=f"reverse_{rel['relationship']}",
                weight=1.0,
            )

    # ------------------------------------------------------------------
    # Querying
    # ------------------------------------------------------------------

    def search(self, query: str, limit: int = 5) -> List[str]:
        """
        Search for images based on a semantic query.

        The query is expanded via ConceptNet, and a temporary query node is
        linked to the matching concept. Images are ranked by
        ``1 / (1 + hops)``, where ``hops`` is the unweighted shortest path
        length from the query node. The ConceptNet edges added here remain in
        the graph; only the temporary query node is removed afterwards.

        Args:
            query: Search query string
            limit: Maximum number of results to return

        Returns:
            List of image IDs ordered by relevance (unreachable images omitted).
        """
        temp_query_node = f"{self._QUERY_PREFIX}{query}"
        try:
            self.add_relationships(temp_query_node, self.expand_concept(query))
            # Link after adding relationships so ConceptNet-only concepts match too.
            self._attach_query_node(temp_query_node, query)

            # One BFS from the query node gives the distance to every reachable node.
            distances = nx.single_source_shortest_path_length(self.graph, temp_query_node)
            image_scores = [
                (image_id, 1.0 / (1.0 + distances[image_id]))
                for image_id in self._image_nodes()
                if image_id in distances
            ]
            image_scores.sort(key=lambda item: item[1], reverse=True)
            return [img_id for img_id, _ in image_scores[:limit]]
        finally:
            self._remove_query_node(temp_query_node)

    def _collect_related(self, concept: str,
                         relationship_type: Optional[str]) -> List[Tuple[str, str]]:
        """List (target, relationship) pairs for *concept*'s out-edges, excluding query and image nodes."""
        if not self.graph.has_node(concept):
            return []
        related = []
        for _, target, edge_data in self.graph.out_edges(concept, data=True):
            rel_type = edge_data.get('relationship', '')
            if relationship_type is not None and rel_type != relationship_type:
                continue
            if str(target).startswith(self._QUERY_PREFIX):
                continue
            if self.graph.nodes[target].get('node_type') == 'image':
                continue
            related.append((target, rel_type))
        return related

    def get_related_concepts(self, concept: str, relationship_type: Optional[str] = None) -> List[Tuple[str, str]]:
        """
        Get concepts related to given concept, optionally filtered by relationship type.

        If the graph has no matching relationships for *concept*, ConceptNet is
        queried, the results are added to the graph, and the lookup is retried.

        Args:
            concept: The source concept to find relations for
            relationship_type: Optional filter for specific relationship types

        Returns:
            List of tuples containing (related_concept, relationship_type).
            Temporary query nodes and image nodes are never included.
        """
        related_concepts = self._collect_related(concept, relationship_type)
        if not related_concepts:
            self.add_relationships(concept, self.expand_concept(concept))
            related_concepts = self._collect_related(concept, relationship_type)
        return related_concepts

    def search_with_depth(self, query: str, max_depth: int = 3, limit: int = 5) -> List[Tuple[str, float]]:
        """
        Enhanced search with depth control and path weights.

        Weighted ConceptNet relationships for the query are added to the graph,
        and a temporary query node is linked to the matching concept. Each
        image is scored by the best :meth:`_calculate_path_score` over all
        simple paths of at most *max_depth* edges from the query node. If
        ConceptNet returns nothing, a message is printed and images are still
        matched through their caption concepts.

        Args:
            query: Search query string
            max_depth: Maximum path length (in edges) to consider
            limit: Maximum number of results to return

        Returns:
            List of tuples (image_id, relevance_score), best first.
        """
        temp_query_node = f"{self._QUERY_PREFIX}{query}"
        try:
            query_relationships = self.expand_concept(query)
            if not query_relationships:
                print(f"No relationships found for query: {query}")
            self.add_weighted_relationships(temp_query_node, query_relationships)
            # Link after adding relationships so ConceptNet-only concepts match too.
            self._attach_query_node(temp_query_node, query)

            image_scores = {}
            for image_id in self._image_nodes():
                paths = nx.all_simple_paths(
                    self.graph,
                    source=temp_query_node,
                    target=image_id,
                    cutoff=max_depth,
                )
                path_scores = [self._calculate_path_score(path) for path in paths]
                if path_scores:
                    image_scores[image_id] = max(path_scores)

            sorted_results = sorted(image_scores.items(), key=lambda item: item[1], reverse=True)
            return sorted_results[:limit]
        finally:
            self._remove_query_node(temp_query_node)

    def search_with_depth1(self, query: str, max_depth: int = 3, limit: int = 5) -> List[Tuple[str, float]]:
        """
        Alias of :meth:`search_with_depth`, kept for backward compatibility.

        Args:
            query: Search query string
            max_depth: Maximum path length to consider
            limit: Maximum number of results to return

        Returns:
            List of tuples (image_id, relevance_score)
        """
        return self.search_with_depth(query, max_depth=max_depth, limit=limit)

    def _calculate_path_score(self, path: List[str]) -> float:
        """
        Calculate the score for a path from relationship weights and depth.

        Hop ``i`` contributes its edge weight times ``decay_factor ** i``. For
        parallel edges, the largest weight is used, and edges without a weight
        count as ``relationship_weights['default']``. The sum is averaged over
        the number of hops; a hop with no edge adds 0 but still counts.
        Returns 0 for a path with fewer than two nodes.
        """
        path_length = len(path) - 1
        if path_length <= 0:
            return 0

        default_weight = self.relationship_weights['default']
        total_score = 0
        for i, (u, v) in enumerate(zip(path, path[1:])):
            # MultiDiGraph: mapping of edge key -> attribute dict, or None.
            edges = self.graph.get_edge_data(u, v)
            if edges:
                max_weight = max(edge.get('weight', default_weight) for edge in edges.values())
                total_score += max_weight * (self.decay_factor ** i)

        return total_score / path_length