# Standard library
import json
import os
import pickle

# These variables are set before torch is imported below.
# NOTE(review): presumably this restricts the process to GPU 0 — confirm.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Third-party (kept after the CUDA environment variables above)
import torch
import gradio as gr
import textstat
from sentence_transformers import SentenceTransformer, util


# Language code interpolated into every data path below.
LANG_CODE = "en"

# Pickled Wikipedia chunks and the tensor of their embeddings.
CHUNKS_PATH = (
    f"/home/mshahidul/readctrl/data/vector_db/db_model/wiki_{LANG_CODE}_chunks.pkl"
)
EMBS_PATH = (
    f"/home/mshahidul/readctrl/data/vector_db/db_model/wiki_{LANG_CODE}_embs.pt"
)

# JSON file with the target documents to annotate.
TARGET_DOCS_PATH = (
    "/home/mshahidul/readctrl/data/synthetic_dataset_diff_labels/"
    f"syn_data_diff_labels_{LANG_CODE}_v1.json"
)

# JSON file the annotator's selections are written to (and resumed from).
SAVE_PATH = (
    "/home/mshahidul/readctrl/data/data_annotator_data/"
    f"manual_selections_{LANG_CODE}.json"
)


print("Loading Model and Tensors...")
model = SentenceTransformer('all-MiniLM-L6-v2')

# NOTE(security): pickle.load can execute arbitrary code from the file, so
# CHUNKS_PATH must only ever point at a trusted, locally produced file.
with open(CHUNKS_PATH, "rb") as f:
    wiki_chunks = pickle.load(f)

device = "cuda" if torch.cuda.is_available() else "cpu"
# map_location lets a tensor that was saved from a GPU be loaded on a
# CPU-only machine; without it torch.load fails before .to(device) can run,
# which would defeat the CPU fallback chosen above.
wiki_embs = torch.load(EMBS_PATH, map_location=device).to(device)

# Read the targets as UTF-8 explicitly instead of relying on the locale.
with open(TARGET_DOCS_PATH, "r", encoding="utf-8") as f:
    raw_targets = json.load(f)


# Flatten the targets: each (document, label) pair in 'diff_label_texts'
# becomes its own entry, so the UI can step through them one at a time.
target_list = [
    {"index": item['index'], "label": label, "text": text}
    for item in raw_targets
    for label, text in item['diff_label_texts'].items()
]


def get_resume_index():
    """Return the position in ``target_list`` where annotation should resume.

    An entry counts as done when the save file at ``SAVE_PATH`` contains a
    record with the same ``(index, label)`` pair.

    Returns:
        int: ``0`` when there is no save file or it cannot be read; otherwise
        the position of the first entry that is not done. When every entry is
        done, the last valid position is returned (``0`` for an empty
        ``target_list``), so the result is never negative.
    """
    if not os.path.exists(SAVE_PATH):
        return 0

    try:
        with open(SAVE_PATH, "r", encoding="utf-8") as f:
            saved_data = json.load(f)
        done_keys = {(d['index'], d['label']) for d in saved_data}
    except (OSError, ValueError, KeyError, TypeError) as e:
        # Best effort: an unreadable or malformed save file restarts from the
        # beginning rather than aborting start-up. json.JSONDecodeError is a
        # ValueError; KeyError/TypeError cover records of the wrong shape.
        print(f"Error loading save file: {e}")
        return 0

    first_pending = next(
        (
            i
            for i, item in enumerate(target_list)
            if (item['index'], item['label']) not in done_keys
        ),
        None,
    )
    if first_pending is not None:
        return first_pending

    # Everything is already saved: stay on the last item. max() keeps the
    # result at 0 instead of -1 when target_list is empty.
    return max(0, len(target_list) - 1)


# Computed once at start-up; used as the initial value of the UI's position state.
START_INDEX = get_resume_index()
print(f"Resuming from index: {START_INDEX}")


def get_candidates(target_text, top_k=20):
    """Return the Wikipedia chunks most similar to ``target_text``.

    The text is embedded with ``model`` and searched against ``wiki_embs``
    via ``util.semantic_search``; each hit's ``corpus_id`` is then used to
    look up the chunk in ``wiki_chunks``.

    Args:
        target_text: Text to embed and search for.
        top_k: Maximum number of hits requested from the search.

    Returns:
        list: The matching chunks, in the order the search returned them.
    """
    query_vec = model.encode(target_text, convert_to_tensor=True).to(device)
    # semantic_search returns one hit list per query; there is a single query.
    ranked = util.semantic_search(query_vec, wiki_embs, top_k=top_k)
    best_hits = ranked[0]
    return [wiki_chunks[h['corpus_id']] for h in best_hits]


