File size: 4,915 Bytes
c7a6fe6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import os
import argparse
# 1. Setup Arguments
# NOTE: argparse + CUDA env vars are handled BEFORE the heavy imports below.
# torch reads CUDA_DEVICE_ORDER / CUDA_VISIBLE_DEVICES at import time, so this
# ordering is deliberate — do not move `import torch` above these lines.
parser = argparse.ArgumentParser()
parser.add_argument("--lang", type=str, default="en")  # language code; also selects the spaCy model and shard files
parser.add_argument("--num_shards", type=int, default=20)  # number of wiki parquet shards to merge when building the cache
parser.add_argument("--cuda", type=str, default="0")  # GPU index (or comma-separated list) exposed to torch
args = parser.parse_args()
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = args.cuda
import json
import tqdm
import numpy as np
import pandas as pd
import textstat
import spacy
import torch
import pickle  # Added for saving text chunks efficiently
from sentence_transformers import SentenceTransformer, util


# device = "cuda" if torch.cuda.is_available() else "cpu"

# Define Paths for the "Vector Database"
# Embeddings are cached as a torch tensor; the parallel chunk texts as a pickle.
db_dir = "/home/mshahidul/readctrl/data/vector_db"
os.makedirs(db_dir, exist_ok=True)
embs_cache_path = os.path.join(db_dir, f"wiki_{args.lang}_embs.pt")
text_cache_path = os.path.join(db_dir, f"wiki_{args.lang}_chunks.pkl")

# 2. Load Models
# Sentence embedder for retrieval; spaCy for dependency parses only, so the
# unused pipeline components are disabled for speed.
model = SentenceTransformer('all-MiniLM-L6-v2')
nlp = spacy.load(f"{args.lang}_core_web_sm", disable=["ner", "lemmatizer", "attribute_ruler"])

# (Helper functions get_parse_tree_stats and walk_tree remain the same...)
def walk_tree(node, depth):
    """Return the depth of the deepest descendant of *node*.

    *depth* is the depth assigned to *node* itself; each level of children
    adds one. A node with no children simply reports its own depth.
    """
    children = list(node.children)
    if not children:
        return depth
    return max(walk_tree(child, depth + 1) for child in children)

def get_parse_tree_stats(text):
    """Return the mean dependency-parse-tree depth across sentences of *text*.

    Each sentence's depth is measured from its syntactic root (depth 1) via
    ``walk_tree``. Returns 0 when the text yields no sentences.
    """
    sentence_depths = [walk_tree(sentence.root, 1) for sentence in nlp(text).sents]
    if not sentence_depths:
        return 0
    return np.mean(sentence_depths)

# ---------------------------------------------------------
# 3. Step 1 & 2: Load or Create Vector Database
# ---------------------------------------------------------
# Fast path: both caches exist, so load embeddings and their parallel chunk
# texts. The two files must stay in sync — index i of all_chunk_embs embeds
# all_wiki_chunks[i].
if os.path.exists(embs_cache_path) and os.path.exists(text_cache_path):
    print("Loading cached vector database...")
    # NOTE(review): torch.load restores the tensor to the device it was saved
    # from — presumably fine here since save/load run on the same machine;
    # confirm if caches are ever shared across hosts.
    all_chunk_embs = torch.load(embs_cache_path)
    with open(text_cache_path, "rb") as f:
        all_wiki_chunks = pickle.load(f)
    print(f"Loaded {len(all_wiki_chunks)} chunks from cache.")
