| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | """Functions for processing the data.""" |
| |
|
| | import json |
| | from typing import Any, Dict, List |
| |
|
| | import matplotlib.pyplot as plt |
| | import networkx as nx |
| | import numpy as np |
| | import tqdm |
| |
|
| | import tensorflow as tf |
| | from tf.io import gfile |
| |
|
# Hub cut-off used by construct_kg_graph: the reverse (`_R`) edge of a triplet
# is skipped when its (relation, target) pair occurs at least this many times.
MAX_REL_PER_ENTITY = 10000

# Data locations, left empty in this file.
# NOTE(review): presumably meant to be filled in with the Wikipedia / WikiData
# dump paths before use; none of them is referenced in the visible code.
WIKIPEDIA_LINK_PATH = ''
WIKIPEDIA_GRAPH_PATH = ''
WIKIDATA_EDGE_PATH = ''
WIKIDATA_ENTITY_PATH = ''

# NOTE(review): presumably the number of shards the data is split into; not
# used in the visible code -- confirm against the callers.
NUM_SPLIT = 1024
| |
|
| |
|
def construct_kg_graph(rel_list) -> Dict[str, Dict[str, List[str]]]:
  """Build an adjacency-dict KG from triplets, then add capped reverse links.

  Args:
    rel_list: KG edge list, in the format of <s, r, t> triplets. It is
      traversed twice, so it has to be re-iterable (e.g. a list).

  Returns:
    Nested dict `graph[source][target] -> list of relation names`. Forward
    edges keep the relation name as given; reverse edges are stored under
    `graph[target][source]` with `_R` appended to the relation name.
    Self-loops (s == t) are dropped in both directions.
  """
  graph = {}
  # incoming[r][t]: how many non-self-loop triplets with relation r point at t.
  incoming = {}

  # Pass 1: forward edges plus the per-(relation, target) counts.
  for head, rel, tail in tqdm.tqdm(rel_list):
    if head == tail:
      continue
    forward = graph.setdefault(head, {}).setdefault(tail, [])
    if rel not in forward:
      forward.append(rel)

    per_target = incoming.setdefault(rel, {})
    per_target[tail] = per_target.get(tail, 0) + 1

  # Pass 2: reverse edges. The complete counts from pass 1 are needed first,
  # so that targets reached too often through one relation get no `_R` links.
  for head, rel, tail in tqdm.tqdm(rel_list):
    if head == tail:
      continue
    if incoming[rel][tail] >= MAX_REL_PER_ENTITY:
      continue
    backward = graph.setdefault(tail, {}).setdefault(head, [])
    reverse_rel = rel + '_R'
    if reverse_rel not in backward:
      backward.append(reverse_rel)

  return graph
| |
|
| |
|
def load_wiki_from_file(file_path) -> List[Dict[str, Any]]:
  """Load a JSON-lines file into a list of parsed records.

  Args:
    file_path: path of a text file holding one JSON document per line. It is
      opened through the module-level `gfile`.

  Returns:
    The decoded value of every non-blank line, in file order.

  Raises:
    json.JSONDecodeError: if a non-blank line is not valid JSON.
  """
  # NOTE(review): this relies on the module-level `gfile` exposing `Open`;
  # the public TensorFlow 2 name is `tf.io.gfile.GFile` -- confirm the import.
  data_list = []
  with gfile.Open(file_path, 'r') as fopen:
    # Iterate the handle instead of calling `readlines()` so the raw text of
    # the whole file is never held in memory next to the parsed records.
    for line in fopen:
      if not line.strip():
        # Blank lines carry no record and would make json.loads raise.
        continue
      data_list.append(json.loads(line))
  return data_list
| |
|
| |
|
def extract_2hop_graph(in_context_ents, kg_graph) -> Dict[Any, Dict[Any, List[Any]]]:
  """Extract the subgraph of 2-hop paths between a page's in-context entities.

  A node outside `in_context_ents` is kept only when at least two distinct
  in-context entities link to it, i.e. when it bridges a pair of them. All
  in-context entities are always kept.

  Args:
    in_context_ents: all entities within each wiki-page, stored as dict (any
      re-iterable collection of hashable ids works; repeated ids are counted
      once).
    kg_graph: the global KG (i.e. WikiData Knowledge Graph), stored as dict
      `kg_graph[source][target] -> relations`.

  Returns:
    Extracted 2-hop subgraph for the page, in the same nested layout as
    `kg_graph`, restricted to edges whose two endpoints are both kept. The
    relation values are the same objects as in `kg_graph` (not copies).
  """
  context = set(in_context_ents)  # O(1) membership even for list input.

  # For each out-of-context neighbour, the distinct in-context entities that
  # link to it. A set is used so a repeated entity cannot count twice.
  linked_from = {}
  for se in in_context_ents:
    if se in kg_graph:
      for te in kg_graph[se]:
        if te not in context:
          linked_from.setdefault(te, set()).add(se)

  # Bridging nodes first, then every in-context entity.
  remain_nodes = {
      e: True for e, sources in linked_from.items() if len(sources) > 1
  }
  for e in in_context_ents:
    remain_nodes[e] = True

  two_graph = {}
  for se in remain_nodes:
    if se not in kg_graph:
      continue
    for te in kg_graph[se]:
      if te not in remain_nodes:
        continue
      two_graph.setdefault(se, {})[te] = kg_graph[se][te]
      # Mirror the edge when the global KG also stores the opposite direction.
      if te in kg_graph and se in kg_graph[te]:
        two_graph.setdefault(te, {})[se] = kg_graph[te][se]
  return two_graph
| |
|
| |
|
def plot_graph(in_graph_ents, graph, entity_dict, print_out_label=True) -> None:
  """Draw the subgraph of one wikipedia page with networkx and matplotlib.

  Args:
    in_graph_ents: all in-context entities; they are drawn as large red nodes,
      every other node as a small blue one.
    graph: subgraph of each wikipage, as `graph[source][target]`. Nothing is
      drawn when it is empty.
    entity_dict: entityID to name; the name is used as node key and label.
    print_out_label: if True every node is labelled, otherwise only the
      in-context entities are.
  """
  if not graph:
    return

  g = nx.Graph()
  all_label = {}
  in_label = {}

  def _register(ent):
    """Adds the node for `ent`, records its labels and returns its name."""
    name = entity_dict[ent]
    all_label[name] = name
    if ent in in_graph_ents:
      g.add_node(name, color='red', size=2000)
      in_label[name] = name
    else:
      g.add_node(name, color='blue', size=100)
    return name

  for src in graph:
    src_name = _register(src)
    for dst in graph[src]:
      dst_name = _register(dst)
      g.add_edge(src_name, dst_name)

  plt.figure(figsize=(10, 10))
  layout = nx.kamada_kawai_layout(g)

  # Both modes draw the same nodes and edges; only the label set differs.
  labels = all_label if print_out_label else in_label
  nx.draw_networkx_nodes(
      g,
      pos=layout,
      node_color=nx.get_node_attributes(g, 'color').values(),
      node_size=list(nx.get_node_attributes(g, 'size').values()))
  nx.draw_networkx_labels(g, pos=layout, labels=labels)
  nx.draw_networkx_edges(g, pos=layout, alpha=0.3, arrows=False)

  # Widen the x-axis by 20% of the layout width on each side.
  xs = np.array(list(layout.values()))[:, 0]
  xmin, xmax = np.min(xs), np.max(xs)
  margin = (xmax - xmin) * 0.2
  plt.xlim(xmin - margin, xmax + margin)
  plt.show()
| |
|