# Source: readctrl/code/vectordb_build/data_annotate_data_prep_test_v4.py
# Uploaded by shahidul034 ("Add files using upload-large-folder tool"), commit c7a6fe6.
import os
import argparse
# 1. Command-line configuration
#   --lang        language code embedded in every input/output file name
#   --num_shards  number of wiki chunk shards to merge when building the cache
#   --cuda        value exported as CUDA_VISIBLE_DEVICES
parser = argparse.ArgumentParser()
for _flag, _type, _default in (
    ("--lang", str, "en"),
    ("--num_shards", int, 20),
    ("--cuda", str, "0"),
):
    parser.add_argument(_flag, type=_type, default=_default)
args = parser.parse_args()

# GPU selection is exported here, before torch is imported further down,
# so that the setting is in place when CUDA is first used.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": args.cuda,
})
import json
import tqdm
import numpy as np
import pandas as pd
import textstat
import spacy
import torch
import pickle # Added for saving text chunks efficiently
from sentence_transformers import SentenceTransformer, util
# On-disk "vector database": a saved tensor of chunk embeddings plus a pickled
# list of the matching chunk texts, one pair of files per language.
db_dir = "/home/mshahidul/readctrl/data/vector_db"
os.makedirs(db_dir, exist_ok=True)
embs_cache_path, text_cache_path = (
    os.path.join(db_dir, f"wiki_{args.lang}_{suffix}")
    for suffix in ("embs.pt", "chunks.pkl")
)
# 2. Load models
# Sentence-embedding model shared by the wiki corpus and the target documents.
# NOTE(review): all-MiniLM-L6-v2 is an English model; for non-English --lang a
# multilingual encoder may retrieve better anchors (changing it invalidates
# the embedding cache).
model = SentenceTransformer('all-MiniLM-L6-v2')


def _load_spacy_pipeline(lang):
    """Load the small spaCy pipeline for *lang*, with NER/lemmatizer disabled.

    spaCy names a few small pipelines ``<lang>_core_web_sm`` (e.g. English) but
    most languages ``<lang>_core_news_sm``; both names are tried in that order.

    Raises:
        OSError: if neither pipeline is installed.
    """
    disabled = ["ner", "lemmatizer", "attribute_ruler"]
    candidates = (f"{lang}_core_web_sm", f"{lang}_core_news_sm")
    last_error = None
    for name in candidates:
        try:
            return spacy.load(name, disable=disabled)
        except OSError as err:  # spaCy raises OSError when a package is missing
            last_error = err
    raise OSError(
        f"No spaCy pipeline found for lang={lang!r}; tried {', '.join(candidates)}"
    ) from last_error


nlp = _load_spacy_pipeline(args.lang)
# Helpers for the parse-tree depth features (doc_tree_depth / wiki_tree_depth).
def walk_tree(node, depth):
    """Return the largest depth reached in the tree rooted at *node*.

    *node* only needs a ``children`` iterable (e.g. a spaCy ``Token``).
    *depth* is the depth assigned to *node* itself; each level below adds 1,
    so a leaf returns *depth* unchanged.

    An explicit stack is used instead of recursion so that very long
    dependency chains cannot raise ``RecursionError``.
    """
    deepest = depth
    stack = [(node, depth)]
    while stack:
        current, level = stack.pop()
        deepest = max(deepest, level)
        stack.extend((child, level + 1) for child in current.children)
    return deepest


def get_parse_tree_stats(text):
    """Return the mean dependency-tree depth over the sentences of *text*.

    Each sentence is measured from its root token, with the root counted as
    depth 1. Returns 0 when spaCy yields no sentences.
    """
    doc = nlp(text)
    depths = [walk_tree(sent.root, 1) for sent in doc.sents]
    return np.mean(depths) if depths else 0
# ---------------------------------------------------------
# 3. Step 1 & 2: Load or create the vector database
# ---------------------------------------------------------
# The cache is keyed only by language: once both files exist, --num_shards is
# ignored. Delete the two cache files to force a rebuild after adding shards.
if os.path.exists(embs_cache_path) and os.path.exists(text_cache_path):
    print("Loading cached vector database...")
    # Place the corpus on the encoder's device, i.e. the same device the query
    # embeddings are produced on, regardless of where the cache was written.
    all_chunk_embs = torch.load(embs_cache_path, map_location=model.device)
    # NOTE: pickle is acceptable only because this file is written by this
    # script itself; never point text_cache_path at untrusted data.
    with open(text_cache_path, "rb") as f:
        all_wiki_chunks = pickle.load(f)
    # Hit corpus_ids index into all_wiki_chunks, so the two caches must align.
    if len(all_chunk_embs) != len(all_wiki_chunks):
        raise ValueError(
            f"Inconsistent vector database: {len(all_chunk_embs)} embeddings vs "
            f"{len(all_wiki_chunks)} chunks. Delete {embs_cache_path} and "
            f"{text_cache_path} to rebuild."
        )
    print(f"Loaded {len(all_wiki_chunks)} chunks from cache.")
