Update app.py
Browse files
app.py
CHANGED
|
@@ -13,6 +13,7 @@ from captum.attr import LayerIntegratedGradients, TokenReferenceBase
|
|
| 13 |
from captum.attr import visualization as viz
|
| 14 |
from huggingface_hub import InferenceClient
|
| 15 |
from datetime import datetime
|
|
|
|
| 16 |
|
| 17 |
|
| 18 |
HF = 'hf'
|
|
@@ -45,19 +46,30 @@ def load_drafts():
|
|
| 45 |
return {}
|
| 46 |
return {}
|
| 47 |
|
| 48 |
-
def update_cache_row(user, pol_a, pol_b, a_list, idx, b_text, rel, inter, just):
|
| 49 |
-
"""Fires automatically on keystrokes/clicks to save progress"""
|
| 50 |
if not user or not a_list or idx >= len(a_list) or not b_text: return
|
| 51 |
curr_a = a_list[idx]
|
| 52 |
|
| 53 |
drafts = load_drafts()
|
| 54 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 55 |
|
| 56 |
cache_key = f"{pol_a}|{pol_b}|{curr_a}"
|
| 57 |
-
if cache_key not in drafts[user]: drafts[user][cache_key] = {}
|
| 58 |
|
| 59 |
-
# Store the exact state of this specific row
|
| 60 |
-
drafts[user][cache_key][b_text] = {
|
|
|
|
|
|
|
| 61 |
|
| 62 |
with open(DRAFT_FILE, 'w') as f:
|
| 63 |
json.dump(drafts, f)
|
|
@@ -313,6 +325,7 @@ def get_worksheet_by_number(spreadsheet, worksheet_number, format=True):
|
|
| 313 |
|
| 314 |
if 'Policy' in df.columns:
|
| 315 |
df['Policy'] = df['Policy'].ffill()
|
|
|
|
| 316 |
|
| 317 |
return df
|
| 318 |
|
|
@@ -497,6 +510,7 @@ with gr.Blocks(theme=gr.themes.Soft(), css=custom_css) as demo:
|
|
| 497 |
|
| 498 |
hf_df_state = gr.State()
|
| 499 |
user_tag_state = gr.State()
|
|
|
|
| 500 |
|
| 501 |
target_a_list_state = gr.State([])
|
| 502 |
pending_tasks_state = gr.State({})
|
|
@@ -666,8 +680,12 @@ with gr.Blocks(theme=gr.themes.Soft(), css=custom_css) as demo:
|
|
| 666 |
for i in range(MAX_ROWS):
|
| 667 |
_, b_text, rel_radio, _, inter_dd, just_box, _, _, _, _, _ = eval_rows[i]
|
| 668 |
|
| 669 |
-
# Gather the exact state needed to cache this row
|
| 670 |
-
inputs_to_cache = [
|
|
|
|
|
|
|
|
|
|
|
|
|
| 671 |
|
| 672 |
# Trigger cache save silently in the background on any change
|
| 673 |
rel_radio.change(fn=update_cache_row, inputs=inputs_to_cache)
|
|
@@ -678,12 +696,50 @@ with gr.Blocks(theme=gr.themes.Soft(), css=custom_css) as demo:
|
|
| 678 |
def authenticate(email):
|
| 679 |
user_tag, msg = get_or_create_user(email)
|
| 680 |
if not user_tag:
|
| 681 |
-
return gr.update(value=f"<font color='red'>{msg}</font>"), gr.update(visible=True), gr.update(visible=False), None, None
|
|
|
|
| 682 |
|
| 683 |
hf_df = load_hf_dataset()
|
| 684 |
-
|
| 685 |
-
|
| 686 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 687 |
domain_a_dd.change(fn=lambda d: gr.update(choices=get_policy_list(d), value=None), inputs=domain_a_dd, outputs=policy_a_dd)
|
| 688 |
domain_b_dd.change(fn=lambda d: gr.update(choices=get_policy_list(d), value=None), inputs=domain_b_dd, outputs=policy_b_dd)
|
| 689 |
|
|
@@ -716,7 +772,7 @@ with gr.Blocks(theme=gr.themes.Soft(), css=custom_css) as demo:
|
|
| 716 |
# Pull Drafts for this specific user and Target A
|
| 717 |
drafts = load_drafts()
|
| 718 |
cache_key = f"{pol_a}|{pol_b}|{curr_a_eng}"
|
| 719 |
-
user_draft = drafts.get(user_tag, {}).get(cache_key, {})
|
| 720 |
|
| 721 |
# Run model predictions for this batch
|
| 722 |
preds = get_model_predictions(curr_a_eng, bs_to_eval_eng)
|
|
@@ -937,8 +993,10 @@ with gr.Blocks(theme=gr.themes.Soft(), css=custom_css) as demo:
|
|
| 937 |
# CLEAR CACHE ON SUCCESSFUL SAVE
|
| 938 |
drafts = load_drafts()
|
| 939 |
cache_key = f"{pol_a}|{pol_b}|{current_a_eng}"
|
| 940 |
-
|
| 941 |
-
|
|
|
|
|
|
|
| 942 |
with open(DRAFT_FILE, 'w') as f:
|
| 943 |
json.dump(drafts, f)
|
| 944 |
|
|
|
|
| 13 |
from captum.attr import visualization as viz
|
| 14 |
from huggingface_hub import InferenceClient
|
| 15 |
from datetime import datetime
|
| 16 |
+
import uuid
|
| 17 |
|
| 18 |
|
| 19 |
HF = 'hf'
|
|
|
|
| 46 |
return {}
|
| 47 |
return {}
|
| 48 |
|
| 49 |
+
def update_cache_row(user, session_id, dom_a, pol_a, dom_b, pol_b, tar_col, ctx_col, a_list, idx, b_text, rel, inter, just):
    """Save one evaluation row and the active workspace selection to the draft file.

    Used as a Gradio change handler so progress is persisted while the user works.

    Args:
        user: Tag identifying the user; nothing is saved when it is empty.
        session_id: Identifier of the current session, stored with both the
            workspace record and the row record.
        dom_a, pol_a, dom_b, pol_b: Current domain / policy selections for
            sides A and B.
        tar_col, ctx_col: Current target and context column selections.
        a_list: List of target A items; ``a_list[idx]`` is the active one.
        idx: Index of the active item in ``a_list``.
        b_text: Text of the row being edited; used as the per-row key.
        rel, inter, just: Current values of the row's inputs, stored under the
            keys ``"rel"``, ``"inter"`` and ``"just"``.

    Returns:
        None. Returns early, without touching the file, when ``user``,
        ``a_list`` or ``b_text`` is empty or ``idx`` is past the end of ``a_list``.
    """
    if not user or not a_list or idx >= len(a_list) or not b_text:
        return
    curr_a = a_list[idx]

    drafts = load_drafts()

    # setdefault instead of an "if user not in drafts" check: an entry saved in
    # the older flat layout (cache keys directly under the user) has no "rows"
    # key, and indexing it would raise KeyError on every autosave.
    user_entry = drafts.setdefault(user, {})
    rows = user_entry.setdefault("rows", {})

    # Save the active workspace so it can be restored on reload
    user_entry["workspace"] = {
        "session_id": session_id,
        "dom_a": dom_a, "pol_a": pol_a,
        "dom_b": dom_b, "pol_b": pol_b,
        "tar_col": tar_col, "ctx_col": ctx_col,
    }

    cache_key = f"{pol_a}|{pol_b}|{curr_a}"

    # Store the exact state of this specific row with the session tag
    rows.setdefault(cache_key, {})[b_text] = {
        "rel": rel, "inter": inter, "just": just, "session_id": session_id,
    }

    with open(DRAFT_FILE, 'w') as f:
        json.dump(drafts, f)
|
|
|
|
| 325 |
|
| 326 |
if 'Policy' in df.columns:
|
| 327 |
df['Policy'] = df['Policy'].ffill()
|
| 328 |
+
|
| 329 |
|
| 330 |
return df
|
| 331 |
|
|
|
|
| 510 |
|
| 511 |
hf_df_state = gr.State()
|
| 512 |
user_tag_state = gr.State()
|
| 513 |
+
session_id_state = gr.State(lambda: str(uuid.uuid4().hex[:12]))
|
| 514 |
|
| 515 |
target_a_list_state = gr.State([])
|
| 516 |
pending_tasks_state = gr.State({})
|
|
|
|
| 680 |
for i in range(MAX_ROWS):
|
| 681 |
_, b_text, rel_radio, _, inter_dd, just_box, _, _, _, _, _ = eval_rows[i]
|
| 682 |
|
| 683 |
+
# Gather the exact state needed to cache this row AND the workspace config
|
| 684 |
+
inputs_to_cache = [
|
| 685 |
+
user_tag_state, session_id_state,
|
| 686 |
+
domain_a_dd, policy_a_dd, domain_b_dd, policy_b_dd, target_col_dd, context_col_dd,
|
| 687 |
+
target_a_list_state, current_index_state, b_text, rel_radio, inter_dd, just_box
|
| 688 |
+
]
|
| 689 |
|
| 690 |
# Trigger cache save silently in the background on any change
|
| 691 |
rel_radio.change(fn=update_cache_row, inputs=inputs_to_cache)
|
|
|
|
| 696 |
def authenticate(email):
    """Log a user in and restore their saved workspace selections when available.

    Args:
        email: Value of the e-mail input, passed to ``get_or_create_user``.

    Returns:
        An 11-tuple: status message update, login-box visibility update,
        app-box visibility update, the user tag, the loaded dataset, and then
        six dropdown updates (domain A, policy A, domain B, policy B, target
        column, context column). On a failed login the user tag and dataset
        are ``None`` and the dropdown updates are no-ops.
    """
    # No-op updates that leave the six workspace dropdowns unchanged
    untouched = tuple(gr.update() for _ in range(6))

    user_tag, msg = get_or_create_user(email)
    if not user_tag:
        return (
            gr.update(value=f"<font color='red'>{msg}</font>"),
            gr.update(visible=True),
            gr.update(visible=False),
            None,
            None,
            *untouched,
        )

    hf_df = load_hf_dataset()

    # Check the draft cache for a workspace saved by a previous session
    drafts = load_drafts()
    user_data = drafts.get(user_tag) or {}
    ws = user_data.get("workspace") or {}

    # Restore only when both domains and both policies were saved: the policy
    # choices are rebuilt from the domain, so a partial record must not be
    # indexed blindly (that would raise KeyError and fail the whole login).
    if all(ws.get(key) for key in ("dom_a", "pol_a", "dom_b", "pol_b")):
        msg += " Restored your previous session workspace."
        dropdown_updates = (
            gr.update(value=ws["dom_a"]),
            gr.update(choices=get_policy_list(ws["dom_a"]), value=ws["pol_a"]),
            gr.update(value=ws["dom_b"]),
            gr.update(choices=get_policy_list(ws["dom_b"]), value=ws["pol_b"]),
            gr.update(value=ws.get("tar_col")),
            gr.update(value=ws.get("ctx_col")),
        )
    else:
        dropdown_updates = untouched

    return (
        gr.update(value=f"{msg} Loaded {len(hf_df)} annotations."),
        gr.update(visible=False),
        gr.update(visible=True),
        user_tag,
        hf_df,
        *dropdown_updates,
    )
|
| 734 |
+
|
| 735 |
+
# Log-in wiring: authenticate() returns eleven values that are matched
# positionally to the outputs below (status, the two boxes, the two states,
# then the six workspace dropdowns).
login_btn.click(
    fn=authenticate,
    inputs=[email_box],
    outputs=[
        login_status, login_box, app_box, user_tag_state, hf_df_state,
        domain_a_dd, policy_a_dd, domain_b_dd, policy_b_dd, target_col_dd, context_col_dd,
    ],
)

# Refresh the policy choices when the user picks a domain. `.input` is used
# instead of `.change` because `.change` also fires when a function sets the
# dropdown value: the login handler writes the domain and policy dropdowns
# together, and a `.change` listener would then reset the policy to None.
domain_a_dd.input(fn=lambda d: gr.update(choices=get_policy_list(d), value=None), inputs=domain_a_dd, outputs=policy_a_dd)
domain_b_dd.input(fn=lambda d: gr.update(choices=get_policy_list(d), value=None), inputs=domain_b_dd, outputs=policy_b_dd)
|
| 745 |
|
|
|
|
| 772 |
# Pull Drafts for this specific user and Target A
|
| 773 |
drafts = load_drafts()
|
| 774 |
cache_key = f"{pol_a}|{pol_b}|{curr_a_eng}"
|
| 775 |
+
user_draft = drafts.get(user_tag, {}).get("rows", {}).get(cache_key, {})
|
| 776 |
|
| 777 |
# Run model predictions for this batch
|
| 778 |
preds = get_model_predictions(curr_a_eng, bs_to_eval_eng)
|
|
|
|
| 993 |
# CLEAR CACHE ON SUCCESSFUL SAVE
|
| 994 |
drafts = load_drafts()
|
| 995 |
cache_key = f"{pol_a}|{pol_b}|{current_a_eng}"
|
| 996 |
+
|
| 997 |
+
# Check inside the "rows" sub-dictionary
|
| 998 |
+
if user_tag in drafts and "rows" in drafts[user_tag] and cache_key in drafts[user_tag]["rows"]:
|
| 999 |
+
del drafts[user_tag]["rows"][cache_key]
|
| 1000 |
with open(DRAFT_FILE, 'w') as f:
|
| 1001 |
json.dump(drafts, f)
|
| 1002 |
|