"""Skill Assessment Evaluation Script for Multiple Datasets.

Loads an inference-output JSON file, keeps the records whose ``qa_type`` is
``"skill_assessment"``, groups them by dataset and reports, per dataset:

* per-aspect exact-match accuracy and mean absolute error of "<n>/5" scores;
* overall skill-level (novice / intermediate / expert) classification accuracy.
"""

import json
import re
import sys
from collections import defaultdict

import numpy as np

# "<score>/5", allowing whitespace around the slash. The negative lookahead
# keeps "3/50" from being read as a "3/5" score.
_SCORE_PATTERN = re.compile(r"(\d+)\s*/\s*5(?!\d)")

# "<aspect name>: <score>/5" inside a single comma-separated chunk.
_ASPECT_PATTERN = re.compile(r"([^:]+?):\s*(\d+)\s*/\s*5(?!\d)")

# Accepted aliases for each canonical skill level. Insertion order matters for
# partial matching: full words are tried before single-letter / numeric codes.
_SKILL_LEVEL_ALIASES = {
    # Direct skill level names
    "novice": "novice",
    "beginner": "novice",
    "intermediate": "intermediate",
    "expert": "expert",
    "advanced": "expert",
    # Letter codes (N, I, E)
    "n": "novice",
    "i": "intermediate",
    "e": "expert",
    # Numeric codes
    "1": "novice",
    "2": "intermediate",
    "3": "expert",
    # Quality descriptors
    "low": "novice",
    "medium": "intermediate",
    "high": "expert",
    "poor": "novice",
    "good": "intermediate",
    "excellent": "expert",
}


def detect_dataset_from_video_id(video_id):
    """Guess the dataset name from the shape of a video ID.

    Heuristics are tried from most to least specific. The AVOS check (any
    11-character ID containing a letter) is very permissive. The CoPESD
    ``<digits>_part<n>`` pattern is therefore checked before it, because an ID
    such as ``"12345_part1"`` is exactly 11 characters long.

    Args:
        video_id: The video identifier (any value; it is converted with ``str``).

    Returns:
        One of ``"jigsaws"``, ``"CoPESD"``, ``"AVOS"``, ``"CholecT50"``,
        ``"NurViD"`` or ``"Unknown"``.
    """
    video_id = str(video_id).lower()

    # JIGSAWS - patterns like "knot_tying_b001", "suturing_b001", etc.
    jigsaws_tasks = ("knot_tying", "suturing", "needle_passing")
    if "_b" in video_id and any(task in video_id for task in jigsaws_tasks):
        return "jigsaws"

    # CoPESD - numerical IDs with parts, e.g. "12345_part1".
    if "_part" in video_id and video_id.replace("_part", "").split("_")[0].isdigit():
        return "CoPESD"

    # AVOS - YouTube-style video IDs (11 characters, at least one letter).
    if len(video_id) == 11 and any(c.isalpha() for c in video_id):
        return "AVOS"

    # CholecT50
    if "video" in video_id and any(c.isdigit() for c in video_id):
        return "CholecT50"

    # NurViD - keyword patterns
    if any(keyword in video_id for keyword in ("nur", "nursing", "medical")):
        return "NurViD"

    return "Unknown"


def detect_dataset_from_question(question):
    """Guess the dataset name from keywords in the question text.

    Args:
        question: The question string.

    Returns:
        One of ``"jigsaws"``, ``"AVOS"``, ``"CoPESD"``, ``"CholecT50"``,
        ``"NurViD"`` or ``"Unknown"``.
    """
    question_lower = question.lower()

    # JIGSAWS - robotic surgery / bench-top task vocabulary.
    jigsaws_terms = (
        "robotic bench-top",
        "knot-tying",
        "needle-passing",
        "suturing",
        "surgical technique",
    )
    if any(term in question_lower for term in jigsaws_terms):
        return "jigsaws"

    # Explicit dataset names.
    if "avos" in question_lower:
        return "AVOS"
    if "copesd" in question_lower:
        return "CoPESD"
    if "cholect50" in question_lower or "cholec" in question_lower:
        return "CholecT50"
    if "nurvid" in question_lower or "nursing" in question_lower:
        return "NurViD"

    # Dataset-specific action vocabulary.
    if any(action in question_lower for action in ("cutting", "tying", "suturing")):
        return "AVOS"
    if "forceps" in question_lower and "knife" in question_lower:
        return "CoPESD"

    return "Unknown"


