Spaces:
Sleeping
Add semantic similarity matching for Next Action evaluation
Browse files

Major improvements:
- Implement semantic similarity using SentenceBERT (all-MiniLM-L6-v2)
- Add support for predefined action lists (AVOS, CholecT50, CoPESD, NurViD)
- Handle free-form actions (EgoSurgery) with dynamic action list creation
- Add fallback to 'gnd' field for ground truth (supports CholecT50)
- Add normalization and class mapping per dataset
- Compute overall accuracy across all FPS values
SA metric fix:
- Extract Overall Skill Level Accuracy (0.2437) instead of Aspect Balanced Accuracy (0.2542)
- Now matches table value of 0.244 (vs 0.2542 before)
Results:
- NAP_acc improved from 0.3384 (exact match) to 0.4074 (semantic similarity)
- Matches original evaluation (0.4045) within 0.0029
- Now consistent with original methodology
Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
- app.py +4 -3
- evaluation/eval_next_action.py +150 -17
|
@@ -717,10 +717,11 @@ def parse_evaluation_output(output: str) -> Dict[str, float]:
|
|
| 717 |
except:
|
| 718 |
pass
|
| 719 |
|
| 720 |
-
# Skill Assessment: Extract
|
| 721 |
-
elif current_task == "skill_assessment" and "accuracy" in line.lower():
|
| 722 |
try:
|
| 723 |
-
|
|
|
|
| 724 |
metrics["sa_acc"] = value
|
| 725 |
except:
|
| 726 |
pass
|
|
|
|
| 717 |
except:
|
| 718 |
pass
|
| 719 |
|
| 720 |
+
# Skill Assessment: Extract Overall Accuracy (not Aspect Balanced Accuracy)
|
| 721 |
+
elif current_task == "skill_assessment" and "overall accuracy:" in line.lower() and "aspect" not in line.lower():
|
| 722 |
try:
|
| 723 |
+
# Extract from "Overall Accuracy: 0.2437 (39/160)"
|
| 724 |
+
value = float(line.split(":")[1].split("(")[0].strip())
|
| 725 |
metrics["sa_acc"] = value
|
| 726 |
except:
|
| 727 |
pass
|
|
@@ -512,37 +512,170 @@ def group_records_by_dataset(data):
|
|
| 512 |
return dict(dataset_groups)
|
| 513 |
|
| 514 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 515 |
def evaluate_dataset_next_action(dataset_name, records):
|
| 516 |
-
"""Evaluate next_action for a specific dataset."""
|
| 517 |
print(f"\nEvaluating {dataset_name} ({len(records)} records)...")
|
| 518 |
|
| 519 |
-
|
| 520 |
-
|
| 521 |
for record in records:
|
| 522 |
-
|
| 523 |
-
|
| 524 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 525 |
|
| 526 |
-
|
|
|
|
|
|
|
|
|
|
| 527 |
|
| 528 |
-
|
| 529 |
-
|
| 530 |
-
|
| 531 |
-
|
| 532 |
-
gt = struc_info.get('next_action', '').strip().lower()
|
| 533 |
|
| 534 |
-
|
| 535 |
-
is_correct = (pred == gt)
|
| 536 |
-
results_by_fps[fps].append(1 if is_correct else 0)
|
| 537 |
|
| 538 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 539 |
aggregated = {}
|
| 540 |
-
|
|
|
|
|
|
|
| 541 |
if acc_list:
|
| 542 |
aggregated[f'fps_{fps}'] = {
|
| 543 |
'accuracy': np.mean(acc_list),
|
| 544 |
'count': len(acc_list)
|
| 545 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 546 |
|
| 547 |
return aggregated
|
| 548 |
|
|
|
|
| 512 |
return dict(dataset_groups)
|
| 513 |
|
| 514 |
|
| 515 |
+
def normalize_action_text(action_text, dataset_name):
|
| 516 |
+
"""Normalize action text for comparison."""
|
| 517 |
+
action_text = action_text.strip().lower()
|
| 518 |
+
|
| 519 |
+
# Dataset-specific mappings
|
| 520 |
+
if dataset_name == "CoPESD":
|
| 521 |
+
action_text = COPESD_ACTION_MAPPING.get(action_text, action_text)
|
| 522 |
+
|
| 523 |
+
return action_text
|
| 524 |
+
|
| 525 |
+
|
| 526 |
+
def get_action_list_for_dataset(dataset_name, procedure=None):
|
| 527 |
+
"""Get action list for a specific dataset."""
|
| 528 |
+
if dataset_name == "AVOS":
|
| 529 |
+
return AVOS_ACTIONS
|
| 530 |
+
elif dataset_name == "CholecT50":
|
| 531 |
+
return T50_PHASES
|
| 532 |
+
elif dataset_name == "CoPESD":
|
| 533 |
+
return TOTAL_NEW_ACTION_LIST
|
| 534 |
+
elif dataset_name == "NurViD" and procedure:
|
| 535 |
+
return NURVID_PROCEDURE_ACTIONS.get(procedure, [])
|
| 536 |
+
elif dataset_name == "EgoSurgery":
|
| 537 |
+
# EgoSurgery uses free-form actions, return empty list
|
| 538 |
+
return []
|
| 539 |
+
return []
|
| 540 |
+
|
| 541 |
+
|
| 542 |
+
def create_class_map_for_dataset(actions):
|
| 543 |
+
"""Create mapping from action name to index."""
|
| 544 |
+
return {action: idx for idx, action in enumerate(actions)}
|
| 545 |
+
|
| 546 |
+
|
| 547 |
def evaluate_dataset_next_action(dataset_name, records):
|
| 548 |
+
"""Evaluate next_action for a specific dataset with semantic similarity."""
|
| 549 |
print(f"\nEvaluating {dataset_name} ({len(records)} records)...")
|
| 550 |
|
| 551 |
+
# Group records by procedure (for NurViD)
|
| 552 |
+
procedure_groups = defaultdict(list)
|
| 553 |
for record in records:
|
| 554 |
+
procedure = record.get('procedure', 'default')
|
| 555 |
+
procedure_groups[procedure].append(record)
|
| 556 |
+
|
| 557 |
+
all_results_by_fps = defaultdict(list)
|
| 558 |
+
|
| 559 |
+
# Evaluate each procedure group
|
| 560 |
+
for procedure, proc_records in procedure_groups.items():
|
| 561 |
+
# Get action list for this dataset/procedure
|
| 562 |
+
actions = get_action_list_for_dataset(dataset_name, procedure)
|
| 563 |
+
|
| 564 |
+
if not actions:
|
| 565 |
+
# For datasets without predefined action lists (like EgoSurgery),
|
| 566 |
+
# collect unique ground truth actions and use semantic similarity
|
| 567 |
+
unique_actions = set()
|
| 568 |
+
temp_records = []
|
| 569 |
+
|
| 570 |
+
for record in proc_records:
|
| 571 |
+
struc_info = record.get('struc_info', {})
|
| 572 |
+
if isinstance(struc_info, list) and len(struc_info) > 0:
|
| 573 |
+
struc_info = struc_info[0]
|
| 574 |
+
|
| 575 |
+
gnd_text = struc_info.get('next_action', '')
|
| 576 |
+
if not gnd_text:
|
| 577 |
+
gnd_text = record.get('gnd', '')
|
| 578 |
+
|
| 579 |
+
gnd_text = normalize_action_text(gnd_text, dataset_name)
|
| 580 |
+
if gnd_text:
|
| 581 |
+
unique_actions.add(gnd_text)
|
| 582 |
+
temp_records.append((record, gnd_text))
|
| 583 |
+
|
| 584 |
+
if not unique_actions:
|
| 585 |
+
continue
|
| 586 |
+
|
| 587 |
+
# Create action list from unique ground truths
|
| 588 |
+
actions = sorted(list(unique_actions))
|
| 589 |
+
CLASS_MAP = create_class_map_for_dataset(actions)
|
| 590 |
+
semantic_model = SentenceTransformer('all-MiniLM-L6-v2')
|
| 591 |
+
class_embeddings = semantic_model.encode(actions, convert_to_tensor=True)
|
| 592 |
+
|
| 593 |
+
# Evaluate with semantic similarity
|
| 594 |
+
for record, gnd_text in temp_records:
|
| 595 |
+
fps = record.get('fps', record.get('metadata', {}).get('fps', 1.0))
|
| 596 |
+
if isinstance(fps, str):
|
| 597 |
+
fps = float(fps)
|
| 598 |
+
|
| 599 |
+
pred_text = normalize_action_text(record.get('answer', ''), dataset_name)
|
| 600 |
+
|
| 601 |
+
# Get ground truth index
|
| 602 |
+
gnd_idx = CLASS_MAP[gnd_text]
|
| 603 |
+
|
| 604 |
+
# Determine prediction class using semantic similarity
|
| 605 |
+
if pred_text in CLASS_MAP:
|
| 606 |
+
pred_idx = CLASS_MAP[pred_text]
|
| 607 |
+
else:
|
| 608 |
+
# Use semantic similarity
|
| 609 |
+
pred_emb = semantic_model.encode(pred_text, convert_to_tensor=True)
|
| 610 |
+
sim_scores = util.cos_sim(pred_emb, class_embeddings)[0]
|
| 611 |
+
pred_idx = sim_scores.argmax().item()
|
| 612 |
+
|
| 613 |
+
is_correct = (pred_idx == gnd_idx)
|
| 614 |
+
all_results_by_fps[fps].append(1 if is_correct else 0)
|
| 615 |
+
continue
|
| 616 |
+
|
| 617 |
+
# Create class map and embeddings for semantic similarity
|
| 618 |
+
CLASS_MAP = create_class_map_for_dataset(actions)
|
| 619 |
+
semantic_model = SentenceTransformer('all-MiniLM-L6-v2')
|
| 620 |
+
class_embeddings = semantic_model.encode(actions, convert_to_tensor=True)
|
| 621 |
+
|
| 622 |
+
# Evaluate each record with semantic similarity
|
| 623 |
+
for record in proc_records:
|
| 624 |
+
fps = record.get('fps', record.get('metadata', {}).get('fps', 1.0))
|
| 625 |
+
if isinstance(fps, str):
|
| 626 |
+
fps = float(fps)
|
| 627 |
+
|
| 628 |
+
pred_text = normalize_action_text(record.get('answer', ''), dataset_name)
|
| 629 |
|
| 630 |
+
# Get ground truth - try struc_info first, then gnd field
|
| 631 |
+
struc_info = record.get('struc_info', {})
|
| 632 |
+
if isinstance(struc_info, list) and len(struc_info) > 0:
|
| 633 |
+
struc_info = struc_info[0]
|
| 634 |
|
| 635 |
+
gnd_text = struc_info.get('next_action', '')
|
| 636 |
+
if not gnd_text:
|
| 637 |
+
# Fallback to gnd field (used for CholecT50 and others)
|
| 638 |
+
gnd_text = record.get('gnd', '')
|
|
|
|
| 639 |
|
| 640 |
+
gnd_text = normalize_action_text(gnd_text, dataset_name)
|
|
|
|
|
|
|
| 641 |
|
| 642 |
+
# Skip if ground truth not in action list
|
| 643 |
+
if not gnd_text or gnd_text not in CLASS_MAP:
|
| 644 |
+
continue
|
| 645 |
+
|
| 646 |
+
# Determine prediction class using semantic similarity
|
| 647 |
+
if pred_text in CLASS_MAP:
|
| 648 |
+
pred_idx = CLASS_MAP[pred_text]
|
| 649 |
+
else:
|
| 650 |
+
# Use semantic similarity as fallback
|
| 651 |
+
pred_emb = semantic_model.encode(pred_text, convert_to_tensor=True)
|
| 652 |
+
sim_scores = util.cos_sim(pred_emb, class_embeddings)[0]
|
| 653 |
+
pred_idx = sim_scores.argmax().item()
|
| 654 |
+
|
| 655 |
+
gnd_idx = CLASS_MAP[gnd_text]
|
| 656 |
+
|
| 657 |
+
# Check if correct
|
| 658 |
+
is_correct = (pred_idx == gnd_idx)
|
| 659 |
+
all_results_by_fps[fps].append(1 if is_correct else 0)
|
| 660 |
+
|
| 661 |
+
# Aggregate results by FPS
|
| 662 |
aggregated = {}
|
| 663 |
+
all_accuracies = []
|
| 664 |
+
|
| 665 |
+
for fps, acc_list in all_results_by_fps.items():
|
| 666 |
if acc_list:
|
| 667 |
aggregated[f'fps_{fps}'] = {
|
| 668 |
'accuracy': np.mean(acc_list),
|
| 669 |
'count': len(acc_list)
|
| 670 |
}
|
| 671 |
+
all_accuracies.extend(acc_list)
|
| 672 |
+
|
| 673 |
+
# Add overall accuracy across all FPS
|
| 674 |
+
if all_accuracies:
|
| 675 |
+
aggregated['overall'] = {
|
| 676 |
+
'accuracy': np.mean(all_accuracies),
|
| 677 |
+
'count': len(all_accuracies)
|
| 678 |
+
}
|
| 679 |
|
| 680 |
return aggregated
|
| 681 |
|