else:
    print(f"Cache not found. Merging {args.num_shards} shards and encoding...")
    all_wiki_chunks = []
    missing_shards = []
    for i in range(args.num_shards):
        path = f"/home/mshahidul/readctrl/data/wiki_chunks/wiki_chunks_{args.lang}_shard_{i}.parquet"
        if not os.path.exists(path):
            missing_shards.append(i)
            continue
        df_shard = pd.read_parquet(path)
        all_wiki_chunks.extend(df_shard['text'].tolist())
    if missing_shards:
        print(f"Warning: {len(missing_shards)} shard(s) not found and skipped: {missing_shards}")
    print(f"Total merged chunks: {len(all_wiki_chunks)}")
    # Stop before writing a cache for an empty corpus: that cache would be
    # silently reused by every later run.
    if not all_wiki_chunks:
        raise FileNotFoundError(
            f"No wiki chunk shards found for lang={args.lang!r} "
            f"(looked for {args.num_shards} shard(s))."
        )
    # Encoding
    all_chunk_embs = model.encode(all_wiki_chunks, convert_to_tensor=True, show_progress_bar=True)
    # SAVE the vector database
    print("Saving vector database for future use...")
    torch.save(all_chunk_embs, embs_cache_path)
    with open(text_cache_path, "wb") as f:
        pickle.dump(all_wiki_chunks, f)
    print("Database saved successfully.")
# ---------------------------------------------------------
# 4. Step 3: Embed the target documents and retrieve wiki candidates
# ---------------------------------------------------------
# utf-8 is explicit because the data is per-language (--lang) and the platform
# default encoding is locale-dependent.
with open(f"/home/mshahidul/readctrl/data/synthetic_dataset_diff_labels/syn_data_diff_labels_{args.lang}_v1.json", "r", encoding="utf-8") as f:
    res = json.load(f)

# One target per (item, label) pair from each item's diff_label_texts mapping.
my_targets = []
for item in res:
    for key, val in item['diff_label_texts'].items():
        my_targets.append({"index": item['index'], "label": key, "text": val})

target_texts = [d['text'] for d in my_targets]
target_embs = model.encode(target_texts, convert_to_tensor=True)
print("Running semantic search...")
# Up to 25 hits per target; each hit's 'corpus_id' indexes all_wiki_chunks.
search_results = util.semantic_search(target_embs, all_chunk_embs, top_k=25)
def _select_wiki_anchor(doc_text, candidate_texts, low=0.8, high=1.2):
    """Pick the wiki chunk to pair with *doc_text*.

    *candidate_texts* is iterated in the given order (here: semantic-search
    hit order). Returns the first candidate whose word count is within
    ``[low, high]`` times the document's word count. Otherwise it returns the
    candidate whose word count is closest to the document's, with the earliest
    one winning ties. Returns None when there are no candidates.
    """
    doc_len = len(doc_text.split())
    best_fallback = None
    min_delta = float('inf')
    for cand_text in candidate_texts:
        cand_len = len(cand_text.split())
        len_diff = abs(cand_len - doc_len)
        if len_diff < min_delta:
            min_delta = len_diff
            best_fallback = cand_text
        # Guard doc_len so an empty target text cannot raise ZeroDivisionError;
        # it then falls back to the closest-length candidate.
        if doc_len and low <= cand_len / doc_len <= high:
            return cand_text
    return best_fallback


processed_data = []
for doc, hits in zip(my_targets, tqdm.tqdm(search_results)):
    candidates = (all_wiki_chunks[hit['corpus_id']] for hit in hits)
    final_anchor = _select_wiki_anchor(doc['text'], candidates)
    # No hits means no anchor: store null wiki metrics instead of passing None
    # to textstat/spaCy.
    has_anchor = final_anchor is not None
    processed_data.append({
        "index": doc['index'],
        "label": doc['label'],
        "original_doc": doc['text'],
        "wiki_anchor": final_anchor,
        "doc_fkgl": textstat.flesch_kincaid_grade(doc['text']),
        "wiki_fkgl": textstat.flesch_kincaid_grade(final_anchor) if has_anchor else None,
        "doc_tree_depth": get_parse_tree_stats(doc['text']),
        "wiki_tree_depth": get_parse_tree_stats(final_anchor) if has_anchor else None,
    })
# Write the annotator input. Create the output directory if it is missing so
# a long run does not fail at its very last step.
final_save_path = f"/home/mshahidul/readctrl/data/data_annotator_data/new_v1/crowdsourcing_input_{args.lang}_fully_merged_v2.json"
os.makedirs(os.path.dirname(final_save_path), exist_ok=True)
# utf-8 + ensure_ascii=False keeps non-English text human-readable for annotators.
with open(final_save_path, "w", encoding="utf-8") as f:
    json.dump(processed_data, f, indent=2, ensure_ascii=False)
print(f"Done! Results saved to {final_save_path}")