# readctrl/code/vectordb_build/vector_db_select_v2.py
# shahidul034's picture
# Add files using upload-large-folder tool
# 1db7196 verified
import os
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import json
import torch
import pickle
import gradio as gr
import textstat
from sentence_transformers import SentenceTransformer, util
# --- Configuration & Paths ---
# Language code interpolated into every data path below.
LANG_CODE = "en"
# Pickled collection of Wikipedia text chunks (read with pickle.load in section 1).
# get_candidates() indexes it by the corpus_id of each search hit, so it is
# presumably aligned row-for-row with the embeddings in EMBS_PATH — TODO confirm.
CHUNKS_PATH = f"/home/mshahidul/readctrl/data/vector_db/db_model/wiki_{LANG_CODE}_chunks.pkl"
# Embedding tensor for the chunks above (read with torch.load in section 1).
EMBS_PATH = f"/home/mshahidul/readctrl/data/vector_db/db_model/wiki_{LANG_CODE}_embs.pt"
# JSON list of target documents; each item is read for its 'index' and
# 'diff_label_texts' (a label -> text mapping).
TARGET_DOCS_PATH = f"/home/mshahidul/readctrl/data/synthetic_dataset_diff_labels/syn_data_diff_labels_{LANG_CODE}_v1.json"
# JSON file that save_selection() writes and get_resume_index() reads back.
SAVE_PATH = f"/home/mshahidul/readctrl/data/data_annotator_data/manual_selections_{LANG_CODE}.json"
# --- 1. Load Resources ---
# Everything below runs once at import time and populates the module-level
# objects used by the rest of the script: model, wiki_chunks, device,
# wiki_embs and target_list.
print("Loading Model and Tensors...")
model = SentenceTransformer('all-MiniLM-L6-v2')

# NOTE(review): pickle.load can execute arbitrary code contained in the file.
# Only point CHUNKS_PATH at files produced by a trusted pipeline.
with open(CHUNKS_PATH, "rb") as f:
    wiki_chunks = pickle.load(f)

device = "cuda" if torch.cuda.is_available() else "cpu"
# map_location remaps the stored tensor while loading, so embeddings that were
# saved from a GPU can still be loaded on a CPU-only host. Without it the load
# itself fails before .to(device) is ever reached.
wiki_embs = torch.load(EMBS_PATH, map_location=device).to(device)

# JSON is UTF-8 by specification; do not depend on the locale's default codec.
with open(TARGET_DOCS_PATH, "r", encoding="utf-8") as f:
    raw_targets = json.load(f)

# Flatten the documents into one entry per (document, readability label) pair.
target_list = []
for item in raw_targets:
    target_list.extend(
        {"index": item['index'], "label": label, "text": text}
        for label, text in item['diff_label_texts'].items()
    )
# --- 2. Resume Logic ---
def get_resume_index():
    """Return the position in ``target_list`` at which annotation should resume.

    The save file at ``SAVE_PATH`` is read and every ``(index, label)`` pair
    found in it is treated as already annotated.

    Returns:
        int: The position of the first ``target_list`` entry that has no saved
        selection. ``0`` is returned when the save file is missing or cannot
        be read or parsed. When every entry is already saved, the last valid
        position is returned (``0`` for an empty ``target_list``), so the
        result is never negative.
    """
    try:
        with open(SAVE_PATH, "r", encoding="utf-8") as f:
            saved_data = json.load(f)
        # Create a set of (index, label) tuples that are already done
        done_keys = {(d['index'], d['label']) for d in saved_data}
    except FileNotFoundError:
        # Nothing has been saved yet: start from the beginning.
        return 0
    except (OSError, ValueError, KeyError, TypeError) as e:
        # ValueError covers json.JSONDecodeError and decoding failures;
        # KeyError/TypeError cover entries that lack the expected shape.
        print(f"Error loading save file: {e}")
        return 0

    for i, item in enumerate(target_list):
        if (item['index'], item['label']) not in done_keys:
            return i
    # All done: stay on the last item, clamped so an empty list gives 0, not -1.
    return max(0, len(target_list) - 1)