def calculate_stats(text):
    """Build the one-line statistics label shown under a text box.

    Args:
        text: The text to describe; may be empty or ``None``.

    Returns:
        str: ``"N/A"`` for empty input, otherwise a label with the
        whitespace-separated word count and the Flesch-Kincaid grade level
        reported by ``textstat``.
    """
    if text:
        word_count = len(text.split())
        grade = textstat.flesch_kincaid_grade(text)
        # NOTE(review): the "π" characters look like mis-decoded emoji;
        # kept as-is because the originals cannot be recovered — confirm.
        return f"π Words: {word_count} | π FKGL: {grade}"
    return "N/A"


def save_selection(target_idx, label, original_text, selected_wiki):
    """Record the chosen Wikipedia anchor for one ``(index, label)`` target.

    The save file at ``SAVE_PATH`` holds a JSON list with at most one entry
    per ``(index, label)`` pair; saving the same pair again replaces the
    earlier entry.

    Args:
        target_idx: Value stored under ``"index"``.
        label: Value stored under ``"label"``.
        original_text: Target document text; stored and scored with FKGL.
        selected_wiki: Chosen Wikipedia chunk; stored and scored with FKGL.

    Returns:
        str: A status message for display in the UI.
    """
    if not selected_wiki:
        # Without a selection there is nothing to score or store; report it
        # instead of letting textstat fail on an empty value.
        return "Nothing saved: no Wikipedia chunk is selected."

    entry = {
        "index": target_idx,
        "label": label,
        "original_text": original_text,
        "selected_wiki_anchor": selected_wiki,
        "wiki_fkgl": textstat.flesch_kincaid_grade(selected_wiki),
        "doc_fkgl": textstat.flesch_kincaid_grade(original_text),
    }

    existing_data = []
    if os.path.exists(SAVE_PATH):
        try:
            with open(SAVE_PATH, "r", encoding="utf-8") as f:
                existing_data = json.load(f)
            if not isinstance(existing_data, list):
                raise ValueError("save file does not contain a JSON list")
        except (OSError, ValueError) as e:
            # json.JSONDecodeError is a ValueError. The unreadable file is
            # about to be replaced, so move it aside first: otherwise every
            # earlier annotation in it would be lost without a trace.
            backup_path = SAVE_PATH + ".corrupt"
            print(f"Warning: could not read {SAVE_PATH} ({e}); moving it to {backup_path}")
            try:
                os.replace(SAVE_PATH, backup_path)
            except OSError as backup_error:
                print(f"Warning: could not back up {SAVE_PATH}: {backup_error}")
            existing_data = []

    # Drop any earlier entry for this (index, label) so the new one replaces it.
    # Records that are not dicts or lack the keys are kept untouched.
    existing_data = [
        d for d in existing_data
        if not (
            isinstance(d, dict)
            and d.get('index') == target_idx
            and d.get('label') == label
        )
    ]
    existing_data.append(entry)

    save_dir = os.path.dirname(SAVE_PATH)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    # Write to a sibling temp file and rename it over the target, so an
    # interrupted write cannot leave a truncated save file behind.
    tmp_path = SAVE_PATH + ".tmp"
    with open(tmp_path, "w", encoding="utf-8") as f:
        json.dump(existing_data, f, indent=2)
    os.replace(tmp_path, SAVE_PATH)

    return f"✅ Saved: ID {target_idx} ({label})"


