Fredrik Sitje commited on
Commit
077c9e3
·
1 Parent(s): 058983a

Refactor Streamlit app to support jurisdiction-specific data handling. Updated functions to accept jurisdiction as a parameter, modified file paths for user data and grading templates, and added jurisdiction selection in the login process. This enhances the app's flexibility for different jurisdictions.

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +128 -101
src/streamlit_app.py CHANGED
@@ -15,9 +15,12 @@ HF_TOKEN = os.getenv("HF_TOKEN")
15
  # Fallback to st.secrets for local development (if not found in environment)
16
  if not HF_DATASET_REPO:
17
  try:
18
- HF_DATASET_REPO = st.secrets.get("HF_DATASET_REPO", "TransLegal/en-us-grading-answers")
19
  except Exception:
20
- HF_DATASET_REPO = "TransLegal/en-us-grading-answers"
 
 
 
21
 
22
  if not HF_TOKEN:
23
  try:
@@ -33,35 +36,32 @@ if not HF_TOKEN:
33
  @st.cache_resource
34
  def get_hf_api():
35
  """Get cached Hugging Face API client - only initializes once per session"""
36
- try:
37
- login(token=HF_TOKEN)
38
  return HfApi(token=HF_TOKEN)
39
- except Exception as e:
40
- st.error(f"❌ **Error initializing Hugging Face API**: {str(e)}")
41
- st.stop()
42
 
43
  # Initialize HF API - cached to avoid re-initialization on every rerun
44
  hf_api = get_hf_api()
45
 
46
  @st.cache_data
47
- def load_grading_template():
48
- """Load grading template from Hugging Face Dataset"""
49
  try:
50
  file_path = hf_hub_download(
51
  repo_id=HF_DATASET_REPO,
52
- filename="grading_template.parquet",
53
  repo_type="dataset",
54
  token=HF_TOKEN
55
  )
56
  return pd.read_parquet(file_path)
57
  except Exception as e:
58
  st.error(f"❌ **Error loading grading template from Hugging Face Dataset**: {str(e)}")
59
- st.error(f"Please ensure the file `grading_template.parquet` exists in the dataset repository: {HF_DATASET_REPO}")
60
  st.stop()
61
 
62
- # Load data from the grading template
63
- df = load_grading_template()
64
-
65
  # Assessment options with descriptive labels
