iyosha committed on
Commit
8c04306
·
verified ·
1 Parent(s): 91d38d3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +165 -28
app.py CHANGED
@@ -4,6 +4,7 @@ import gradio as gr
4
  from uuid import uuid4
5
  from datasets import load_dataset
6
  from collections import Counter
 
7
  from configs import configs
8
  from clients import backend, logger
9
  from backend.helpers import get_random_session_samples
@@ -46,7 +47,7 @@ def human_eval_tab():
46
  if p == configs.USER_PASSWORD and usr.strip() != "":
47
  new_session_id = str(uuid4())
48
  sample_indices, stage = get_random_session_samples(
49
- backend, dataset, STAGE_SPLITS, usr, num_samples=2
50
  )
51
  logger.info(f"Session ID: {new_session_id}, Stage: {stage}")
52
  return (
@@ -271,6 +272,53 @@ def human_eval_tab():
271
  )
272
 
273
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
274
  def get_admin_tab():
275
  with gr.Tab("Admin Console"):
276
  admin_password = gr.Text(label="Enter Admin Password", type="password")
@@ -281,7 +329,7 @@ def get_admin_tab():
281
  def calculate_majority_vote_accuracy(pw):
282
  if pw != configs.ADMIN_PASSWORD:
283
  return gr.update(
284
- visible=True, value="\u274c Incorrect password."
285
  ), gr.update(visible=False)
286
 
287
  df = backend.get_all_rows()
@@ -290,43 +338,131 @@ def get_admin_tab():
290
  visible=False
291
  )
292
 
 
293
  majority_answers = {}
294
  for interp_id, group in df.groupby("interpretation_id"):
295
  answer_counts = Counter(group["answer"])
296
  if answer_counts:
297
  majority_answers[interp_id] = answer_counts.most_common(1)[0][0]
298
 
299
- total = 0
300
- correct = 0
301
-
302
- for sample in dataset:
303
- interp_id = sample["interpretation_id"]
304
- if interp_id not in majority_answers:
305
- continue
306
- predicted_answer = majority_answers[interp_id]
307
- correct_label_idx = sample["label"]
308
- correct_answer_text = sample["possible_answers"][correct_label_idx]
309
- total += 1
310
- if predicted_answer == correct_answer_text:
311
- correct += 1
312
-
313
- acc = correct / total if total > 0 else 0
314
- # calculate total answers submited
315
  total_answers = len(df)
316
- answers_to_go = (3 * len(dataset)) - total_answers
317
  users_count = df["user_id"].nunique()
318
- # update the admin console
319
- return gr.update(visible=False), gr.update(
320
- visible=True,
321
- value=f"""**Accuracy over answered samples:** {acc:.3%} ({correct}/{total})
322
 
323
- **Total answers submitted:** {total_answers}
324
-
325
- **Answers to go:** {answers_to_go}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
326
 
327
- **Users count:** {users_count}""",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
328
  )
