akaburia commited on
Commit
220e04a
·
verified ·
1 Parent(s): 575d747

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +73 -15
app.py CHANGED
@@ -13,6 +13,7 @@ from captum.attr import LayerIntegratedGradients, TokenReferenceBase
13
  from captum.attr import visualization as viz
14
  from huggingface_hub import InferenceClient
15
  from datetime import datetime
 
16
 
17
 
18
  HF = 'hf'
@@ -45,19 +46,30 @@ def load_drafts():
45
  return {}
46
  return {}
47
 
48
- def update_cache_row(user, pol_a, pol_b, a_list, idx, b_text, rel, inter, just):
49
- """Fires automatically on keystrokes/clicks to save progress"""
50
  if not user or not a_list or idx >= len(a_list) or not b_text: return
51
  curr_a = a_list[idx]
52
 
53
  drafts = load_drafts()
54
- if user not in drafts: drafts[user] = {}
 
 
 
 
 
 
 
 
 
55
 
56
  cache_key = f"{pol_a}|{pol_b}|{curr_a}"
57
- if cache_key not in drafts[user]: drafts[user][cache_key] = {}
58
 
59
- # Store the exact state of this specific row
60
- drafts[user][cache_key][b_text] = {"rel": rel, "inter": inter, "just": just}
 
 
61
 
62
  with open(DRAFT_FILE, 'w') as f:
63
  json.dump(drafts, f)
@@ -313,6 +325,7 @@ def get_worksheet_by_number(spreadsheet, worksheet_number, format=True):
313
 
314
  if 'Policy' in df.columns:
315
  df['Policy'] = df['Policy'].ffill()
 
316
 
317
  return df
318
 
@@ -497,6 +510,7 @@ with gr.Blocks(theme=gr.themes.Soft(), css=custom_css) as demo:
497
 
498
  hf_df_state = gr.State()
499
  user_tag_state = gr.State()
 
500
 
501
  target_a_list_state = gr.State([])
502
  pending_tasks_state = gr.State({})
@@ -666,8 +680,12 @@ with gr.Blocks(theme=gr.themes.Soft(), css=custom_css) as demo:
666
  for i in range(MAX_ROWS):
667
  _, b_text, rel_radio, _, inter_dd, just_box, _, _, _, _, _ = eval_rows[i]
668
 
669
- # Gather the exact state needed to cache this row
670
- inputs_to_cache = [user_tag_state, policy_a_dd, policy_b_dd, target_a_list_state, current_index_state, b_text, rel_radio, inter_dd, just_box]
 
 
 
 
671
 
672
  # Trigger cache save silently in the background on any change
673
  rel_radio.change(fn=update_cache_row, inputs=inputs_to_cache)
@@ -678,12 +696,50 @@ with gr.Blocks(theme=gr.themes.Soft(), css=custom_css) as demo:
678
  def authenticate(email):
679
  user_tag, msg = get_or_create_user(email)
680
  if not user_tag:
681
- return gr.update(value=f"<font color='red'>{msg}</font>"), gr.update(visible=True), gr.update(visible=False), None, None
 
682
 
683
  hf_df = load_hf_dataset()
684
- return gr.update(value=f"{msg} Loaded {len(hf_df)} annotations."), gr.update(visible=False), gr.update(visible=True), user_tag, hf_df
685
-
686
- login_btn.click(fn=authenticate, inputs=[email_box], outputs=[login_status, login_box, app_box, user_tag_state, hf_df_state])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
687
  domain_a_dd.change(fn=lambda d: gr.update(choices=get_policy_list(d), value=None), inputs=domain_a_dd, outputs=policy_a_dd)
688
  domain_b_dd.change(fn=lambda d: gr.update(choices=get_policy_list(d), value=None), inputs=domain_b_dd, outputs=policy_b_dd)
689
 
@@ -716,7 +772,7 @@ with gr.Blocks(theme=gr.themes.Soft(), css=custom_css) as demo:
716
  # Pull Drafts for this specific user and Target A
717
  drafts = load_drafts()
718
  cache_key = f"{pol_a}|{pol_b}|{curr_a_eng}"
719
- user_draft = drafts.get(user_tag, {}).get(cache_key, {})
720
 
721
  # Run model predictions for this batch
722
  preds = get_model_predictions(curr_a_eng, bs_to_eval_eng)
@@ -937,8 +993,10 @@ with gr.Blocks(theme=gr.themes.Soft(), css=custom_css) as demo:
937
  # CLEAR CACHE ON SUCCESSFUL SAVE
938
  drafts = load_drafts()
939
  cache_key = f"{pol_a}|{pol_b}|{current_a_eng}"
940
- if user_tag in drafts and cache_key in drafts[user_tag]:
941
- del drafts[user_tag][cache_key]
 
 
942
  with open(DRAFT_FILE, 'w') as f:
943
  json.dump(drafts, f)
944
 
 
13
  from captum.attr import visualization as viz
14
  from huggingface_hub import InferenceClient
15
  from datetime import datetime
16
+ import uuid
17
 
18
 
19
  HF = 'hf'
 
46
  return {}
47
  return {}
48
 
49
+ def update_cache_row(user, session_id, dom_a, pol_a, dom_b, pol_b, tar_col, ctx_col, a_list, idx, b_text, rel, inter, just):
50
+ """Fires automatically on keystrokes/clicks to save progress and workspace state"""
51
  if not user or not a_list or idx >= len(a_list) or not b_text: return
52
  curr_a = a_list[idx]
53
 
54
  drafts = load_drafts()
