import json
import re
from collections import defaultdict

import numpy as np


def extract_time_segments(text):
    """Parse free-form time spans from raw answer text."""
    segments = []

    # Match: from 12.1 to 117.0 / from 113.2s to 163.4s / from 10.0 seconds to 15.0 seconds
    pattern1 = re.findall(
        r'(?:from|is from|takes place from)?\s*'  # optional "from"
        r'(\d+(?:\.\d+)?)(?:s| seconds?)?\s*'
        r'to\s*'
        r'(\d+(?:\.\d+)?)(?:s| seconds?)?',
        text, flags=re.IGNORECASE)

    # Match: 00:00:00 to 00:00:08
    pattern2 = re.findall(
        r'(\d+):(\d+):(\d+)\s+to\s+(\d+):(\d+):(\d+)',
        text, flags=re.IGNORECASE)

    for start, end in pattern1:
        try:
            segments.append({
                'start': round(float(start), 2),
                'end': round(float(end), 2)
            })
        except ValueError:
            continue

    for h1, m1, s1, h2, m2, s2 in pattern2:
        start_sec = int(h1) * 3600 + int(m1) * 60 + int(s1)
        end_sec = int(h2) * 3600 + int(m2) * 60 + int(s2)
        segments.append({
            'start': float(start_sec),
            'end': float(end_sec)
        })

    return segments


def extract_segments_from_text(text):
    """Parse 'start-end' pairs, falling back to free-form time parsing."""
    # Match patterns like 379-419 or 540-540 (integer or decimal endpoints)
    pattern = re.findall(r'(\d+(?:\.\d+)?)\s*-\s*(\d+(?:\.\d+)?)', text)
    segments = []
    for start, end in pattern:
        segments.append({'start': float(start), 'end': float(end)})
    if not segments:
        # Process the raw text, usually a zero-shot answer
        segments = extract_time_segments(text)
    if not segments:
        print(f"Warning: No valid segments found in text: {text}")
    return segments


def compute_iou(seg1, seg2):
    """Temporal IoU between two {'start', 'end'} segments."""
    inter_start = max(seg1['start'], seg2['start'])
    inter_end = min(seg1['end'], seg2['end'])
    inter = max(0, inter_end - inter_start)
    union = max(seg1['end'], seg2['end']) - min(seg1['start'], seg2['start'])
    return inter / union if union > 0 else 0.0


def evaluate_pair(preds, gts, tiou_thresh=0.5):
    """Greedily match predictions to ground truths at a tIoU threshold."""
    gt_matched = [False] * len(gts)
    pred_matched = [False] * len(preds)
    matched_ious = []

    for i, gt in enumerate(gts):
        best_iou = 0
        best_j = -1
        for j, pred in enumerate(preds):
            if pred_matched[j]:  # avoid multiple GTs matching the same prediction
                continue
            iou = compute_iou(pred, gt)
            if iou > best_iou:
                best_iou = iou
                best_j = j
        if best_iou >= tiou_thresh:
            gt_matched[i] = True
            pred_matched[best_j] = True
            matched_ious.append(best_iou)

    recall = sum(gt_matched) / len(gts) if gts else 0.0
    precision = sum(pred_matched) / len(preds) if preds else 0.0
    f1 = (2 * precision * recall) / (precision + recall) if (precision + recall) > 0 else 0.0
    mean_iou = sum(matched_ious) / len(matched_ious) if matched_ious else 0.0

    return recall, precision, f1, mean_iou


def evaluate_tal_record(tal_record, tiou_thresh=0.5):
    """Average per-entry metrics over a list of prediction/ground-truth pairs."""
    recalls, precisions, f1s, mean_ious = [], [], [], []
    for entry in tal_record:
        preds = entry['prediction']
        gts = entry['ground_truth']
        recall, precision, f1, mean_iou = evaluate_pair(preds, gts, tiou_thresh)
        recalls.append(recall)
        precisions.append(precision)
        f1s.append(f1)
        mean_ious.append(mean_iou)

    def avg(x):
        return sum(x) / len(x) if x else 0.0

    return {
        f"Recall@{tiou_thresh:.2f}": avg(recalls),
        # f"Precision@{tiou_thresh:.2f}": avg(precisions),
        # f"F1@{tiou_thresh:.2f}": avg(f1s),
        f"meanIoU@{tiou_thresh:.2f}": avg(mean_ious),
    }


def pretty_print_summary(summary, label):
    print(f"\n{label}")
    for k, v in summary.items():
        print(f"  {k}: {v:.4f}")
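

# Illustrative checks (hypothetical inputs, not part of the pipeline) showing
# what the parsing and matching helpers above produce; kept as doctest-style
# comments so nothing executes on import:
#
#   >>> extract_segments_from_text("The action spans 379-419.")
#   [{'start': 379.0, 'end': 419.0}]
#   >>> extract_segments_from_text("It takes place from 12.1s to 117.0s.")
#   [{'start': 12.1, 'end': 117.0}]
#   >>> compute_iou({'start': 0.0, 'end': 10.0}, {'start': 5.0, 'end': 15.0})
#   0.3333333333333333
#   >>> evaluate_pair([{'start': 0.0, 'end': 10.0}],
#   ...               [{'start': 2.0, 'end': 10.0}, {'start': 20.0, 'end': 30.0}])
#   (0.5, 1.0, 0.6666666666666666, 0.8)
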
def group_records_by_dataset(data):
    """Group TAL records by dataset for per-dataset evaluation."""
    dataset_groups = defaultdict(list)

    for key, record in data.items():
        qa_type = record.get('qa_type', '')
        if 'tal' not in qa_type.lower():
            continue

        # Detect the dataset from video_id or other fields
        video_id = record.get('video_id', '')
        # Check data_source first (used in the leaderboard format), then fall
        # back to dataset/dataset_name
        dataset = record.get('data_source',
                             record.get('dataset',
                                        record.get('dataset_name', 'Unknown')))
        if dataset == 'Unknown' and video_id:
            # Try to infer the dataset from video_id patterns
            video_id_lower = str(video_id).lower()
            if len(video_id) == 11 and any(c.isalpha() for c in video_id):
                dataset = "AVOS"
            elif "_part" in video_id_lower:
                dataset = "CoPESD"
            elif "video" in video_id_lower:
                dataset = "CholecT50"

        dataset_groups[dataset].append(record)

    return dict(dataset_groups)
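

# Sketch of the record layout these functions expect (hypothetical sample;
# the field names are taken from the lookups in this file, and real result
# files may carry additional keys):
#
#   {
#     "sample_0001": {
#       "qa_type": "tal",
#       "data_source": "CholecT50",
#       "video_id": "video01",
#       "fps": 1.0,
#       "answer": "The event occurs from 12.1s to 117.0s.",
#       "struc_info": [{"spans": [{"start": 10.0, "end": 120.0}]}]
#     }
#   }
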
def evaluate_dataset_tal(dataset_name, records):
    """Evaluate TAL for a specific dataset."""
    print(f"\nEvaluating {dataset_name} ({len(records)} records)...")

    results_by_fps = defaultdict(list)

    for record in records:
        fps = record.get('fps', record.get('metadata', {}).get('fps', 1.0))
        if isinstance(fps, str):
            fps = float(fps)

        # Parse the prediction from the 'answer' field
        raw_answer = record.get('answer', '')
        predicted_segments = extract_segments_from_text(raw_answer)

        # Get ground-truth spans from struc_info
        struc_info = record.get('struc_info', [])
        gt_spans = []
        if isinstance(struc_info, list):
            for item in struc_info:
                if 'spans' in item:
                    gt_spans.extend(item['spans'])

        # Convert to frames
        for seg in predicted_segments:
            seg['start'] = float(seg['start'] * fps)
            seg['end'] = float(seg['end'] * fps)
        for span in gt_spans:
            span['start'] = float(span['start'] * fps)
            span['end'] = float(span['end'] * fps)

        # Create a formatted record for evaluation
        formatted_record = [{
            'prediction': predicted_segments,
            'ground_truth': gt_spans
        }]

        # Evaluate this record at both IoU thresholds
        result_03 = evaluate_tal_record(formatted_record, tiou_thresh=0.3)
        result_05 = evaluate_tal_record(formatted_record, tiou_thresh=0.5)
        results_by_fps[fps].append({'0.3': result_03, '0.5': result_05})

    # Aggregate results per FPS bucket
    aggregated = {}
    for fps, results_list in results_by_fps.items():
        # Extract metrics from the results at both thresholds
        all_recalls_03 = [r['0.3'].get('Recall@0.30', 0) for r in results_list if r]
        all_mean_ious_03 = [r['0.3'].get('meanIoU@0.30', 0) for r in results_list if r]
        all_recalls_05 = [r['0.5'].get('Recall@0.50', 0) for r in results_list if r]
        all_mean_ious_05 = [r['0.5'].get('meanIoU@0.50', 0) for r in results_list if r]

        if all_recalls_03:
            aggregated[f'fps_{fps}'] = {
                'recall@0.3': np.mean(all_recalls_03),
                'meanIoU@0.3': np.mean(all_mean_ious_03),
                'recall@0.5': np.mean(all_recalls_05),
                'meanIoU@0.5': np.mean(all_mean_ious_05),
                'count': len(all_recalls_03)
            }

    # Compute overall metrics (combining all FPS) if we have multiple FPS values
    if len(results_by_fps) > 1:
        # Collect every record's results across all FPS values
        all_results_03 = []
        all_results_05 = []
        for fps, results_list in results_by_fps.items():
            for r in results_list:
                if r:
                    all_results_03.append(r['0.3'])
                    all_results_05.append(r['0.5'])

        # Compute overall averages
        if all_results_03:
            overall_recalls_03 = [r.get('Recall@0.30', 0) for r in all_results_03]
            overall_mean_ious_03 = [r.get('meanIoU@0.30', 0) for r in all_results_03]
            overall_recalls_05 = [r.get('Recall@0.50', 0) for r in all_results_05]
            overall_mean_ious_05 = [r.get('meanIoU@0.50', 0) for r in all_results_05]

            aggregated['IoU_0.3'] = {
                'Recall': np.mean(overall_recalls_03),
                'meanIoU': np.mean(overall_mean_ious_03),
            }
            aggregated['IoU_0.5'] = {
                'Recall': np.mean(overall_recalls_05),
                'meanIoU': np.mean(overall_mean_ious_05),
            }

    return aggregated


def main():
    """Main evaluation function for TAL."""
    import sys

    # Require a results file via command-line argument
    if len(sys.argv) < 2:
        print("Usage: python eval_tal.py <results_file.json>")
        print("Example: python eval_tal.py results/model_results.json")
        sys.exit(1)

    output_file = sys.argv[1]
    print(f"Loading results from: {output_file}")
    with open(output_file, "r") as f:
        infer_output = json.load(f)

    # Group by dataset
    dataset_records = group_records_by_dataset(infer_output)
    print(f"\nFound datasets: {list(dataset_records.keys())}")
    for dataset, records in dataset_records.items():
        print(f"  {dataset}: {len(records)} TAL records")

    if not any(dataset_records.values()):
        print("No TAL records found!")
        return

    # Evaluate each dataset
    all_results = {}
    for dataset_name, records in dataset_records.items():
        if records:
            results = evaluate_dataset_tal(dataset_name, records)
            all_results[dataset_name] = results

    # Print summary
    print(f"\n{'=' * 80}")
    print("TAL EVALUATION SUMMARY")
    print("=" * 80)

    # Aggregate metrics across all datasets
    all_miou_03 = []
    all_miou_05 = []

    for dataset_name, fps_results in all_results.items():
        if fps_results:
            print(f"\n{dataset_name}:")
            for fps_key, metrics in sorted(fps_results.items()):
                print(f"  {fps_key}:")
                for metric_name, value in metrics.items():
                    if metric_name != 'count':
                        print(f"    {metric_name}: {value:.4f}")
                    else:
                        print(f"    samples: {value}")

                # Collect for the overall average
                if 'meanIoU@0.3' in metrics:
                    all_miou_03.append(metrics['meanIoU@0.3'])
                if 'meanIoU@0.5' in metrics:
                    all_miou_05.append(metrics['meanIoU@0.5'])

    # Return overall aggregated results
    return {
        'meanIoU@0.3': np.mean(all_miou_03) if all_miou_03 else 0.0,
        'meanIoU@0.5': np.mean(all_miou_05) if all_miou_05 else 0.0
    }


if __name__ == "__main__":
    main()
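

# Smoke-test sketch (hypothetical record; exercises the pipeline without a
# results file, and the perfectly matching span should score 1.0 everywhere):
#
#   records = [{'fps': 1.0,
#               'answer': 'The event spans 10.0-20.0',
#               'struc_info': [{'spans': [{'start': 10.0, 'end': 20.0}]}]}]
#   result = evaluate_dataset_tal('demo', records)
#   assert result['fps_1.0']['recall@0.3'] == 1.0
#   assert result['fps_1.0']['meanIoU@0.5'] == 1.0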