import spacy
import zstandard as zstd
import json
import typing
import os
from tqdm import tqdm
import multiprocessing
import random
from langdetect import detect
import argparse
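# Pipeline:
#   1. Extract English Wikipedia documents from RedPajama .zst shards.
#   2. Run NER + entity linking to find each document's most frequent entity.
#   3. Replace that entity (and its aliases) with another entity of the same
#      label to produce a perturbed corpus.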

parser = argparse.ArgumentParser()
parser.add_argument('--input_dir', type=str, help='Path to the input directory')
args = parser.parse_args()
input_dir = args.input_dir


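# Keep only documents that langdetect identifies as English; detect() raises on
# empty or non-linguistic input, which we treat as non-English.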
def is_english(text):
    try:
        return detect(text) == 'en'
    except Exception:
        return False

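# Count entity mentions by surface text (individual spaCy Span objects are
# distinct per mention, so they cannot serve as frequency keys) and record the
# most frequent entity of each document.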
def process_text(texts, model, out_f, lock):
    for text in texts:
        doc = model(text)
        freq_cnt = {}
        span_by_text = {}
        for ent in doc.ents:
            freq_cnt[ent.text] = freq_cnt.get(ent.text, 0) + 1
            span_by_text[ent.text] = ent
        if len(freq_cnt) == 0:
            continue
        # most frequent entity text; ties resolve to the last one in sort order
        sorted_freq = sorted(freq_cnt.items(), key=lambda x: x[1])
        most_freq = span_by_text[sorted_freq[-1][0]]
        data = {'text': text, 'main_entity': most_freq.text, 'label': most_freq.label_, 'id': most_freq.kb_id_}
        json_data = json.dumps(data)
        with lock:
            out_f.write(json_data + '\n')
            out_f.flush()

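# Fan documents out to worker processes in chunks of 1000, all started at once.
# This relies on the fork start method (the Linux default) so the loaded spaCy
# model, the open file handle, and the lock are inherited by the children; under
# spawn (macOS/Windows) these arguments would need to be picklable.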
def run_ner_linking(texts: typing.List[str], ner_model_path: str):
    nlp = spacy.load(ner_model_path)
    os.makedirs('result', exist_ok=True)
    out_f = open('result/temp_store_data.json', 'w', encoding='utf-8')
    lock = multiprocessing.Lock()
    processes = []

    for i in tqdm(range(0, len(texts), 1000)):
        p = multiprocessing.Process(target=process_text, args=(texts[i:i+1000], nlp, out_f, lock))
        processes.append(p)
        p.start()

    for p in processes:
        p.join()

    out_f.close()


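# Step 1: walk every .zst shard under input_dir, decompress it, and keep the
# English documents whose RedPajama subset is Wikipedia.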
wikipedia_out_path = 'result/wikipedia.json'
os.makedirs('result', exist_ok=True)
subdirectories = [f.path for f in os.scandir(input_dir) if f.is_dir()]
wikipedia_data = []
for sub_dir in subdirectories:
    zst_files = [f for f in os.listdir(sub_dir) if f.endswith('.zst')]
    for file in tqdm(zst_files):
        with open(os.path.join(sub_dir, file), 'rb') as compressed_file:
            decompressor = zstd.ZstdDecompressor()
            with decompressor.stream_reader(compressed_file) as reader:
                decompressed_data = reader.read()
            for line in decompressed_data.splitlines():
                data = json.loads(line)
                if data['meta']['redpajama_set_name'] == 'RedPajamaWikipedia':
                    if is_english(data['text']):
                        wikipedia_data.append(data)

with open(wikipedia_out_path, 'w', encoding='utf-8') as f:
    for data in wikipedia_data:
        f.write(json.dumps(data) + '\n')

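# Step 2: reload the filtered texts and run NER + entity linking over them.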
wikipedia_data = []
ner_model_path = 'kc-ner-model'
with open(wikipedia_out_path, 'r', encoding='utf-8') as f:
    for line in tqdm(f):
        data = json.loads(line)
        wikipedia_data.append(data['text'])
run_ner_linking(wikipedia_data, ner_model_path)

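# entity_info.json is produced upstream and is assumed to map each linked
# entity id (kb_id_) to a record containing at least an 'aliases' list.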
entity_info_path = 'result/entity_info.json'
with open(entity_info_path, 'r', encoding='utf-8') as f:
    entity_info = json.load(f)


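# Group every extracted main entity by its NER label so that replacements can
# be sampled from entities of the same type. The lists keep duplicates (one
# entry per document), so sampling is weighted toward frequent entities.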
category = {}
all_data = []
with open('result/temp_store_data.json', 'r', encoding='utf-8') as f:
    for line in f:
        data = json.loads(line)
        all_data.append(data)
        if data['label'] not in category:
            category[data['label']] = []
        category[data['label']].append(data['main_entity'])

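# Step 3: for each document, swap every occurrence of the main entity and its
# aliases for a randomly chosen same-label entity. Filtering the candidate pool
# up front avoids the unbounded retry loop that rejection sampling would hit
# when a label contains no entity other than the main one.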
with open('result/processed_data.json', 'w', encoding='utf-8') as f:
    for data in tqdm(all_data):
        text = data['text']
        main_entity = [data['main_entity']]
        if data['id'] in entity_info:
            main_entity.extend(entity_info[data['id']]['aliases'])
        # same-label entities that are neither the main entity nor one of its aliases
        candidates = [e for e in category[data['label']] if e not in main_entity]
        if len(candidates) == 0:
            continue
        replaced_entity = random.choice(candidates)
        for entity in main_entity:
            text = text.replace(entity, replaced_entity)
        out_data = {
            'text': text,
            'original_main_entity': main_entity,
            'replaced_entity': replaced_entity
        }
        f.write(json.dumps(out_data) + '\n')