import os

# GPU selection must be exported before torch is imported: these variables
# are read when CUDA is initialised, so setting them later has no effect.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import json
import pickle

import gradio as gr
import textstat
import torch
from sentence_transformers import SentenceTransformer, util

# --- Configuration & Paths ---
LANG_CODE = "en"
CHUNKS_PATH = f"/home/mshahidul/readctrl/data/vector_db/db_model/wiki_{LANG_CODE}_chunks.pkl"
EMBS_PATH = f"/home/mshahidul/readctrl/data/vector_db/db_model/wiki_{LANG_CODE}_embs.pt"
TARGET_DOCS_PATH = f"/home/mshahidul/readctrl/data/synthetic_dataset_diff_labels/syn_data_diff_labels_{LANG_CODE}_v1.json"
SAVE_PATH = f"/home/mshahidul/readctrl/data/data_annotator_data/manual_selections_{LANG_CODE}.json"

# --- 1. Load Resources ---
print("Loading Model and Tensors...")
model = SentenceTransformer('all-MiniLM-L6-v2')

# NOTE(security): pickle.load can execute arbitrary code; CHUNKS_PATH must
# only ever point at a file produced by a trusted pipeline.
with open(CHUNKS_PATH, "rb") as f:
    wiki_chunks = pickle.load(f)

device = "cuda" if torch.cuda.is_available() else "cpu"
# map_location lets embeddings that were saved from a GPU be loaded on a
# CPU-only host; the trailing .to(device) is then a no-op safety net.
wiki_embs = torch.load(EMBS_PATH, map_location=device).to(device)

with open(TARGET_DOCS_PATH, "r", encoding="utf-8") as f:
    raw_targets = json.load(f)

# Flatten every (document, readability label) pair into its own record so the
# UI can step through them with a single integer position.
target_list = [
    {"index": item['index'], "label": label, "text": text}
    for item in raw_targets
    for label, text in item['diff_label_texts'].items()
]
# --- 2. Logic Functions ---
def get_candidates(target_text, top_k=20):
    """Return up to ``top_k`` wiki chunks retrieved for ``target_text``.

    The text is encoded with the module-level ``model`` and searched against
    ``wiki_embs`` via ``util.semantic_search``. The returned list holds the
    matching entries of ``wiki_chunks`` in the order the search returned them.
    """
    query_emb = model.encode(target_text, convert_to_tensor=True).to(device)
    hits = util.semantic_search(query_emb, wiki_embs, top_k=top_k)[0]
    return [wiki_chunks[hit['corpus_id']] for hit in hits]


def calculate_stats(text):
    """Return a one-line summary of word count and FKGL for ``text``.

    Returns the string ``"N/A"`` when ``text`` is empty or ``None``.
    """
    if not text:
        return "N/A"
    wc = len(text.split())
    fk = textstat.flesch_kincaid_grade(text)
    return f"📏 Words: {wc} | 🎓 FKGL: {fk}"


def save_selection(target_idx, label, original_text, selected_wiki):
    """Persist the chosen wiki anchor for one (document, label) pair.

    Any earlier entry in ``SAVE_PATH`` with the same ``index`` and ``label``
    is replaced by the new one.

    Args:
        target_idx: Document identifier stored under ``"index"``.
        label: Readability label stored under ``"label"``.
        original_text: The target document text.
        selected_wiki: The wiki chunk chosen as anchor.

    Returns:
        A status string for display in the UI.
    """
    # Without a selection there is nothing meaningful to store, and the
    # readability scorer cannot handle a missing text.
    if not selected_wiki:
        gr.Warning("No Wikipedia chunk selected; nothing was saved.")
        return "⚠️ Nothing saved: no Wikipedia chunk selected."

    entry = {
        "index": target_idx,
        "label": label,
        "original_text": original_text,
        "selected_wiki_anchor": selected_wiki,
        "wiki_fkgl": textstat.flesch_kincaid_grade(selected_wiki),
        "doc_fkgl": textstat.flesch_kincaid_grade(original_text)
    }

    existing_data = []
    if os.path.exists(SAVE_PATH):
        try:
            with open(SAVE_PATH, "r", encoding="utf-8") as f:
                existing_data = json.load(f)
        except (OSError, ValueError):
            # Best effort: an unreadable or malformed file is treated as
            # empty (json.JSONDecodeError is a ValueError subclass).
            existing_data = []
    if not isinstance(existing_data, list):
        existing_data = []

    existing_data = [
        d for d in existing_data
        if not (d.get('index') == target_idx and d.get('label') == label)
    ]
    existing_data.append(entry)

    # Write to a sibling temp file and swap it in, so an interrupted write
    # cannot leave a truncated save file that would later be read as empty.
    tmp_path = f"{SAVE_PATH}.tmp"
    with open(tmp_path, "w", encoding="utf-8") as f:
        json.dump(existing_data, f, indent=2)
    os.replace(tmp_path, SAVE_PATH)

    gr.Info(f"Successfully saved ID {target_idx} ({label})")
    return f"✅ Saved: ID {target_idx} ({label})"
