"""Pair each target document with a length-matched wiki chunk and score both.

Steps performed by this script:
  1. Load a cached "vector database" (chunk texts + embeddings) or build it
     from the parquet shards and cache it.
  2. Embed every target text and retrieve its most similar wiki chunks.
  3. Keep the first retrieved chunk whose word count is within a ratio band
     of the target's word count, falling back to the closest length.
  4. Write both texts, their Flesch-Kincaid grade and their mean parse-tree
     depth to a JSON file.
"""
import os
import argparse

# 1. Setup Arguments
parser = argparse.ArgumentParser()
parser.add_argument("--lang", type=str, default="en")
parser.add_argument("--num_shards", type=int, default=20)
parser.add_argument("--cuda", type=str, default="0")
args = parser.parse_args()

# NOTE: the heavy imports below are deliberately placed after this point,
# presumably so the GPU selection is in place before torch initialises CUDA.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = args.cuda

import json
import tqdm
import numpy as np
import pandas as pd
import textstat
import spacy
import torch
import pickle  # Stores the list of chunk texts next to the embeddings
from sentence_transformers import SentenceTransformer, util

# Number of nearest wiki chunks retrieved per target document.
TOP_K = 25
# Accepted candidate/target word-count ratio band for a wiki anchor.
LEN_RATIO_MIN = 0.8
LEN_RATIO_MAX = 1.2

# Define Paths for the "Vector Database"
db_dir = "/home/mshahidul/readctrl/data/vector_db"
os.makedirs(db_dir, exist_ok=True)
embs_cache_path = os.path.join(db_dir, f"wiki_{args.lang}_embs.pt")
text_cache_path = os.path.join(db_dir, f"wiki_{args.lang}_chunks.pkl")


def load_spacy_pipeline(lang):
    """Load the small spaCy pipeline for *lang* with unused components disabled.

    The original "<lang>_core_web_sm" name is tried first; if spaCy cannot
    find it, "<lang>_core_news_sm" is tried, because most non-English spaCy
    pipelines are published under that name.

    Raises:
        OSError: if neither pipeline is installed (the error of the first
            attempt is re-raised).
    """
    disabled = ["ner", "lemmatizer", "attribute_ruler"]
    first_error = None
    for suffix in ("core_web_sm", "core_news_sm"):
        try:
            return spacy.load(f"{lang}_{suffix}", disable=disabled)
        except OSError as err:
            if first_error is None:
                first_error = err
    raise first_error


# 2. Load Models
model = SentenceTransformer('all-MiniLM-L6-v2')
nlp = load_spacy_pipeline(args.lang)


def walk_tree(node, depth):
    """Return the depth of the deepest descendant of *node*.

    *depth* is the depth assigned to *node* itself; a node without children
    returns *depth* unchanged.
    """
    return max(
        (walk_tree(child, depth + 1) for child in node.children),
        default=depth,
    )


def get_parse_tree_stats(text):
    """Return the mean dependency-tree depth over the sentences of *text*.

    Each sentence root counts as depth 1. Returns 0 when spaCy yields no
    sentences.
    """
    doc = nlp(text)
    depths = [walk_tree(sent.root, 1) for sent in doc.sents]
    return np.mean(depths) if depths else 0


def select_anchor(doc_text, candidates):
    """Pick the wiki chunk that best matches *doc_text* in length.

    *candidates* are scanned in the given order (the retrieval ranking).
    The first candidate whose whitespace-token count is within
    [LEN_RATIO_MIN, LEN_RATIO_MAX] times the target's count is returned
    immediately. Otherwise the candidate with the smallest absolute length
    difference is returned, or None when *candidates* is empty.
    """
    doc_len = len(doc_text.split())
    best_fallback = None
    min_delta = float('inf')
    for cand_text in candidates:
        cand_len = len(cand_text.split())
        len_diff = abs(cand_len - doc_len)
        if len_diff < min_delta:
            min_delta = len_diff
            best_fallback = cand_text
        # An empty target has no meaningful ratio; skipping the test avoids
        # a ZeroDivisionError and leaves the closest-length fallback.
        if doc_len and LEN_RATIO_MIN <= (cand_len / doc_len) <= LEN_RATIO_MAX:
            return cand_text
    return best_fallback


