Spaces:
Running
Running
Fredrik Sitje
committed on
Commit
·
077c9e3
1
Parent(s):
058983a
Refactor Streamlit app to support jurisdiction-specific data handling. Updated functions to accept jurisdiction as a parameter, modified file paths for user data and grading templates, and added jurisdiction selection in the login process. This enhances the app's flexibility for different jurisdictions.
Browse files- src/streamlit_app.py +128 -101
src/streamlit_app.py
CHANGED
|
@@ -15,9 +15,12 @@ HF_TOKEN = os.getenv("HF_TOKEN")
|
|
| 15 |
# Fallback to st.secrets for local development (if not found in environment)
|
| 16 |
if not HF_DATASET_REPO:
|
| 17 |
try:
|
| 18 |
-
HF_DATASET_REPO = st.secrets.get("HF_DATASET_REPO", "TransLegal/
|
| 19 |
except Exception:
|
| 20 |
-
HF_DATASET_REPO = "TransLegal/
|
|
|
|
|
|
|
|
|
|
| 21 |
|
| 22 |
if not HF_TOKEN:
|
| 23 |
try:
|
|
@@ -33,35 +36,32 @@ if not HF_TOKEN:
|
|
| 33 |
@st.cache_resource
|
| 34 |
def get_hf_api():
|
| 35 |
"""Get cached Hugging Face API client - only initializes once per session"""
|
| 36 |
-
|
| 37 |
-
|
| 38 |
return HfApi(token=HF_TOKEN)
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
|
| 43 |
# Initialize HF API - cached to avoid re-initialization on every rerun
|
| 44 |
hf_api = get_hf_api()
|
| 45 |
|
| 46 |
@st.cache_data
|
| 47 |
-
def load_grading_template():
|
| 48 |
-
"""Load grading template from Hugging Face Dataset"""
|
| 49 |
try:
|
| 50 |
file_path = hf_hub_download(
|
| 51 |
repo_id=HF_DATASET_REPO,
|
| 52 |
-
filename="grading_template.parquet",
|
| 53 |
repo_type="dataset",
|
| 54 |
token=HF_TOKEN
|
| 55 |
)
|
| 56 |
return pd.read_parquet(file_path)
|
| 57 |
except Exception as e:
|
| 58 |
st.error(f"❌ **Error loading grading template from Hugging Face Dataset**: {str(e)}")
|
| 59 |
-
st.error(f"Please ensure the file `grading_template.parquet` exists in the dataset repository: {HF_DATASET_REPO}")
|
| 60 |
st.stop()
|
| 61 |
|
| 62 |
-
# Load data from the grading template
|
| 63 |
-
df = load_grading_template()
|
| 64 |
-
|
| 65 |
# Assessment options with descriptive labels
|
| 66 |
ASSESSMENT_OPTIONS = [
|
| 67 |
"Perfect",
|
|
@@ -138,12 +138,12 @@ def hash_password(password):
|
|
| 138 |
return hashlib.sha256(password.encode()).hexdigest()
|
| 139 |
|
| 140 |
@st.cache_data
|
| 141 |
-
def load_users():
|
| 142 |
-
"""Load user credentials from Hugging Face Dataset"""
|
| 143 |
try:
|
| 144 |
file_path = hf_hub_download(
|
| 145 |
repo_id=HF_DATASET_REPO,
|
| 146 |
-
filename="users/users.json",
|
| 147 |
repo_type="dataset",
|
| 148 |
token=HF_TOKEN
|
| 149 |
)
|
|
@@ -153,8 +153,8 @@ def load_users():
|
|
| 153 |
# File doesn't exist yet (first run), return empty dict
|
| 154 |
return {}
|
| 155 |
|
| 156 |
-
def save_users(users):
|
| 157 |
-
"""Save user credentials to Hugging Face Dataset"""
|
| 158 |
try:
|
| 159 |
with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
|
| 160 |
json.dump(users, f, indent=2)
|
|
@@ -162,7 +162,7 @@ def save_users(users):
|
|
| 162 |
|
| 163 |
hf_api.upload_file(
|
| 164 |
path_or_fileobj=temp_path,
|
| 165 |
-
path_in_repo="users/users.json",
|
| 166 |
repo_id=HF_DATASET_REPO,
|
| 167 |
repo_type="dataset",
|
| 168 |
token=HF_TOKEN
|
|
@@ -170,7 +170,7 @@ def save_users(users):
|
|
| 170 |
os.unlink(temp_path)
|
| 171 |
|
| 172 |
# Clear cache for users to ensure fresh data on next load
|
| 173 |
-
load_users.clear()
|
| 174 |
|
| 175 |
return True
|
| 176 |
except Exception as e:
|
|
@@ -178,25 +178,26 @@ def save_users(users):
|
|
| 178 |
raise
|
| 179 |
|
| 180 |
@st.cache_data(ttl=3600) # Cache for 1 hour as safety measure
|
| 181 |
-
def load_user_data(username):
|
| 182 |
-
"""Load user's answer data from Hugging Face Dataset"""
|
| 183 |
try:
|
| 184 |
file_path = hf_hub_download(
|
| 185 |
repo_id=HF_DATASET_REPO,
|
| 186 |
-
filename=f"users/{username}_answers.parquet",
|
| 187 |
repo_type="dataset",
|
| 188 |
token=HF_TOKEN
|
| 189 |
)
|
| 190 |
return pd.read_parquet(file_path)
|
| 191 |
except Exception:
|
| 192 |
-
# File doesn't exist yet (new user), create new dataframe
|
|
|
|
| 193 |
user_df = df.copy()
|
| 194 |
user_df['legal_accuracy_score'] = None
|
| 195 |
user_df['time_stamp'] = None
|
| 196 |
return user_df
|
| 197 |
|
| 198 |
-
def save_user_data(username, user_df, commit_message=None):
|
| 199 |
-
"""Save user's answer data to Hugging Face Dataset"""
|
| 200 |
try:
|
| 201 |
with tempfile.NamedTemporaryFile(suffix='.parquet', delete=False) as f:
|
| 202 |
user_df.to_parquet(f.name, index=False)
|
|
@@ -204,7 +205,7 @@ def save_user_data(username, user_df, commit_message=None):
|
|
| 204 |
|
| 205 |
upload_kwargs = {
|
| 206 |
'path_or_fileobj': temp_path,
|
| 207 |
-
'path_in_repo': f"users/{username}_answers.parquet",
|
| 208 |
'repo_id': HF_DATASET_REPO,
|
| 209 |
'repo_type': "dataset",
|
| 210 |
'token': HF_TOKEN
|
|
@@ -217,18 +218,18 @@ def save_user_data(username, user_df, commit_message=None):
|
|
| 217 |
hf_api.upload_file(**upload_kwargs)
|
| 218 |
os.unlink(temp_path)
|
| 219 |
|
| 220 |
-
# Clear cache for this user to ensure fresh data on next load
|
| 221 |
-
load_user_data.clear(username)
|
| 222 |
|
| 223 |
return True
|
| 224 |
except Exception as e:
|
| 225 |
st.error(f"❌ **Error saving user data to Hugging Face Dataset**: {str(e)}")
|
| 226 |
raise
|
| 227 |
|
| 228 |
-
def update_user_answer(username, term, category, subcategory, question, answer, score):
|
| 229 |
"""Update a specific answer in the user's data (deprecated - use update_category_answers for bulk updates)"""
|
| 230 |
try:
|
| 231 |
-
user_df = load_user_data(username)
|
| 232 |
|
| 233 |
# Find the matching row
|
| 234 |
mask = (
|
|
@@ -241,7 +242,7 @@ def update_user_answer(username, term, category, subcategory, question, answer,
|
|
| 241 |
|
| 242 |
if mask.any():
|
| 243 |
user_df.loc[mask, 'legal_accuracy_score'] = score
|
| 244 |
-
save_user_data(username, user_df)
|
| 245 |
return True
|
| 246 |
else:
|
| 247 |
print(f"Warning: Could not find matching row for: {term}, {category}, {subcategory}, {question}")
|
|
@@ -273,7 +274,7 @@ def auto_score_unknown_answers(username, term, category, df):
|
|
| 273 |
return [(row['subcategory'], row['question'], row['answer'], "NA")
|
| 274 |
for _, row in unknown_rows.iterrows()]
|
| 275 |
|
| 276 |
-
def auto_score_all_unknown_answers_for_new_user(username):
|
| 277 |
"""
|
| 278 |
Automatically score all Unknown answers for all categories when a new user is created.
|
| 279 |
This runs in the background during account creation.
|
|
@@ -297,7 +298,7 @@ def auto_score_all_unknown_answers_for_new_user(username):
|
|
| 297 |
return True
|
| 298 |
|
| 299 |
# Load user dataframe once
|
| 300 |
-
user_df = load_user_data(username)
|
| 301 |
|
| 302 |
# Get current timestamp once for all updates
|
| 303 |
current_timestamp = pd.Timestamp.now()
|
|
@@ -327,14 +328,14 @@ def auto_score_all_unknown_answers_for_new_user(username):
|
|
| 327 |
|
| 328 |
# Save once with a single commit message
|
| 329 |
commit_message = f"Auto-score all Unknown answers for new user {username}"
|
| 330 |
-
save_user_data(username, user_df, commit_message=commit_message)
|
| 331 |
|
| 332 |
return True
|
| 333 |
except Exception as e:
|
| 334 |
print(f"Error auto-scoring Unknown answers for new user {username}: {str(e)}")
|
| 335 |
return False
|
| 336 |
|
| 337 |
-
def update_category_answers(username, term, category, answers_list, commit_message=None):
|
| 338 |
"""
|
| 339 |
Update all answers for a category in a single commit.
|
| 340 |
|
|
@@ -343,6 +344,7 @@ def update_category_answers(username, term, category, answers_list, commit_messa
|
|
| 343 |
term: Term name
|
| 344 |
category: Category name
|
| 345 |
answers_list: List of tuples (subcategory, question, answer, score)
|
|
|
|
| 346 |
commit_message: Optional commit message (auto-generated if None)
|
| 347 |
|
| 348 |
Returns:
|
|
@@ -350,7 +352,7 @@ def update_category_answers(username, term, category, answers_list, commit_messa
|
|
| 350 |
"""
|
| 351 |
try:
|
| 352 |
# Load user dataframe once
|
| 353 |
-
user_df = load_user_data(username)
|
| 354 |
|
| 355 |
# Get current timestamp once for all updates in this category
|
| 356 |
current_timestamp = pd.Timestamp.now()
|
|
@@ -383,15 +385,15 @@ def update_category_answers(username, term, category, answers_list, commit_messa
|
|
| 383 |
commit_message = f"Update answers for {username} - {term} - {category}"
|
| 384 |
|
| 385 |
# Save once with commit message
|
| 386 |
-
save_user_data(username, user_df, commit_message=commit_message)
|
| 387 |
return True
|
| 388 |
except Exception as e:
|
| 389 |
print(f"Error updating category answers: {str(e)}")
|
| 390 |
return False
|
| 391 |
|
| 392 |
-
def get_user_answer(username, term, category, subcategory, question, answer):
|
| 393 |
"""Get user's answer for a specific question"""
|
| 394 |
-
user_df = load_user_data(username)
|
| 395 |
|
| 396 |
mask = (
|
| 397 |
(user_df['term'] == term) &
|
|
@@ -407,12 +409,12 @@ def get_user_answer(username, term, category, subcategory, question, answer):
|
|
| 407 |
return score
|
| 408 |
return None
|
| 409 |
|
| 410 |
-
def find_first_unanswered_category(username):
|
| 411 |
"""Find the first category that hasn't been fully answered"""
|
| 412 |
-
user_df = load_user_data(username)
|
| 413 |
|
| 414 |
-
#
|
| 415 |
-
|
| 416 |
|
| 417 |
for idx, (term, category) in enumerate(term_category_pairs):
|
| 418 |
# Get all subcategories for this term-category pair from base df
|
|
@@ -445,11 +447,13 @@ def find_first_unanswered_category(username):
|
|
| 445 |
|
| 446 |
return len(term_category_pairs) # All answered, return last index
|
| 447 |
|
| 448 |
-
def restore_submitted_status(username):
|
| 449 |
"""Restore submitted status for categories that have all answers in parquet file"""
|
| 450 |
-
user_df = load_user_data(username)
|
|
|
|
|
|
|
|
|
|
| 451 |
|
| 452 |
-
# Use the global term_category_pairs to ensure consistent ordering
|
| 453 |
submitted_pairs = set()
|
| 454 |
for idx, (term, category) in enumerate(term_category_pairs):
|
| 455 |
pair_key = f"{term}_{category}_{idx}"
|
|
@@ -585,22 +589,19 @@ def get_term_category_pairs(df):
|
|
| 585 |
return [(term, category) for term, category in all_pairs
|
| 586 |
if category_has_subcategories(term, category, df)]
|
| 587 |
|
| 588 |
-
#
|
| 589 |
-
term_category_pairs = get_term_category_pairs(df)
|
| 590 |
-
total_pairs = len(term_category_pairs)
|
| 591 |
-
|
| 592 |
-
# Cache for Term instances
|
| 593 |
term_cache = {}
|
| 594 |
|
| 595 |
-
def get_term_instance(term_name):
|
| 596 |
-
"""Get or create a Term instance"""
|
| 597 |
-
|
| 598 |
-
|
| 599 |
-
|
|
|
|
| 600 |
|
| 601 |
-
def get_category_for_pair(term_name, category_name):
|
| 602 |
"""Get Category instance for a term-category pair"""
|
| 603 |
-
term = get_term_instance(term_name)
|
| 604 |
return term.get_category_by_name(category_name)
|
| 605 |
|
| 606 |
# Initialize session state
|
|
@@ -608,6 +609,8 @@ if 'logged_in' not in st.session_state:
|
|
| 608 |
st.session_state.logged_in = False
|
| 609 |
if 'username' not in st.session_state:
|
| 610 |
st.session_state.username = None
|
|
|
|
|
|
|
| 611 |
if 'current_index' not in st.session_state:
|
| 612 |
st.session_state.current_index = 0
|
| 613 |
if 'show_term_complete' not in st.session_state:
|
|
@@ -636,7 +639,11 @@ if 'has_unsaved_changes' not in st.session_state:
|
|
| 636 |
# Login page
|
| 637 |
if not st.session_state.logged_in:
|
| 638 |
st.markdown("# Login")
|
| 639 |
-
st.markdown("Please enter your username and password to continue.")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 640 |
|
| 641 |
username = st.text_input("Username")
|
| 642 |
password = st.text_input("Password", type="password")
|
|
@@ -644,44 +651,58 @@ if not st.session_state.logged_in:
|
|
| 644 |
col1, col2 = st.columns(2)
|
| 645 |
with col1:
|
| 646 |
if st.button("Login", type="primary", use_container_width=True):
|
| 647 |
-
|
| 648 |
-
|
| 649 |
-
if username in users:
|
| 650 |
-
# Existing user - check password
|
| 651 |
-
if users[username]['password'] == hash_password(password):
|
| 652 |
-
st.session_state.logged_in = True
|
| 653 |
-
st.session_state.username = username
|
| 654 |
-
# Restore submitted status for previously submitted categories
|
| 655 |
-
st.session_state.submitted_pairs = restore_submitted_status(username)
|
| 656 |
-
# Find first unanswered category and resume there
|
| 657 |
-
resume_index = find_first_unanswered_category(username)
|
| 658 |
-
st.session_state.current_index = resume_index
|
| 659 |
-
st.rerun()
|
| 660 |
-
else:
|
| 661 |
-
st.error("Incorrect password")
|
| 662 |
else:
|
| 663 |
-
|
| 664 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 665 |
|
| 666 |
with col2:
|
| 667 |
if st.button("Register New User", use_container_width=True):
|
| 668 |
-
|
| 669 |
-
|
| 670 |
-
st.error("Username already exists")
|
| 671 |
-
elif username and password:
|
| 672 |
-
users[username] = {'password': hash_password(password)}
|
| 673 |
-
save_users(users)
|
| 674 |
-
# Auto-score all Unknown answers for the new user in the background
|
| 675 |
-
auto_score_all_unknown_answers_for_new_user(username)
|
| 676 |
-
st.success("User registered successfully! Please click Login.")
|
| 677 |
else:
|
| 678 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 679 |
|
| 680 |
# Main application (only shown if logged in)
|
| 681 |
elif st.session_state.logged_in:
|
| 682 |
username = st.session_state.username
|
|
|
|
| 683 |
current_index = st.session_state.current_index
|
| 684 |
|
|
|
|
|
|
|
|
|
|
| 685 |
# Debug info (can be removed in production)
|
| 686 |
with st.sidebar:
|
| 687 |
with st.expander("Debug Info"):
|
|
@@ -689,8 +710,9 @@ elif st.session_state.logged_in:
|
|
| 689 |
st.write(f"HF Token configured: {HF_TOKEN is not None}")
|
| 690 |
st.write(f"HF API initialized: {hf_api is not None}")
|
| 691 |
if username:
|
| 692 |
-
st.write(f"
|
| 693 |
-
st.write(f"
|
|
|
|
| 694 |
|
| 695 |
# Check if we should show the annotation guide first
|
| 696 |
if st.session_state.show_guide:
|
|
@@ -829,10 +851,14 @@ elif st.session_state.logged_in:
|
|
| 829 |
st.session_state.next_term = None
|
| 830 |
st.rerun()
|
| 831 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 832 |
elif current_index < total_pairs:
|
| 833 |
term_name, category_name = term_category_pairs[current_index]
|
| 834 |
-
category = get_category_for_pair(term_name, category_name)
|
| 835 |
-
term = get_term_instance(term_name)
|
| 836 |
|
| 837 |
# Safety check: skip if category has no subcategories (shouldn't happen due to filtering, but just in case)
|
| 838 |
if not category or len(category.subcategories) == 0:
|
|
@@ -858,7 +884,7 @@ elif st.session_state.logged_in:
|
|
| 858 |
# Check visible subcategories
|
| 859 |
for i, subcat in enumerate(category.subcategories):
|
| 860 |
saved_score = get_user_answer(username, term_name, category_name, subcat.subcategory_name,
|
| 861 |
-
subcat.question, subcat.answer)
|
| 862 |
if saved_score is None:
|
| 863 |
category_fully_answered = False
|
| 864 |
break
|
|
@@ -867,7 +893,7 @@ elif st.session_state.logged_in:
|
|
| 867 |
if category_fully_answered:
|
| 868 |
unknown_answers = auto_score_unknown_answers(username, term_name, category_name, df)
|
| 869 |
for subcategory, question, answer, score in unknown_answers:
|
| 870 |
-
saved_score = get_user_answer(username, term_name, category_name, subcategory, question, answer)
|
| 871 |
if saved_score is None:
|
| 872 |
category_fully_answered = False
|
| 873 |
break
|
|
@@ -890,7 +916,7 @@ elif st.session_state.logged_in:
|
|
| 890 |
st.session_state.original_selections[pair_key][radio_key] = st.session_state[radio_key]
|
| 891 |
else:
|
| 892 |
saved_score = get_user_answer(username, term_name, category_name, subcat.subcategory_name,
|
| 893 |
-
subcat.question, subcat.answer)
|
| 894 |
if saved_score is not None:
|
| 895 |
score_to_option = {v: k for k, v in ASSESSMENT_TO_SCORE.items()}
|
| 896 |
if saved_score in score_to_option:
|
|
@@ -929,7 +955,7 @@ elif st.session_state.logged_in:
|
|
| 929 |
# Get saved value to determine if we should set a default index
|
| 930 |
# Don't pre-set session_state - only use index parameter to avoid conflicts
|
| 931 |
saved_score = get_user_answer(username, term_name, category_name, subcat.subcategory_name,
|
| 932 |
-
subcat.question, subcat.answer)
|
| 933 |
default_index = None
|
| 934 |
|
| 935 |
# Check if there's a saved value in parquet file
|
|
@@ -1014,7 +1040,7 @@ elif st.session_state.logged_in:
|
|
| 1014 |
st.session_state.back_current_index = current_index
|
| 1015 |
st.session_state.show_term_back_warning = True
|
| 1016 |
st.session_state.back_current_term = term.formatted_name
|
| 1017 |
-
prev_term = get_term_instance(prev_term_name)
|
| 1018 |
st.session_state.back_previous_term = prev_term.formatted_name
|
| 1019 |
else:
|
| 1020 |
# Same term, just move back
|
|
@@ -1058,7 +1084,7 @@ elif st.session_state.logged_in:
|
|
| 1058 |
|
| 1059 |
# Update all answers in a single commit
|
| 1060 |
commit_message = f"Update answers for {username} - {term_name} - {category_name}"
|
| 1061 |
-
save_success = update_category_answers(username, term_name, category_name, answers_list, commit_message)
|
| 1062 |
|
| 1063 |
if save_success:
|
| 1064 |
# Save current selections as new originals
|
|
@@ -1090,7 +1116,7 @@ elif st.session_state.logged_in:
|
|
| 1090 |
# Update all answers in a single commit
|
| 1091 |
# Note: Unknown answers are already auto-scored during account creation
|
| 1092 |
commit_message = f"Update answers for {username} - {term_name} - {category_name}"
|
| 1093 |
-
save_success = update_category_answers(username, term_name, category_name, answers_list, commit_message)
|
| 1094 |
|
| 1095 |
if save_success:
|
| 1096 |
# Mark as submitted and save original selections
|
|
@@ -1113,7 +1139,7 @@ elif st.session_state.logged_in:
|
|
| 1113 |
# Show intermediate page
|
| 1114 |
st.session_state.show_term_complete = True
|
| 1115 |
st.session_state.completed_term = term.formatted_name
|
| 1116 |
-
next_term = get_term_instance(next_term_name)
|
| 1117 |
st.session_state.next_term = next_term.formatted_name
|
| 1118 |
else:
|
| 1119 |
# Same term, just move to next category
|
|
@@ -1148,7 +1174,7 @@ elif st.session_state.logged_in:
|
|
| 1148 |
# Moving to a different term - show term switching page
|
| 1149 |
st.session_state.show_term_complete = True
|
| 1150 |
st.session_state.completed_term = term.formatted_name
|
| 1151 |
-
|
| 1152 |
st.session_state.next_term = next_term.formatted_name
|
| 1153 |
else:
|
| 1154 |
# Same term, just move forward
|
|
@@ -1170,6 +1196,7 @@ elif st.session_state.logged_in:
|
|
| 1170 |
if st.button("Logout"):
|
| 1171 |
st.session_state.logged_in = False
|
| 1172 |
st.session_state.username = None
|
|
|
|
| 1173 |
st.session_state.current_index = 0
|
| 1174 |
st.session_state.show_guide = True
|
| 1175 |
st.session_state.submitted_pairs = set()
|
|
|
|
| 15 |
# Fallback to st.secrets for local development (if not found in environment)
|
| 16 |
if not HF_DATASET_REPO:
|
| 17 |
try:
|
| 18 |
+
HF_DATASET_REPO = st.secrets.get("HF_DATASET_REPO", "TransLegal/grading-answers")
|
| 19 |
except Exception:
|
| 20 |
+
HF_DATASET_REPO = "TransLegal/grading-answers"
|
| 21 |
+
|
| 22 |
+
# Available jurisdictions
|
| 23 |
+
AVAILABLE_JURISDICTIONS = ["en-us", "hr-hr", "sv-se"]
|
| 24 |
|
| 25 |
if not HF_TOKEN:
|
| 26 |
try:
|
|
|
|
| 36 |
@st.cache_resource
|
| 37 |
def get_hf_api():
|
| 38 |
"""Get cached Hugging Face API client - only initializes once per session"""
|
| 39 |
+
try:
|
| 40 |
+
login(token=HF_TOKEN)
|
| 41 |
return HfApi(token=HF_TOKEN)
|
| 42 |
+
except Exception as e:
|
| 43 |
+
st.error(f"❌ **Error initializing Hugging Face API**: {str(e)}")
|
| 44 |
+
st.stop()
|
| 45 |
|
| 46 |
# Initialize HF API - cached to avoid re-initialization on every rerun
|
| 47 |
hf_api = get_hf_api()
|
| 48 |
|
| 49 |
@st.cache_data
|
| 50 |
+
def load_grading_template(jurisdiction):
|
| 51 |
+
"""Load grading template from Hugging Face Dataset for the specified jurisdiction"""
|
| 52 |
try:
|
| 53 |
file_path = hf_hub_download(
|
| 54 |
repo_id=HF_DATASET_REPO,
|
| 55 |
+
filename=f"{jurisdiction}/grading_template.parquet",
|
| 56 |
repo_type="dataset",
|
| 57 |
token=HF_TOKEN
|
| 58 |
)
|
| 59 |
return pd.read_parquet(file_path)
|
| 60 |
except Exception as e:
|
| 61 |
st.error(f"❌ **Error loading grading template from Hugging Face Dataset**: {str(e)}")
|
| 62 |
+
st.error(f"Please ensure the file `{jurisdiction}/grading_template.parquet` exists in the dataset repository: {HF_DATASET_REPO}")
|
| 63 |
st.stop()
|
| 64 |
|
|
|
|
|
|
|
|
|
|
| 65 |
# Assessment options with descriptive labels
|
| 66 |
ASSESSMENT_OPTIONS = [
|
| 67 |
"Perfect",
|
|
|
|
| 138 |
return hashlib.sha256(password.encode()).hexdigest()
|
| 139 |
|
| 140 |
@st.cache_data
|
| 141 |
+
def load_users(jurisdiction):
|
| 142 |
+
"""Load user credentials from Hugging Face Dataset for the specified jurisdiction"""
|
| 143 |
try:
|
| 144 |
file_path = hf_hub_download(
|
| 145 |
repo_id=HF_DATASET_REPO,
|
| 146 |
+
filename=f"{jurisdiction}/users/users.json",
|
| 147 |
repo_type="dataset",
|
| 148 |
token=HF_TOKEN
|
| 149 |
)
|
|
|
|
| 153 |
# File doesn't exist yet (first run), return empty dict
|
| 154 |
return {}
|
| 155 |
|
| 156 |
+
def save_users(users, jurisdiction):
|
| 157 |
+
"""Save user credentials to Hugging Face Dataset for the specified jurisdiction"""
|
| 158 |
try:
|
| 159 |
with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
|
| 160 |
json.dump(users, f, indent=2)
|
|
|
|
| 162 |
|
| 163 |
hf_api.upload_file(
|
| 164 |
path_or_fileobj=temp_path,
|
| 165 |
+
path_in_repo=f"{jurisdiction}/users/users.json",
|
| 166 |
repo_id=HF_DATASET_REPO,
|
| 167 |
repo_type="dataset",
|
| 168 |
token=HF_TOKEN
|
|
|
|
| 170 |
os.unlink(temp_path)
|
| 171 |
|
| 172 |
# Clear cache for users to ensure fresh data on next load
|
| 173 |
+
load_users.clear(jurisdiction)
|
| 174 |
|
| 175 |
return True
|
| 176 |
except Exception as e:
|
|
|
|
| 178 |
raise
|
| 179 |
|
| 180 |
@st.cache_data(ttl=3600) # Cache for 1 hour as safety measure
|
| 181 |
+
def load_user_data(username, jurisdiction):
|
| 182 |
+
"""Load user's answer data from Hugging Face Dataset for the specified jurisdiction"""
|
| 183 |
try:
|
| 184 |
file_path = hf_hub_download(
|
| 185 |
repo_id=HF_DATASET_REPO,
|
| 186 |
+
filename=f"{jurisdiction}/users/{username}_answers.parquet",
|
| 187 |
repo_type="dataset",
|
| 188 |
token=HF_TOKEN
|
| 189 |
)
|
| 190 |
return pd.read_parquet(file_path)
|
| 191 |
except Exception:
|
| 192 |
+
# File doesn't exist yet (new user), create new dataframe from grading template
|
| 193 |
+
df = load_grading_template(jurisdiction)
|
| 194 |
user_df = df.copy()
|
| 195 |
user_df['legal_accuracy_score'] = None
|
| 196 |
user_df['time_stamp'] = None
|
| 197 |
return user_df
|
| 198 |
|
| 199 |
+
def save_user_data(username, user_df, jurisdiction, commit_message=None):
|
| 200 |
+
"""Save user's answer data to Hugging Face Dataset for the specified jurisdiction"""
|
| 201 |
try:
|
| 202 |
with tempfile.NamedTemporaryFile(suffix='.parquet', delete=False) as f:
|
| 203 |
user_df.to_parquet(f.name, index=False)
|
|
|
|
| 205 |
|
| 206 |
upload_kwargs = {
|
| 207 |
'path_or_fileobj': temp_path,
|
| 208 |
+
'path_in_repo': f"{jurisdiction}/users/{username}_answers.parquet",
|
| 209 |
'repo_id': HF_DATASET_REPO,
|
| 210 |
'repo_type': "dataset",
|
| 211 |
'token': HF_TOKEN
|
|
|
|
| 218 |
hf_api.upload_file(**upload_kwargs)
|
| 219 |
os.unlink(temp_path)
|
| 220 |
|
| 221 |
+
# Clear cache for this user/jurisdiction to ensure fresh data on next load
|
| 222 |
+
load_user_data.clear(username, jurisdiction)
|
| 223 |
|
| 224 |
return True
|
| 225 |
except Exception as e:
|
| 226 |
st.error(f"❌ **Error saving user data to Hugging Face Dataset**: {str(e)}")
|
| 227 |
raise
|
| 228 |
|
| 229 |
+
def update_user_answer(username, term, category, subcategory, question, answer, score, jurisdiction, df):
|
| 230 |
"""Update a specific answer in the user's data (deprecated - use update_category_answers for bulk updates)"""
|
| 231 |
try:
|
| 232 |
+
user_df = load_user_data(username, jurisdiction)
|
| 233 |
|
| 234 |
# Find the matching row
|
| 235 |
mask = (
|
|
|
|
| 242 |
|
| 243 |
if mask.any():
|
| 244 |
user_df.loc[mask, 'legal_accuracy_score'] = score
|
| 245 |
+
save_user_data(username, user_df, jurisdiction)
|
| 246 |
return True
|
| 247 |
else:
|
| 248 |
print(f"Warning: Could not find matching row for: {term}, {category}, {subcategory}, {question}")
|
|
|
|
| 274 |
return [(row['subcategory'], row['question'], row['answer'], "NA")
|
| 275 |
for _, row in unknown_rows.iterrows()]
|
| 276 |
|
| 277 |
+
def auto_score_all_unknown_answers_for_new_user(username, jurisdiction, df):
|
| 278 |
"""
|
| 279 |
Automatically score all Unknown answers for all categories when a new user is created.
|
| 280 |
This runs in the background during account creation.
|
|
|
|
| 298 |
return True
|
| 299 |
|
| 300 |
# Load user dataframe once
|
| 301 |
+
user_df = load_user_data(username, jurisdiction)
|
| 302 |
|
| 303 |
# Get current timestamp once for all updates
|
| 304 |
current_timestamp = pd.Timestamp.now()
|
|
|
|
| 328 |
|
| 329 |
# Save once with a single commit message
|
| 330 |
commit_message = f"Auto-score all Unknown answers for new user {username}"
|
| 331 |
+
save_user_data(username, user_df, jurisdiction, commit_message=commit_message)
|
| 332 |
|
| 333 |
return True
|
| 334 |
except Exception as e:
|
| 335 |
print(f"Error auto-scoring Unknown answers for new user {username}: {str(e)}")
|
| 336 |
return False
|
| 337 |
|
| 338 |
+
def update_category_answers(username, term, category, answers_list, jurisdiction, commit_message=None):
|
| 339 |
"""
|
| 340 |
Update all answers for a category in a single commit.
|
| 341 |
|
|
|
|
| 344 |
term: Term name
|
| 345 |
category: Category name
|
| 346 |
answers_list: List of tuples (subcategory, question, answer, score)
|
| 347 |
+
jurisdiction: Jurisdiction identifier
|
| 348 |
commit_message: Optional commit message (auto-generated if None)
|
| 349 |
|
| 350 |
Returns:
|
|
|
|
| 352 |
"""
|
| 353 |
try:
|
| 354 |
# Load user dataframe once
|
| 355 |
+
user_df = load_user_data(username, jurisdiction)
|
| 356 |
|
| 357 |
# Get current timestamp once for all updates in this category
|
| 358 |
current_timestamp = pd.Timestamp.now()
|
|
|
|
| 385 |
commit_message = f"Update answers for {username} - {term} - {category}"
|
| 386 |
|
| 387 |
# Save once with commit message
|
| 388 |
+
save_user_data(username, user_df, jurisdiction, commit_message=commit_message)
|
| 389 |
return True
|
| 390 |
except Exception as e:
|
| 391 |
print(f"Error updating category answers: {str(e)}")
|
| 392 |
return False
|
| 393 |
|
| 394 |
+
def get_user_answer(username, term, category, subcategory, question, answer, jurisdiction):
|
| 395 |
"""Get user's answer for a specific question"""
|
| 396 |
+
user_df = load_user_data(username, jurisdiction)
|
| 397 |
|
| 398 |
mask = (
|
| 399 |
(user_df['term'] == term) &
|
|
|
|
| 409 |
return score
|
| 410 |
return None
|
| 411 |
|
| 412 |
+
def find_first_unanswered_category(username, jurisdiction, df):
|
| 413 |
"""Find the first category that hasn't been fully answered"""
|
| 414 |
+
user_df = load_user_data(username, jurisdiction)
|
| 415 |
|
| 416 |
+
# Get term_category_pairs for this jurisdiction
|
| 417 |
+
term_category_pairs = get_term_category_pairs(df)
|
| 418 |
|
| 419 |
for idx, (term, category) in enumerate(term_category_pairs):
|
| 420 |
# Get all subcategories for this term-category pair from base df
|
|
|
|
| 447 |
|
| 448 |
return len(term_category_pairs) # All answered, return last index
|
| 449 |
|
| 450 |
+
def restore_submitted_status(username, jurisdiction, df):
|
| 451 |
"""Restore submitted status for categories that have all answers in parquet file"""
|
| 452 |
+
user_df = load_user_data(username, jurisdiction)
|
| 453 |
+
|
| 454 |
+
# Get term_category_pairs for this jurisdiction
|
| 455 |
+
term_category_pairs = get_term_category_pairs(df)
|
| 456 |
|
|
|
|
| 457 |
submitted_pairs = set()
|
| 458 |
for idx, (term, category) in enumerate(term_category_pairs):
|
| 459 |
pair_key = f"{term}_{category}_{idx}"
|
|
|
|
| 589 |
return [(term, category) for term, category in all_pairs
|
| 590 |
if category_has_subcategories(term, category, df)]
|
| 591 |
|
| 592 |
+
# Cache for Term instances (keyed by jurisdiction and term_name)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 593 |
term_cache = {}
|
| 594 |
|
| 595 |
+
def get_term_instance(term_name, df):
|
| 596 |
+
"""Get or create a Term instance for the given dataframe"""
|
| 597 |
+
cache_key = f"{id(df)}_{term_name}" # Use df id to differentiate jurisdictions
|
| 598 |
+
if cache_key not in term_cache:
|
| 599 |
+
term_cache[cache_key] = Term(term_name, df)
|
| 600 |
+
return term_cache[cache_key]
|
| 601 |
|
| 602 |
+
def get_category_for_pair(term_name, category_name, df):
    """Resolve the Category object for a (term, category) pair.

    Uses the cached Term instance for *df*; returns whatever
    Term.get_category_by_name yields (presumably None when the category
    is absent — TODO confirm against the Term implementation).
    """
    return get_term_instance(term_name, df).get_category_by_name(category_name)
|
| 606 |
|
| 607 |
# Initialize session state
|
|
|
|
| 609 |
st.session_state.logged_in = False
|
| 610 |
if 'username' not in st.session_state:
|
| 611 |
st.session_state.username = None
|
| 612 |
+
if 'jurisdiction' not in st.session_state:
|
| 613 |
+
st.session_state.jurisdiction = None
|
| 614 |
if 'current_index' not in st.session_state:
|
| 615 |
st.session_state.current_index = 0
|
| 616 |
if 'show_term_complete' not in st.session_state:
|
|
|
|
| 639 |
# Login page
# Shown until the user authenticates. Collects a jurisdiction plus
# username/password, and offers registration for new users. All per-user
# data (answers parquet, users.json) is namespaced by jurisdiction.
if not st.session_state.logged_in:
    st.markdown("# Login")
    st.markdown("Please select a jurisdiction and enter your username and password to continue.")

    # Jurisdiction selector. Pre-select the previously chosen jurisdiction
    # when it is still valid; otherwise fall back to the first option.
    # (Equivalent to the old nested conditional expression, just readable.)
    if st.session_state.jurisdiction in AVAILABLE_JURISDICTIONS:
        _default_idx = AVAILABLE_JURISDICTIONS.index(st.session_state.jurisdiction)
    else:
        _default_idx = 0
    jurisdiction = st.selectbox("Jurisdiction", options=AVAILABLE_JURISDICTIONS, index=_default_idx)
    st.session_state.jurisdiction = jurisdiction

    username = st.text_input("Username")
    password = st.text_input("Password", type="password")

    col1, col2 = st.columns(2)
    with col1:
        if st.button("Login", type="primary", use_container_width=True):
            if not jurisdiction:
                st.error("Please select a jurisdiction")
            else:
                users = load_users(jurisdiction)
                if username in users:
                    # Existing user - check password.
                    # compare_digest gives a constant-time comparison of the
                    # stored hash vs. the freshly computed one, avoiding
                    # timing side channels on this untrusted input.
                    import hmac
                    if hmac.compare_digest(users[username]['password'], hash_password(password)):
                        st.session_state.logged_in = True
                        st.session_state.username = username
                        # Load grading template for this jurisdiction
                        df = load_grading_template(jurisdiction)
                        # Restore submitted status for previously submitted categories
                        st.session_state.submitted_pairs = restore_submitted_status(username, jurisdiction, df)
                        # Find first unanswered category and resume there
                        resume_index = find_first_unanswered_category(username, jurisdiction, df)
                        st.session_state.current_index = resume_index
                        st.rerun()
                    else:
                        st.error("Incorrect password")
                else:
                    # Username not found - require registration
                    st.error("Username not found. Please register first using the 'Register New User' button.")

    with col2:
        if st.button("Register New User", use_container_width=True):
            if not jurisdiction:
                st.error("Please select a jurisdiction")
            else:
                users = load_users(jurisdiction)
                if username in users:
                    st.error("Username already exists")
                elif username and password:
                    # Store only the password hash, then persist the user list
                    # for this jurisdiction.
                    users[username] = {'password': hash_password(password)}
                    save_users(users, jurisdiction)
                    # Load grading template for this jurisdiction
                    df = load_grading_template(jurisdiction)
                    # Auto-score all Unknown answers for the new user in the background
                    auto_score_all_unknown_answers_for_new_user(username, jurisdiction, df)
                    st.success("User registered successfully! Please click Login.")
                else:
                    st.error("Please enter both username and password")
|
| 696 |
|
| 697 |
# Main application (only shown if logged in)
|
| 698 |
elif st.session_state.logged_in:
|
| 699 |
username = st.session_state.username
|
| 700 |
+
jurisdiction = st.session_state.jurisdiction
|
| 701 |
current_index = st.session_state.current_index
|
| 702 |
|
| 703 |
+
# Load grading template for the selected jurisdiction
|
| 704 |
+
df = load_grading_template(jurisdiction)
|
| 705 |
+
|
| 706 |
# Debug info (can be removed in production)
|
| 707 |
with st.sidebar:
|
| 708 |
with st.expander("Debug Info"):
|
|
|
|
| 710 |
st.write(f"HF Token configured: {HF_TOKEN is not None}")
|
| 711 |
st.write(f"HF API initialized: {hf_api is not None}")
|
| 712 |
if username:
|
| 713 |
+
st.write(f"Jurisdiction: `{jurisdiction}`")
|
| 714 |
+
st.write(f"User parquet file: `{jurisdiction}/users/{username}_answers.parquet`")
|
| 715 |
+
st.write(f"Users file: `{jurisdiction}/users/users.json`")
|
| 716 |
|
| 717 |
# Check if we should show the annotation guide first
|
| 718 |
if st.session_state.show_guide:
|
|
|
|
| 851 |
st.session_state.next_term = None
|
| 852 |
st.rerun()
|
| 853 |
|
| 854 |
+
# Get term_category_pairs for this jurisdiction
|
| 855 |
+
term_category_pairs = get_term_category_pairs(df)
|
| 856 |
+
total_pairs = len(term_category_pairs)
|
| 857 |
+
|
| 858 |
elif current_index < total_pairs:
|
| 859 |
term_name, category_name = term_category_pairs[current_index]
|
| 860 |
+
category = get_category_for_pair(term_name, category_name, df)
|
| 861 |
+
term = get_term_instance(term_name, df)
|
| 862 |
|
| 863 |
# Safety check: skip if category has no subcategories (shouldn't happen due to filtering, but just in case)
|
| 864 |
if not category or len(category.subcategories) == 0:
|
|
|
|
| 884 |
# Check visible subcategories
|
| 885 |
for i, subcat in enumerate(category.subcategories):
|
| 886 |
saved_score = get_user_answer(username, term_name, category_name, subcat.subcategory_name,
|
| 887 |
+
subcat.question, subcat.answer, jurisdiction)
|
| 888 |
if saved_score is None:
|
| 889 |
category_fully_answered = False
|
| 890 |
break
|
|
|
|
| 893 |
if category_fully_answered:
|
| 894 |
unknown_answers = auto_score_unknown_answers(username, term_name, category_name, df)
|
| 895 |
for subcategory, question, answer, score in unknown_answers:
|
| 896 |
+
saved_score = get_user_answer(username, term_name, category_name, subcategory, question, answer, jurisdiction)
|
| 897 |
if saved_score is None:
|
| 898 |
category_fully_answered = False
|
| 899 |
break
|
|
|
|
| 916 |
st.session_state.original_selections[pair_key][radio_key] = st.session_state[radio_key]
|
| 917 |
else:
|
| 918 |
saved_score = get_user_answer(username, term_name, category_name, subcat.subcategory_name,
|
| 919 |
+
subcat.question, subcat.answer, jurisdiction)
|
| 920 |
if saved_score is not None:
|
| 921 |
score_to_option = {v: k for k, v in ASSESSMENT_TO_SCORE.items()}
|
| 922 |
if saved_score in score_to_option:
|
|
|
|
| 955 |
# Get saved value to determine if we should set a default index
|
| 956 |
# Don't pre-set session_state - only use index parameter to avoid conflicts
|
| 957 |
saved_score = get_user_answer(username, term_name, category_name, subcat.subcategory_name,
|
| 958 |
+
subcat.question, subcat.answer, jurisdiction)
|
| 959 |
default_index = None
|
| 960 |
|
| 961 |
# Check if there's a saved value in parquet file
|
|
|
|
| 1040 |
st.session_state.back_current_index = current_index
|
| 1041 |
st.session_state.show_term_back_warning = True
|
| 1042 |
st.session_state.back_current_term = term.formatted_name
|
| 1043 |
+
prev_term = get_term_instance(prev_term_name, df)
|
| 1044 |
st.session_state.back_previous_term = prev_term.formatted_name
|
| 1045 |
else:
|
| 1046 |
# Same term, just move back
|
|
|
|
| 1084 |
|
| 1085 |
# Update all answers in a single commit
|
| 1086 |
commit_message = f"Update answers for {username} - {term_name} - {category_name}"
|
| 1087 |
+
save_success = update_category_answers(username, term_name, category_name, answers_list, jurisdiction, commit_message)
|
| 1088 |
|
| 1089 |
if save_success:
|
| 1090 |
# Save current selections as new originals
|
|
|
|
| 1116 |
# Update all answers in a single commit
|
| 1117 |
# Note: Unknown answers are already auto-scored during account creation
|
| 1118 |
commit_message = f"Update answers for {username} - {term_name} - {category_name}"
|
| 1119 |
+
save_success = update_category_answers(username, term_name, category_name, answers_list, jurisdiction, commit_message)
|
| 1120 |
|
| 1121 |
if save_success:
|
| 1122 |
# Mark as submitted and save original selections
|
|
|
|
| 1139 |
# Show intermediate page
|
| 1140 |
st.session_state.show_term_complete = True
|
| 1141 |
st.session_state.completed_term = term.formatted_name
|
| 1142 |
+
next_term = get_term_instance(next_term_name, df)
|
| 1143 |
st.session_state.next_term = next_term.formatted_name
|
| 1144 |
else:
|
| 1145 |
# Same term, just move to next category
|
|
|
|
| 1174 |
# Moving to a different term - show term switching page
|
| 1175 |
st.session_state.show_term_complete = True
|
| 1176 |
st.session_state.completed_term = term.formatted_name
|
| 1177 |
+
next_term = get_term_instance(next_term_name, df)
|
| 1178 |
st.session_state.next_term = next_term.formatted_name
|
| 1179 |
else:
|
| 1180 |
# Same term, just move forward
|
|
|
|
| 1196 |
if st.button("Logout"):
|
| 1197 |
st.session_state.logged_in = False
|
| 1198 |
st.session_state.username = None
|
| 1199 |
+
st.session_state.jurisdiction = None
|
| 1200 |
st.session_state.current_index = 0
|
| 1201 |
st.session_state.show_guide = True
|
| 1202 |
st.session_state.submitted_pairs = set()
|