Spaces:
Build error
Build error
| import random | |
| import pandas as pd | |
| import json | |
def get_current_stage(backend, dataset, stage_splits, threshold=3):
    """Return the global annotation stage: 1, 2 or 3.

    A stage counts as complete once every interpretation in it has been
    annotated by at least ``threshold`` distinct ``user_id`` values. The first
    incomplete stage among stages 1 and 2 is returned; if both are complete
    the result is 3 (completion of stage 3 is not checked here).

    Args:
        backend: Object whose ``get_all_rows()`` returns a pandas DataFrame
            with ``interpretation_id`` and ``user_id`` columns.
        dataset: Indexable collection whose items are mappings containing an
            ``"interpretation_id"`` key.
        stage_splits: Mapping with ``"stage1"`` and ``"stage2"`` keys, each a
            sequence of indices into ``dataset``.
        threshold: Minimum number of distinct annotators an interpretation
            needs before it counts as done.

    Returns:
        int: 1, 2 or 3.
    """
    df = backend.get_all_rows()

    # A frame with no rows may also have no columns, in which case the groupby
    # below would raise KeyError. Either way nothing has been annotated yet.
    if df.empty:
        counts = {}
    else:
        counts = df.groupby("interpretation_id")["user_id"].nunique().to_dict()

    # Stages are checked in order; a later stage is only looked at once every
    # earlier stage is complete.
    for stage_num in (1, 2):
        stage_ids = [
            dataset[i]["interpretation_id"]
            for i in stage_splits[f"stage{stage_num}"]
        ]
        if not all(counts.get(iid, 0) >= threshold for iid in stage_ids):
            return stage_num
    return 3
def get_random_session_samples(
    backend, dataset, stage_splits, user_name, num_samples=30, threshold=3
):
    """Pick a random batch of dataset indices for one user's session.

    Items are drawn from the lowest stage (1..3) that still contains an
    interpretation which (a) has fewer than ``threshold`` distinct annotators
    (counted by ``user_id``) and (b) has not yet been annotated by
    ``user_name``.

    Args:
        backend: Object whose ``get_all_rows()`` returns a pandas DataFrame
            with ``interpretation_id``, ``user_id`` and ``user_name`` columns.
        dataset: Indexable collection whose items are mappings containing an
            ``"interpretation_id"`` key.
        stage_splits: Mapping with ``"stage1"``, ``"stage2"`` and ``"stage3"``
            keys, each a sequence of indices into ``dataset``.
        user_name: User requesting the session; items they already annotated
            are excluded.
        num_samples: Maximum number of indices to return.
        threshold: Number of distinct annotators after which an interpretation
            is no longer offered (same meaning as in ``get_current_stage``).

    Returns:
        tuple[list, int]: ``(indices, stage)``. ``indices`` holds at most
        ``num_samples`` indices drawn without replacement from
        ``stage_splits[f"stage{stage}"]``. If the backend has no rows, a
        sample of stage 1 is returned with stage 1. If no stage has anything
        left for this user, ``([], 4)`` is returned.
    """
    df = backend.get_all_rows()

    # Defensive fallback: nothing annotated yet, so start everyone on stage 1.
    if df.empty:
        pool = stage_splits["stage1"]
        return random.sample(pool, min(num_samples, len(pool))), 1

    counts = df.groupby("interpretation_id")["user_id"].nunique().to_dict()
    # NOTE(review): progress is counted per user_id, but "already seen" is
    # keyed on user_name -- confirm both columns identify the same person.
    seen_ids = set(df.loc[df["user_name"] == user_name, "interpretation_id"])

    # Every stage below the global stage (see get_current_stage) has all of
    # its items at >= threshold annotators, so it can never yield eligible
    # items. Scanning from stage 1 therefore returns the same result as
    # starting at the global stage, but uses this single snapshot instead of
    # reading the backend a second time.
    for stage_num in range(1, 4):
        eligible = []
        for idx in stage_splits[f"stage{stage_num}"]:
            iid = dataset[idx]["interpretation_id"]
            if counts.get(iid, 0) < threshold and iid not in seen_ids:
                eligible.append(idx)
        if eligible:
            batch = random.sample(eligible, min(num_samples, len(eligible)))
            return batch, stage_num

    # This user has nothing left to annotate in any stage.
    return [], 4