Metin's picture
Initial commit
d97a439
raw
history blame
4.43 kB
import networkx as nx
import pandas as pd
def get_unique_article_titles(df: pd.DataFrame) -> list[str]:
    """Return the distinct article titles in ascending sorted order.

    Args:
        df: DataFrame with an ``article_title_processed`` column.

    Returns:
        A new list holding each distinct value of that column exactly once,
        sorted ascending.
    """
    titles = df["article_title_processed"].unique().tolist()
    titles.sort()
    return titles
def create_graph_from_df(df, directed: bool = False) -> nx.Graph:
    """Build an article graph from a DataFrame of articles and their links.

    Every row becomes a node keyed by its ``article_title_processed`` value,
    carrying the row's ``predicted_topic`` as the ``label`` attribute and the
    row's ``embedding`` as the ``embedding`` attribute. An edge
    ``article -> reference`` is added for every entry of the row's
    ``links_processed`` value that is itself a node of the graph; self-links
    are skipped.

    Args:
        df: DataFrame with the columns ``article_title_processed``,
            ``predicted_topic``, ``embedding`` and ``links_processed``.
            ``links_processed`` is evaluated as a Python expression and then
            iterated (presumably the string form of a list of titles --
            confirm against the data pipeline).
        directed: When true, build an ``nx.DiGraph`` whose edges point from
            the citing article to the referenced one. When false (default),
            build an undirected ``nx.Graph``.

    Returns:
        The populated graph (``nx.DiGraph`` is a subclass of ``nx.Graph``).
    """
    # Honour the ``directed`` flag when choosing the graph class. An
    # undirected graph already treats (u, v) and (v, u) as the same edge, so
    # no explicit reverse edge is needed in that case.
    G = nx.DiGraph() if directed else nx.Graph()

    # Pass 1: register every article first, so that pass 2 can decide whether
    # a reference points at a known article regardless of row order.
    for _, row in df.iterrows():
        G.add_node(
            row["article_title_processed"],
            label=row["predicted_topic"],
            embedding=row["embedding"],
        )

    # Pass 2: connect each article to the references that are known nodes.
    for _, row in df.iterrows():
        node_title = row["article_title_processed"]
        # SECURITY NOTE(review): eval() executes arbitrary code contained in
        # the cell. If the column only ever holds list literals, consider
        # switching to ast.literal_eval once that has been verified.
        references = eval(row["links_processed"])
        for ref in references:
            if ref in G and ref != node_title:
                G.add_edge(node_title, ref)
    return G
def gather_neighbors(
    graph: nx.DiGraph, node_title: str, references: list[str], depth: int = 1
):
    """Return the visualizer neighborhood of an article attached to a graph copy.

    The input graph is left untouched: the article and its links are added to
    a copy, and the neighborhood is computed on that copy.

    Args:
        graph: Graph of known articles.
        node_title: Title of the article to attach and to start from.
        references: Titles the article links to. Only those already present
            in the graph (and different from ``node_title``) become edges.
        depth: Number of hops to include around ``node_title``.

    Returns:
        The ``{"nodes": [...], "edges": [...]}`` dictionary produced by
        ``get_neighbors_for_visualizer`` for ``node_title``.
    """
    working_copy = graph.copy()
    working_copy.add_node(node_title)

    # Link the article to each reference that is already a node; unknown
    # references and self-references are skipped.
    working_copy.add_edges_from(
        (node_title, target)
        for target in references
        if target in working_copy and target != node_title
    )

    return get_neighbors_for_visualizer(working_copy, node_title, depth=depth)
def get_neighbors_for_visualizer(graph: nx.Graph, start_node, depth=1):
    """Export the neighborhood of a node as Cytoscape-style node/edge records.

    Args:
        graph (nx.Graph): The source NetworkX graph.
        start_node: The title/ID of the node at the center of the neighborhood.
        depth (int): Radius of the neighborhood in hops.

    Returns:
        dict: ``{"nodes": [...], "edges": [...]}``. Each node record is
        ``{"data": {"id", "label", "name"}}`` and each edge record is
        ``{"data": {"id", "label", "source", "target"}}``. Both lists are
        empty when ``start_node`` is not in ``graph``.
    """
    # Unknown start node: nothing to show.
    if start_node not in graph:
        return {"nodes": [], "edges": []}

    ego = nx.ego_graph(graph, start_node, radius=depth)

    # The output format uses integer IDs, so number the nodes 1..N in
    # iteration order and remember the mapping for the edge endpoints.
    ids = {title: number for number, title in enumerate(ego.nodes(), start=1)}

    # Only the "label" attribute is exported; other node attributes (such as
    # the embedding) are left out of the payload.
    node_records = [
        {
            "data": {
                "id": number,
                "label": ego.nodes[title].get("label", "Unknown"),
                "name": str(title),
            }
        }
        for title, number in ids.items()
    ]

    # Edge IDs continue right after the last node ID so that no ID is shared
    # between a node and an edge.
    first_edge_id = len(ids) + 1
    edge_records = [
        {
            "data": {
                "id": edge_number,
                # Edges without their own label default to "CITES".
                "label": ego.edges[src, dst].get("label", "CITES"),
                "source": ids[src],
                "target": ids[dst],
            }
        }
        for edge_number, (src, dst) in enumerate(ego.edges(), start=first_edge_id)
    ]

    return {"nodes": node_records, "edges": edge_records}
if __name__ == "__main__":
    # Manual smoke test: load the training data, build the article graph and
    # print the 2-hop neighborhood of a made-up article with three references.
    parquet_path = r"C:\Users\pc\Desktop\Projects\Masters\data_mining\semantic_knowledge_graph\demo\input\train_data_with_embeddings.parquet"
    articles_df = pd.read_parquet(parquet_path)
    article_graph = create_graph_from_df(articles_df)

    test_title = "Sample Article Title"
    test_references = ["finansal matematik", "genel yapay zekâ", "andrej karpathy"]

    neighbors = gather_neighbors(
        article_graph, test_title, test_references, depth=2
    )
    print(f"Neighbors of '{test_title}': {neighbors}")