import os
import argparse
# 1. Setup Arguments
# --lang: language code used in data paths and the spaCy model name.
# --num_shards: number of parquet wiki-chunk shards to merge.
# --cuda: GPU index string for CUDA_VISIBLE_DEVICES.
arg_parser = argparse.ArgumentParser()
arg_parser.add_argument("--lang", type=str, default="en")
arg_parser.add_argument("--num_shards", type=int, default=20)
arg_parser.add_argument("--cuda", type=str, default="0")
args = arg_parser.parse_args()

# Pin GPU visibility before any CUDA-aware library (torch, etc.) is imported.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = args.cuda
import json
import tqdm
import numpy as np
import pandas as pd
import textstat
import spacy
import torch
import pickle # Added for saving text chunks efficiently
from sentence_transformers import SentenceTransformer, util
# ---------------------------------------------------------
# Paths for the on-disk "vector database" (chunk embeddings + chunk texts).
# ---------------------------------------------------------
db_dir = "/home/mshahidul/readctrl/data/vector_db"
os.makedirs(db_dir, exist_ok=True)
embs_cache_path = os.path.join(db_dir, f"wiki_{args.lang}_embs.pt")
text_cache_path = os.path.join(db_dir, f"wiki_{args.lang}_chunks.pkl")

# 2. Load Models
# Sentence embedder used for semantic search over the wiki chunks.
model = SentenceTransformer('all-MiniLM-L6-v2')
# Dependency parser only; NER/lemmatizer/attribute_ruler disabled for speed.
# NOTE(review): assumes a "{lang}_core_web_sm" pipeline exists for every lang —
# non-English spaCy models are usually named "*_core_news_sm"; confirm for lang != "en".
nlp = spacy.load(f"{args.lang}_core_web_sm", disable=["ner", "lemmatizer", "attribute_ruler"])
# (Helper functions get_parse_tree_stats and walk_tree remain the same...)
def walk_tree(node, depth):
    """Return the maximum depth of the subtree rooted at *node*.

    *node* is expected to expose an iterable ``children`` attribute
    (e.g. a spaCy token); *node* itself counts as level *depth*.
    """
    children = list(node.children)
    if not children:
        return depth
    return max(walk_tree(child, depth + 1) for child in children)
def get_parse_tree_stats(text):
    """Return the mean dependency-parse tree depth over the sentences of *text*.

    Uses the module-level spaCy pipeline ``nlp``; returns 0 when *text*
    yields no sentences.
    """
    parsed = nlp(text)
    sentence_depths = [walk_tree(sentence.root, 1) for sentence in parsed.sents]
    if not sentence_depths:
        return 0
    return np.mean(sentence_depths)
# ---------------------------------------------------------
# 3. Step 1 & 2: Load or Create Vector Database
# ---------------------------------------------------------
if os.path.exists(embs_cache_path) and os.path.exists(text_cache_path):
    # Both halves of the cache (embeddings tensor + chunk texts) exist:
    # reuse them instead of re-merging and re-encoding.
    print("Loading cached vector database...")
    all_chunk_embs = torch.load(embs_cache_path)
    with open(text_cache_path, "rb") as f:
        all_wiki_chunks = pickle.load(f)
    print(f"Loaded {len(all_wiki_chunks)} chunks from cache.")
else:
    print(f"Cache not found. Merging {args.num_shards} shards and encoding...")
    all_wiki_chunks = []
    for i in range(args.num_shards):
        path = f"/home/mshahidul/readctrl/data/wiki_chunks/wiki_chunks_{args.lang}_shard_{i}.parquet"
        if os.path.exists(path):
            df_shard = pd.read_parquet(path)
            all_wiki_chunks.extend(df_shard['text'].tolist())
        else:
            # Best-effort merge: tolerate individual missing shards, but be loud about it.
            print(f"Warning: shard not found, skipping: {path}")
    print(f"Total merged chunks: {len(all_wiki_chunks)}")
    if not all_wiki_chunks:
        # Fail fast: encoding and caching an empty database would silently
        # break the semantic search below AND poison the cache for future runs.
        raise FileNotFoundError(
            f"No wiki chunk shards found for lang={args.lang}; nothing to encode."
        )
    # Encoding
    all_chunk_embs = model.encode(all_wiki_chunks, convert_to_tensor=True, show_progress_bar=True)
    # SAVE the vector database
    print("Saving vector database for future use...")
    torch.save(all_chunk_embs, embs_cache_path)
    with open(text_cache_path, "wb") as f:
        pickle.dump(all_wiki_chunks, f)
    print("Database saved successfully.")
# ---------------------------------------------------------
# 4. Step 3: Run Target Documents
# ---------------------------------------------------------
# Flatten the synthetic dataset: one target record per (index, difficulty-label) pair.
with open(f"/home/mshahidul/readctrl/data/synthetic_dataset_diff_labels/syn_data_diff_labels_{args.lang}_v1.json", "r") as f:
    res = json.load(f)
my_targets = [
    {"index": item['index'], "label": key, "text": val}
    for item in res
    for key, val in item['diff_label_texts'].items()
]
# Embed every target text and retrieve its 25 nearest wiki chunks.
target_embs = model.encode([d['text'] for d in my_targets], convert_to_tensor=True)
print("Running semantic search...")
search_results = util.semantic_search(target_embs, all_chunk_embs, top_k=25)
# For every target document, pick a Wikipedia "anchor" chunk of comparable
# length from its semantic-search hits, then record readability (FKGL) and
# mean parse-tree-depth stats for both texts.
processed_data = []
for i, hits in enumerate(tqdm.tqdm(search_results)):
    doc = my_targets[i]
    doc_len = len(doc['text'].split())
    wiki_anchor = None      # first hit within +/-20% of the target's word count
    best_fallback = None    # hit with the smallest absolute word-count gap
    min_delta = float('inf')
    for hit in hits:
        cand_text = all_wiki_chunks[hit['corpus_id']]
        cand_len = len(cand_text.split())
        len_diff = abs(cand_len - doc_len)
        if len_diff < min_delta:
            min_delta = len_diff
            best_fallback = cand_text
        # Guard doc_len > 0: an empty/whitespace-only target text would
        # otherwise raise ZeroDivisionError here.
        if doc_len > 0 and 0.8 <= (cand_len / doc_len) <= 1.2:
            wiki_anchor = cand_text
            break
    final_anchor = wiki_anchor if wiki_anchor is not None else best_fallback
    if final_anchor is None:
        # No hits for this target (e.g. empty hit list): skip the record
        # instead of crashing textstat/spacy on None.
        continue
    processed_data.append({
        "index": doc['index'],
        "label": doc['label'],
        "original_doc": doc['text'],
        "wiki_anchor": final_anchor,
        "doc_fkgl": textstat.flesch_kincaid_grade(doc['text']),
        "wiki_fkgl": textstat.flesch_kincaid_grade(final_anchor),
        "doc_tree_depth": get_parse_tree_stats(doc['text']),
        "wiki_tree_depth": get_parse_tree_stats(final_anchor)
    })
# Write the merged annotation input for crowdsourcing.
final_save_path = (
    f"/home/mshahidul/readctrl/data/data_annotator_data/new_v1/"
    f"crowdsourcing_input_{args.lang}_fully_merged_v2.json"
)
with open(final_save_path, "w") as f:
    json.dump(processed_data, f, indent=2)
print(f"Done! Results saved to {final_save_path}")