def load_cached_vector_db():
    """Return (chunks, embeddings) from the cache, or None if it is unusable.

    The cache is rejected when either file is missing, when it holds no
    chunks, or when the number of chunks and embeddings disagree (hit
    indices would then point at the wrong text).
    """
    if not (os.path.exists(embs_cache_path) and os.path.exists(text_cache_path)):
        return None
    print("Loading cached vector database...")
    # map_location lets a cache written on one device load on another
    # (e.g. saved from a GPU run, loaded on a CPU-only machine).
    embs = torch.load(embs_cache_path, map_location=model.device)
    # NOTE: pickle is only safe because this file is produced by this script;
    # never point text_cache_path at untrusted data.
    with open(text_cache_path, "rb") as f:
        chunks = pickle.load(f)
    if not chunks or len(chunks) != len(embs):
        print("Cached vector database is empty or inconsistent; rebuilding it.")
        return None
    print(f"Loaded {len(chunks)} chunks from cache.")
    return chunks, embs


def build_vector_db():
    """Merge the parquet shards, encode every chunk and cache the result.

    Returns:
        (chunks, embeddings) for the merged corpus.

    Raises:
        FileNotFoundError: if no shard contributed any chunk, so that an
            empty database is never written to the cache.
    """
    print(f"Cache not found. Merging {args.num_shards} shards and encoding...")
    chunks = []
    for i in range(args.num_shards):
        path = f"/home/mshahidul/readctrl/data/wiki_chunks/wiki_chunks_{args.lang}_shard_{i}.parquet"
        if os.path.exists(path):
            df_shard = pd.read_parquet(path)
            chunks.extend(df_shard['text'].tolist())
        else:
            # Best-effort merge is kept, but a missing shard is made visible.
            print(f"Warning: shard not found, skipping: {path}")

    print(f"Total merged chunks: {len(chunks)}")
    if not chunks:
        raise FileNotFoundError(
            f"No wiki chunks found for lang '{args.lang}' in {args.num_shards} shards."
        )

    # Encoding
    embs = model.encode(chunks, convert_to_tensor=True, show_progress_bar=True)

    # SAVE the vector database
    print("Saving vector database for future use...")
    torch.save(embs, embs_cache_path)
    with open(text_cache_path, "wb") as f:
        pickle.dump(chunks, f)
    print("Database saved successfully.")
    return chunks, embs


# ---------------------------------------------------------
# 3. Step 1 & 2: Load or Create Vector Database
# ---------------------------------------------------------
vector_db = load_cached_vector_db()
if vector_db is None:
    vector_db = build_vector_db()
all_wiki_chunks, all_chunk_embs = vector_db

# ---------------------------------------------------------
# 4. Step 3: Run Target Documents
# ---------------------------------------------------------
with open(
    f"/home/mshahidul/readctrl/data/synthetic_dataset_diff_labels/syn_data_diff_labels_{args.lang}_v1.json",
    "r",
    encoding="utf-8",
) as f:
    res = json.load(f)

# Flatten: one target per (item, label) pair found in 'diff_label_texts'.
my_targets = [
    {"index": item['index'], "label": key, "text": val}
    for item in res
    for key, val in item['diff_label_texts'].items()
]

if my_targets:
    target_texts = [d['text'] for d in my_targets]
    target_embs = model.encode(target_texts, convert_to_tensor=True)
    print("Running semantic search...")
    search_results = util.semantic_search(target_embs, all_chunk_embs, top_k=TOP_K)
else:
    # Nothing to embed or search; still write an (empty) result file below.
    search_results = []

processed_data = []
for doc, hits in zip(my_targets, tqdm.tqdm(search_results)):
    final_anchor = select_anchor(
        doc['text'],
        (all_wiki_chunks[hit['corpus_id']] for hit in hits),
    )
    processed_data.append({
        "index": doc['index'],
        "label": doc['label'],
        "original_doc": doc['text'],
        "wiki_anchor": final_anchor,
        "doc_fkgl": textstat.flesch_kincaid_grade(doc['text']),
        "wiki_fkgl": textstat.flesch_kincaid_grade(final_anchor),
        "doc_tree_depth": get_parse_tree_stats(doc['text']),
        "wiki_tree_depth": get_parse_tree_stats(final_anchor),
    })

final_save_path = f"/home/mshahidul/readctrl/data/data_annotator_data/new_v1/crowdsourcing_input_{args.lang}_fully_merged_v2.json"
# Create the output directory up front so the run cannot fail at the very
# last step after all the expensive work is done.
os.makedirs(os.path.dirname(final_save_path), exist_ok=True)
with open(final_save_path, "w", encoding="utf-8") as f:
    json.dump(processed_data, f, indent=2)

print(f"Done! Results saved to {final_save_path}")