def parse_skill_scores(skill_text):
    """Return the mean of every "<n>/5" score found in *skill_text*.

    Args:
        skill_text: Free text such as ``"Respect for tissue: 2/5, Flow: 4/5"``.

    Returns:
        The arithmetic mean of the scores as a float, or ``None`` if no score
        is present.
    """
    scores = [int(score) for score in _SCORE_PATTERN.findall(skill_text)]
    if scores:
        return sum(scores) / len(scores)
    return None


def parse_aspect_scores(skill_text):
    """Parse "<aspect>: <n>/5" pairs from comma-separated text.

    Example input: ``"Respect for tissue: 2/5, Suture/needle handling: 1/5"``.
    Aspect names containing commas are not supported, because the text is
    split on commas first.

    Args:
        skill_text: Free text with comma-separated aspect scores.

    Returns:
        A dict mapping each aspect name (stripped) to its integer score.
        If an aspect appears twice, the later occurrence wins.
    """
    aspect_scores = {}
    for part in skill_text.split(","):
        match = _ASPECT_PATTERN.search(part.strip())
        if match:
            aspect_scores[match.group(1).strip()] = int(match.group(2))
    return aspect_scores


def normalize_skill_level(skill_text):
    """Map a skill-level label to "novice", "intermediate" or "expert".

    An exact alias match (after stripping and lower-casing) is tried first.
    Otherwise, the first alias that occurs as a whole word is used. Whole-word
    matching stops short codes such as "n" or "1" from matching inside
    unrelated words (e.g. "unknown").

    Args:
        skill_text: The label; non-string values are converted with ``str``.

    Returns:
        The canonical level. If nothing matches, a warning is printed and the
        normalised input text is returned unchanged.
    """
    skill_text = str(skill_text).strip().lower()

    if skill_text in _SKILL_LEVEL_ALIASES:
        return _SKILL_LEVEL_ALIASES[skill_text]

    for alias, level in _SKILL_LEVEL_ALIASES.items():
        if re.search(rf"\b{re.escape(alias)}\b", skill_text):
            return level

    # Return original if no mapping found (for debugging)
    print(f"Warning: No mapping found for skill_text: '{skill_text}'")
    return skill_text


def convert_scores_to_skill_level(skill_text):
    """Convert "<n>/5" scores in *skill_text* to a skill level.

    The mean score maps to "novice" (<= 2.0), "intermediate" (<= 3.5) or
    "expert" (> 3.5).

    Returns:
        The skill level string, or ``None`` if the text contains no scores.
    """
    avg_score = parse_skill_scores(skill_text)
    if avg_score is None:
        return None
    if avg_score <= 2.0:
        return "novice"
    if avg_score <= 3.5:
        return "intermediate"
    return "expert"


def calculate_balanced_accuracy(per_class_correct, per_class_total):
    """Return the mean per-class recall (balanced accuracy).

    Args:
        per_class_correct: Mapping class -> number of correct predictions.
            Classes missing from this mapping count as 0 correct.
        per_class_total: Mapping class -> number of samples. Classes with a
            total of 0 are ignored.

    Returns:
        The balanced accuracy as a float, or 0.0 when no class has samples.
    """
    recalls = [
        per_class_correct.get(class_name, 0) / count
        for class_name, count in per_class_total.items()
        if count > 0
    ]
    return float(np.mean(recalls)) if recalls else 0.0


def group_records_by_dataset(data):
    """Group ``skill_assessment`` records by dataset.

    The dataset is taken from the record's ``data_source`` field. When that is
    missing, empty or ``"Unknown"``, it is detected from the video ID and then
    from the question text.

    Args:
        data: Either a dict of records (keys are ignored) or a list of records.

    Returns:
        A ``defaultdict(list)`` mapping dataset name -> list of dicts with keys
        ``question``, ``answer``, ``gnd``, ``video_id`` and ``struc_info``.
    """
    records = data.values() if isinstance(data, dict) else data
    dataset_records = defaultdict(list)

    for record in records:
        if record.get("qa_type") != "skill_assessment":
            continue

        # Tolerate records without metadata instead of raising KeyError.
        video_id = (record.get("metadata") or {}).get("video_id", "")

        # Prefer the explicit data_source; fall back to heuristics.
        dataset = record.get("data_source") or "Unknown"
        if dataset == "Unknown":
            dataset = detect_dataset_from_video_id(video_id)
        if dataset == "Unknown":
            dataset = detect_dataset_from_question(record["question"])

        dataset_records[dataset].append({
            "question": record["question"],
            "answer": record["answer"],
            "gnd": record["gnd"],
            "video_id": video_id,
            "struc_info": record.get("struc_info", []),
        })

    return dataset_records


