import pandas as pd
from sklearn.metrics.pairwise import cosine_similarity
from typing import List, Any, Optional
from collections import defaultdict, deque


def predict_topic_nth_degree(
    new_article_title: str,
    new_article_embedding: List[float],
    edges: List[str],
    G: Any,
    max_depth: int = 1,
    is_weighted: bool = False,
    decay_factor: float = 1.0,
) -> Optional[pd.DataFrame]:
    """
    Score candidate topics for a new article from neighbors up to n hops away.

    The new article is not looked up in ``G``. Its references (``edges``)
    form the first hop, and a breadth-first search through ``G`` continues
    from there. Every reached node with a ``"label"`` attribute votes for
    that label.

    Args:
        new_article_title: Title of the new article. It is marked visited up
            front so the search never counts the article itself.
        new_article_embedding: Embedding of the new article. It is only used
            when ``is_weighted`` is True.
        edges: Node names the new article references. Names not present in
            ``G`` are ignored, and duplicates are counted once.
        G: Graph supporting ``name in G``, ``G.nodes[name]`` (an attribute
            mapping with ``.get``) and ``G.neighbors(name)``. Presumably a
            networkx graph -- confirm against callers.
        max_depth: How many hops to traverse (1 = direct neighbors,
            2 = neighbors of neighbors). Direct neighbors are always scored,
            even when ``max_depth`` is below 1.
        is_weighted: If True, each vote is the cosine similarity between
            ``new_article_embedding`` and the node's ``"embedding"`` attribute,
            which must then exist (a missing one raises KeyError). Otherwise
            every vote is 1.0.
        decay_factor: Multiplier for distance. 1.0 = no decay. 0.5 means a
            neighbor at depth 2 has half the voting power of depth 1.

    Returns:
        A DataFrame with columns ``"Class"`` and ``"Score"``, one row per
        topic, sorted by score descending. Tied topics keep first-reached
        order. Returns ``None`` if no referenced node exists in ``G`` or no
        reached node has a label.
    """
    # The new article is a "virtual" node. Marking it visited ensures that a
    # back-link to it (if it was later added to G) is never scored.
    visited = {new_article_title}
    # Queue items are (node_name, depth); depth 1 means directly referenced.
    queue = deque()
    for ref in edges:
        if ref in G and ref not in visited:
            visited.add(ref)
            queue.append((ref, 1))

    if not queue:
        return None

    query = [new_article_embedding]  # loop-invariant input to cosine_similarity
    topic_scores = defaultdict(float)

    while queue:
        current_node, current_depth = queue.popleft()
        node_data = G.nodes[current_node]
        topic = node_data.get("label")

        # Skip only missing or empty labels. A falsy-but-valid label such as
        # the integer class 0 must still receive its vote.
        if topic is not None and topic != "":
            if is_weighted:
                base_score = cosine_similarity(query, [node_data["embedding"]])[0][0]
            else:
                base_score = 1.0
            # Depth 1 keeps the full score; each extra hop multiplies by decay_factor.
            topic_scores[topic] += base_score * (decay_factor ** (current_depth - 1))

        # Unlabeled nodes are still expanded: they can lead to labeled ones.
        if current_depth < max_depth:
            for neighbor in G.neighbors(current_node):
                if neighbor not in visited:
                    visited.add(neighbor)
                    queue.append((neighbor, current_depth + 1))

    if not topic_scores:
        return None

    # A stable sort makes ties deterministic: the topic reached first in BFS
    # order (i.e. the closer one) ranks higher.
    return pd.DataFrame(
        list(topic_scores.items()), columns=["Class", "Score"]
    ).sort_values(by="Score", ascending=False, kind="stable")