import os
import argparse

# CLI: target language code, number of parquet shards to merge, and which
# GPU id(s) to expose to torch.
parser = argparse.ArgumentParser()
parser.add_argument("--lang", type=str, default="en")
parser.add_argument("--num_shards", type=int, default=20)
parser.add_argument("--cuda", type=str, default="0")
args = parser.parse_args()

# Restrict CUDA visibility BEFORE importing torch / sentence_transformers;
# setting these after the import would have no effect on device selection.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = args.cuda

import json
import pickle

import numpy as np
import pandas as pd
import spacy
import textstat
import torch
import tqdm
from sentence_transformers import SentenceTransformer, util
|
|
|
|
| |
|
|
| |
# On-disk cache for the encoded wiki corpus: one (embeddings, chunk texts)
# pair of files per language.
db_dir = "/home/mshahidul/readctrl/data/vector_db"
os.makedirs(db_dir, exist_ok=True)

embs_cache_path = os.path.join(db_dir, f"wiki_{args.lang}_embs.pt")
text_cache_path = os.path.join(db_dir, f"wiki_{args.lang}_chunks.pkl")
|
|
| |
# Sentence embedder used both for the corpus and the query texts.
model = SentenceTransformer('all-MiniLM-L6-v2')

# spaCy pipeline used only for dependency parsing (tree-depth stats), so the
# unused components are disabled for speed.
# BUG FIX: only English ships a "<lang>_core_web_sm" package; most other
# languages (de, fr, es, ...) are named "<lang>_core_news_sm".  Fall back to
# the news naming so non-English --lang values do not crash on load.
_spacy_disable = ["ner", "lemmatizer", "attribute_ruler"]
try:
    nlp = spacy.load(f"{args.lang}_core_web_sm", disable=_spacy_disable)
except OSError:
    nlp = spacy.load(f"{args.lang}_core_news_sm", disable=_spacy_disable)
|
|
| |
def walk_tree(node, depth):
    """Return the maximum depth reachable from *node* in its parse tree.

    *depth* is the depth assigned to *node* itself; a leaf simply reports
    its own depth back.
    """
    child_depths = [walk_tree(child, depth + 1) for child in node.children]
    if not child_depths:
        return depth
    return max(child_depths)
|
|
def get_parse_tree_stats(text):
    """Mean dependency-tree depth over the sentences of *text* (0 if no sentences)."""
    parsed = nlp(text)
    sentence_depths = [walk_tree(sentence.root, 1) for sentence in parsed.sents]
    if not sentence_depths:
        return 0
    return np.mean(sentence_depths)
|
|
| |
| |
| |
if os.path.exists(embs_cache_path) and os.path.exists(text_cache_path):
    # Fast path: reuse the previously encoded corpus from disk.
    print("Loading cached vector database...")
    all_chunk_embs = torch.load(embs_cache_path)
    with open(text_cache_path, "rb") as f:
        all_wiki_chunks = pickle.load(f)
    print(f"Loaded {len(all_wiki_chunks)} chunks from cache.")
else:
    # Slow path: merge whichever parquet shards exist, encode everything,
    # and persist both texts and embeddings for the next run.
    print(f"Cache not found. Merging {args.num_shards} shards and encoding...")
    all_wiki_chunks = []
    for shard_idx in range(args.num_shards):
        shard_path = (
            f"/home/mshahidul/readctrl/data/wiki_chunks/"
            f"wiki_chunks_{args.lang}_shard_{shard_idx}.parquet"
        )
        if os.path.exists(shard_path):
            shard_df = pd.read_parquet(shard_path)
            all_wiki_chunks.extend(shard_df['text'].tolist())

    print(f"Total merged chunks: {len(all_wiki_chunks)}")

    all_chunk_embs = model.encode(all_wiki_chunks, convert_to_tensor=True, show_progress_bar=True)

    print("Saving vector database for future use...")
    torch.save(all_chunk_embs, embs_cache_path)
    with open(text_cache_path, "wb") as f:
        pickle.dump(all_wiki_chunks, f)
    print("Database saved successfully.")
|
|
| |
| |
| |
| |
# Load the synthetic dataset and flatten it into one record per
# (document index, difficulty label) pair.
with open(f"/home/mshahidul/readctrl/data/synthetic_dataset_diff_labels/syn_data_diff_labels_{args.lang}_v1.json", "r") as f:
    res = json.load(f)

my_targets = [
    {"index": item['index'], "label": key, "text": val}
    for item in res
    for key, val in item['diff_label_texts'].items()
]

target_texts = [entry['text'] for entry in my_targets]
target_embs = model.encode(target_texts, convert_to_tensor=True)
|
|
print("Running semantic search...")
search_results = util.semantic_search(target_embs, all_chunk_embs, top_k=25)

# For each target text, pick the highest-ranked wiki hit whose word count is
# within +/-20% of the target's; if none qualifies, fall back to the hit with
# the closest word count.
processed_data = []
for i, hits in enumerate(tqdm.tqdm(search_results)):
    doc = my_targets[i]
    doc_len = len(doc['text'].split())

    wiki_anchor = None
    best_fallback = None
    min_delta = float('inf')

    for hit in hits:
        cand_text = all_wiki_chunks[hit['corpus_id']]
        cand_len = len(cand_text.split())
        len_diff = abs(cand_len - doc_len)

        if len_diff < min_delta:
            min_delta = len_diff
            best_fallback = cand_text

        # BUG FIX: guard doc_len == 0 (empty / whitespace-only target text),
        # which previously raised ZeroDivisionError; such targets now just
        # use the closest-length fallback.
        if doc_len and 0.8 <= (cand_len / doc_len) <= 1.2:
            wiki_anchor = cand_text
            break

    final_anchor = wiki_anchor if wiki_anchor else best_fallback

    processed_data.append({
        "index": doc['index'],
        "label": doc['label'],
        "original_doc": doc['text'],
        "wiki_anchor": final_anchor,
        "doc_fkgl": textstat.flesch_kincaid_grade(doc['text']),
        "wiki_fkgl": textstat.flesch_kincaid_grade(final_anchor),
        "doc_tree_depth": get_parse_tree_stats(doc['text']),
        "wiki_tree_depth": get_parse_tree_stats(final_anchor)
    })
|
|
final_save_path = f"/home/mshahidul/readctrl/data/data_annotator_data/new_v1/crowdsourcing_input_{args.lang}_fully_merged_v2.json"
# ROBUSTNESS FIX: create the output directory if missing (mirrors the
# makedirs done for the vector DB) so a fresh checkout doesn't crash here
# after the expensive encoding/search work has already completed.
os.makedirs(os.path.dirname(final_save_path), exist_ok=True)
with open(final_save_path, "w") as f:
    json.dump(processed_data, f, indent=2)

print(f"Done! Results saved to {final_save_path}")