File size: 9,976 Bytes
a36b7fe
331979f
 
 
a36b7fe
 
331979f
 
 
a36b7fe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
331979f
 
 
 
 
 
a36b7fe
 
331979f
 
a605ebb
 
331979f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dd1b9c6
a36b7fe
331979f
 
dd1b9c6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
331979f
a36b7fe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
331979f
 
 
 
5f41159
331979f
5f41159
331979f
 
 
5f41159
 
331979f
5f41159
 
331979f
 
 
 
 
 
 
 
 
 
 
 
a36b7fe
331979f
 
 
 
5f41159
331979f
 
 
 
 
 
a36b7fe
 
 
 
331979f
 
 
 
 
a36b7fe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
331979f
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
"""Dense Video Captioning evaluation using LLM judge + temporal F1."""

import json
import sys
import numpy as np
from collections import defaultdict
from eval_caption_llm_judge import evaluate_caption_task


def compute_iou(pred_segment, gt_segment):
    """Return the temporal IoU of two [start, end] segments."""
    p_start, p_end = pred_segment
    g_start, g_end = gt_segment

    # Overlap length, clamped at zero when the segments are disjoint.
    overlap = max(0, min(p_end, g_end) - max(p_start, g_start))

    # Union = sum of both lengths minus the shared part.
    total = (p_end - p_start) + (g_end - g_start) - overlap

    # Guard against degenerate zero-length segments.
    if total == 0:
        return 0

    return overlap / total


def compute_temporal_f1(pred_segments, gt_segments, iou_threshold=0.5):
    """
    Compute F1 score for temporal segment matching.

    Args:
        pred_segments: List of predicted [start, end] segments
        gt_segments: List of ground truth [start, end] segments
        iou_threshold: IoU threshold for matching (default 0.5)

    Returns:
        Dict with precision, recall, and f1 scores
    """
    # No predictions or no ground truth -> nothing can match.
    if not pred_segments or not gt_segments:
        return {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}

    used_gt = set()
    used_pred = set()

    # Greedy one-to-one matching in prediction order: each prediction claims
    # the still-unmatched GT segment with the highest IoU, provided that IoU
    # clears the threshold (ties keep the earlier GT index).
    for p_idx, p_seg in enumerate(pred_segments):
        top_iou = 0
        top_gt = -1

        for g_idx, g_seg in enumerate(gt_segments):
            if g_idx in used_gt:
                continue

            overlap = compute_iou(p_seg, g_seg)
            if overlap >= iou_threshold and overlap > top_iou:
                top_iou = overlap
                top_gt = g_idx

        if top_gt != -1:
            used_pred.add(p_idx)
            used_gt.add(top_gt)

    # Precision over predictions, recall over ground truth; harmonic mean.
    precision = len(used_pred) / len(pred_segments)
    recall = len(used_gt) / len(gt_segments)
    if precision + recall > 0:
        f1 = 2 * precision * recall / (precision + recall)
    else:
        f1 = 0

    return {
        'precision': precision,
        'recall': recall,
        'f1': f1
    }


def parse_dvc_segments(text):
    """
    Extract [start, end] time segments from a DVC model output string.

    Recognized formats:
    - [start-end] caption
    - (start-end) caption
    - start-end seconds: caption
    """
    import re

    # Bracketed or parenthesized ranges, e.g. "[0.0-5.2]" or "(0.0-5.2)".
    bracketed = r'[\[\(](\d+\.?\d*)\s*-\s*(\d+\.?\d*)[\]\)]'

    # Bare ranges followed by "second(s):", e.g. "0.0-5.2 seconds:".
    worded = r'(\d+\.?\d*)\s*-\s*(\d+\.?\d*)\s*seconds?:'

    found = []
    for pat in (bracketed, worded):
        for m in re.finditer(pat, text, re.IGNORECASE):
            found.append([float(m.group(1)), float(m.group(2))])

    return found


def group_records_by_dataset(data):
    """Bucket DVC records by source dataset for per-dataset evaluation."""
    groups = defaultdict(list)

    for record in data.values():
        qa = record.get('qa_type', '').lower()
        # Keep any dense_captioning variant (dense_captioning, dense_captioning_gpt, dense_captioning_gemini, dc)
        if all(tag not in qa for tag in ('dense_captioning', 'dense_caption', 'dc')):
            continue

        meta = record.get('metadata', {})
        # Prefer data_source (leaderboard format), then dataset / dataset_name / metadata.
        dataset = record.get(
            'data_source',
            record.get('dataset', record.get('dataset_name', meta.get('dataset', 'Unknown')))
        )
        video_id = record.get('video_id', meta.get('video_id', ''))

        # Fall back to video-id heuristics when no dataset field is present.
        if dataset == 'Unknown' and video_id:
            vid_lower = str(video_id).lower()
            if len(video_id) == 11 and any(c.isalpha() for c in video_id):
                dataset = "AVOS"
            elif "_part" in vid_lower:
                dataset = "CoPESD"
            elif "video" in vid_lower:
                dataset = "CholecT50"

        groups[dataset].append(record)

    return dict(groups)