# Gradio UI: the target document on the left, Wikipedia candidates on the
# right, and Previous / Save / Next buttons underneath.
with gr.Blocks(theme=gr.themes.Soft(), title="Wiki Anchor Selector") as demo:
    # NOTE(review): the "π" looks like a mis-decoded emoji; kept as-is
    # because the original character cannot be recovered — confirm.
    gr.Markdown("# π ReadCtrl: Anchor Selection (Resume Mode)")

    # Position in target_list of the document currently on screen.
    current_idx = gr.State(START_INDEX)

    with gr.Row():
        with gr.Column(scale=1):
            target_info = gr.Markdown("### Loading...")
            label_display = gr.Textbox(label="Target Readability Level", interactive=False)
            display_text = gr.Textbox(label="Medical Text", lines=12, interactive=False)
            target_stats = gr.Markdown("Stats: ...")

        with gr.Column(scale=2):
            wiki_dropdown = gr.Dropdown(
                label="Select Candidate Number",
                choices=[],
                interactive=True
            )
            full_wiki_view = gr.Textbox(label="Wikipedia Chunk Preview", lines=12, interactive=False)
            wiki_stats = gr.Markdown("Stats: ...")

    status_msg = gr.Markdown("### *Status: Ready*")

    with gr.Row():
        prev_btn = gr.Button("⬅️ Previous")
        save_btn = gr.Button("💾 Confirm & Save", variant="primary")
        next_btn = gr.Button("Next / Skip ➡️")

    def load_item(idx):
        """Return the eight display values for the document at position ``idx``.

        The tuple order matches ``item_outputs`` below: header, label, text,
        text stats, dropdown update, chunk preview, chunk stats, status.
        """
        # Out of range happens e.g. when target_list is empty.
        if not (0 <= idx < len(target_list)):
            return "End", "None", "", "", gr.update(choices=[], value=None), "", "", "Finished all items!"

        doc = target_list[idx]
        candidates = get_candidates(doc['text'], top_k=20)

        info = f"### Document {idx + 1} of {len(target_list)} (ID: {doc['index']})"
        t_stats = calculate_stats(doc['text'])

        if not candidates:
            # An empty search result would make candidates[0] raise; show the
            # document with an empty candidate panel instead.
            return (
                info,
                doc['label'].upper(),
                doc['text'],
                t_stats,
                gr.update(choices=[], value=None),
                "",
                "",
                f"No Wikipedia candidates found for index {idx}"
            )

        # Dropdown entries are (visible name, value); the value is the chunk text.
        dropdown_choices = [(f"Candidate {i+1}", c) for i, c in enumerate(candidates)]

        return (
            info,
            doc['label'].upper(),
            doc['text'],
            t_stats,
            gr.update(choices=dropdown_choices, value=candidates[0]),
            candidates[0],
            calculate_stats(candidates[0]),
            f"Currently viewing index {idx}"
        )

    def on_dropdown_change(selected_text):
        """Return the preview text and stats for the selected candidate."""
        if not selected_text:
            return "", ""
        return selected_text, calculate_stats(selected_text)

    def handle_next(idx):
        """Move one position forward (staying on the last item) and reload."""
        new_idx = min(len(target_list) - 1, idx + 1)
        return [new_idx] + list(load_item(new_idx))

    def handle_prev(idx):
        """Move one position back (staying on the first item) and reload."""
        new_idx = max(0, idx - 1)
        return [new_idx] + list(load_item(new_idx))

    def handle_save(idx, original_text, selected_wiki):
        """Save the current selection, or explain why nothing can be saved."""
        # Guard the lookup: with no document loaded, indexing target_list
        # would raise instead of giving the annotator a message.
        if not (0 <= idx < len(target_list)):
            return "Nothing to save: no document is loaded."
        if not selected_wiki:
            return "Nothing to save: select a Wikipedia candidate first."
        doc = target_list[idx]
        return save_selection(doc['index'], doc['label'], original_text, selected_wiki)

    # Components refreshed whenever a document is (re)loaded; the order must
    # match the tuple returned by load_item.
    item_outputs = [
        target_info,
        label_display,
        display_text,
        target_stats,
        wiki_dropdown,
        full_wiki_view,
        wiki_stats,
        status_msg,
    ]

    demo.load(load_item, inputs=[current_idx], outputs=item_outputs)

    wiki_dropdown.change(on_dropdown_change, inputs=wiki_dropdown, outputs=[full_wiki_view, wiki_stats])

    save_btn.click(handle_save,
                   inputs=[current_idx, display_text, wiki_dropdown],
                   outputs=[status_msg])

    # The navigation handlers return the new position followed by load_item's values.
    next_btn.click(handle_next, inputs=[current_idx], outputs=[current_idx] + item_outputs)
    prev_btn.click(handle_prev, inputs=[current_idx], outputs=[current_idx] + item_outputs)


if __name__ == "__main__":
    # NOTE(review): server_name "0.0.0.0" together with share=True appears to
    # expose the app beyond localhost — confirm this is intended.
    demo.launch(
        server_name="0.0.0.0",
        server_port=7861,
        share=True,
    )