# NOTE: The "Spaces:" / "Sleeping" lines below this comment's position were
# Hugging Face Spaces page-banner residue from extraction, not source code;
# converted to this comment so the file parses.
import networkx as nx
import requests
from typing import Dict, List, Optional, Tuple
class KnowledgeGraph:
    """Manages the knowledge graph for image-concept relationships.

    Nodes carry a 'node_type' attribute: 'image' for indexed images,
    'concept' for caption words / ConceptNet labels, and 'query' for the
    temporary nodes created during a search. Edges carry a 'relationship'
    label and (mostly) a numeric 'weight' used for relevance scoring.
    """

    def __init__(self):
        # MultiDiGraph so one concept pair can hold several parallel edges,
        # one per ConceptNet relationship type.
        self.graph = nx.MultiDiGraph()
        self.conceptnet_api_base = "http://api.conceptnet.io"
        # Prior weight per ConceptNet relationship type; 'default' covers
        # any type not listed explicitly.
        self.relationship_weights = {
            'IsA': 1.0,
            'HasProperty': 0.8,
            'RelatedTo': 0.7,
            'PartOf': 0.9,
            'UsedFor': 0.8,
            'CapableOf': 0.8,
            'AtLocation': 0.7,
            'default': 0.5
        }
        self.decay_factor = 0.8  # Multiplicative weight decay per hop of path depth

    def get_relationship_weight(self, relationship: str, confidence: float = 1.0) -> float:
        """Calculate edge weight based on relationship type and confidence.

        Args:
            relationship: ConceptNet relationship label (e.g. 'IsA').
            confidence: Confidence score reported by ConceptNet.

        Returns:
            Base weight for the relationship type scaled by confidence.
        """
        base_weight = self.relationship_weights.get(
            relationship, self.relationship_weights['default'])
        return base_weight * confidence

    def add_weighted_relationships(self, source: str, relationships: List[Dict]) -> None:
        """Add relationships with confidence-scaled edge weights.

        Args:
            source: Concept these relationships were expanded from. Retained
                for API compatibility; the relationship dicts name their own
                endpoints, so this argument is not read here.
            relationships: Dicts with 'source', 'target', 'relationship' and
                optionally 'weight' (ConceptNet confidence) keys.
        """
        for rel in relationships:
            confidence = rel.get('weight', 1.0)
            rel_type = rel['relationship']
            weight = self.get_relationship_weight(rel_type, confidence)
            # Create endpoint nodes on demand, then the weighted edge.
            for node in (rel['source'], rel['target']):
                if not self.graph.has_node(node):
                    self.graph.add_node(node, node_type='concept')
            self.graph.add_edge(
                rel['source'],
                rel['target'],
                relationship=rel_type,
                weight=weight,
                confidence=confidence
            )

    def add_image_node(self, image_id: str, caption: str) -> None:
        """Add an image node with its caption to the graph.

        Args:
            image_id: Unique identifier for the image.
            caption: BLIP-2 generated caption for the image.
        """
        self.graph.add_node(image_id, node_type='image', caption=caption)
        # Naive tokenization: lower-cased whitespace split (punctuation stays
        # attached to words; acceptable for the current prototype).
        concepts = [word.lower() for word in caption.split()]
        for concept in concepts:
            self.graph.add_node(concept, node_type='concept')
            self.graph.add_edge(image_id, concept, relationship='has_concept')
            # BUGFIX: the graph is directed, so without a concept->image edge
            # no search path starting at a query node could ever reach an
            # image. Mirrors the reverse_* convention of add_relationships().
            self.graph.add_edge(concept, image_id,
                                relationship='reverse_has_concept')

    def expand_concept(self, concept: str) -> List[Dict]:
        """Query ConceptNet for relationships about a concept.

        Args:
            concept: The concept to query relationships for.

        Returns:
            List of dicts with 'source', 'target' and 'relationship' keys;
            empty list on network errors.
        """
        # ConceptNet URIs are lower case with underscores for spaces.
        formatted_concept = f"/c/en/{concept.lower().replace(' ', '_')}"
        try:
            response = requests.get(
                f"{self.conceptnet_api_base}{formatted_concept}",
                params={'limit': 50},
                timeout=10  # BUGFIX: never hang forever on an unresponsive API
            )
            response.raise_for_status()
            edges = response.json().get('edges', [])
            relationships = []
            for edge in edges:
                try:
                    start = edge.get('start', {})
                    end = edge.get('end', {})
                    # Keep only edges where both endpoints are English.
                    if (start.get('language', '') == 'en' and
                            end.get('language', '') == 'en'):
                        relationships.append({
                            'source': start.get('label', ''),
                            'target': end.get('label', ''),
                            'relationship': edge.get('rel', {}).get('label', '')
                        })
                except (KeyError, TypeError):
                    continue  # Skip malformed edges
            return relationships
        except requests.exceptions.RequestException as e:
            print(f"Error querying ConceptNet: {e}")
            return []

    def add_relationships(self, source: str, relationships: List[Dict]) -> None:
        """Add unweighted relationships (plus reverse edges) to the graph.

        Args:
            source: The concept these relationships were expanded from.
                Retained for API compatibility; not read here.
            relationships: List of relationship dicts from expand_concept().
        """
        for rel in relationships:
            for node in (rel['source'], rel['target']):
                if not self.graph.has_node(node):
                    self.graph.add_node(node, node_type='concept')
            self.graph.add_edge(
                rel['source'],
                rel['target'],
                relationship=rel['relationship'],
                weight=1.0  # Default weight; could use ConceptNet's confidence
            )
            # Reverse edge so bidirectional-style search works on the digraph.
            self.graph.add_edge(
                rel['target'],
                rel['source'],
                relationship=f"reverse_{rel['relationship']}",
                weight=1.0
            )

    def _link_query_node(self, temp_node: str, query: str,
                         relationships: List[Dict]) -> None:
        """Attach a temporary query node to the concepts representing the query.

        expand_concept() returns edges between ConceptNet labels but never
        mentions the temporary node itself, so without this step the query
        node stays isolated and no search path can ever reach an image.
        """
        if not self.graph.has_node(temp_node):
            self.graph.add_node(temp_node, node_type='query')
        # Anchor on the caption-style concept (lower-cased query) plus any
        # relationship endpoint whose label matches the query.
        anchors = {query.lower()}
        for rel in relationships:
            for endpoint in (rel.get('source', ''), rel.get('target', '')):
                if endpoint and endpoint.lower() == query.lower():
                    anchors.add(endpoint)
        for anchor in anchors:
            if not self.graph.has_node(anchor):
                self.graph.add_node(anchor, node_type='concept')
            self.graph.add_edge(temp_node, anchor,
                                relationship='query_concept', weight=1.0)

    def search(self, query: str, limit: int = 5) -> List[str]:
        """Search for images based on a semantic query.

        Args:
            query: Search query string.
            limit: Maximum number of results to return.

        Returns:
            List of image IDs ordered by relevance (shorter graph paths
            from the query rank higher).
        """
        query_relationships = self.expand_concept(query)
        temp_query_node = f"_query_{query}"
        self.add_relationships(temp_query_node, query_relationships)
        # BUGFIX: the temporary node must actually be connected to the graph;
        # previously it was never created, so no path to any image existed
        # and the cleanup below raised NetworkXError.
        self._link_query_node(temp_query_node, query, query_relationships)
        try:
            image_nodes = [n for n, attr in self.graph.nodes(data=True)
                           if attr.get('node_type') == 'image']
            image_scores = []
            for image_id in image_nodes:
                try:
                    # Shorter path => more relevant.
                    path_length = nx.shortest_path_length(
                        self.graph,
                        source=temp_query_node,
                        target=image_id
                    )
                    image_scores.append((image_id, 1.0 / (1.0 + path_length)))
                except (nx.NetworkXNoPath, nx.NodeNotFound):
                    # NodeNotFound covers nodes missing from the graph.
                    continue
            image_scores.sort(key=lambda pair: pair[1], reverse=True)
            return [img_id for img_id, _ in image_scores[:limit]]
        finally:
            # Guarded so cleanup can never raise.
            if self.graph.has_node(temp_query_node):
                self.graph.remove_node(temp_query_node)

    def get_related_concepts(self, concept: str, relationship_type: Optional[str] = None) -> List[Tuple[str, str]]:
        """Get concepts related to a concept, optionally filtered by type.

        Args:
            concept: The source concept to find relations for.
            relationship_type: Optional filter for a specific relationship type.

        Returns:
            List of (related_concept, relationship_type) tuples.
        """
        related = self._collect_related(concept, relationship_type)
        if not related:
            # Lazily pull relationships from ConceptNet, then retry.
            self.add_relationships(concept, self.expand_concept(concept))
            related = self._collect_related(concept, relationship_type)
        return related

    def _collect_related(self, concept: str,
                         relationship_type: Optional[str]) -> List[Tuple[str, str]]:
        """Collect (target, relationship) pairs from a concept's out-edges,
        skipping temporary query nodes and image nodes."""
        results: List[Tuple[str, str]] = []
        if not self.graph.has_node(concept):
            # Guard: out_edges() on a missing node raises NetworkXError.
            return results
        for _, target, edge_data in self.graph.out_edges(concept, data=True):
            rel_type = edge_data.get('relationship', '')
            if relationship_type is not None and rel_type != relationship_type:
                continue
            if target.startswith('_query_'):
                continue
            if self.graph.nodes[target].get('node_type') == 'image':
                continue
            results.append((target, rel_type))
        return results

    def search_with_depth(self, query: str, max_depth: int = 3, limit: int = 5) -> List[Tuple[str, float]]:
        """Enhanced search with depth control and weighted path scoring.

        Args:
            query: Search query string.
            max_depth: Maximum path length (hops) to consider.
            limit: Maximum number of results to return.

        Returns:
            List of (image_id, relevance_score) tuples, best first.
        """
        temp_query_node = f"_query_{query}"
        try:
            query_relationships = self.expand_concept(query)
            if not query_relationships:
                print(f"No relationships found for query: {query}")
                return []
            self.graph.add_node(temp_query_node, node_type='query')
            self.add_weighted_relationships(temp_query_node, query_relationships)
            # BUGFIX: connect the query node so paths to images can exist.
            self._link_query_node(temp_query_node, query, query_relationships)
            image_nodes = [n for n, attr in self.graph.nodes(data=True)
                           if attr.get('node_type') == 'image']
            image_scores = {}
            for image_id in image_nodes:
                try:
                    paths = nx.all_simple_paths(
                        self.graph,
                        source=temp_query_node,
                        target=image_id,
                        cutoff=max_depth
                    )
                    path_scores = [self._calculate_path_score(p) for p in paths]
                    if path_scores:
                        # Best path wins.
                        image_scores[image_id] = max(path_scores)
                except (nx.NetworkXNoPath, nx.NodeNotFound):
                    # all_simple_paths raises NodeNotFound for missing nodes;
                    # it yields nothing (no exception) when no path exists.
                    continue
            return sorted(image_scores.items(), key=lambda kv: kv[1],
                          reverse=True)[:limit]
        finally:
            if self.graph.has_node(temp_query_node):
                self.graph.remove_node(temp_query_node)

    def search_with_depth1(self, query: str, max_depth: int = 3, limit: int = 5) -> List[Tuple[str, float]]:
        """Deprecated near-duplicate of search_with_depth(); delegates to it.

        Kept for backward compatibility with existing callers. The original
        copy shared the isolated-query-node defect and additionally crashed
        in its finally block when the query node was never created;
        delegating removes both problems.
        """
        return self.search_with_depth(query, max_depth=max_depth, limit=limit)

    def _calculate_path_score(self, path: List[str]) -> float:
        """Score a path using relationship weights with per-hop depth decay.

        Where parallel edges exist between two nodes, the maximum-weight
        edge is used; hop i contributes weight * decay_factor**i, and the
        sum is normalized by the number of hops (0 for a single-node path).
        """
        total_score = 0.0
        hops = len(path) - 1
        for i in range(hops):
            # get_edge_data returns a dict of parallel edges on MultiDiGraph.
            edges = self.graph.get_edge_data(path[i], path[i + 1])
            if edges:
                max_weight = max(
                    edge.get('weight', self.relationship_weights['default'])
                    for edge in edges.values()
                )
                total_score += max_weight * (self.decay_factor ** i)
        return total_score / hops if hops > 0 else 0