def evaluate_dataset_dvc(dataset_name, records, skip_llm_judge=False):
    """Evaluate DVC for a specific dataset using caption quality + temporal F1."""
    print(f"\nEvaluating {dataset_name} ({len(records)} records)...")

    # Step 1: Evaluate caption quality using LLM judge (unless skipped)
    if skip_llm_judge:
        print(f"  Skipping LLM judge caption evaluation (--skip-llm-judge flag)")
        caption_score = 0.0
        caption_method = 'skipped'
    else:
        import tempfile
        import os

        temp_data = {str(i): record for i, record in enumerate(records)}

        with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
            json.dump(temp_data, f)
            temp_file = f.name

        try:
            # Use caption evaluator for caption quality
            caption_result = evaluate_caption_task(temp_file, 'dense_captioning')
            caption_score = caption_result['score']
            caption_method = caption_result['method']
        finally:
            os.unlink(temp_file)

    # Step 2: Compute temporal F1 for segment localization
    all_f1_scores = []

    for record in records:
        # Get FPS for time-to-frame conversion
        fps = record.get('fps', record.get('metadata', {}).get('fps', 1.0))
        if isinstance(fps, str):
            fps = float(fps)

        # Parse predicted segments from answer
        pred_text = record.get('answer', '')
        pred_segments = parse_dvc_segments(pred_text)

        # Get ground truth segments from struc_info
        struc_info = record.get('struc_info', [])
        gt_segments = []

        if isinstance(struc_info, list):
            for item in struc_info:
                if isinstance(item, dict):
                    # Handle different formats
                    if 'dc_segments' in item:
                        # NurViD format
                        segments = item['dc_segments']
                    elif 'start' in item and 'end' in item:
                        # Direct segment format
                        segments = [item]
                    else:
                        continue

                    for seg in (segments if isinstance(segments, list) else [segments]):
                        if 'start' in seg and 'end' in seg:
                            # Convert to seconds (struc_info is in seconds)
                            gt_segments.append([
                                float(seg['start']),
                                float(seg['end'])
                            ])

        # Compute F1 for this sample
        if pred_segments and gt_segments:
            f1_result = compute_temporal_f1(pred_segments, gt_segments, iou_threshold=0.5)
            all_f1_scores.append(f1_result['f1'])

    # Aggregate F1 scores
    avg_f1 = np.mean(all_f1_scores) if all_f1_scores else 0.0

    # Return both caption quality and temporal F1
    return {
        'overall': {
            'caption_score': caption_score,
            'caption_method': caption_method,
            'temporal_f1': avg_f1,
            'count': len(records),
            'f1_samples': len(all_f1_scores)
        }
    }


def main():
    """CLI entry point: load a results JSON and score every DVC record in it."""
    # Require at least the results-file argument.
    if len(sys.argv) < 2:
        print("Usage: python eval_dvc.py <results_json_file> [--skip-llm-judge]")
        print("Example: python eval_dvc.py results/model_results.json")
        print("Example: python eval_dvc.py results/model_results.json --skip-llm-judge")
        sys.exit(1)

    output_file = sys.argv[1]
    skip_llm_judge = '--skip-llm-judge' in sys.argv

    print(f"Loading results from: {output_file}")
    if skip_llm_judge:
        print("⚠️  --skip-llm-judge flag detected: Skipping caption evaluation, computing temporal F1 only")

    with open(output_file, "r") as f:
        infer_output = json.load(f)

    dataset_records = group_records_by_dataset(infer_output)

    print(f"\nFound datasets: {list(dataset_records.keys())}")
    for name, recs in dataset_records.items():
        print(f"  {name}: {len(recs)} DVC records")

    # Bail out early when no dataset produced any DVC records.
    if not any(dataset_records.values()):
        print("No DVC records found!")
        return {}

    all_results = {}
    for name, recs in dataset_records.items():
        if recs:
            all_results[name] = evaluate_dataset_dvc(name, recs, skip_llm_judge=skip_llm_judge)

    print(f"\n{'='*80}")
    print("DENSE VIDEO CAPTIONING EVALUATION SUMMARY")
    print(f"{'='*80}")

    # Print per-dataset metrics and collect them for the overall average.
    caption_scores = []
    f1_scores = []

    for name, results in all_results.items():
        if not results:
            continue
        print(f"\n{name}:")
        for metrics in results.values():
            if not isinstance(metrics, dict):
                continue
            print(f"  Caption Score ({metrics.get('caption_method', 'unknown')}): {metrics.get('caption_score', 0):.4f}")
            print(f"  Temporal F1@0.5: {metrics.get('temporal_f1', 0):.4f}")
            print(f"  Total samples: {metrics.get('count', 0)}")
            print(f"  F1 computed on: {metrics.get('f1_samples', 0)} samples")

            caption_scores.append(metrics.get('caption_score', 0))
            f1_scores.append(metrics.get('temporal_f1', 0))

    # Report the caption method of the first evaluated dataset.
    if all_results:
        first_dataset = next(iter(all_results))
        method = all_results[first_dataset]['overall'].get('caption_method', 'unknown')
    else:
        method = 'unknown'

    return {
        'caption_score': np.mean(caption_scores) if caption_scores else 0.0,
        'temporal_f1': np.mean(f1_scores) if f1_scores else 0.0,
        'method': method
    }


# Run the CLI when executed as a script; main()'s return value is ignored here.
if __name__ == "__main__":
    main()