fcxfcx's picture
Upload 2446 files
1327f34 verified
# Copyright 2025 The Scenic Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Functions for processing the data."""
import json
from typing import Any, Dict, List
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import tqdm
import tensorflow as tf
from tf.io import gfile
# Threshold used by construct_kg_graph: the reverse ('_R') edge of a triplet is
# not added when its (relation, target) pair occurs in at least this many
# non-self-loop triplets.
MAX_REL_PER_ENTITY = 10000
# Add path to your data here:
# (All four paths are empty placeholders; none of them is read by the
# functions defined in this part of the file.)
WIKIPEDIA_LINK_PATH = ''
WIKIPEDIA_GRAPH_PATH = ''
WIKIDATA_EDGE_PATH = ''
WIKIDATA_ENTITY_PATH = ''
# NOTE(review): not referenced in this part of the file; presumably the number
# of shards the data is split into -- confirm against the code that uses it.
NUM_SPLIT = 1024
def construct_kg_graph(rel_list) -> Dict[str, Dict[str, List[str]]]:
  """Builds an adjacency-dict knowledge graph with forward and reverse links.

  Triplets whose source equals their target are ignored. Every other triplet
  <s, r, t> adds relation `r` to graph[s][t]. The mirrored relation `r + '_R'`
  is added to graph[t][s] unless the pair (r, t) occurs in at least
  MAX_REL_PER_ENTITY non-self-loop triplets.

  Args:
    rel_list: KG edge list, in the format of <s, r, t> triplets. It is
      traversed twice, so it has to be re-iterable (e.g. a list).

  Returns:
    A nested dict where graph[s][t] is the list of distinct relation names
    linking s to t (reverse links carry the '_R' suffix).
  """
  graph = {}
  tail_counts = {}

  # Pass 1: forward edges, plus how often each target appears per relation.
  for head, rel, tail in tqdm.tqdm(rel_list):
    if head == tail:
      continue
    forward = graph.setdefault(head, {}).setdefault(tail, [])
    if rel not in forward:
      forward.append(rel)
    per_relation = tail_counts.setdefault(rel, {})
    per_relation[tail] = per_relation.get(tail, 0) + 1

  # Pass 2: reverse edges. The counts must be complete before this pass, which
  # is why the edge list is walked a second time.
  for head, rel, tail in tqdm.tqdm(rel_list):
    if head == tail:
      continue
    if tail_counts[rel][tail] >= MAX_REL_PER_ENTITY:
      continue
    backward = graph.setdefault(tail, {}).setdefault(head, [])
    reverse_rel = rel + '_R'
    if reverse_rel not in backward:
      backward.append(reverse_rel)

  return graph
def load_wiki_from_file(file_path) -> List[Dict[str, Any]]:
  """Loads a JSON-lines file into a list of parsed records.

  Args:
    file_path: path of a text file, readable through tf.io.gfile, that holds
      one JSON document per line.

  Returns:
    The decoded JSON value of every non-blank line, in file order.

  Raises:
    json.JSONDecodeError: if a non-blank line is not valid JSON.
  """
  data_list = []
  # The public tf.io.gfile API exposes `GFile` as its file opener; it has no
  # `Open` attribute.
  with tf.io.gfile.GFile(file_path, 'r') as fopen:
    # Iterate the handle line by line so the raw text of the whole file is
    # never held in memory next to the parsed records.
    for line in fopen:
      # Tolerate blank lines (e.g. a trailing newline at the end of the file).
      if line.strip():
        data_list.append(json.loads(line))
  return data_list
