| |
| |
| """ |
| Reference: |
| - [graphrag](https://github.com/microsoft/graphrag) |
| """ |
|
|
from dataclasses import dataclass
from typing import Any

import networkx as nx
import numpy as np

from graphrag.leiden import stable_largest_connected_component
|
|
|
|
@dataclass
class NodeEmbeddings:
    """Node identifiers paired with the embedding matrix computed for them.

    Instances are built positionally or by keyword in the order
    ``nodes``, ``embeddings``. Consumers in this module pair the two by
    index (``nodes[i]`` with row ``i`` of ``embeddings``).
    """

    # Node labels, in the same order as the rows of ``embeddings``.
    nodes: list[str]
    # Embedding vectors as returned by the node2vec backend; one row per node
    # (assumed from how callers zip it with ``nodes`` — confirm for new producers).
    embeddings: np.ndarray
|
|
|
|
def embed_nod2vec(
    graph: nx.Graph | nx.DiGraph,
    dimensions: int = 1536,
    num_walks: int = 10,
    walk_length: int = 40,
    window_size: int = 2,
    iterations: int = 3,
    random_seed: int = 86,
) -> NodeEmbeddings:
    """Generate node embeddings using Node2Vec.

    Thin wrapper around ``graspologic.embed.node2vec_embed``; every argument
    below is forwarded to it unchanged under the same keyword name.

    Args:
        graph: Graph to embed.
        dimensions: Embedding dimensionality (``dimensions`` in graspologic).
        num_walks: Random walks per source node (``num_walks``).
        walk_length: Length of each walk (``walk_length``).
        window_size: Context window used when training on the walks
            (``window_size``).
        iterations: Training epochs (``iterations``).
        random_seed: Seed forwarded to the backend (``random_seed``).

    Returns:
        ``NodeEmbeddings`` whose ``embeddings`` is element 0 and ``nodes`` is
        element 1 of the backend's result.
    """
    # ``gc`` is not imported at module level, so referencing it here raised
    # NameError. Importing inside the function also keeps the heavy
    # graspologic dependency off this module's import path.
    import graspologic as gc

    result = gc.embed.node2vec_embed(
        graph=graph,
        dimensions=dimensions,
        window_size=window_size,
        iterations=iterations,
        num_walks=num_walks,
        walk_length=walk_length,
        random_seed=random_seed,
    )
    # The backend returns (embedding matrix, node labels).
    return NodeEmbeddings(embeddings=result[0], nodes=result[1])
|
|
|
|
def run(graph: nx.Graph, args: dict[str, Any]) -> dict[str, list[float]]:
    """Embed ``graph`` with Node2Vec and return a node -> vector mapping.

    Args:
        graph: Graph to embed.
        args: Settings read with ``dict.get``; any missing key falls back to
            the default in parentheses:

            - ``use_lcc`` (True): when truthy, replace ``graph`` with
              ``stable_largest_connected_component(graph)`` before embedding.
            - ``dimensions`` (1536), ``num_walks`` (10), ``walk_length`` (40),
              ``window_size`` (2), ``iterations`` (3), ``random_seed`` (86):
              forwarded to ``embed_nod2vec``.

    Returns:
        A plain dict (not a ``NodeEmbeddings``) mapping each embedded node to
        its embedding converted with ``ndarray.tolist()``. Keys are inserted in
        ascending node order.
    """
    if args.get("use_lcc", True):
        graph = stable_largest_connected_component(graph)

    node_embeddings = embed_nod2vec(
        graph=graph,
        dimensions=args.get("dimensions", 1536),
        num_walks=args.get("num_walks", 10),
        walk_length=args.get("walk_length", 40),
        window_size=args.get("window_size", 2),
        iterations=args.get("iterations", 3),
        random_seed=args.get("random_seed", 86),
    )

    # strict=True: node labels and embedding rows must correspond one-to-one;
    # a length mismatch raises ValueError instead of silently truncating.
    pairs = zip(
        node_embeddings.nodes,
        node_embeddings.embeddings.tolist(),
        strict=True,
    )
    # Sort on the node label only, which gives the result a deterministic key
    # order and never compares the vectors themselves.
    return dict(sorted(pairs, key=lambda pair: pair[0]))