"""Pair synthetic target documents with length-matched Wikipedia "anchor" chunks.

For each target text, the most semantically similar wiki chunk of comparable
word count is selected, and readability metrics (Flesch-Kincaid grade and mean
dependency-tree depth) are recorded for both texts.

Usage: python <script> --lang en --cuda 3
"""
import os
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--lang", type=str, default="en", help="language code")
parser.add_argument("--cuda", type=str, default="3", help="CUDA device ID to use")
args = parser.parse_args()
lang_code = args.lang

# These must be set before torch (or anything importing it) initialises CUDA,
# which is why the heavy imports below come after this point.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = args.cuda

import json
import tqdm
import numpy as np
import pandas as pd
import textstat
import spacy
import torch
import glob
from sentence_transformers import SentenceTransformer, util

# 1. Load Models
# NOTE(review): 'all-MiniLM-L6-v2' is an English sentence-embedding model;
# retrieval quality for other --lang values may be poor — confirm intended.
model = SentenceTransformer('all-MiniLM-L6-v2')


def _load_spacy_pipeline(code):
    """Load the small spaCy pipeline for language *code* (NER/lemmatizer off).

    spaCy names its small pipelines ``<lang>_core_web_sm`` for some languages
    (e.g. en, zh) but ``<lang>_core_news_sm`` for most others (e.g. de, fr,
    es), so both names are tried in turn.

    Raises:
        OSError: if neither pipeline package is installed.
    """
    disabled = ["ner", "lemmatizer", "attribute_ruler"]
    tried = [f"{code}_core_web_sm", f"{code}_core_news_sm"]
    last_err = None
    for name in tried:
        try:
            return spacy.load(name, disable=disabled)
        except OSError as err:  # spaCy raises OSError when the package is missing
            last_err = err
    raise OSError(
        f"No small spaCy pipeline installed for '{code}' (tried {', '.join(tried)})"
    ) from last_err


nlp = _load_spacy_pipeline(lang_code)


def get_parse_tree_stats(text):
    """Return the mean dependency-tree depth over the sentences of *text*.

    A sentence's depth is the number of tokens on its longest root-to-leaf
    path (a one-token sentence has depth 1). Returns 0 when spaCy yields no
    sentences (e.g. empty text).
    """
    doc = nlp(text)
    depths = []
    for sent in doc.sents:
        # Iterative DFS: same result as a recursive walk, but immune to
        # Python's recursion limit on pathologically deep parses.
        max_depth = 0
        stack = [(sent.root, 1)]
        while stack:
            node, depth = stack.pop()
            if depth > max_depth:
                max_depth = depth
            stack.extend((child, depth + 1) for child in node.children)
        depths.append(max_depth)
    return float(np.mean(depths)) if depths else 0


# 2. Load and Merge All Shards
print("Loading and merging all shards...")
shard_pattern = f"/home/mshahidul/readctrl/data/wiki_chunks/wiki_chunks_{lang_code}_shard_*.parquet"
shard_files = sorted(glob.glob(shard_pattern))
if not shard_files:
    # pd.concat([]) would otherwise fail with an unhelpful "No objects to concatenate".
    raise FileNotFoundError(f"No wiki chunk shards match {shard_pattern}")
df_merged = pd.concat([pd.read_parquet(path) for path in shard_files], ignore_index=True)
wiki_chunks = df_merged['text'].tolist()
print(f"Total wiki chunks loaded: {len(wiki_chunks)}")

# 3. Encode Merged Chunks (Keep on GPU)
# Row i of chunk_embs corresponds to wiki_chunks[i]; semantic_search's
# 'corpus_id' indexes back into wiki_chunks.
print("Encoding merged chunks...")
chunk_embs = model.encode(wiki_chunks, convert_to_tensor=True, show_progress_bar=True)

# 4. Load Target Docs
# Each item's 'diff_label_texts' maps a label to one text; flatten to one
# record per (index, label) pair.
with open(
    f"/home/mshahidul/readctrl/data/synthetic_dataset_diff_labels/syn_data_diff_labels_{lang_code}_v1.json",
    "r",
    encoding="utf-8",
) as f:
    res = json.load(f)
my_target_documents = []
for item in res:
    for key, value in item['diff_label_texts'].items():
        my_target_documents.append({"index": item['index'], "label": key, "text": value})