def extract_2hop_graph(
    in_context_ents, kg_graph
) -> Dict[Any, Dict[Any, List[Any]]]:
  """Extracts the subgraph linking a page's entities by at most two hops.

  The kept node set is every in-context entity plus every other node that is a
  direct neighbour (in `kg_graph`) of at least two distinct in-context
  entities. The result contains the `kg_graph` edges whose two endpoints are
  both kept nodes.

  Args:
    in_context_ents: all entities within a wiki-page. Any iterable of hashable
      entity ids is accepted (dict, list, set, one-shot iterator); repeated
      ids count once.
    kg_graph: the global KG (i.e. WikiData Knowledge Graph), stored as a dict
      such that kg_graph[s][t] is the list of relations from s to t.

  Returns:
    Extracted 2-hop subgraph for the page, in the same nested-dict layout as
    `kg_graph` (subgraph[s][t] is the relation list taken from kg_graph[s][t]).
  """
  # Ordered, de-duplicated snapshot of the input: gives O(1) membership tests
  # and makes the repeated traversals below safe for lists and iterators.
  context = dict.fromkeys(in_context_ents)

  # For each node outside the page, count the in-context entities linking to it.
  bridge_counts = {}
  for se in context:
    for te in kg_graph.get(se, ()):
      if te not in context:
        bridge_counts[te] = bridge_counts.get(te, 0) + 1

  # Bridging nodes first (in discovery order), then the page entities.
  remain_nodes = dict.fromkeys(
      te for te, count in bridge_counts.items() if count > 1)
  remain_nodes.update(context)

  two_graph = {}
  for se in remain_nodes:
    for te, relations in kg_graph.get(se, {}).items():
      if te not in remain_nodes:
        continue
      two_graph.setdefault(se, {})[te] = relations
      # Mirror the reverse edge right away when the global KG has one.
      if te in kg_graph and se in kg_graph[te]:
        two_graph.setdefault(te, {})[se] = kg_graph[te][se]
  return two_graph
def plot_graph(in_graph_ents, graph, entity_dict, print_out_label=True) -> None:
  """Plots the subgraph of one wikipedia page.

  In-context entities are drawn as large red nodes, all other nodes as small
  blue ones. Nodes are identified by their display name, so distinct entity
  ids that map to the same name are drawn as one node. Nothing is drawn when
  `graph` is empty.

  Args:
    in_graph_ents: all in-context entities (container of entity ids).
    graph: subgraph of the wikipage, as a dict mapping a source entity id to an
      iterable of its target entity ids.
    entity_dict: mapping from entity id to display name; it must contain every
      id that appears in `graph`.
    print_out_label: if True, label every node; otherwise label only the
      in-context entities.
  """
  if not graph:
    return

  g = nx.Graph()
  all_label = {}
  in_label = {}

  def _add_entity(ent):
    """Adds `ent` as a styled node, records its labels, returns its name."""
    name = entity_dict[ent]
    all_label[name] = name
    if ent in in_graph_ents:
      g.add_node(name, color='red', size=2000)
      in_label[name] = name
    else:
      g.add_node(name, color='blue', size=100)
    return name

  for se, neighbours in graph.items():
    se_name = _add_entity(se)
    for te in neighbours:
      g.add_edge(se_name, _add_entity(te))

  plt.figure(figsize=(10, 10))
  layout = nx.kamada_kawai_layout(g)
  # Both per-node attributes are passed as plain lists in node order.
  nx.draw_networkx_nodes(
      g,
      pos=layout,
      node_color=list(nx.get_node_attributes(g, 'color').values()),
      node_size=list(nx.get_node_attributes(g, 'size').values()))
  # The label set is the only thing that depends on `print_out_label`.
  labels = all_label if print_out_label else in_label
  nx.draw_networkx_labels(g, pos=layout, labels=labels)
  nx.draw_networkx_edges(g, pos=layout, alpha=0.3, arrows=False)

  # Widen the x-axis by 20% of the layout width on each side.
  xs = np.array(list(layout.values()))[:, 0]
  xmin, xmax = np.min(xs), np.max(xs)
  margin = (xmax - xmin) * 0.2
  plt.xlim(xmin - margin, xmax + margin)
  plt.show()