File size: 5,410 Bytes
1327f34
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
# Copyright 2025 The Scenic Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Functions for processing the data."""

import json
from typing import Any, Dict, List

import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import tqdm

import tensorflow as tf
from tf.io import gfile

# Threshold read by construct_kg_graph: reverse ('_R') links are not added for
# a (relation, target entity) pair that occurs in at least this many triplets.
MAX_REL_PER_ENTITY = 10000
# Add path to your data here:
# (The four paths below are empty placeholders; none of the functions visible
# in this file reads them.)
WIKIPEDIA_LINK_PATH = ''
WIKIPEDIA_GRAPH_PATH = ''
WIKIDATA_EDGE_PATH = ''
WIKIDATA_ENTITY_PATH = ''

# Not used by the functions visible in this file; presumably the number of
# shards the data is split into -- TODO confirm against the caller.
NUM_SPLIT = 1024


def construct_kg_graph(rel_list) -> Dict[str, Dict[str, List[str]]]:
  """Build a nested adjacency dict from KG triplets, including reverse links.

  The input is traversed twice (once for forward edges and counting, once for
  reverse edges), so it must be re-iterable, e.g. a list.

  Args:
    rel_list: KG edge list, in the format of <s, r, t> triplets.

  Returns:
    A dict such that `result[s][t]` is the list of distinct relations from `s`
    to `t`. Self-loops (s == t) are dropped. For every kept triplet a reverse
    relation `r + '_R'` is stored under `result[t][s]`, unless the pair (r, t)
    occurs in MAX_REL_PER_ENTITY or more triplets.
  """
  kg_graph = {}
  # kg_rel[r][t]: number of non-self-loop triplets with relation r and target t.
  kg_rel = {}

  # Pass 1: forward edges plus per-(relation, target) occurrence counts.
  for head, rel, tail in tqdm.tqdm(rel_list):
    if head == tail:
      continue
    forward_rels = kg_graph.setdefault(head, {}).setdefault(tail, [])
    if rel not in forward_rels:
      forward_rels.append(rel)

    tail_counts = kg_rel.setdefault(rel, {})
    tail_counts[tail] = tail_counts.get(tail, 0) + 1

  # Pass 2: reverse edges, skipped for targets that are too popular for a
  # given relation (the counts are only complete after pass 1).
  for head, rel, tail in tqdm.tqdm(rel_list):
    if head == tail:
      continue
    if kg_rel[rel][tail] >= MAX_REL_PER_ENTITY:
      continue
    reverse_rel = rel + '_R'
    backward_rels = kg_graph.setdefault(tail, {}).setdefault(head, [])
    if reverse_rel not in backward_rels:
      backward_rels.append(reverse_rel)

  return kg_graph


def load_wiki_from_file(file_path) -> List[Dict[str, Any]]:
  """Load a JSON-lines file into a list of parsed records.

  Args:
    file_path: path of the file to read; it is opened in text mode through
      `gfile.Open`.

  Returns:
    One parsed JSON value per non-blank line, in file order.

  Raises:
    json.JSONDecodeError: if a non-blank line is not valid JSON.
  """
  with gfile.Open(file_path, 'r') as fopen:
    lines = fopen.readlines()
  # Whitespace-only lines carry no record and would make json.loads raise, so
  # they are skipped instead of aborting the whole load.
  return [json.loads(line) for line in lines if line.strip()]


def extract_2hop_graph(
    in_context_ents, kg_graph
) -> Dict[Any, Dict[Any, List[Any]]]:
  """Extract the subgraph of direct and 2-hop links between page entities.

  A node is kept if it is an in-context entity, or if it is a neighbour (in
  `kg_graph`) of at least two distinct in-context entities, i.e. it bridges a
  2-hop path between a pair of them. All `kg_graph` edges between kept nodes
  are copied into the result.

  Args:
    in_context_ents: all entities within each wiki-page. Any iterable of
      entity ids works (a dict is iterated by key); duplicates are ignored.
    kg_graph: the global KG (i.e. WikiData Knowledge Graph), stored as a
      nested dict `kg_graph[source][target] -> list of relations`.

  Returns:
    Extracted 2-hop subgraph for the page, in the same nested-dict layout as
    `kg_graph`. The relation lists are the same objects as in `kg_graph`, not
    copies.
  """
  # De-duplicate while keeping order. Without this, an entity listed twice
  # would make each of its neighbours look like it links two entities.
  ents = dict.fromkeys(in_context_ents)

  # For every out-of-context neighbour, collect the in-context entities that
  # point to it.
  bridge_sources = {}
  for se in ents:
    for te in kg_graph.get(se, ()):
      if te not in ents:
        bridge_sources.setdefault(te, []).append(se)

  remain_nodes = {
      e: True for e, sources in bridge_sources.items() if len(sources) > 1
  }
  for e in ents:
    remain_nodes[e] = True

  two_graph = {}
  for se in remain_nodes:
    for te, rels in kg_graph.get(se, {}).items():
      if te not in remain_nodes:
        continue
      two_graph.setdefault(se, {})[te] = rels
      # Mirror the opposite direction when the KG has it.
      if se in kg_graph.get(te, ()):
        two_graph.setdefault(te, {})[se] = kg_graph[te][se]
  return two_graph


def plot_graph(in_graph_ents, graph, entity_dict, print_out_label=True) -> None:
  """Draw the subgraph of one wikipedia page with networkx / matplotlib.

  In-context entities are drawn as large red nodes (size 2000), all other
  nodes as small blue ones (size 100). Nothing is drawn for an empty graph.

  Args:
    in_graph_ents: all in-context entities
    graph: subgraph of each wikipage, as `graph[source] -> iterable of targets`.
    entity_dict: entityID to name; names are used as node ids and labels.
    print_out_label: if true, label every node; otherwise label only the
      in-context entities.
  """
  if not graph:
    return

  g = nx.Graph()
  all_label = {}
  in_label = {}

  def _register(ent):
    """Adds `ent` as a styled node, records its labels, returns its name."""
    name = entity_dict[ent]
    all_label[name] = name
    if ent in in_graph_ents:
      g.add_node(name, color='red', size=2000)
      in_label[name] = name
    else:
      g.add_node(name, color='blue', size=100)
    return name

  for se, targets in graph.items():
    source_name = _register(se)
    for te in targets:
      g.add_edge(source_name, _register(te))

  plt.figure(figsize=(10, 10))
  layout = nx.kamada_kawai_layout(g)

  # The two display modes differ only in which label set is rendered.
  labels = all_label if print_out_label else in_label
  nx.draw_networkx_nodes(
      g,
      pos=layout,
      node_color=nx.get_node_attributes(g, 'color').values(),
      node_size=list(nx.get_node_attributes(g, 'size').values()))
  nx.draw_networkx_labels(g, pos=layout, labels=labels)
  nx.draw_networkx_edges(g, pos=layout, alpha=0.3, arrows=False)

  # Widen the x-axis by 20% of the layout's x-extent on each side.
  xs = np.array(list(layout.values()))[:, 0]
  xmin, xmax = np.min(xs), np.max(xs)
  margin = (xmax - xmin) * 0.2
  plt.xlim(xmin - margin, xmax + margin)
  plt.show()