# --- 3. Gradio UI ---
# NOTE(review): the original indentation was lost, so the layout nesting
# below is reconstructed from statement order and the panel comments.
with gr.Blocks(theme=gr.themes.Soft(), title="Wiki Anchor Selector") as demo:
    gr.Markdown("# 🔍 ReadCtrl: Anchor Selection (Numeric View)")

    # Position of the currently displayed record within target_list.
    current_idx = gr.State(0)

    with gr.Row():
        # Left Panel
        with gr.Column(scale=1):
            target_info = gr.Markdown("### Loading...")
            # Changed from HighlightedText to Textbox for stability
            label_display = gr.Textbox(label="Target Readability Level", interactive=False)
            display_text = gr.Textbox(label="Medical Text", lines=12, interactive=False)
            target_stats = gr.Markdown("Stats: ...")

        # Right Panel
        with gr.Column(scale=2):
            wiki_dropdown = gr.Dropdown(
                label="Select Candidate Number",
                choices=[],
                interactive=True
            )
            full_wiki_view = gr.Textbox(label="Wikipedia Chunk Preview", lines=12, interactive=False)
            wiki_stats = gr.Markdown("Stats: ...")

    status_msg = gr.Markdown("### *Status: Ready*")

    with gr.Row():
        prev_btn = gr.Button("⬅️ Previous")
        save_btn = gr.Button("💾 Confirm & Save", variant="primary")
        next_btn = gr.Button("Next / Skip ➡️")

    # --- UI Logic ---
    def load_item(idx):
        """Build the eight UI values for the record at position ``idx``.

        The tuple order matches ``item_outputs`` below: info header, label,
        document text, document stats, dropdown update, wiki preview, wiki
        stats, status message.
        """
        if not (0 <= idx < len(target_list)):
            return "End", "None", "", "", gr.update(choices=[], value=None), "", "", "Finished!"

        doc = target_list[idx]
        candidates = get_candidates(doc['text'], top_k=20)

        info = f"### Document {idx + 1} of {len(target_list)} (ID: {doc['index']})"
        t_stats = calculate_stats(doc['text'])

        # The search can come back empty; show the document anyway rather
        # than failing on candidates[0].
        if not candidates:
            return (
                info,
                doc['label'].upper(),
                doc['text'],
                t_stats,
                gr.update(choices=[], value=None),
                "",
                "",
                "⚠️ No Wikipedia candidates found for this document."
            )

        # Each choice is (visible label, value); the value is the chunk text
        # itself, so the dropdown hands the full chunk to its listeners.
        dropdown_choices = [(f"Candidate {i+1}", c) for i, c in enumerate(candidates)]

        return (
            info,
            doc['label'].upper(),  # Simple string for the Label Textbox
            doc['text'],
            t_stats,
            gr.update(choices=dropdown_choices, value=candidates[0]),
            candidates[0],
            calculate_stats(candidates[0]),
            ""
        )

    def on_dropdown_change(selected_text):
        """Refresh the wiki preview and its stats for the chosen chunk."""
        if not selected_text:
            return "", ""
        return selected_text, calculate_stats(selected_text)

    def handle_next(idx):
        """Advance one record, staying on the last one at the end."""
        # The outer max keeps the position at 0 when target_list is empty.
        new_idx = max(0, min(len(target_list) - 1, idx + 1))
        return [new_idx, *load_item(new_idx)]

    def handle_prev(idx):
        """Go back one record, staying on the first one at the start."""
        new_idx = max(0, idx - 1)
        return [new_idx, *load_item(new_idx)]

    def handle_save(idx, doc_text, selected_wiki):
        """Save the current selection and return the status text."""
        if not (0 <= idx < len(target_list)):
            return "⚠️ Nothing to save: no document is loaded."
        if not selected_wiki:
            return "⚠️ Nothing saved: select a Wikipedia candidate first."
        doc = target_list[idx]
        return save_selection(doc['index'], doc['label'], doc_text, selected_wiki)

    # --- Event Bindings ---
    # Components refreshed whenever a record is (re)loaded, in the order
    # load_item returns their values.
    item_outputs = [
        target_info, label_display, display_text, target_stats,
        wiki_dropdown, full_wiki_view, wiki_stats, status_msg
    ]

    demo.load(load_item, inputs=[current_idx], outputs=item_outputs)
    wiki_dropdown.change(on_dropdown_change, inputs=wiki_dropdown, outputs=[full_wiki_view, wiki_stats])
    save_btn.click(handle_save, inputs=[current_idx, display_text, wiki_dropdown], outputs=[status_msg])
    next_btn.click(handle_next, inputs=[current_idx], outputs=[current_idx] + item_outputs)
    prev_btn.click(handle_prev, inputs=[current_idx], outputs=[current_idx] + item_outputs)

if __name__ == "__main__":
    # NOTE(security): this listens on every network interface and share=True
    # additionally requests a public Gradio link; anyone who can reach the
    # app can read the documents and overwrite saved selections.
    demo.launch(server_name="0.0.0.0", server_port=7861, share=True)