| import json |
| import re |
| from matplotlib import text |
| import numpy as np |
| from typing import Tuple |
| from collections import defaultdict |
|
|
|
|
def extract_time_segments(text, verbose=False):
    """Parse "X to Y" style time ranges out of free-form model output.

    Two formats are recognised:

    * plain numbers with an optional unit, e.g. ``"from 1.5s to 3 seconds"``;
    * clock timestamps ``H:MM:SS[.f]`` joined by "to" or a dash, e.g.
      ``"00:01:05 to 00:02:10"`` (converted to seconds).

    Args:
        text: Free-form answer text.
        verbose: When True, echo ``text`` to stdout between separator lines
            (debug output that used to be printed unconditionally).

    Returns:
        List of ``{'start': float, 'end': float}`` dicts rounded to 2 decimals:
        plain-number ranges first, then clock ranges, each in order of
        appearance. Empty when nothing matched.
    """
    if verbose:
        print("=" * 50)
        print(text)

    segments = []

    # The look-around guards stop a plain number from being carved out of a
    # clock timestamp: without them "00:01:05 to 00:02:10" also produced a
    # bogus {'start': 5.0, 'end': 0.0} from the "05 to 00" fragment.
    plain_ranges = re.findall(
        r'(?<![\d:.])(\d+(?:\.\d+)?)(?:s| seconds?)?\s*'
        r'to\s*'
        r'(\d+(?:\.\d+)?)(?![\d:])(?:s| seconds?)?',
        text,
        flags=re.IGNORECASE,
    )
    for start, end in plain_ranges:
        # Capture groups only admit digit strings, so float() cannot raise here.
        segments.append({
            'start': round(float(start), 2),
            'end': round(float(end), 2),
        })

    clock_ranges = re.findall(
        r'(\d+):(\d+):(\d+(?:\.\d+)?)\s*(?:to|[-\u2013\u2014])\s*'
        r'(\d+):(\d+):(\d+(?:\.\d+)?)',
        text,
        flags=re.IGNORECASE,
    )
    for h1, m1, s1, h2, m2, s2 in clock_ranges:
        start_sec = int(h1) * 3600 + int(m1) * 60 + float(s1)
        end_sec = int(h2) * 3600 + int(m2) * 60 + float(s2)
        segments.append({
            'start': round(start_sec, 2),
            'end': round(end_sec, 2),
        })

    return segments
|
|
|
|
|
|
|
|
def extract_segments_from_text(text):
    """Extract ``{'start', 'end'}`` time segments from a model answer.

    Dash-separated numeric ranges (``"1.5 - 3.0"``, ``"12s-15s"``; en and em
    dashes are accepted too) are tried first. Only when none are found does
    it fall back to ``extract_time_segments`` ("X to Y" and clock-timestamp
    ranges). A warning is printed when both come up empty.

    Returns:
        List of ``{'start': float, 'end': float}`` dicts (possibly empty).
    """
    # Guards: the start may not follow a digit, ':' or '.', and the end may not
    # be followed by a digit or ':'. This keeps fragments of clock timestamps
    # such as "00:01:05 - 00:02:10" from being misread as the range 5 -> 0
    # (which also used to suppress the fallback parser).
    dash_ranges = re.findall(
        r'(?<![\d:.])(\d+(?:\.\d+)?)\s*(?:s|sec|seconds?)?\s*'
        r'[-\u2013\u2014]\s*'
        r'(\d+(?:\.\d+)?)(?![\d:])',
        text,
        flags=re.IGNORECASE,
    )
    segments = [{'start': float(start), 'end': float(end)} for start, end in dash_ranges]

    if not segments:
        segments = extract_time_segments(text)
        if not segments:
            print(f"Warning: No valid segments found in text: {text}")
    return segments
|
|
|
|
def compute_iou(seg1, seg2):
    """Return the temporal IoU of two segments given as {'start', 'end'} dicts.

    The denominator is the span covering both segments. For well-formed
    segments this equals the true union whenever they overlap; when they do
    not, the overlap is 0 anyway. Returns 0.0 if that span is not positive.
    """
    a_lo, a_hi = seg1['start'], seg1['end']
    b_lo, b_hi = seg2['start'], seg2['end']

    overlap = max(0, min(a_hi, b_hi) - max(a_lo, b_lo))
    hull = max(a_hi, b_hi) - min(a_lo, b_lo)

    if hull > 0:
        return overlap / hull
    return 0.0
