# evaluation/backend/helpers.py
# Source page header: iyosha -- "Update backend/helpers.py" (commit a95ee59, verified)
import random
import pandas as pd
import json
def get_current_stage(backend, dataset, stage_splits, threshold=3):
    """Return the current global stage number (1, 2 or 3).

    A stage is treated as complete when every interpretation id in its
    split has rows from at least ``threshold`` distinct ``user_id`` values.

    Args:
        backend: Object exposing ``get_all_rows()``, which returns a pandas
            DataFrame with ``interpretation_id`` and ``user_id`` columns.
        dataset: Indexable collection whose items provide an
            ``"interpretation_id"`` key.
        stage_splits: Mapping with ``"stage1"`` and ``"stage2"`` keys, each
            holding an iterable of indices into ``dataset``.
        threshold: Minimum number of distinct users required per
            interpretation for it to count as covered.

    Returns:
        ``1`` while stage 1 is incomplete, ``2`` once stage 1 is complete
        but stage 2 is not, and ``3`` when both are complete.
    """
    df = backend.get_all_rows()

    # With no rows nothing has been covered yet. Returning early also avoids
    # a KeyError from groupby() when the empty frame carries no columns.
    if df.empty:
        return 1

    counts = df.groupby("interpretation_id")["user_id"].nunique().to_dict()

    def _stage_complete(stage_key):
        # Resolve the split to interpretation ids first, then check coverage.
        stage_ids = [dataset[i]["interpretation_id"] for i in stage_splits[stage_key]]
        return all(counts.get(iid, 0) >= threshold for iid in stage_ids)

    if not _stage_complete("stage1"):
        return 1
    # Stage 2 is only inspected once stage 1 is fully covered.
    if not _stage_complete("stage2"):
        return 2
    return 3
def get_random_session_samples(
    backend, dataset, stage_splits, user_name, num_samples=30, threshold=3
):
    """Pick a random batch of dataset indices for one user's session.

    Starting at the current global stage and moving forward through stage 3,
    the first stage that still has eligible items supplies the batch. An item
    is eligible when its interpretation has rows from fewer than
    ``threshold`` distinct ``user_id`` values and no row for it carries this
    ``user_name``.

    Args:
        backend: Object exposing ``get_all_rows()``, which returns a pandas
            DataFrame with ``interpretation_id``, ``user_id`` and
            ``user_name`` columns.
        dataset: Indexable collection whose items provide an
            ``"interpretation_id"`` key.
        stage_splits: Mapping with ``"stage1"``, ``"stage2"`` and
            ``"stage3"`` keys, each holding a list of indices into
            ``dataset``.
        user_name: Name used to exclude interpretations this user already
            has rows for.
        num_samples: Maximum number of indices to return.
        threshold: Number of distinct users at which an interpretation stops
            being eligible. It is also forwarded to ``get_current_stage`` so
            both functions use the same cut-off.

    Returns:
        A ``(indices, stage)`` tuple. ``indices`` is a random sample of at
        most ``num_samples`` eligible indices and ``stage`` is the stage they
        come from. ``([], 4)`` is returned when no stage from the current one
        onwards has anything eligible for this user.
    """
    df = backend.get_all_rows()

    # Defensive fallback: with no rows at all, sample straight from stage 1.
    if df.empty:
        first_pool = stage_splits["stage1"]
        return random.sample(first_pool, min(num_samples, len(first_pool))), 1

    # NOTE(review): get_current_stage() fetches the rows from the backend
    # again, so the stage and `counts` below may come from different
    # snapshots if rows are written in between -- confirm this is acceptable.
    global_stage = get_current_stage(
        backend, dataset, stage_splits, threshold=threshold
    )
    counts = df.groupby("interpretation_id")["user_id"].nunique().to_dict()
    # NOTE(review): coverage is counted by `user_id` but the user's own rows
    # are matched by `user_name` -- confirm both identify the same person.
    seen_ids = set(df[df["user_name"] == user_name]["interpretation_id"])

    # If the user has exhausted the global stage, they move on to later ones.
    for stage_num in range(global_stage, 4):  # up to and including stage 3
        eligible_indices = []
        for i in stage_splits[f"stage{stage_num}"]:
            interpretation_id = dataset[i]["interpretation_id"]
            if (
                counts.get(interpretation_id, 0) < threshold
                and interpretation_id not in seen_ids
            ):
                eligible_indices.append(i)

        if eligible_indices:
            batch_size = min(num_samples, len(eligible_indices))
            return random.sample(eligible_indices, batch_size), stage_num

    # Nothing left for this user in any remaining stage.
    return [], 4