329
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
330
  check_btn.click(
331
  fn=calculate_majority_vote_accuracy,
332
  inputs=admin_password,
@@ -339,5 +475,6 @@ with gr.Blocks() as demo:
339
  human_eval_tab()
340
  get_admin_tab()
341
 
342
- #demo.launch()
 
343
 
 
4
  from uuid import uuid4
5
  from datasets import load_dataset
6
  from collections import Counter
7
+ import numpy as np
8
  from configs import configs
9
  from clients import backend, logger
10
  from backend.helpers import get_random_session_samples
 
47
  if p == configs.USER_PASSWORD and usr.strip() != "":
48
  new_session_id = str(uuid4())
49
  sample_indices, stage = get_random_session_samples(
50
+ backend, dataset, STAGE_SPLITS, usr, num_samples=30
51
  )
52
  logger.info(f"Session ID: {new_session_id}, Stage: {stage}")
53
  return (
 
272
  )
273
 
274
 
275
def compute_random_sampled_accuracy(df, dataset, n_rounds=100, seed=42):
    """Estimate single-answer accuracy by repeated random sampling.

    In every round, one submitted answer is drawn at random for each
    ``interpretation_id`` that was answered by at least three distinct users.
    The drawn answer is compared with the ground truth of the dataset entry
    referenced by the row's ``index_in_dataset``
    (``entry["possible_answers"][entry["label"]]``), and the fraction of
    correct draws is the score of that round.

    Args:
        df: DataFrame of submitted answers. The columns read here are
            ``interpretation_id``, ``user_id``, ``answer`` and
            ``index_in_dataset``.
        dataset: Integer-indexable collection whose entries expose
            ``possible_answers`` and ``label``.
        n_rounds: Number of sampling rounds.
        seed: Seed for the NumPy generator that drives the per-draw seeds.

    Returns:
        A tuple ``(mean_acc, std_acc, mean_total, ci_95)``: mean round
        accuracy, sample standard deviation of the round accuracies, number
        of items scored per round, and the half-width ``1.96 * std / sqrt(k)``
        over the ``k`` scored rounds. ``(None, None, 0, None)`` is returned
        when no item is eligible or no round is run.
    """
    rng = np.random.default_rng(seed)

    # Keep only interpretation_ids answered by at least 3 distinct users.
    users_per_item = df.groupby("interpretation_id")["user_id"].nunique()
    eligible_ids = set(users_per_item[users_per_item >= 3].index)
    eligible = df[df["interpretation_id"].isin(eligible_ids)]
    if eligible.empty or n_rounds <= 0:
        return None, None, 0, None

    # The ground truth does not change between rounds, so every dataset entry
    # is read at most once instead of once per draw.
    ground_truth = {}

    def _is_correct(answer, raw_idx):
        idx = int(raw_idx)
        if idx not in ground_truth:
            entry = dataset[idx]
            ground_truth[idx] = entry["possible_answers"][entry["label"]]
        return bool(answer == ground_truth[idx])

    flags = [
        _is_correct(answer, raw_idx)
        for answer, raw_idx in zip(
            eligible["answer"], eligible["index_in_dataset"]
        )
    ]

    # One boolean Series per item; drawing from it selects the same position
    # as drawing a full row from the corresponding group would.
    grouped_flags = eligible.assign(_is_correct=flags).groupby(
        "interpretation_id"
    )["_is_correct"]
    per_item_flags = [group for _, group in grouped_flags if not group.empty]
    n_items = len(per_item_flags)
    if n_items == 0:
        return None, None, 0, None

    scores = []
    for _ in range(n_rounds):
        correct = 0
        for item_flags in per_item_flags:
            # A fresh seed per draw keeps the whole run reproducible.
            draw_seed = rng.integers(1_000_000)
            if item_flags.sample(1, random_state=draw_seed).iloc[0]:
                correct += 1
        scores.append(correct / n_items)

    mean_acc = np.mean(scores)
    # The sample std (ddof=1) is undefined for a single round; report 0.0
    # rather than NaN.
    std_acc = np.std(scores, ddof=1) if len(scores) > 1 else 0.0
    ci_95 = 1.96 * std_acc / np.sqrt(len(scores))
    return mean_acc, std_acc, n_items, ci_95
320
+
321
+
322
  def get_admin_tab():
323
  with gr.Tab("Admin Console"):
324
  admin_password = gr.Text(label="Enter Admin Password", type="password")
 
329
  def calculate_majority_vote_accuracy(pw):
330
  if pw != configs.ADMIN_PASSWORD:
331
  return gr.update(
332
+ visible=True, value=" Incorrect password."
333
  ), gr.update(visible=False)
334
 
335
  df = backend.get_all_rows()
 
338
  visible=False
339
  )
340
 
341
+ # Majority vote per interpretation_id
342
  majority_answers = {}
343
  for interp_id, group in df.groupby("interpretation_id"):
344
  answer_counts = Counter(group["answer"])
345
  if answer_counts:
346
  majority_answers[interp_id] = answer_counts.most_common(1)[0][0]
347
 
348
+ counts = df.groupby("interpretation_id")["user_id"].nunique().to_dict()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
349
  total_answers = len(df)
 
350
  users_count = df["user_id"].nunique()
 
 
 
 
351
 
352
+ stage_acc = {}
353
+ stage_completes = {}
354
+ stage_counts = {}
355
+ stage_remaining = {}
356
+
357
+ # global_correct = 0
358
+ # global_total = 0
359
+
360
+ for stage in ["stage1", "stage2", "stage3"]:
361
+ correct, total = 0, 0
362
+ complete = 0
363
+ for i in STAGE_SPLITS[stage]:
364
+ sample = dataset[i]
365
+ interp_id = sample["interpretation_id"]
366
+ label = sample["label"]
367
+ gt = sample["possible_answers"][label]
368
+
369
+ n = counts.get(interp_id, 0)
370
+ if n >= 3:
371
+ complete += 1
372
+ if interp_id in majority_answers:
373
+ pred = majority_answers[interp_id]
374
+ total += 1
375
+ if pred == gt:
376
+ correct += 1
377
+
378
+ stage_counts[stage] = len(STAGE_SPLITS[stage])
379
+ stage_completes[stage] = complete
380
+ stage_remaining[stage] = 3 * len(STAGE_SPLITS[stage]) - sum(
381
+ counts.get(dataset[i]["interpretation_id"], 0)
382
+ for i in STAGE_SPLITS[stage]
383
+ )
384
 