# 5. Output Path (shard_id removed from the filename: all shards are merged)
save_path = f"/home/mshahidul/readctrl/data/data_annotator_data/new_v2/crowdsourcing_input_{lang_code}_merged_v1.json"
os.makedirs(os.path.dirname(save_path), exist_ok=True)

TOP_K = 25                            # semantic-search candidates per document
LEN_RATIO_MIN, LEN_RATIO_MAX = 0.8, 1.2  # acceptable candidate/target word-count ratio
CHECKPOINT_EVERY = 20                 # save progress every N processed records


def _save_json_atomic(obj, path):
    """Write *obj* as UTF-8 JSON to *path* via a temp file and ``os.replace``.

    If the process dies mid-write, the previous checkpoint stays intact
    instead of being left truncated (which would break ``json.load`` on
    resume).
    """
    tmp_path = f"{path}.tmp"
    with open(tmp_path, "w", encoding="utf-8") as fh:
        json.dump(obj, fh, indent=2, ensure_ascii=False)
    os.replace(tmp_path, path)


def _pick_wiki_anchor(hits, chunks, doc_len):
    """Choose an anchor chunk text from search *hits*; None if *hits* is empty.

    *hits* are iterated in the order given (util.semantic_search returns them
    by descending score). The first hit whose word count is within
    [LEN_RATIO_MIN, LEN_RATIO_MAX] of *doc_len* is returned. Otherwise, the
    hit with the smallest absolute word-count difference (earliest on ties)
    is returned.
    """
    best_fallback = None
    min_delta = float('inf')
    for hit in hits:
        cand_text = chunks[hit['corpus_id']]
        cand_len = len(cand_text.split())
        len_diff = abs(cand_len - doc_len)
        if len_diff < min_delta:
            min_delta = len_diff
            best_fallback = cand_text
        # doc_len == 0 (empty/whitespace-only doc) would divide by zero; such
        # a doc can only ever use the closest-length fallback.
        if doc_len and LEN_RATIO_MIN <= cand_len / doc_len <= LEN_RATIO_MAX:
            return cand_text
    return best_fallback


# Resume support: reload any previous checkpoint and skip finished pairs.
processed_data = []
if os.path.exists(save_path):
    with open(save_path, "r", encoding="utf-8") as f:
        processed_data = json.load(f)
processed_keys = {(d['index'], d['label']) for d in processed_data}

# 6. Process Loop
print(f"Starting Matching Loop for {len(my_target_documents)} documents...")
for doc in tqdm.tqdm(my_target_documents):
    key = (doc['index'], doc['label'])
    if key in processed_keys:
        continue

    doc_text = doc['text']
    doc_emb = model.encode(doc_text, convert_to_tensor=True)
    doc_len = len(doc_text.split())

    # Search across the entire merged corpus
    hits = util.semantic_search(doc_emb, chunk_embs, top_k=TOP_K)[0]
    wiki_anchor = _pick_wiki_anchor(hits, wiki_chunks, doc_len)
    if wiki_anchor is None:
        # No candidates at all (empty corpus): textstat would crash on None.
        print(f"No wiki candidates found for {key}; skipping.")
        continue

    # Calculate Metrics (each FKGL computed once and reused for the delta)
    doc_fkgl = textstat.flesch_kincaid_grade(doc_text)
    wiki_fkgl = textstat.flesch_kincaid_grade(wiki_anchor)
    processed_data.append({
        "index": doc['index'],
        "label": doc['label'],
        "original_doc": doc_text,
        "wiki_anchor": wiki_anchor,
        "doc_fkgl": doc_fkgl,
        "wiki_fkgl": wiki_fkgl,
        "doc_tree_depth": get_parse_tree_stats(doc_text),
        "wiki_tree_depth": get_parse_tree_stats(wiki_anchor),
        "fkgl_delta": doc_fkgl - wiki_fkgl,
    })
    # Track new keys too, so duplicate (index, label) inputs are handled the
    # same on a fresh run as on a resumed one.
    processed_keys.add(key)

    if len(processed_data) % CHECKPOINT_EVERY == 0:
        _save_json_atomic(processed_data, save_path)

# Final Save
_save_json_atomic(processed_data, save_path)
print(f"Processing complete. Saved to {save_path}")