import random
import pandas as pd
import json


def _annotator_counts(df):
    """Map each interpretation_id in *df* to its number of distinct user_id values."""
    return df.groupby("interpretation_id")["user_id"].nunique().to_dict()


def _stage_from_counts(counts, dataset, stage_splits, threshold):
    """Return the first stage (1 or 2) with an under-annotated item, else 3.

    A stage counts as complete when every item of its split has at least
    *threshold* distinct annotators according to *counts*. The "stage2"
    split is only looked up once "stage1" is complete.
    """
    for stage_num in (1, 2):
        stage_ids = (
            dataset[i]["interpretation_id"]
            for i in stage_splits[f"stage{stage_num}"]
        )
        if not all(counts.get(iid, 0) >= threshold for iid in stage_ids):
            return stage_num
    return 3


def get_current_stage(backend, dataset, stage_splits, threshold=3):
    """Return the global annotation stage (1, 2 or 3).

    Args:
        backend: Object whose ``get_all_rows()`` returns a pandas DataFrame
            with ``interpretation_id`` and ``user_id`` columns.
        dataset: Indexable collection; ``dataset[i]["interpretation_id"]``
            must be valid for every index listed in *stage_splits*.
        stage_splits: Mapping with ``"stage1"`` and ``"stage2"`` keys, each
            holding an iterable of indices into *dataset*.
        threshold: Minimum number of distinct users an item needs before it
            counts as done.

    Returns:
        1 while any stage-1 item is below *threshold*, 2 while any stage-2
        item is below it, otherwise 3. An empty table yields 1.
    """
    df = backend.get_all_rows()
    # An empty frame may have no columns at all, in which case groupby would
    # raise KeyError; no annotations means we are still in stage 1.
    if df.empty:
        return 1
    return _stage_from_counts(
        _annotator_counts(df), dataset, stage_splits, threshold
    )


def get_random_session_samples(
    backend, dataset, stage_splits, user_name, num_samples=30, threshold=3
):
    """Pick a random batch of dataset indices for one user's session.

    Starting at the current global stage, the first stage (up to stage 3)
    that still has items which are both under-annotated (fewer than
    *threshold* distinct users) and unseen by *user_name* supplies the batch.

    Args:
        backend: Object whose ``get_all_rows()`` returns a pandas DataFrame
            with ``interpretation_id``, ``user_id`` and ``user_name`` columns.
        dataset: Indexable collection; ``dataset[i]["interpretation_id"]``
            must be valid for every index listed in *stage_splits*.
        stage_splits: Mapping with ``"stage1"``, ``"stage2"`` and
            ``"stage3"`` keys, each holding a sequence of dataset indices.
        user_name: Value matched against the ``user_name`` column to find
            items this user has already handled.
        num_samples: Maximum number of indices to return.
        threshold: Distinct-user count at which an item stops being offered;
            also used to determine the global stage.

    Returns:
        A ``(indices, stage)`` tuple. ``indices`` holds at most *num_samples*
        entries; ``([], 4)`` means nothing is left for this user.
    """
    df = backend.get_all_rows()

    # Defensive fallback: with no rows at all, serve straight from stage 1.
    if df.empty:
        stage_pool = stage_splits["stage1"]
        return random.sample(stage_pool, min(num_samples, len(stage_pool))), 1

    # Derive the stage from the rows already fetched so the stage and the
    # eligibility filter below are computed from the same snapshot.
    counts = _annotator_counts(df)
    global_stage = _stage_from_counts(counts, dataset, stage_splits, threshold)

    # NOTE(review): counts are keyed on "user_id" while this filter uses
    # "user_name" -- confirm both columns identify users consistently.
    seen_ids = set(df.loc[df["user_name"] == user_name, "interpretation_id"])

    # A user who has exhausted the global stage may move on to later stages.
    for stage_num in range(global_stage, 4):
        eligible_indices = []
        for i in stage_splits[f"stage{stage_num}"]:
            iid = dataset[i]["interpretation_id"]
            if counts.get(iid, 0) < threshold and iid not in seen_ids:
                eligible_indices.append(i)
        if eligible_indices:
            batch_size = min(num_samples, len(eligible_indices))
            return random.sample(eligible_indices, batch_size), stage_num

    # This user has completed everything (even beyond the current stage).
    return [], 4