385
+ if complete == len(STAGE_SPLITS[stage]):
386
+ acc = correct / total if total > 0 else 0
387
+ stage_acc[stage] = (acc, correct, total)
388
+ else:
389
+ stage_acc[stage] = None # not shown yet
390
+
391
+ # Determine active stage
392
+ if stage_completes["stage1"] < stage_counts["stage1"]:
393
+ current_stage = "Stage 1"
394
+ elif stage_completes["stage2"] < stage_counts["stage2"]:
395
+ current_stage = "Stage 2"
396
+ else:
397
+ current_stage = "Stage 3"
398
+
399
+ # Majority Vote Accuracy Section
400
+ agg_lines = []
401
+ if stage_acc["stage1"]:
402
+ acc1, c1, t1 = stage_acc["stage1"]
403
+ agg_lines.append(f"- **Stage 1:** {acc1:.2%} ({c1}/{t1})")
404
+ if stage_acc["stage2"]:
405
+ acc2, c2, t2 = stage_acc["stage2"]
406
+ agg_lines.append(
407
+ f"- **Stage 1+2:** {(c1 + c2) / (t1 + t2):.2%} ({c1 + c2}/{t1 + t2})"
408
+ )
409
+ if stage_acc["stage3"]:
410
+ acc3, c3, t3 = stage_acc["stage3"]
411
+ agg_lines.append(
412
+ f"- **All Stages:** {(c1 + c2 + c3) / (t1 + t2 + t3):.2%} ({c1 + c2 + c3}/{t1 + t2 + t3})"
413
+ )
414
+ agg_msg = "\n".join(agg_lines) if agg_lines else "No completed stages yet."
415
+ # Compute random-sampled accuracy
416
+ n_rounds = 100
417
+ rand_acc, rand_std, rand_total, rand_ci = compute_random_sampled_accuracy(
418
+ df, dataset, n_rounds=n_rounds
419
  )
420
 
421
+ # Random-sampled Accuracy
422
+ if rand_acc is not None:
423
+ rand_acc_msg = (
424
+ f"**Accuracy:** {rand_acc:.2%} ± {rand_ci:.2%} (95% CI)\n\n"
425
+ f"Standard deviation: {rand_std:.2%}\n\n"
426
+ f"Samples used: {rand_total} × {n_rounds} rounds"
427
+ )
428
+ else:
429
+ rand_acc_msg = "Random sampling failed (no data)."
430
+
431
+ # Final message (no indentation!)
432
+ msg = f"""
433
+ ## ✅ Accuracy Summary
434
+
435
+ ### Majority Vote
436
+ {agg_msg}
437
+
438
+ ---
439
+
440
+ ### Random-Sampled Accuracy
441
+ {rand_acc_msg}
442
+
443
+ ---
444
+
445
+ ## 📊 Answer Progress
446
+
447
+ - **Total answers submitted:** {total_answers}
448
+ - **Answers to go (global):** {3 * len(dataset) - total_answers}
449
+ - **Unique users:** {users_count}
450
+
451
+ ---
452
+
453
+ ## 🧱 Stage Breakdown
454
+
455
+ | Stage | Completed | Total | Remaining Answers |
456
+ |-------|-----------|--------|-------------------|
457
+ | 1 | {stage_completes['stage1']} / {stage_counts['stage1']} | {stage_counts['stage1']} | {stage_remaining['stage1']} |
458
+ | 2 | {stage_completes['stage2']} / {stage_counts['stage2']} | {stage_counts['stage2']} | {stage_remaining['stage2']} |
459
+ | 3 | {stage_completes['stage3']} / {stage_counts['stage3']} | {stage_counts['stage3']} | {stage_remaining['stage3']} |
460
+
461
+ **➡️ Current Active Stage:** {current_stage}
462
+ """
463
+
464
+ return gr.update(visible=False), gr.update(visible=True, value=msg)
465
+
466
  check_btn.click(
467
  fn=calculate_majority_vote_accuracy,
468
  inputs=admin_password,
 
475
  human_eval_tab()
476
  get_admin_tab()
477
 
478
+
479
+ demo.launch()
480