# Computed once at import time; this value seeds the gr.State that tracks the
# current position in the UI below.
START_INDEX = get_resume_index()
print(f"Resuming from index: {START_INDEX}")
# --- 3. Logic Functions ---
def get_candidates(target_text, top_k=20):
    """Return the Wikipedia chunks that best match ``target_text``.

    The text is encoded with the module-level ``model``, moved to ``device``
    and searched against ``wiki_embs`` using ``util.semantic_search``.

    Args:
        target_text: The text to use as the search query.
        top_k: Maximum number of hits requested from the search.

    Returns:
        list: The ``wiki_chunks`` entries for the hits of the single query,
        in the order the search returned them.
    """
    query_vector = model.encode(target_text, convert_to_tensor=True).to(device)
    per_query_hits = util.semantic_search(query_vector, wiki_embs, top_k=top_k)
    # Only one query was submitted, so its hits are the first result list.
    return [wiki_chunks[match['corpus_id']] for match in per_query_hits[0]]
def calculate_stats(text):
    """Build the one-line statistics summary shown under a text box.

    Args:
        text: The text to summarise; may be empty or ``None``.

    Returns:
        str: ``"N/A"`` for a falsy ``text``; otherwise a string containing the
        whitespace-separated word count and the Flesch-Kincaid grade level
        reported by ``textstat``.
    """
    if text:
        wc = len(text.split())
        fk = textstat.flesch_kincaid_grade(text)
        return f"πŸ“ Words: {wc} | πŸŽ“ FKGL: {fk}"
    return "N/A"
def _read_saved_selections(path):
    """Return the list of selections stored at ``path`` without losing data.

    A missing file yields an empty list. A file that cannot be read, parsed,
    or that does not contain a JSON list is moved aside to ``path + ".corrupt"``
    (best effort) before an empty list is returned, so the caller's next write
    does not silently destroy earlier annotations.
    """
    try:
        with open(path, "r", encoding="utf-8") as f:
            loaded = json.load(f)
        if isinstance(loaded, list):
            return loaded
        problem = "save file does not contain a JSON list"
    except FileNotFoundError:
        return []
    except (OSError, ValueError) as e:
        # ValueError covers json.JSONDecodeError and decoding failures.
        problem = str(e)

    backup_path = path + ".corrupt"
    print(f"Could not read existing selections ({problem}); keeping a copy at {backup_path}")
    try:
        os.replace(path, backup_path)
    except OSError as e:
        print(f"Could not back up unreadable save file: {e}")
    return []


def save_selection(target_idx, label, original_text, selected_wiki):
    """Persist the chosen Wikipedia anchor for one (document, label) pair.

    Any previously saved entry with the same ``index`` and ``label`` is
    replaced; all other entries in ``SAVE_PATH`` are kept.

    Args:
        target_idx: The document's 'index' value.
        label: The readability label of the target text.
        original_text: The target text that was shown to the annotator.
        selected_wiki: The selected Wikipedia chunk text; may be empty or
            ``None`` when nothing is selected.

    Returns:
        str: A status message for the UI. Nothing is written when
        ``selected_wiki`` is empty.
    """
    # Without a selection there is nothing to score or store, and textstat
    # would fail on None.
    if not selected_wiki:
        return "⚠️ No candidate selected — nothing was saved."

    entry = {
        "index": target_idx,
        "label": label,
        "original_text": original_text,
        "selected_wiki_anchor": selected_wiki,
        "wiki_fkgl": textstat.flesch_kincaid_grade(selected_wiki),
        "doc_fkgl": textstat.flesch_kincaid_grade(original_text)
    }

    existing_data = _read_saved_selections(SAVE_PATH)

    # Overwrite if exists, otherwise append
    existing_data = [
        d for d in existing_data
        if not (isinstance(d, dict) and d.get('index') == target_idx and d.get('label') == label)
    ]
    existing_data.append(entry)

    # Write to a sibling temp file and swap it in, so an interrupted write
    # cannot leave a truncated save file behind.
    os.makedirs(os.path.dirname(SAVE_PATH) or ".", exist_ok=True)
    tmp_path = SAVE_PATH + ".tmp"
    with open(tmp_path, "w", encoding="utf-8") as f:
        json.dump(existing_data, f, indent=2)
    os.replace(tmp_path, SAVE_PATH)

    return f"βœ… Saved: ID {target_idx} ({label})"