|
|
def evaluate_pair(preds, gts, tiou_thresh=0.5):
    """Greedily match predicted segments to ground-truth segments and score them.

    Each GT, in order, claims the not-yet-matched prediction with the highest
    positive IoU. The pair counts as a match if that IoU is >= ``tiou_thresh``.
    Matching is greedy (GT order matters), not a globally optimal assignment.

    Args:
        preds: Predicted ``{'start', 'end'}`` segments.
        gts: Ground-truth ``{'start', 'end'}`` segments.
        tiou_thresh: Minimum temporal IoU for a match.

    Returns:
        ``(recall, precision, f1, mean_iou)``, where ``mean_iou`` averages only
        the matched pairs. Each value is 0.0 when its denominator is empty/zero.
    """
    gt_matched = [False] * len(gts)
    pred_matched = [False] * len(preds)
    matched_ious = []

    for i, gt in enumerate(gts):
        best_iou, best_j = 0, -1
        for j, pred in enumerate(preds):
            if pred_matched[j]:
                continue
            iou = compute_iou(pred, gt)
            if iou > best_iou:
                best_iou, best_j = iou, j
        # best_j stays -1 when no free prediction overlaps this GT. Without the
        # check, tiou_thresh <= 0 would mark pred_matched[-1] (the last
        # prediction) as matched, or raise IndexError when preds is empty.
        if best_j >= 0 and best_iou >= tiou_thresh:
            gt_matched[i] = True
            pred_matched[best_j] = True
            matched_ious.append(best_iou)

    recall = sum(gt_matched) / len(gts) if gts else 0.0
    precision = sum(pred_matched) / len(preds) if preds else 0.0
    f1 = (2 * precision * recall) / (precision + recall) if (precision + recall) > 0 else 0.0
    mean_iou = sum(matched_ious) / len(matched_ious) if matched_ious else 0.0

    return recall, precision, f1, mean_iou
|
|
|
|
def evaluate_tal_record(tal_record, tiou_thresh=0.5):
    """Average recall and matched-pair mean IoU over a list of TAL entries.

    Args:
        tal_record: Iterable of dicts with 'prediction' and 'ground_truth'
            segment lists.
        tiou_thresh: IoU threshold forwarded to ``evaluate_pair``.

    Returns:
        ``{"Recall@T": float, "meanIoU@T": float}`` with ``T`` formatted to two
        decimals. Both values are 0.0 when ``tal_record`` is empty.
    """
    recalls, mean_ious = [], []
    for entry in tal_record:
        # evaluate_pair also yields precision and F1; they are not reported here.
        recall, _precision, _f1, mean_iou = evaluate_pair(
            entry['prediction'], entry['ground_truth'], tiou_thresh)
        recalls.append(recall)
        mean_ious.append(mean_iou)

    def avg(values):
        return sum(values) / len(values) if values else 0.0

    return {
        f"Recall@{tiou_thresh:.2f}": avg(recalls),
        f"meanIoU@{tiou_thresh:.2f}": avg(mean_ious),
    }
| |
| |
def pretty_print_summary(summary, label):
    """Print a labelled block of metric values to stdout, one metric per line.

    Args:
        summary: Mapping of metric name -> numeric value (formatted to 4 decimals).
        label: Heading printed above the metrics. It was previously accepted
            but never shown.
    """
    print(f"\n{label}:")
    for k, v in summary.items():
        print(f" {k}: {v:.4f}")