55
+ # Upgraded structure to hold workspace settings AND row data
56
+ if user not in drafts: drafts[user] = {"workspace": {}, "rows": {}}
57
+
58
+ # Save the active workspace so we can restore it on reload
59
+ drafts[user]["workspace"] = {
60
+ "session_id": session_id,
61
+ "dom_a": dom_a, "pol_a": pol_a,
62
+ "dom_b": dom_b, "pol_b": pol_b,
63
+ "tar_col": tar_col, "ctx_col": ctx_col
64
+ }
65
 
66
  cache_key = f"{pol_a}|{pol_b}|{curr_a}"
67
+ if cache_key not in drafts[user]["rows"]: drafts[user]["rows"][cache_key] = {}
68
 
69
+ # Store the exact state of this specific row with the unique session tag
70
+ drafts[user]["rows"][cache_key][b_text] = {
71
+ "rel": rel, "inter": inter, "just": just, "session_id": session_id
72
+ }
73
 
74
  with open(DRAFT_FILE, 'w') as f:
75
  json.dump(drafts, f)
 
325
 
326
  if 'Policy' in df.columns:
327
  df['Policy'] = df['Policy'].ffill()
328
+
329
 
330
  return df
331
 
 
510
 
511
  hf_df_state = gr.State()
512
  user_tag_state = gr.State()
513
+ session_id_state = gr.State(lambda: str(uuid.uuid4().hex[:12]))
514
 
515
  target_a_list_state = gr.State([])
516
  pending_tasks_state = gr.State({})
 
680
  for i in range(MAX_ROWS):
681
  _, b_text, rel_radio, _, inter_dd, just_box, _, _, _, _, _ = eval_rows[i]
682
 
683
+ # Gather the exact state needed to cache this row AND the workspace config
684
+ inputs_to_cache = [
685
+ user_tag_state, session_id_state,
686
+ domain_a_dd, policy_a_dd, domain_b_dd, policy_b_dd, target_col_dd, context_col_dd,
687
+ target_a_list_state, current_index_state, b_text, rel_radio, inter_dd, just_box
688
+ ]
689
 
690
  # Trigger cache save silently in the background on any change
691
  rel_radio.change(fn=update_cache_row, inputs=inputs_to_cache)
 
696
  def authenticate(email):
697
  user_tag, msg = get_or_create_user(email)
698
  if not user_tag:
699
+ return (gr.update(value=f"<font color='red'>{msg}</font>"), gr.update(visible=True), gr.update(visible=False), None, None,
700
+ gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update())
701
 
702
  hf_df = load_hf_dataset()
703
+
704
+ # Check Cache for Session Recovery
705
+ drafts = load_drafts()
706
+ user_data = drafts.get(user_tag, {})
707
+ ws = user_data.get("workspace", {})
708
+
709
+ # If we found a saved workspace, restore the dropdowns!
710
+ if ws.get("pol_a") and ws.get("pol_b"):
711
+ msg += f" Restored your previous session workspace."
712
+ return (
713
+ gr.update(value=f"{msg} Loaded {len(hf_df)} annotations."),
714
+ gr.update(visible=False),
715
+ gr.update(visible=True),
716
+ user_tag,
717
+ hf_df,
718
+ gr.update(value=ws["dom_a"]),
719
+ gr.update(choices=get_policy_list(ws["dom_a"]), value=ws["pol_a"]),
720
+ gr.update(value=ws["dom_b"]),
721
+ gr.update(choices=get_policy_list(ws["dom_b"]), value=ws["pol_b"]),
722
+ gr.update(value=ws["tar_col"]),
723
+ gr.update(value=ws["ctx_col"])
724
+ )
725
+ else:
726
+ return (
727
+ gr.update(value=f"{msg} Loaded {len(hf_df)} annotations."),
728
+ gr.update(visible=False),
729
+ gr.update(visible=True),
730
+ user_tag,
731
+ hf_df,
732
+ gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update()
733
+ )
734
+
735
+ login_btn.click(
736
+ fn=authenticate,
737
+ inputs=[email_box],
738
+ outputs=[
739
+ login_status, login_box, app_box, user_tag_state, hf_df_state,
740
+ domain_a_dd, policy_a_dd, domain_b_dd, policy_b_dd, target_col_dd, context_col_dd # <-- ADDED THESE
741
+ ]
742
+ )
743
  domain_a_dd.change(fn=lambda d: gr.update(choices=get_policy_list(d), value=None), inputs=domain_a_dd, outputs=policy_a_dd)
744
  domain_b_dd.change(fn=lambda d: gr.update(choices=get_policy_list(d), value=None), inputs=domain_b_dd, outputs=policy_b_dd)
745
 
 
772
  # Pull Drafts for this specific user and Target A
773
  drafts = load_drafts()
774
  cache_key = f"{pol_a}|{pol_b}|{curr_a_eng}"
775
+ user_draft = drafts.get(user_tag, {}).get("rows", {}).get(cache_key, {})
776
 
777
  # Run model predictions for this batch
778
  preds = get_model_predictions(curr_a_eng, bs_to_eval_eng)
 
993
  # CLEAR CACHE ON SUCCESSFUL SAVE
994
  drafts = load_drafts()
995
  cache_key = f"{pol_a}|{pol_b}|{current_a_eng}"
996
+
997
+ # Check inside the "rows" sub-dictionary
998
+ if user_tag in drafts and "rows" in drafts[user_tag] and cache_key in drafts[user_tag]["rows"]:
999
+ del drafts[user_tag]["rows"][cache_key]
1000
  with open(DRAFT_FILE, 'w') as f:
1001
  json.dump(drafts, f)
1002