# --- 4. Gradio UI ---
# Two-column layout: the target text on the left, the candidate Wikipedia
# chunks on the right, with Previous / Save / Next buttons underneath.
with gr.Blocks(theme=gr.themes.Soft(), title="Wiki Anchor Selector") as demo:
    gr.Markdown("# πŸ” ReadCtrl: Anchor Selection (Resume Mode)")

    # Initialize state with the calculated START_INDEX
    current_idx = gr.State(START_INDEX)

    with gr.Row():
        with gr.Column(scale=1):
            target_info = gr.Markdown("### Loading...")
            label_display = gr.Textbox(label="Target Readability Level", interactive=False)
            display_text = gr.Textbox(label="Medical Text", lines=12, interactive=False)
            target_stats = gr.Markdown("Stats: ...")
        with gr.Column(scale=2):
            wiki_dropdown = gr.Dropdown(
                label="Select Candidate Number",
                choices=[],
                interactive=True
            )
            full_wiki_view = gr.Textbox(label="Wikipedia Chunk Preview", lines=12, interactive=False)
            wiki_stats = gr.Markdown("Stats: ...")

    status_msg = gr.Markdown("### *Status: Ready*")

    with gr.Row():
        prev_btn = gr.Button("⬅️ Previous")
        save_btn = gr.Button("πŸ’Ύ Confirm & Save", variant="primary")
        next_btn = gr.Button("Next / Skip ➑️")

    def load_item(idx):
        """Return the eight display values for the target at position ``idx``.

        The tuple order matches ``view_outputs`` below: header, label, target
        text, target stats, dropdown update, preview text, preview stats and
        status message. An out-of-range ``idx`` yields the "finished" view.
        """
        if not (0 <= idx < len(target_list)):
            return "End", "None", "", "", gr.update(choices=[], value=None), "", "", "Finished all items!"

        doc = target_list[idx]
        candidates = get_candidates(doc['text'], top_k=20)
        info = f"### Document {idx + 1} of {len(target_list)} (ID: {doc['index']})"
        t_stats = calculate_stats(doc['text'])
        dropdown_choices = [(f"Candidate {i+1}", c) for i, c in enumerate(candidates)]
        # The search can come back empty; show a blank preview then
        # instead of failing on candidates[0].
        top_candidate = candidates[0] if candidates else None
        return (
            info,
            doc['label'].upper(),
            doc['text'],
            t_stats,
            gr.update(choices=dropdown_choices, value=top_candidate),
            top_candidate or "",
            calculate_stats(top_candidate),
            f"Currently viewing index {idx}"
        )

    def on_dropdown_change(selected_text):
        """Mirror the selected candidate into the preview box and its stats."""
        if not selected_text:
            return "", ""
        return selected_text, calculate_stats(selected_text)

    def handle_save(idx, shown_text, selected_wiki):
        """Save the current selection and return the status message."""
        # The state can point outside target_list (e.g. when there is nothing
        # to annotate); refuse to save rather than raise.
        if not (0 <= idx < len(target_list)):
            return "⚠️ Nothing to save: no document is loaded."
        if not selected_wiki:
            return "⚠️ No candidate selected — nothing was saved."
        doc = target_list[idx]
        return save_selection(doc['index'], doc['label'], shown_text, selected_wiki)

    def handle_next(idx):
        """Advance one position (clamped to the last item) and reload the view."""
        new_idx = min(len(target_list) - 1, idx + 1)
        return [new_idx] + list(load_item(new_idx))

    def handle_prev(idx):
        """Go back one position (clamped to the first item) and reload the view."""
        new_idx = max(0, idx - 1)
        return [new_idx] + list(load_item(new_idx))

    # --- Event Bindings ---
    # Components refreshed by load_item, in the order of its return tuple.
    view_outputs = [
        target_info,
        label_display,
        display_text,
        target_stats,
        wiki_dropdown,
        full_wiki_view,
        wiki_stats,
        status_msg,
    ]
    # Navigation additionally writes the new position back into the state.
    nav_outputs = [current_idx] + view_outputs

    # Trigger load_item on page load using the START_INDEX from state
    demo.load(load_item, inputs=[current_idx], outputs=view_outputs)
    wiki_dropdown.change(on_dropdown_change, inputs=wiki_dropdown, outputs=[full_wiki_view, wiki_stats])
    save_btn.click(handle_save,
                   inputs=[current_idx, display_text, wiki_dropdown],
                   outputs=[status_msg])
    next_btn.click(handle_next, inputs=[current_idx], outputs=nav_outputs)
    prev_btn.click(handle_prev, inputs=[current_idx], outputs=nav_outputs)
# Script entry point: the server is started only when this file is run directly.
if __name__ == "__main__":
    # NOTE(review): server_name="0.0.0.0" combined with share=True presumably
    # makes the app reachable from other hosts and through a public Gradio
    # link — confirm that is intended for this annotation tool.
    demo.launch(server_name="0.0.0.0", server_port=7861, share=True)