|
|
|
|
def group_records_by_dataset(data):
    """Group TAL records by dataset for per-dataset evaluation.

    A record is kept when its ``qa_type`` contains "tal" (case-insensitive).
    The dataset name comes from 'data_source', else 'dataset', else
    'dataset_name'. When none is present, it is guessed from ``video_id``:

    * 11 characters with at least one letter -> "AVOS";
    * contains "_part" -> "CoPESD";
    * contains "video" -> "CholecT50";
    * otherwise it stays "Unknown".

    Args:
        data: Dict of key -> record (keys are ignored) or a plain list of
            record dicts.

    Returns:
        Dict of dataset name -> list of records, in first-seen order.
    """
    records = data.values() if isinstance(data, dict) else data
    dataset_groups = defaultdict(list)

    for record in records:
        # `or ''` also covers an explicit null qa_type, which used to crash on .lower().
        qa_type = str(record.get('qa_type') or '')
        if 'tal' not in qa_type.lower():
            continue

        video_id = record.get('video_id', '')
        dataset = record.get('data_source', record.get('dataset', record.get('dataset_name', 'Unknown')))

        if dataset == 'Unknown' and video_id:
            # Use the string form throughout: a numeric video_id used to raise
            # TypeError on len() even though .lower() was already guarded by str().
            video_id_str = str(video_id)
            video_id_lower = video_id_str.lower()
            if len(video_id_str) == 11 and any(c.isalpha() for c in video_id_str):
                dataset = "AVOS"
            elif "_part" in video_id_lower:
                dataset = "CoPESD"
            elif "video" in video_id_lower:
                dataset = "CholecT50"

        dataset_groups[dataset].append(record)

    return dict(dataset_groups)
|
|
|
|
def _record_fps(record):
    """Return the record's fps: 'fps', else metadata['fps'], else 1.0 (strings parsed as float)."""
    fps = record.get('fps')
    if fps is None:
        fps = (record.get('metadata') or {}).get('fps', 1.0)
    if fps is None:
        fps = 1.0
    if isinstance(fps, str):
        fps = float(fps)
    return fps


def _record_gt_spans(record):
    """Concatenate the 'spans' lists of every dict entry in record['struc_info'] (when it is a list)."""
    struc_info = record.get('struc_info', [])
    gt_spans = []
    if isinstance(struc_info, list):
        for item in struc_info:
            if isinstance(item, dict) and 'spans' in item:
                gt_spans.extend(item['spans'])
    return gt_spans


def _scaled_segments(segments, factor):
    """Return new segment dicts with 'start'/'end' cast to float and multiplied by factor."""
    return [
        {**seg, 'start': float(seg['start']) * factor, 'end': float(seg['end']) * factor}
        for seg in segments
    ]


def _mean_metric(results, thresh_key, metric_key):
    """np.mean of results[i][thresh_key][metric_key], counting a missing metric as 0."""
    return np.mean([r[thresh_key].get(metric_key, 0) for r in results])


def evaluate_dataset_tal(dataset_name, records):
    """Evaluate TAL for a specific dataset.

    Each record's 'answer' is parsed into predicted segments, and its
    'struc_info' spans are taken as ground truth. The record is then scored
    at tIoU 0.3 and 0.5, and records are grouped by their fps value.

    Args:
        dataset_name: Name used in the progress message.
        records: Record dicts for this dataset. They are not modified.

    Returns:
        Dict with one ``'fps_<fps>'`` entry per fps group, holding the mean
        'recall@0.3', 'meanIoU@0.3', 'recall@0.5', 'meanIoU@0.5' and 'count'.
        When there is more than one fps group, pooled 'IoU_0.3' / 'IoU_0.5'
        entries ({'Recall', 'meanIoU'}) over all records are added. Returns
        an empty dict when ``records`` is empty.
    """
    print(f"\nEvaluating {dataset_name} ({len(records)} records)...")

    results_by_fps = defaultdict(list)
    for record in records:
        fps = _record_fps(record)
        predicted_segments = extract_segments_from_text(record.get('answer') or '')
        gt_spans = _record_gt_spans(record)

        # Both sides are scaled by the same factor, which leaves IoU unchanged
        # for fps > 0. New dicts are built so the caller's records are not
        # mutated; the span dicts inside `records` used to be rescaled in place.
        formatted_record = [{
            'prediction': _scaled_segments(predicted_segments, fps),
            'ground_truth': _scaled_segments(gt_spans, fps),
        }]
        results_by_fps[fps].append({
            '0.3': evaluate_tal_record(formatted_record, tiou_thresh=0.3),
            '0.5': evaluate_tal_record(formatted_record, tiou_thresh=0.5),
        })

    aggregated = {}
    for fps, results in results_by_fps.items():
        aggregated[f'fps_{fps}'] = {
            'recall@0.3': _mean_metric(results, '0.3', 'Recall@0.30'),
            'meanIoU@0.3': _mean_metric(results, '0.3', 'meanIoU@0.30'),
            'recall@0.5': _mean_metric(results, '0.5', 'Recall@0.50'),
            'meanIoU@0.5': _mean_metric(results, '0.5', 'meanIoU@0.50'),
            'count': len(results),
        }

    # With several fps groups, also report an average pooled over every record.
    if len(results_by_fps) > 1:
        pooled = [r for results in results_by_fps.values() for r in results]
        aggregated['IoU_0.3'] = {
            'Recall': _mean_metric(pooled, '0.3', 'Recall@0.30'),
            'meanIoU': _mean_metric(pooled, '0.3', 'meanIoU@0.30'),
        }
        aggregated['IoU_0.5'] = {
            'Recall': _mean_metric(pooled, '0.5', 'Recall@0.50'),
            'meanIoU': _mean_metric(pooled, '0.5', 'meanIoU@0.50'),
        }

    return aggregated