def _first_struc_info(record):
    """Return the first entry of ``record["struc_info"]``, or None if absent/empty."""
    struc_info = record.get("struc_info")
    return struc_info[0] if struc_info else None


def _ground_truth_skill(record):
    """Resolve the ground-truth skill level (struc_info code first, then gnd text)."""
    struc = _first_struc_info(record)
    if struc is not None:
        skill_level_code = struc.get("skill_level", "")
        if skill_level_code:
            gnd_skill = normalize_skill_level(skill_level_code)
            if gnd_skill:
                return gnd_skill
    return convert_scores_to_skill_level(record["gnd"])


def _ground_truth_aspects(record):
    """Resolve ground-truth aspect scores (struc_info first, then gnd text)."""
    struc = _first_struc_info(record)
    gnd_aspects = struc.get("skill_scores", {}) if struc is not None else None
    if not gnd_aspects:
        gnd_aspects = parse_aspect_scores(record["gnd"])
    return gnd_aspects


def evaluate_skill_assessment(records, *, skip_unparsed_predictions=False):
    """Evaluate skill-level accuracy and per-aspect score accuracy / MAE.

    Records whose ground truth cannot be parsed are skipped, since they cannot
    be scored. By default, records whose *prediction* cannot be parsed are
    counted as incorrect. Dropping them would silently inflate accuracy.

    Args:
        records: List of dicts with ``answer``, ``gnd`` and optional
            ``struc_info`` keys (as produced by ``group_records_by_dataset``).
        skip_unparsed_predictions: If True, exclude unparseable predictions
            from the denominator, as earlier versions did.

    Returns:
        A dict with ``accuracy``, ``correct``, ``total``, ``per_skill``,
        ``per_aspect``, ``aspect_balanced_accuracy`` and
        ``unparsed_predictions``.
    """
    if not records:
        return {
            "accuracy": 0.0,
            "correct": 0,
            "total": 0,
            "per_skill": {},
            "per_aspect": {},
            "aspect_balanced_accuracy": 0.0,
            "unparsed_predictions": 0,
        }

    correct = 0
    total = 0
    unparsed = 0
    per_skill_correct = defaultdict(int)
    per_skill_total = defaultdict(int)
    aspect_correct = defaultdict(int)
    aspect_total = defaultdict(int)
    aspect_abs_error = defaultdict(float)  # summed |pred - gnd| per aspect

    for record in records:
        gnd_skill = _ground_truth_skill(record)
        if gnd_skill is None:
            print(f"Warning: Could not parse ground truth for skill level: '{record['gnd']}'. Skipping record.")
            continue

        pred_skill = convert_scores_to_skill_level(record["answer"])
        if pred_skill is None:
            if skip_unparsed_predictions:
                print(f"Warning: Could not parse answer for skill level: '{record['answer']}'. Skipping record.")
                continue
            print(f"Warning: Could not parse answer for skill level: '{record['answer']}'. Counting as incorrect.")
            unparsed += 1

        per_skill_total[gnd_skill] += 1
        total += 1
        if pred_skill == gnd_skill:
            correct += 1
            per_skill_correct[gnd_skill] += 1

        # Only aspects present in both prediction and ground truth are scored.
        pred_aspects = parse_aspect_scores(record["answer"])
        for aspect_name, gnd_score in _ground_truth_aspects(record).items():
            if aspect_name not in pred_aspects:
                continue
            pred_score = pred_aspects[aspect_name]
            aspect_total[aspect_name] += 1
            if pred_score == gnd_score:
                aspect_correct[aspect_name] += 1
            aspect_abs_error[aspect_name] += abs(pred_score - gnd_score)

    per_skill_accuracies = {
        skill: {
            "accuracy": per_skill_correct[skill] / count,
            "correct": per_skill_correct[skill],
            "total": count,
        }
        for skill, count in per_skill_total.items()
    }

    per_aspect_metrics = {
        aspect: {
            "accuracy": aspect_correct[aspect] / count,
            "correct": aspect_correct[aspect],
            "total": count,
            "mae": aspect_abs_error[aspect] / count,
        }
        for aspect, count in aspect_total.items()
    }

    return {
        "accuracy": correct / total if total > 0 else 0.0,
        "correct": correct,
        "total": total,
        "per_skill": per_skill_accuracies,
        "per_aspect": per_aspect_metrics,
        "aspect_balanced_accuracy": calculate_balanced_accuracy(aspect_correct, aspect_total),
        "unparsed_predictions": unparsed,
    }