else:
    print(f"Cache not found. Merging {args.num_shards} shards and encoding...")
    all_wiki_chunks = []
    for i in range(args.num_shards):
        path = f"/home/mshahidul/readctrl/data/wiki_chunks/wiki_chunks_{args.lang}_shard_{i}.parquet"
        # NOTE(review): missing shards are skipped silently; a partial cache
        # can be built without warning if some shard files are absent.
        if os.path.exists(path):
            df_shard = pd.read_parquet(path)
            all_wiki_chunks.extend(df_shard['text'].tolist())

    print(f"Total merged chunks: {len(all_wiki_chunks)}")
    
    # Encoding
    all_chunk_embs = model.encode(all_wiki_chunks, convert_to_tensor=True, show_progress_bar=True)
    
    # SAVE the vector database
    print("Saving vector database for future use...")
    torch.save(all_chunk_embs, embs_cache_path)
    with open(text_cache_path, "wb") as f:
        pickle.dump(all_wiki_chunks, f)
    print("Database saved successfully.")

# ---------------------------------------------------------
# 4. Step 3: Run Target Documents
# ---------------------------------------------------------
# (The rest of your target processing logic remains the same)
# Load the synthetic dataset; each item carries one text per difficulty label
# under 'diff_label_texts'.
with open(f"/home/mshahidul/readctrl/data/synthetic_dataset_diff_labels/syn_data_diff_labels_{args.lang}_v1.json", "r") as f:
    res = json.load(f)

# Flatten to one record per (document index, difficulty label) pair so each
# variant is searched independently.
my_targets = []
for item in res:
    for key, val in item['diff_label_texts'].items():
        my_targets.append({"index": item['index'], "label": key, "text": val})

target_texts = [d['text'] for d in my_targets]
target_embs = model.encode(target_texts, convert_to_tensor=True)

print("Running semantic search...")
# For each target, retrieve the 25 most semantically similar wiki chunks;
# search_results[i] aligns with my_targets[i].
search_results = util.semantic_search(target_embs, all_chunk_embs, top_k=25)

# For each target document, pick a wiki "anchor" from its top-k semantic
# neighbors: prefer the first hit whose word count is within +/-20% of the
# target's; otherwise fall back to the hit with the closest length.
processed_data = []
for i, hits in enumerate(tqdm.tqdm(search_results)):
    doc = my_targets[i]
    doc_len = len(doc['text'].split())
    
    wiki_anchor = None
    best_fallback = None
    min_delta = float('inf')

    for hit in hits:
        cand_text = all_wiki_chunks[hit['corpus_id']]
        cand_len = len(cand_text.split())
        len_diff = abs(cand_len - doc_len)
        
        # Track the closest-length candidate as a fallback.
        if len_diff < min_delta:
            min_delta = len_diff
            best_fallback = cand_text
            
        # BUGFIX: guard doc_len > 0 — an empty/whitespace-only target text
        # previously raised ZeroDivisionError here.
        if doc_len and 0.8 <= (cand_len / doc_len) <= 1.2:
            wiki_anchor = cand_text
            break
    
    final_anchor = wiki_anchor if wiki_anchor else best_fallback

    # BUGFIX: with an empty hit list final_anchor is None, and the metric
    # calls below would crash on it — skip such targets instead.
    if final_anchor is None:
        print(f"Warning: no anchor found for index {doc['index']} ({doc['label']}); skipping.")
        continue

    processed_data.append({
        "index": doc['index'],
        "label": doc['label'],
        "original_doc": doc['text'],
        "wiki_anchor": final_anchor,
        # Readability (Flesch-Kincaid grade) and mean parse-tree depth for
        # both the target and its chosen anchor.
        "doc_fkgl": textstat.flesch_kincaid_grade(doc['text']),
        "wiki_fkgl": textstat.flesch_kincaid_grade(final_anchor),
        "doc_tree_depth": get_parse_tree_stats(doc['text']),
        "wiki_tree_depth": get_parse_tree_stats(final_anchor)
    })

final_save_path = f"/home/mshahidul/readctrl/data/data_annotator_data/new_v1/crowdsourcing_input_{args.lang}_fully_merged_v2.json"
with open(final_save_path, "w") as f:
    json.dump(processed_data, f, indent=2)

print(f"Done! Results saved to {final_save_path}")