|
|
|
|
def main():
    """Main evaluation function for TAL.

    Reads the results JSON named by ``sys.argv[1]``, groups its TAL records by
    dataset, evaluates each dataset and prints a per-dataset / per-fps summary
    followed by the overall mean IoUs.

    Returns:
        ``None`` if no TAL records are found. Otherwise a dict containing:

        * ``'per_dataset'``: dataset -> ``evaluate_dataset_tal`` output;
        * ``'meanIoU@0.3'`` / ``'meanIoU@0.5'``: unweighted means over all
          per-fps groups of all datasets (0.0 when there are none).

    Exits with status 1, after printing usage, when no path is given.
    """
    import sys

    if len(sys.argv) < 2:
        print("Usage: python eval_tal.py <results_json_file>")
        print("Example: python eval_tal.py results/model_results.json")
        sys.exit(1)

    output_file = sys.argv[1]
    print(f"Loading results from: {output_file}")

    # JSON text is UTF-8 (RFC 8259). Don't rely on the platform's default
    # encoding, which fails on non-ASCII answers under e.g. Windows cp1252.
    with open(output_file, "r", encoding="utf-8") as f:
        infer_output = json.load(f)

    dataset_records = group_records_by_dataset(infer_output)

    print(f"\nFound datasets: {list(dataset_records.keys())}")
    for dataset, records in dataset_records.items():
        print(f" {dataset}: {len(records)} TAL records")

    if not any(dataset_records.values()):
        print("No TAL records found!")
        return

    all_results = {
        dataset_name: evaluate_dataset_tal(dataset_name, records)
        for dataset_name, records in dataset_records.items()
        if records
    }

    print(f"\n{'='*80}")
    print("TAL EVALUATION SUMMARY")
    print(f"{'='*80}")

    all_miou_03 = []
    all_miou_05 = []
    for dataset_name, fps_results in all_results.items():
        if not fps_results:
            continue
        print(f"\n{dataset_name}:")
        for fps_key, metrics in sorted(fps_results.items()):
            print(f" {fps_key}:")
            for metric_name, value in metrics.items():
                if metric_name != 'count':
                    print(f" {metric_name}: {value:.4f}")
                else:
                    print(f" samples: {value}")
            # Only per-fps entries carry these keys. The pooled 'IoU_*' entries
            # use 'meanIoU' instead, so they are not counted here.
            if 'meanIoU@0.3' in metrics:
                all_miou_03.append(metrics['meanIoU@0.3'])
            if 'meanIoU@0.5' in metrics:
                all_miou_05.append(metrics['meanIoU@0.5'])

    overall_03 = np.mean(all_miou_03) if all_miou_03 else 0.0
    overall_05 = np.mean(all_miou_05) if all_miou_05 else 0.0
    # Print these as well: the __main__ entry point discards main()'s return value.
    print(f"\nOverall meanIoU@0.3 (unweighted over dataset/fps groups): {overall_03:.4f}")
    print(f"Overall meanIoU@0.5 (unweighted over dataset/fps groups): {overall_05:.4f}")

    return {
        'per_dataset': all_results,
        'meanIoU@0.3': overall_03,
        'meanIoU@0.5': overall_05,
    }
|
|
|
|
if __name__ == "__main__":
    # Script entry point; the summary dict returned by main() is not used here.
    main()