fcxfcx's picture
Upload 2446 files
1327f34 verified
# Copyright 2025 The Scenic Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Extract 2-hop subgraph from each wikipedia page."""
import json
import multiprocessing as mp
import os
from absl import app
from absl import logging
from scenic.projects.knowledge_visual_language.data.wikidata import data_util
import tqdm
import tensorflow as tf
from tf.io import gfile
def extract_graph_from_file(kg_graph, i) -> None:
"""Extract 2hop graphs for all pages in each file.
Args:
kg_graph: the global KG (i.e. WikiData Knowledge Graph), stored as dict.
i: index of wiki file.
Returns:
A list of wikipedia pages with extracted graph.
"""
d_list = data_util.load_wiki_from_file(
os.path.join(data_util.WIKIPEDIA_LINK_PATH, 'wiki_%d.txt' % i))
with gfile.Open(
os.path.join(data_util.WIKIPEDIA_GRAPH_PATH, 'wiki_graph_%d.txt' % i),
'w') as fopen:
for d in d_list:
if 'link' in d and isinstance(d['link'], dict):
in_context_ents = {ent: True for ent in d['link'] if ent != 'UNK'}
d['graph'] = data_util.extract_2hop_graph(in_context_ents, kg_graph)
fopen.write(json.dumps(d) + '\n')
del d_list
def extract_graph_mp(kg_graph, n_pool=1) -> None:
"""Use multiprocessing to parallize graph extraction.
Args:
kg_graph: the global KG (i.e. WikiData Knowledge Graph), stored as dict.
n_pool: number of process (worker)
"""
jobs = []
if not gfile.Exists(data_util.WIKIPEDIA_GRAPH_PATH):
gfile.MakeDirs(
data_util.WIKIPEDIA_GRAPH_PATH,
mode=gfile.LEGACY_GROUP_WRITABLE_WORLD_READABLE,
)
if n_pool == 1:
for i in range(data_util.NUM_SPLIT):
extract_graph_from_file(kg_graph, i)
else:
with mp.Pool(n_pool) as pool:
for i in range(data_util.NUM_SPLIT):
jobs += [pool.apply_async(extract_graph_from_file, args=(kg_graph, i))]
for job in tqdm.tqdm(jobs):
job.get()
def main(_) -> None:
logging.info('Load WikiData relational edges.')
with gfile.Open(data_util.WIKIDATA_EDGE_PATH) as f:
rel_list = json.load(f)
logging.info('Construct KG Graph from edge list.')
kg_graph = data_util.construct_kg_graph(rel_list)
logging.info('Extract 2-hop subgraphs within wiki-pages.')
extract_graph_mp(kg_graph, n_pool=64)
if __name__ == '__main__':
app.run(main)