66
  ASSESSMENT_OPTIONS = [
67
  "Perfect",
@@ -138,12 +138,12 @@ def hash_password(password):
138
  return hashlib.sha256(password.encode()).hexdigest()
139
 
140
  @st.cache_data
141
- def load_users():
142
- """Load user credentials from Hugging Face Dataset"""
143
  try:
144
  file_path = hf_hub_download(
145
  repo_id=HF_DATASET_REPO,
146
- filename="users/users.json",
147
  repo_type="dataset",
148
  token=HF_TOKEN
149
  )
@@ -153,8 +153,8 @@ def load_users():
153
  # File doesn't exist yet (first run), return empty dict
154
  return {}
155
 
156
- def save_users(users):
157
- """Save user credentials to Hugging Face Dataset"""
158
  try:
159
  with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
160
  json.dump(users, f, indent=2)
@@ -162,7 +162,7 @@ def save_users(users):
162
 
163
  hf_api.upload_file(
164
  path_or_fileobj=temp_path,
165
- path_in_repo="users/users.json",
166
  repo_id=HF_DATASET_REPO,
167
  repo_type="dataset",
168
  token=HF_TOKEN
@@ -170,7 +170,7 @@ def save_users(users):
170
  os.unlink(temp_path)
171
 
172
  # Clear cache for users to ensure fresh data on next load
173
- load_users.clear()
174
 
175
  return True
176
  except Exception as e:
@@ -178,25 +178,26 @@ def save_users(users):
178
  raise
179
 
180
  @st.cache_data(ttl=3600) # Cache for 1 hour as safety measure
181
- def load_user_data(username):
182
- """Load user's answer data from Hugging Face Dataset"""
183
  try:
184
  file_path = hf_hub_download(
185
  repo_id=HF_DATASET_REPO,
186
- filename=f"users/{username}_answers.parquet",
187
  repo_type="dataset",
188
  token=HF_TOKEN
189
  )
190
  return pd.read_parquet(file_path)
191
  except Exception:
192
- # File doesn't exist yet (new user), create new dataframe
 
193
  user_df = df.copy()
194
  user_df['legal_accuracy_score'] = None
195
  user_df['time_stamp'] = None
196
  return user_df
197
 
198
- def save_user_data(username, user_df, commit_message=None):
199
- """Save user's answer data to Hugging Face Dataset"""
200
  try:
201
  with tempfile.NamedTemporaryFile(suffix='.parquet', delete=False) as f:
202
  user_df.to_parquet(f.name, index=False)
@@ -204,7 +205,7 @@ def save_user_data(username, user_df, commit_message=None):
204
 
205
  upload_kwargs = {
206
  'path_or_fileobj': temp_path,
207
- 'path_in_repo': f"users/{username}_answers.parquet",
208
  'repo_id': HF_DATASET_REPO,
209
  'repo_type': "dataset",
210
  'token': HF_TOKEN
@@ -217,18 +218,18 @@ def save_user_data(username, user_df, commit_message=None):
217
  hf_api.upload_file(**upload_kwargs)
218
  os.unlink(temp_path)
219
 
220
- # Clear cache for this user to ensure fresh data on next load
221
- load_user_data.clear(username)
222
 
223
  return True
224
  except Exception as e:
225
  st.error(f"❌ **Error saving user data to Hugging Face Dataset**: {str(e)}")
226
  raise
227
 
228
- def update_user_answer(username, term, category, subcategory, question, answer, score):
229
  """Update a specific answer in the user's data (deprecated - use update_category_answers for bulk updates)"""
230
  try:
231
- user_df = load_user_data(username)
232
 
233
  # Find the matching row
234
  mask = (
@@ -241,7 +242,7 @@ def update_user_answer(username, term, category, subcategory, question, answer,
241
 
242
  if mask.any():
243
  user_df.loc[mask, 'legal_accuracy_score'] = score
244
- save_user_data(username, user_df)
245
  return True
246
  else:
247
  print(f"Warning: Could not find matching row for: {term}, {category}, {subcategory}, {question}")
@@ -273,7 +274,7 @@ def auto_score_unknown_answers(username, term, category, df):
273
  return [(row['subcategory'], row['question'], row['answer'], "NA")
274
  for _, row in unknown_rows.iterrows()]
275
 
276
- def auto_score_all_unknown_answers_for_new_user(username):
277
  """
278
  Automatically score all Unknown answers for all categories when a new user is created.
279
  This runs in the background during account creation.
@@ -297,7 +298,7 @@ def auto_score_all_unknown_answers_for_new_user(username):
297
  return True
298
 
299
  # Load user dataframe once
300
- user_df = load_user_data(username)
301
 
302
  # Get current timestamp once for all updates
303
  current_timestamp = pd.Timestamp.now()
@@ -327,14 +328,14 @@ def auto_score_all_unknown_answers_for_new_user(username):
327
 
328
  # Save once with a single commit message
329
  commit_message = f"Auto-score all Unknown answers for new user {username}"
330
- save_user_data(username, user_df, commit_message=commit_message)
331
 
332
  return True
333
  except Exception as e:
334
  print(f"Error auto-scoring Unknown answers for new user {username}: {str(e)}")
335
  return False
336
 
337
- def update_category_answers(username, term, category, answers_list, commit_message=None):
338
  """
339
  Update all answers for a category in a single commit.
340
 
@@ -343,6 +344,7 @@ def update_category_answers(username, term, category, answers_list, commit_messa
343
  term: Term name
344
  category: Category name
345
  answers_list: List of tuples (subcategory, question, answer, score)
 
346
  commit_message: Optional commit message (auto-generated if None)
347
 
348
  Returns:
@@ -350,7 +352,7 @@ def update_category_answers(username, term, category, answers_list, commit_messa
350
  """
351
  try:
352
  # Load user dataframe once
353
- user_df = load_user_data(username)
354
 
355
  # Get current timestamp once for all updates in this category
356
  current_timestamp = pd.Timestamp.now()
@@ -383,15 +385,15 @@ def update_category_answers(username, term, category, answers_list, commit_messa
383
  commit_message = f"Update answers for {username} - {term} - {category}"
384
 
385
  # Save once with commit message
386
- save_user_data(username, user_df, commit_message=commit_message)
387
  return True
388
  except Exception as e:
389
  print(f"Error updating category answers: {str(e)}")
390
  return False
391
 
392
- def get_user_answer(username, term, category, subcategory, question, answer):
393
  """Get user's answer for a specific question"""
394
- user_df = load_user_data(username)
395
 
396
  mask = (
397
  (user_df['term'] == term) &
@@ -407,12 +409,12 @@ def get_user_answer(username, term, category, subcategory, question, answer):
407
  return score
408
  return None
409
 
410
- def find_first_unanswered_category(username):
411
  """Find the first category that hasn't been fully answered"""
412
- user_df = load_user_data(username)
413
 
414
- # Use the global term_category_pairs to ensure consistent ordering
415
- # This matches the order used in the main application
416
 
417
  for idx, (term, category) in enumerate(term_category_pairs):
418
  # Get all subcategories for this term-category pair from base df
@@ -445,11 +447,13 @@ def find_first_unanswered_category(username):
445
 
446
  return len(term_category_pairs) # All answered, return last index
447
 
448
- def restore_submitted_status(username):
449
  """Restore submitted status for categories that have all answers in parquet file"""
450
- user_df = load_user_data(username)
 
 
 
451
 
452
- # Use the global term_category_pairs to ensure consistent ordering
453
  submitted_pairs = set()
454
  for idx, (term, category) in enumerate(term_category_pairs):
455
  pair_key = f"{term}_{category}_{idx}"
@@ -585,22 +589,19 @@ def get_term_category_pairs(df):
585
  return [(term, category) for term, category in all_pairs
586
  if category_has_subcategories(term, category, df)]
587
 
588
- # Create a list of unique (term, category) pairs for navigation
589
- term_category_pairs = get_term_category_pairs(df)
590
- total_pairs = len(term_category_pairs)
591
-
592
- # Cache for Term instances
593
  term_cache = {}
594
 
595
- def get_term_instance(term_name):
596
- """Get or create a Term instance"""
597
- if term_name not in term_cache:
598
- term_cache[term_name] = Term(term_name, df)
599
- return term_cache[term_name]
 
600
 
601
- def get_category_for_pair(term_name, category_name):
602
  """Get Category instance for a term-category pair"""
603
- term = get_term_instance(term_name)
604
  return term.get_category_by_name(category_name)
605
 
606
  # Initialize session state
@@ -608,6 +609,8 @@ if 'logged_in' not in st.session_state:
608
  st.session_state.logged_in = False
609
  if 'username' not in st.session_state:
610
  st.session_state.username = None
 
 
611
  if 'current_index' not in st.session_state:
612
  st.session_state.current_index = 0
613
  if 'show_term_complete' not in st.session_state:
@@ -636,7 +639,11 @@ if 'has_unsaved_changes' not in st.session_state:
636
  # Login page
637
  if not st.session_state.logged_in:
638
  st.markdown("# Login")
639
- st.markdown("Please enter your username and password to continue.")
 
 
 
 
640
 
641
  username = st.text_input("Username")
642
  password = st.text_input("Password", type="password")
@@ -644,44 +651,58 @@ if not st.session_state.logged_in:
644
  col1, col2 = st.columns(2)
645
  with col1:
646
  if st.button("Login", type="primary", use_container_width=True):
647
- users = load_users()
648
-
649
- if username in users:
650
- # Existing user - check password
651
- if users[username]['password'] == hash_password(password):
652
- st.session_state.logged_in = True
653
- st.session_state.username = username
654
- # Restore submitted status for previously submitted categories
655
- st.session_state.submitted_pairs = restore_submitted_status(username)
656
- # Find first unanswered category and resume there
657
- resume_index = find_first_unanswered_category(username)
658
- st.session_state.current_index = resume_index
659
- st.rerun()
660
- else:
661
- st.error("Incorrect password")
662
  else:
663
- # Username not found - require registration
664
- st.error("Username not found. Please register first using the 'Register New User' button.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
665
 
666
  with col2:
667
  if st.button("Register New User", use_container_width=True):
668
- users = load_users()
669
- if username in users:
670
- st.error("Username already exists")
671
- elif username and password:
672
- users[username] = {'password': hash_password(password)}
673
- save_users(users)
674
- # Auto-score all Unknown answers for the new user in the background
675
- auto_score_all_unknown_answers_for_new_user(username)
676
- st.success("User registered successfully! Please click Login.")
677
  else:
678
- st.error("Please enter both username and password")
 
 
 
 
 
 
 
 
 
 
 
 
679
 
680
  # Main application (only shown if logged in)
681
  elif st.session_state.logged_in:
682
  username = st.session_state.username
 
683
  current_index = st.session_state.current_index
684
 
 
 
 
685
  # Debug info (can be removed in production)
686
  with st.sidebar:
687
  with st.expander("Debug Info"):
@@ -689,8 +710,9 @@ elif st.session_state.logged_in:
689
  st.write(f"HF Token configured: {HF_TOKEN is not None}")
690
  st.write(f"HF API initialized: {hf_api is not None}")
691
  if username:
692
- st.write(f"User parquet file: `users/{username}_answers.parquet`")
693
- st.write(f"Users file: `users/users.json`")
 
694
 
695
  # Check if we should show the annotation guide first
696
  if st.session_state.show_guide:
@@ -829,10 +851,14 @@ elif st.session_state.logged_in:
829
  st.session_state.next_term = None
830
  st.rerun()
831
 
 
 
 
 
832
  elif current_index < total_pairs:
833
  term_name, category_name = term_category_pairs[current_index]
834
- category = get_category_for_pair(term_name, category_name)
835
- term = get_term_instance(term_name)
836
 
837
  # Safety check: skip if category has no subcategories (shouldn't happen due to filtering, but just in case)
838
  if not category or len(category.subcategories) == 0:
@@ -858,7 +884,7 @@ elif st.session_state.logged_in:
858
  # Check visible subcategories
859
  for i, subcat in enumerate(category.subcategories):
860
  saved_score = get_user_answer(username, term_name, category_name, subcat.subcategory_name,
861
- subcat.question, subcat.answer)
862
  if saved_score is None:
863
  category_fully_answered = False
864
  break
@@ -867,7 +893,7 @@ elif st.session_state.logged_in:
867
  if category_fully_answered:
868
  unknown_answers = auto_score_unknown_answers(username, term_name, category_name, df)
869
  for subcategory, question, answer, score in unknown_answers:
870
- saved_score = get_user_answer(username, term_name, category_name, subcategory, question, answer)
871
  if saved_score is None:
872
  category_fully_answered = False
873
  break
@@ -890,7 +916,7 @@ elif st.session_state.logged_in:
890
  st.session_state.original_selections[pair_key][radio_key] = st.session_state[radio_key]
891
  else:
892
  saved_score = get_user_answer(username, term_name, category_name, subcat.subcategory_name,
893
- subcat.question, subcat.answer)
894
  if saved_score is not None:
895
  score_to_option = {v: k for k, v in ASSESSMENT_TO_SCORE.items()}
896
  if saved_score in score_to_option:
@@ -929,7 +955,7 @@ elif st.session_state.logged_in:
929
  # Get saved value to determine if we should set a default index
930
  # Don't pre-set session_state - only use index parameter to avoid conflicts
931
  saved_score = get_user_answer(username, term_name, category_name, subcat.subcategory_name,
932
- subcat.question, subcat.answer)
933
  default_index = None
934
 
935
  # Check if there's a saved value in parquet file
@@ -1014,7 +1040,7 @@ elif st.session_state.logged_in:
1014
  st.session_state.back_current_index = current_index
1015
  st.session_state.show_term_back_warning = True
1016
  st.session_state.back_current_term = term.formatted_name
1017
- prev_term = get_term_instance(prev_term_name)
1018
  st.session_state.back_previous_term = prev_term.formatted_name
1019
  else:
1020
  # Same term, just move back
@@ -1058,7 +1084,7 @@ elif st.session_state.logged_in:
1058
 
1059
  # Update all answers in a single commit
1060
  commit_message = f"Update answers for {username} - {term_name} - {category_name}"
1061
- save_success = update_category_answers(username, term_name, category_name, answers_list, commit_message)
1062
 
1063
  if save_success:
1064
  # Save current selections as new originals
@@ -1090,7 +1116,7 @@ elif st.session_state.logged_in:
1090
  # Update all answers in a single commit
1091
  # Note: Unknown answers are already auto-scored during account creation
1092
  commit_message = f"Update answers for {username} - {term_name} - {category_name}"
1093
- save_success = update_category_answers(username, term_name, category_name, answers_list, commit_message)
1094
 
1095
  if save_success:
1096
  # Mark as submitted and save original selections
@@ -1113,7 +1139,7 @@ elif st.session_state.logged_in:
1113
  # Show intermediate page
1114
  st.session_state.show_term_complete = True
1115
  st.session_state.completed_term = term.formatted_name
1116
- next_term = get_term_instance(next_term_name)
1117
  st.session_state.next_term = next_term.formatted_name
1118
  else:
1119
  # Same term, just move to next category
@@ -1148,7 +1174,7 @@ elif st.session_state.logged_in:
1148
  # Moving to a different term - show term switching page
1149
  st.session_state.show_term_complete = True
1150
  st.session_state.completed_term = term.formatted_name
1151
- next_term = get_term_instance(next_term_name)
1152
  st.session_state.next_term = next_term.formatted_name
1153
  else:
1154
  # Same term, just move forward
@@ -1170,6 +1196,7 @@ elif st.session_state.logged_in:
1170
  if st.button("Logout"):
1171
  st.session_state.logged_in = False
1172
  st.session_state.username = None
 
1173
  st.session_state.current_index = 0
1174
  st.session_state.show_guide = True
1175
  st.session_state.submitted_pairs = set()
 
15
  # Fallback to st.secrets for local development (if not found in environment)
16
  if not HF_DATASET_REPO:
17
  try:
18
+ HF_DATASET_REPO = st.secrets.get("HF_DATASET_REPO", "TransLegal/grading-answers")
19
  except Exception:
20
+ HF_DATASET_REPO = "TransLegal/grading-answers"
21
+
22
+ # Available jurisdictions
23
+ AVAILABLE_JURISDICTIONS = ["en-us", "hr-hr", "sv-se"]
24
 
25
  if not HF_TOKEN:
26
  try:
 
36
  @st.cache_resource
37
  def get_hf_api():
38
  """Get cached Hugging Face API client - only initializes once per session"""
39
+ try:
40
+ login(token=HF_TOKEN)
41
  return HfApi(token=HF_TOKEN)
42
+ except Exception as e:
43
+ st.error(f"❌ **Error initializing Hugging Face API**: {str(e)}")
44
+ st.stop()
45
 
46
  # Initialize HF API - cached to avoid re-initialization on every rerun
47
  hf_api = get_hf_api()
48
 
49
  @st.cache_data
50
+ def load_grading_template(jurisdiction):
51
+ """Load grading template from Hugging Face Dataset for the specified jurisdiction"""
52
  try:
53
  file_path = hf_hub_download(
54
  repo_id=HF_DATASET_REPO,
55
+ filename=f"{jurisdiction}/grading_template.parquet",
56
  repo_type="dataset",
57
  token=HF_TOKEN
58
  )
59
  return pd.read_parquet(file_path)
60
  except Exception as e:
61
  st.error(f"❌ **Error loading grading template from Hugging Face Dataset**: {str(e)}")
62
+ st.error(f"Please ensure the file `{jurisdiction}/grading_template.parquet` exists in the dataset repository: {HF_DATASET_REPO}")
63
  st.stop()
64
 
 
 
 
65
  # Assessment options with descriptive labels
66
  ASSESSMENT_OPTIONS = [
67
  "Perfect",
 
138
  return hashlib.sha256(password.encode()).hexdigest()
139
 
140
  @st.cache_data
141
+ def load_users(jurisdiction):
142
+ """Load user credentials from Hugging Face Dataset for the specified jurisdiction"""
143
  try:
144
  file_path = hf_hub_download(
145
  repo_id=HF_DATASET_REPO,
146
+ filename=f"{jurisdiction}/users/users.json",
147
  repo_type="dataset",
148
  token=HF_TOKEN
149
  )
 
153
  # File doesn't exist yet (first run), return empty dict
154
  return {}
155
 
156
+ def save_users(users, jurisdiction):
157
+ """Save user credentials to Hugging Face Dataset for the specified jurisdiction"""
158
  try:
159
  with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
160
  json.dump(users, f, indent=2)
 
162
 
163
  hf_api.upload_file(
164
  path_or_fileobj=temp_path,
165
+ path_in_repo=f"{jurisdiction}/users/users.json",
166
  repo_id=HF_DATASET_REPO,
167
  repo_type="dataset",
168
  token=HF_TOKEN
 
170
  os.unlink(temp_path)
171
 
172
  # Clear cache for users to ensure fresh data on next load
173
+ load_users.clear(jurisdiction)
174
 
175
  return True
176
  except Exception as e:
 
178
  raise
179
 
180
  @st.cache_data(ttl=3600) # Cache for 1 hour as safety measure
181
+ def load_user_data(username, jurisdiction):
182
+ """Load user's answer data from Hugging Face Dataset for the specified jurisdiction"""
183
  try:
184
  file_path = hf_hub_download(
185
  repo_id=HF_DATASET_REPO,
186
+ filename=f"{jurisdiction}/users/{username}_answers.parquet",
187
  repo_type="dataset",
188
  token=HF_TOKEN
189
  )
190
  return pd.read_parquet(file_path)
191
  except Exception:
192
+ # File doesn't exist yet (new user), create new dataframe from grading template
193
+ df = load_grading_template(jurisdiction)
194
  user_df = df.copy()
195
  user_df['legal_accuracy_score'] = None
196
  user_df['time_stamp'] = None
197
  return user_df
198
 
199
+ def save_user_data(username, user_df, jurisdiction, commit_message=None):
200
+ """Save user's answer data to Hugging Face Dataset for the specified jurisdiction"""
201
  try:
202
  with tempfile.NamedTemporaryFile(suffix='.parquet', delete=False) as f:
203
  user_df.to_parquet(f.name, index=False)
 
205
 
206
  upload_kwargs = {
207
  'path_or_fileobj': temp_path,
208
+ 'path_in_repo': f"{jurisdiction}/users/{username}_answers.parquet",
209
  'repo_id': HF_DATASET_REPO,
210
  'repo_type': "dataset",
211
  'token': HF_TOKEN
 
218
  hf_api.upload_file(**upload_kwargs)
219
  os.unlink(temp_path)
220
 
221
+ # Clear cache for this user/jurisdiction to ensure fresh data on next load
222
+ load_user_data.clear(username, jurisdiction)
223
 
224
  return True
225
  except Exception as e:
226
  st.error(f"❌ **Error saving user data to Hugging Face Dataset**: {str(e)}")
227
  raise
228
 
229
+ def update_user_answer(username, term, category, subcategory, question, answer, score, jurisdiction, df):
230
  """Update a specific answer in the user's data (deprecated - use update_category_answers for bulk updates)"""
231
  try:
232
+ user_df = load_user_data(username, jurisdiction)
233
 
234
  # Find the matching row
235
  mask = (
 
242
 
243
  if mask.any():
244
  user_df.loc[mask, 'legal_accuracy_score'] = score
245
+ save_user_data(username, user_df, jurisdiction)
246
  return True
247
  else:
248
  print(f"Warning: Could not find matching row for: {term}, {category}, {subcategory}, {question}")
 
274
  return [(row['subcategory'], row['question'], row['answer'], "NA")
275
  for _, row in unknown_rows.iterrows()]
276
 
277
+ def auto_score_all_unknown_answers_for_new_user(username, jurisdiction, df):
278
  """
279
  Automatically score all Unknown answers for all categories when a new user is created.
280
  This runs in the background during account creation.
 
298
  return True
299
 
300
  # Load user dataframe once
301
+ user_df = load_user_data(username, jurisdiction)
302
 
303
  # Get current timestamp once for all updates
304
  current_timestamp = pd.Timestamp.now()
 
328
 
329
  # Save once with a single commit message
330
  commit_message = f"Auto-score all Unknown answers for new user {username}"
331
+ save_user_data(username, user_df, jurisdiction, commit_message=commit_message)
332
 
333
  return True
334
  except Exception as e:
335
  print(f"Error auto-scoring Unknown answers for new user {username}: {str(e)}")
336
  return False
337
 
338
+ def update_category_answers(username, term, category, answers_list, jurisdiction, commit_message=None):
339
  """
340
  Update all answers for a category in a single commit.
341
 
 
344
  term: Term name
345
  category: Category name
346
  answers_list: List of tuples (subcategory, question, answer, score)
347
+ jurisdiction: Jurisdiction identifier
348
  commit_message: Optional commit message (auto-generated if None)
349
 
350
  Returns:
 
352
  """
353
  try:
354
  # Load user dataframe once
355
+ user_df = load_user_data(username, jurisdiction)
356
 
357
  # Get current timestamp once for all updates in this category
358
  current_timestamp = pd.Timestamp.now()
 
385
  commit_message = f"Update answers for {username} - {term} - {category}"
386
 
387
  # Save once with commit message
388
+ save_user_data(username, user_df, jurisdiction, commit_message=commit_message)
389
  return True
390
  except Exception as e:
391
  print(f"Error updating category answers: {str(e)}")
392
  return False
393
 
394
+ def get_user_answer(username, term, category, subcategory, question, answer, jurisdiction):
395
  """Get user's answer for a specific question"""
396
+ user_df = load_user_data(username, jurisdiction)
397
 
398
  mask = (
399
  (user_df['term'] == term) &
 
409
  return score
410
  return None
411
 
412
+ def find_first_unanswered_category(username, jurisdiction, df):
413
  """Find the first category that hasn't been fully answered"""
414
+ user_df = load_user_data(username, jurisdiction)
415
 
416
+ # Get term_category_pairs for this jurisdiction
417
+ term_category_pairs = get_term_category_pairs(df)
418
 
419
  for idx, (term, category) in enumerate(term_category_pairs):
420
  # Get all subcategories for this term-category pair from base df
 
447
 
448
  return len(term_category_pairs) # All answered, return last index
449
 
450
+ def restore_submitted_status(username, jurisdiction, df):
451
  """Restore submitted status for categories that have all answers in parquet file"""
452
+ user_df = load_user_data(username, jurisdiction)
453
+
454
+ # Get term_category_pairs for this jurisdiction
455
+ term_category_pairs = get_term_category_pairs(df)
456
 
 
457
  submitted_pairs = set()
458
  for idx, (term, category) in enumerate(term_category_pairs):
459
  pair_key = f"{term}_{category}_{idx}"
 
589
  return [(term, category) for term, category in all_pairs
590
  if category_has_subcategories(term, category, df)]
591
 
592
+ # Cache for Term instances (keyed by jurisdiction and term_name)
 
 
 
 
593
  term_cache = {}
594
 
595
+ def get_term_instance(term_name, df):
596
+ """Get or create a Term instance for the given dataframe"""
597
+ cache_key = f"{id(df)}_{term_name}" # Use df id to differentiate jurisdictions
598
+ if cache_key not in term_cache:
599
+ term_cache[cache_key] = Term(term_name, df)
600
+ return term_cache[cache_key]
601
 
602
+ def get_category_for_pair(term_name, category_name, df):
603
  """Get Category instance for a term-category pair"""
604
+ term = get_term_instance(term_name, df)
605
  return term.get_category_by_name(category_name)
606
 
607
  # Initialize session state
 
609
  st.session_state.logged_in = False
610
  if 'username' not in st.session_state:
611
  st.session_state.username = None
612
+ if 'jurisdiction' not in st.session_state:
613
+ st.session_state.jurisdiction = None
614
  if 'current_index' not in st.session_state:
615
  st.session_state.current_index = 0
616
  if 'show_term_complete' not in st.session_state:
 
639
  # Login page
640
  if not st.session_state.logged_in:
641
  st.markdown("# Login")
642
+ st.markdown("Please select a jurisdiction and enter your username and password to continue.")
643
+
644
+ # Jurisdiction selector
645
+ jurisdiction = st.selectbox("Jurisdiction", options=AVAILABLE_JURISDICTIONS, index=0 if st.session_state.jurisdiction is None else AVAILABLE_JURISDICTIONS.index(st.session_state.jurisdiction) if st.session_state.jurisdiction in AVAILABLE_JURISDICTIONS else 0)
646
+ st.session_state.jurisdiction = jurisdiction
647
 
648
  username = st.text_input("Username")
649
  password = st.text_input("Password", type="password")
 
651
  col1, col2 = st.columns(2)
652
  with col1:
653
  if st.button("Login", type="primary", use_container_width=True):
654
+ if not jurisdiction:
655
+ st.error("Please select a jurisdiction")
 
 
 
 
 
 
 
 
 
 
 
 
 
656
  else:
657
+ users = load_users(jurisdiction)
658
+
659
+ if username in users:
660
+ # Existing user - check password
661
+ if users[username]['password'] == hash_password(password):
662
+ st.session_state.logged_in = True
663
+ st.session_state.username = username
664
+ # Load grading template for this jurisdiction
665
+ df = load_grading_template(jurisdiction)
666
+ # Restore submitted status for previously submitted categories
667
+ st.session_state.submitted_pairs = restore_submitted_status(username, jurisdiction, df)
668
+ # Find first unanswered category and resume there
669
+ resume_index = find_first_unanswered_category(username, jurisdiction, df)
670
+ st.session_state.current_index = resume_index
671
+ st.rerun()
672
+ else:
673
+ st.error("Incorrect password")
674
+ else:
675
+ # Username not found - require registration
676
+ st.error("Username not found. Please register first using the 'Register New User' button.")
677
 
678
  with col2:
679
  if st.button("Register New User", use_container_width=True):
680
+ if not jurisdiction:
681
+ st.error("Please select a jurisdiction")
 
 
 
 
 
 
 
682
  else:
683
+ users = load_users(jurisdiction)
684
+ if username in users:
685
+ st.error("Username already exists")
686
+ elif username and password:
687
+ users[username] = {'password': hash_password(password)}
688
+ save_users(users, jurisdiction)
689
+ # Load grading template for this jurisdiction
690
+ df = load_grading_template(jurisdiction)
691
+ # Auto-score all Unknown answers for the new user in the background
692
+ auto_score_all_unknown_answers_for_new_user(username, jurisdiction, df)
693
+ st.success("User registered successfully! Please click Login.")
694
+ else:
695
+ st.error("Please enter both username and password")
696
 
697
  # Main application (only shown if logged in)
698
  elif st.session_state.logged_in:
699
  username = st.session_state.username
700
+ jurisdiction = st.session_state.jurisdiction
701
  current_index = st.session_state.current_index
702
 
703
+ # Load grading template for the selected jurisdiction
704
+ df = load_grading_template(jurisdiction)
705
+
706
  # Debug info (can be removed in production)
707
  with st.sidebar:
708
  with st.expander("Debug Info"):
 
710
  st.write(f"HF Token configured: {HF_TOKEN is not None}")
711
  st.write(f"HF API initialized: {hf_api is not None}")
712
  if username:
713
+ st.write(f"Jurisdiction: `{jurisdiction}`")
714
+ st.write(f"User parquet file: `{jurisdiction}/users/{username}_answers.parquet`")
715
+ st.write(f"Users file: `{jurisdiction}/users/users.json`")
716
 
717
  # Check if we should show the annotation guide first
718
  if st.session_state.show_guide:
 
851
  st.session_state.next_term = None
852
  st.rerun()
853
 
854
+ # Get term_category_pairs for this jurisdiction
855
+ term_category_pairs = get_term_category_pairs(df)
856
+ total_pairs = len(term_category_pairs)
857
+
858
  elif current_index < total_pairs:
859
  term_name, category_name = term_category_pairs[current_index]
860
+ category = get_category_for_pair(term_name, category_name, df)
861
+ term = get_term_instance(term_name, df)
862
 
863
  # Safety check: skip if category has no subcategories (shouldn't happen due to filtering, but just in case)
864
  if not category or len(category.subcategories) == 0:
 
884
  # Check visible subcategories
885
  for i, subcat in enumerate(category.subcategories):
886
  saved_score = get_user_answer(username, term_name, category_name, subcat.subcategory_name,
887
+ subcat.question, subcat.answer, jurisdiction)
888
  if saved_score is None:
889
  category_fully_answered = False
890
  break
 
893
  if category_fully_answered:
894
  unknown_answers = auto_score_unknown_answers(username, term_name, category_name, df)
895
  for subcategory, question, answer, score in unknown_answers:
896
+ saved_score = get_user_answer(username, term_name, category_name, subcategory, question, answer, jurisdiction)
897
  if saved_score is None:
898
  category_fully_answered = False
899
  break
 
916
  st.session_state.original_selections[pair_key][radio_key] = st.session_state[radio_key]
917
  else:
918
  saved_score = get_user_answer(username, term_name, category_name, subcat.subcategory_name,
919
+ subcat.question, subcat.answer, jurisdiction)
920
  if saved_score is not None:
921
  score_to_option = {v: k for k, v in ASSESSMENT_TO_SCORE.items()}
922
  if saved_score in score_to_option:
 
955
  # Get saved value to determine if we should set a default index
956
  # Don't pre-set session_state - only use index parameter to avoid conflicts
957
  saved_score = get_user_answer(username, term_name, category_name, subcat.subcategory_name,
958
+ subcat.question, subcat.answer, jurisdiction)
959
  default_index = None
960
 
961
  # Check if there's a saved value in parquet file
 
1040
  st.session_state.back_current_index = current_index
1041
  st.session_state.show_term_back_warning = True
1042
  st.session_state.back_current_term = term.formatted_name
1043
+ prev_term = get_term_instance(prev_term_name, df)
1044
  st.session_state.back_previous_term = prev_term.formatted_name
1045
  else:
1046
  # Same term, just move back
 
1084
 
1085
  # Update all answers in a single commit
1086
  commit_message = f"Update answers for {username} - {term_name} - {category_name}"
1087
+ save_success = update_category_answers(username, term_name, category_name, answers_list, jurisdiction, commit_message)
1088
 
1089
  if save_success:
1090
  # Save current selections as new originals
 
1116
  # Update all answers in a single commit
1117
  # Note: Unknown answers are already auto-scored during account creation
1118
  commit_message = f"Update answers for {username} - {term_name} - {category_name}"
1119
+ save_success = update_category_answers(username, term_name, category_name, answers_list, jurisdiction, commit_message)
1120
 
1121
  if save_success:
1122
  # Mark as submitted and save original selections
 
1139
  # Show intermediate page
1140
  st.session_state.show_term_complete = True
1141
  st.session_state.completed_term = term.formatted_name
1142
+ next_term = get_term_instance(next_term_name, df)
1143
  st.session_state.next_term = next_term.formatted_name
1144
  else:
1145
  # Same term, just move to next category
 
1174
  # Moving to a different term - show term switching page
1175
  st.session_state.show_term_complete = True
1176
  st.session_state.completed_term = term.formatted_name
1177
+ next_term = get_term_instance(next_term_name, df)
1178
  st.session_state.next_term = next_term.formatted_name
1179
  else:
1180
  # Same term, just move forward
 
1196
  if st.button("Logout"):
1197
  st.session_state.logged_in = False
1198
  st.session_state.username = None
1199
+ st.session_state.jurisdiction = None
1200
  st.session_state.current_index = 0
1201
  st.session_state.show_guide = True
1202
  st.session_state.submitted_pairs = set()