import os
import argparse

# Parse CLI options *before* any CUDA-dependent import: the GPU selection
# exported below only takes effect if it is set before torch initializes CUDA.
parser = argparse.ArgumentParser()
parser.add_argument("--lang", type=str, default="en")      # language code, e.g. "en"
parser.add_argument("--num_shards", type=int, default=20)  # number of wiki-chunk parquet shards to merge
parser.add_argument("--cuda", type=str, default="0")       # value for CUDA_VISIBLE_DEVICES
args = parser.parse_args()
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = args.cuda

# NOTE: these imports intentionally come AFTER the CUDA env vars are set.
import json
import tqdm
import numpy as np
import pandas as pd
import textstat
import spacy
import torch
import pickle
from sentence_transformers import SentenceTransformer, util
| |
|
| |
|
| | |
| |
|
| | |
# Vector-database cache locations: embeddings tensor (.pt) + pickled chunk texts.
db_dir = "/home/mshahidul/readctrl/data/vector_db"
os.makedirs(db_dir, exist_ok=True)
embs_cache_path = os.path.join(db_dir, f"wiki_{args.lang}_embs.pt")
text_cache_path = os.path.join(db_dir, f"wiki_{args.lang}_chunks.pkl")

# Sentence embedder used for both the wiki chunks and the target texts.
model = SentenceTransformer('all-MiniLM-L6-v2')
# spaCy pipeline used only for dependency parsing (tree-depth stats below);
# unused components are disabled for speed.
# NOTE(review): "{lang}_core_web_sm" is the English naming scheme — confirm
# the model name pattern holds for other --lang values (e.g. de_core_news_sm).
nlp = spacy.load(f"{args.lang}_core_web_sm", disable=["ner", "lemmatizer", "attribute_ruler"])
| |
|
| | |
def walk_tree(node, depth):
    """Return the maximum depth of the subtree rooted at *node*.

    *node* itself counts as *depth*; each child level adds one.
    Only requires that *node* expose an iterable ``children`` attribute.
    """
    child_depths = [walk_tree(child, depth + 1) for child in node.children]
    if child_depths:
        return max(child_depths)
    return depth
| |
|
def get_parse_tree_stats(text):
    """Average dependency-parse tree depth across the sentences of *text*.

    Uses the module-level ``nlp`` pipeline and ``walk_tree``. Returns 0.0
    when the text yields no sentences.

    Fix: both branches now return a plain Python ``float`` — previously the
    non-empty branch returned ``np.float64`` while the empty branch returned
    the int ``0``, giving callers an inconsistent result type.
    """
    doc = nlp(text)
    depths = [walk_tree(sent.root, 1) for sent in doc.sents]
    return float(np.mean(depths)) if depths else 0.0
| |
|
| | |
| | |
| | |
# Build or load the "vector database": the chunk texts plus their embeddings.
# Both cache files must exist for the cache to be used; otherwise everything
# is rebuilt from the parquet shards and re-saved.
if os.path.exists(embs_cache_path) and os.path.exists(text_cache_path):
    print("Loading cached vector database...")
    all_chunk_embs = torch.load(embs_cache_path)
    with open(text_cache_path, "rb") as f:
        all_wiki_chunks = pickle.load(f)
    print(f"Loaded {len(all_wiki_chunks)} chunks from cache.")
else:
    print(f"Cache not found. Merging {args.num_shards} shards and encoding...")
    all_wiki_chunks = []
    for i in range(args.num_shards):
        path = f"/home/mshahidul/readctrl/data/wiki_chunks/wiki_chunks_{args.lang}_shard_{i}.parquet"
        # Missing shards are skipped silently (partial shard sets are tolerated).
        if os.path.exists(path):
            df_shard = pd.read_parquet(path)
            all_wiki_chunks.extend(df_shard['text'].tolist())

    print(f"Total merged chunks: {len(all_wiki_chunks)}")

    # Encode every chunk once; kept as a tensor for util.semantic_search below.
    all_chunk_embs = model.encode(all_wiki_chunks, convert_to_tensor=True, show_progress_bar=True)

    print("Saving vector database for future use...")
    torch.save(all_chunk_embs, embs_cache_path)
    with open(text_cache_path, "wb") as f:
        pickle.dump(all_wiki_chunks, f)
    print("Database saved successfully.")
| |
|
| | |
| | |
| | |
| | |
# Load the synthetic difficulty-labelled dataset and flatten it into one
# record per (document index, difficulty label) pair.
with open(f"/home/mshahidul/readctrl/data/synthetic_dataset_diff_labels/syn_data_diff_labels_{args.lang}_v1.json", "r") as f:
    res = json.load(f)

my_targets = []
for item in res:
    for key, val in item['diff_label_texts'].items():
        my_targets.append({"index": item['index'], "label": key, "text": val})

target_texts = [d['text'] for d in my_targets]
target_embs = model.encode(target_texts, convert_to_tensor=True)

print("Running semantic search...")
# For each target text, retrieve the 25 most similar wiki chunks; results are
# ordered by descending similarity score.
search_results = util.semantic_search(target_embs, all_chunk_embs, top_k=25)
| |
|
# For each target text, pick a wiki "anchor" chunk of comparable length:
# prefer the first (highest-similarity) hit whose word count is within ±20%
# of the target's, otherwise fall back to the hit with the smallest absolute
# length difference. Readability stats are computed for both sides.
processed_data = []
for i, hits in enumerate(tqdm.tqdm(search_results)):
    doc = my_targets[i]
    doc_len = len(doc['text'].split())

    wiki_anchor = None
    best_fallback = None
    min_delta = float('inf')

    for hit in hits:
        cand_text = all_wiki_chunks[hit['corpus_id']]
        cand_len = len(cand_text.split())
        len_diff = abs(cand_len - doc_len)

        if len_diff < min_delta:
            min_delta = len_diff
            best_fallback = cand_text

        # Fix: guard doc_len — an empty target text previously raised
        # ZeroDivisionError on the ratio test.
        if doc_len and 0.8 <= (cand_len / doc_len) <= 1.2:
            wiki_anchor = cand_text
            break

    final_anchor = wiki_anchor if wiki_anchor is not None else best_fallback

    # Fix: if semantic search returned no hits, there is nothing to anchor on;
    # previously a None anchor crashed textstat below.
    if final_anchor is None:
        continue

    processed_data.append({
        "index": doc['index'],
        "label": doc['label'],
        "original_doc": doc['text'],
        "wiki_anchor": final_anchor,
        "doc_fkgl": textstat.flesch_kincaid_grade(doc['text']),
        "wiki_fkgl": textstat.flesch_kincaid_grade(final_anchor),
        "doc_tree_depth": get_parse_tree_stats(doc['text']),
        "wiki_tree_depth": get_parse_tree_stats(final_anchor)
    })
| |
|
# Persist the paired (document, wiki anchor) records with readability stats.
final_save_path = f"/home/mshahidul/readctrl/data/data_annotator_data/new_v1/crowdsourcing_input_{args.lang}_fully_merged_v2.json"
with open(final_save_path, "w") as f:
    json.dump(processed_data, f, indent=2)

print(f"Done! Results saved to {final_save_path}")