def evaluate_dataset_skill_assessment(dataset_name, dataset_records):
    """Evaluate one dataset's records and print a report.

    Returns:
        The result dict from ``evaluate_skill_assessment``, or ``{}`` when
        there are no records.
    """
    print(f"\n=== Skill Assessment Evaluation for {dataset_name} ===")
    print(f"Number of records: {len(dataset_records)}")

    if not dataset_records:
        print("No records found for this dataset.")
        return {}

    results = evaluate_skill_assessment(dataset_records)

    # Per-aspect results first (main focus).
    if results.get("per_aspect"):
        print("\n*** PER-ASPECT PERFORMANCE ***")
        print(f"Aspect Balanced Accuracy: {results.get('aspect_balanced_accuracy', 0.0):.4f}")
        print("\nIndividual Aspect Performance:")
        for aspect, metrics in sorted(results["per_aspect"].items()):
            print(f" {aspect}:")
            print(f" Accuracy: {metrics['accuracy']:.4f} ({metrics['correct']}/{metrics['total']})")
            print(f" Mean Absolute Error: {metrics['mae']:.3f}")

    # Overall skill-level results (secondary).
    print("\n*** OVERALL SKILL LEVEL CLASSIFICATION ***")
    print(f"Overall Accuracy: {results['accuracy']:.4f} ({results['correct']}/{results['total']})")
    if results.get("unparsed_predictions", 0) > 0:
        print(f"Unparseable predictions (counted as incorrect): {results['unparsed_predictions']}")

    if results.get("per_skill"):
        print("\nPer-skill Level Accuracy:")
        for skill, metrics in sorted(results["per_skill"].items()):
            print(f" {skill}: {metrics['accuracy']:.4f} ({metrics['correct']}/{metrics['total']})")

    return results


def main():
    """Entry point: ``python eval_skill_assessment.py <output_file>``."""
    if len(sys.argv) < 2:
        print("Usage: python eval_skill_assessment.py <output_file>", file=sys.stderr)
        sys.exit(1)
    output_file = sys.argv[1]

    print(f"Loading results from: {output_file}")
    with open(output_file, "r", encoding="utf-8") as f:
        infer_output = json.load(f)

    dataset_records = group_records_by_dataset(infer_output)

    print(f"\nFound datasets: {list(dataset_records.keys())}")
    for dataset, records in dataset_records.items():
        print(f" {dataset}: {len(records)} skill assessment records")

    if not any(dataset_records.values()):
        print("No skill assessment records found!")
        return

    all_results = {}
    for dataset_name, records in dataset_records.items():
        if records:  # Only evaluate if we have records
            all_results[dataset_name] = evaluate_dataset_skill_assessment(dataset_name, records)

    # Summary
    print(f"\n{'=' * 80}")
    print("SKILL ASSESSMENT EVALUATION SUMMARY")
    print(f"{'=' * 80}")

    for dataset_name, results in all_results.items():
        if not results:
            continue
        print(f"\n{dataset_name}:")
        if results.get("per_aspect"):
            print(f" Aspect Balanced Accuracy: {results.get('aspect_balanced_accuracy', 0.0):.4f}")
            print(" Per-Aspect Accuracy:")
            for aspect, metrics in sorted(results["per_aspect"].items()):
                print(f" {aspect}: {metrics['accuracy']:.4f} (MAE: {metrics['mae']:.3f})")
        print(f" Overall Skill Level Accuracy: {results['accuracy']:.4f} ({results['correct']}/{results['total']})")


if __name__ == "__main__":
    main()