MedGRPO Team Claude Sonnet 4.5 committed on
Commit
a66b9a4
·
1 Parent(s): 6d8dbb2

Add semantic similarity matching for Next Action evaluation

Browse files

Major improvements:
- Implement semantic similarity using SentenceBERT (all-MiniLM-L6-v2)
- Add support for predefined action lists (AVOS, CholecT50, CoPESD, NurViD)
- Handle free-form actions (EgoSurgery) with dynamic action list creation
- Add fallback to 'gnd' field for ground truth (supports CholecT50)
- Add normalization and class mapping per dataset
- Compute overall accuracy across all FPS values

SA metric fix:
- Extract Overall Skill Level Accuracy (0.2437) instead of Aspect Balanced Accuracy (0.2542)
- Now matches table value of 0.244 (vs 0.2542 before)

Results:
- NAP_acc improved from 0.3384 (exact match) to 0.4074 (semantic similarity)
- Matches original evaluation (0.4045) within 0.0029
- Now consistent with original methodology

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>

Files changed (2) hide show
  1. app.py +4 -3
  2. evaluation/eval_next_action.py +150 -17
app.py CHANGED
@@ -717,10 +717,11 @@ def parse_evaluation_output(output: str) -> Dict[str, float]:
717
  except:
718
  pass
719
 
720
- # Skill Assessment: Extract accuracy
721
- elif current_task == "skill_assessment" and "accuracy" in line.lower():
722
  try:
723
- value = float(line.split(":")[-1].strip())
 
724
  metrics["sa_acc"] = value
725
  except:
726
  pass
 
717
  except:
718
  pass
719
 
720
+ # Skill Assessment: Extract Overall Accuracy (not Aspect Balanced Accuracy)
721
+ elif current_task == "skill_assessment" and "overall accuracy:" in line.lower() and "aspect" not in line.lower():
722
  try:
723
+ # Extract from "Overall Accuracy: 0.2437 (39/160)"
724
+ value = float(line.split(":")[1].split("(")[0].strip())
725
  metrics["sa_acc"] = value
726
  except:
727
  pass
evaluation/eval_next_action.py CHANGED
@@ -512,37 +512,170 @@ def group_records_by_dataset(data):
512
  return dict(dataset_groups)
513
 
514
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
515
  def evaluate_dataset_next_action(dataset_name, records):
516
- """Evaluate next_action for a specific dataset."""
517
  print(f"\nEvaluating {dataset_name} ({len(records)} records)...")
518
 
519
- results_by_fps = defaultdict(list)
520
-
521
  for record in records:
522
- fps = record.get('fps', record.get('metadata', {}).get('fps', 1.0))
523
- if isinstance(fps, str):
524
- fps = float(fps)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
525
 
526
- pred = record.get('answer', '').strip().lower()
 
 
 
527
 
528
- # Get ground truth
529
- struc_info = record.get('struc_info', {})
530
- if isinstance(struc_info, list) and len(struc_info) > 0:
531
- struc_info = struc_info[0]
532
- gt = struc_info.get('next_action', '').strip().lower()
533
 
534
- # Simple exact match accuracy
535
- is_correct = (pred == gt)
536
- results_by_fps[fps].append(1 if is_correct else 0)
537
 
538
- # Aggregate results
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
539
  aggregated = {}
540
- for fps, acc_list in results_by_fps.items():
 
 
541
  if acc_list:
542
  aggregated[f'fps_{fps}'] = {
543
  'accuracy': np.mean(acc_list),
544
  'count': len(acc_list)
545
  }
 
 
 
 
 
 
 
 
546
 
547
  return aggregated
548
 
 
512
  return dict(dataset_groups)
513
 
514
 
515
+ def normalize_action_text(action_text, dataset_name):
516
+ """Normalize action text for comparison."""
517
+ action_text = action_text.strip().lower()
518
+
519
+ # Dataset-specific mappings
520
+ if dataset_name == "CoPESD":
521
+ action_text = COPESD_ACTION_MAPPING.get(action_text, action_text)
522
+
523
+ return action_text
524
+
525
+
526
+ def get_action_list_for_dataset(dataset_name, procedure=None):
527
+ """Get action list for a specific dataset."""
528
+ if dataset_name == "AVOS":
529
+ return AVOS_ACTIONS
530
+ elif dataset_name == "CholecT50":
531
+ return T50_PHASES
532
+ elif dataset_name == "CoPESD":
533
+ return TOTAL_NEW_ACTION_LIST
534
+ elif dataset_name == "NurViD" and procedure:
535
+ return NURVID_PROCEDURE_ACTIONS.get(procedure, [])
536
+ elif dataset_name == "EgoSurgery":
537
+ # EgoSurgery uses free-form actions, return empty list
538
+ return []
539
+ return []
540
+
541
+
542
+ def create_class_map_for_dataset(actions):
543
+ """Create mapping from action name to index."""
544
+ return {action: idx for idx, action in enumerate(actions)}
545
+
546
+
547
  def evaluate_dataset_next_action(dataset_name, records):
548
+ """Evaluate next_action for a specific dataset with semantic similarity."""
549
  print(f"\nEvaluating {dataset_name} ({len(records)} records)...")
550
 
551
+ # Group records by procedure (for NurViD)
552
+ procedure_groups = defaultdict(list)
553
  for record in records:
554
+ procedure = record.get('procedure', 'default')
555
+ procedure_groups[procedure].append(record)
556
+
557
+ all_results_by_fps = defaultdict(list)
558
+
559
+ # Evaluate each procedure group
560
+ for procedure, proc_records in procedure_groups.items():
561
+ # Get action list for this dataset/procedure
562
+ actions = get_action_list_for_dataset(dataset_name, procedure)
563
+
564
+ if not actions:
565
+ # For datasets without predefined action lists (like EgoSurgery),
566
+ # collect unique ground truth actions and use semantic similarity
567
+ unique_actions = set()
568
+ temp_records = []
569
+
570
+ for record in proc_records:
571
+ struc_info = record.get('struc_info', {})
572
+ if isinstance(struc_info, list) and len(struc_info) > 0:
573
+ struc_info = struc_info[0]
574
+
575
+ gnd_text = struc_info.get('next_action', '')
576
+ if not gnd_text:
577
+ gnd_text = record.get('gnd', '')
578
+
579
+ gnd_text = normalize_action_text(gnd_text, dataset_name)
580
+ if gnd_text:
581
+ unique_actions.add(gnd_text)
582
+ temp_records.append((record, gnd_text))
583
+
584
+ if not unique_actions:
585
+ continue
586
+
587
+ # Create action list from unique ground truths
588
+ actions = sorted(list(unique_actions))
589
+ CLASS_MAP = create_class_map_for_dataset(actions)
590
+ semantic_model = SentenceTransformer('all-MiniLM-L6-v2')
591
+ class_embeddings = semantic_model.encode(actions, convert_to_tensor=True)
592
+
593
+ # Evaluate with semantic similarity
594
+ for record, gnd_text in temp_records:
595
+ fps = record.get('fps', record.get('metadata', {}).get('fps', 1.0))
596
+ if isinstance(fps, str):
597
+ fps = float(fps)
598
+
599
+ pred_text = normalize_action_text(record.get('answer', ''), dataset_name)
600
+
601
+ # Get ground truth index
602
+ gnd_idx = CLASS_MAP[gnd_text]
603
+
604
+ # Determine prediction class using semantic similarity
605
+ if pred_text in CLASS_MAP:
606
+ pred_idx = CLASS_MAP[pred_text]
607
+ else:
608
+ # Use semantic similarity
609
+ pred_emb = semantic_model.encode(pred_text, convert_to_tensor=True)
610
+ sim_scores = util.cos_sim(pred_emb, class_embeddings)[0]
611
+ pred_idx = sim_scores.argmax().item()
612
+
613
+ is_correct = (pred_idx == gnd_idx)
614
+ all_results_by_fps[fps].append(1 if is_correct else 0)
615
+ continue
616
+
617
+ # Create class map and embeddings for semantic similarity
618
+ CLASS_MAP = create_class_map_for_dataset(actions)
619
+ semantic_model = SentenceTransformer('all-MiniLM-L6-v2')
620
+ class_embeddings = semantic_model.encode(actions, convert_to_tensor=True)
621
+
622
+ # Evaluate each record with semantic similarity
623
+ for record in proc_records:
624
+ fps = record.get('fps', record.get('metadata', {}).get('fps', 1.0))
625
+ if isinstance(fps, str):
626
+ fps = float(fps)
627
+
628
+ pred_text = normalize_action_text(record.get('answer', ''), dataset_name)
629
 
630
+ # Get ground truth - try struc_info first, then gnd field
631
+ struc_info = record.get('struc_info', {})
632
+ if isinstance(struc_info, list) and len(struc_info) > 0:
633
+ struc_info = struc_info[0]
634
 
635
+ gnd_text = struc_info.get('next_action', '')
636
+ if not gnd_text:
637
+ # Fallback to gnd field (used for CholecT50 and others)
638
+ gnd_text = record.get('gnd', '')
 
639
 
640
+ gnd_text = normalize_action_text(gnd_text, dataset_name)
 
 
641
 
642
+ # Skip if ground truth not in action list
643
+ if not gnd_text or gnd_text not in CLASS_MAP:
644
+ continue
645
+
646
+ # Determine prediction class using semantic similarity
647
+ if pred_text in CLASS_MAP:
648
+ pred_idx = CLASS_MAP[pred_text]
649
+ else:
650
+ # Use semantic similarity as fallback
651
+ pred_emb = semantic_model.encode(pred_text, convert_to_tensor=True)
652
+ sim_scores = util.cos_sim(pred_emb, class_embeddings)[0]
653
+ pred_idx = sim_scores.argmax().item()
654
+
655
+ gnd_idx = CLASS_MAP[gnd_text]
656
+
657
+ # Check if correct
658
+ is_correct = (pred_idx == gnd_idx)
659
+ all_results_by_fps[fps].append(1 if is_correct else 0)
660
+
661
+ # Aggregate results by FPS
662
  aggregated = {}
663
+ all_accuracies = []
664
+
665
+ for fps, acc_list in all_results_by_fps.items():
666
  if acc_list:
667
  aggregated[f'fps_{fps}'] = {
668
  'accuracy': np.mean(acc_list),
669
  'count': len(acc_list)
670
  }
671
+ all_accuracies.extend(acc_list)
672
+
673
+ # Add overall accuracy across all FPS
674
+ if all_accuracies:
675
+ aggregated['overall'] = {
676
+ 'accuracy': np.mean(all_accuracies),
677
+ 'count': len(all_accuracies)
678
+ }
679
 
680
  return aggregated
681