MedGRPO Team commited on
Commit
aa5db53
·
1 Parent(s): 63df552
evaluation/eval_dvc.py CHANGED
@@ -1,106 +1,197 @@
1
- """Dense Video Captioning evaluation using LLM judge + temporal F1."""
 
 
 
 
 
 
 
2
 
3
  import json
 
4
  import sys
5
  import numpy as np
6
  from collections import defaultdict
7
  from eval_caption_llm_judge import evaluate_caption_task
8
 
9
 
10
- def compute_iou(pred_segment, gt_segment):
11
- """Compute IoU between two segments [start, end]."""
12
- pred_start, pred_end = pred_segment
13
- gt_start, gt_end = gt_segment
14
-
15
- # Compute intersection
16
- inter_start = max(pred_start, gt_start)
17
- inter_end = min(pred_end, gt_end)
18
- intersection = max(0, inter_end - inter_start)
19
-
20
- # Compute union
21
- union = (pred_end - pred_start) + (gt_end - gt_start) - intersection
22
-
23
- if union == 0:
24
- return 0
 
 
 
 
 
 
 
 
 
 
 
25
 
26
- return intersection / union
27
-
28
-
29
- def compute_temporal_f1(pred_segments, gt_segments, iou_threshold=0.5):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  """
31
- Compute F1 score for temporal segment matching.
32
-
33
- Args:
34
- pred_segments: List of predicted [start, end] segments
35
- gt_segments: List of ground truth [start, end] segments
36
- iou_threshold: IoU threshold for matching (default 0.5)
37
-
38
- Returns:
39
- Dict with precision, recall, and f1 scores
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
  """
41
- if not pred_segments or not gt_segments:
42
- return {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}
43
-
44
- # Match predicted segments to ground truth
45
- matched_gt = set()
46
- matched_pred = set()
47
-
48
- for pred_idx, pred_seg in enumerate(pred_segments):
49
- best_iou = 0
50
- best_gt_idx = -1
51
 
52
- for gt_idx, gt_seg in enumerate(gt_segments):
53
- if gt_idx in matched_gt:
54
- continue
55
 
56
- iou = compute_iou(pred_seg, gt_seg)
57
- if iou >= iou_threshold and iou > best_iou:
58
- best_iou = iou
59
- best_gt_idx = gt_idx
60
-
61
- if best_gt_idx >= 0:
62
- matched_pred.add(pred_idx)
63
- matched_gt.add(best_gt_idx)
64
-
65
- # Compute precision, recall, F1
66
- precision = len(matched_pred) / len(pred_segments) if pred_segments else 0
67
- recall = len(matched_gt) / len(gt_segments) if gt_segments else 0
68
- f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
69
 
70
  return {
71
- 'precision': precision,
72
- 'recall': recall,
73
- 'f1': f1
74
  }
75
 
76
 
77
- def parse_dvc_segments(text):
78
- """
79
- Parse DVC output to extract segments.
80
- Supports multiple formats:
81
- - [start-end] caption
82
- - (start-end) caption
83
- - start-end seconds: caption
84
- """
85
- import re
86
- segments = []
87
-
88
- # Pattern 1: [0.0-5.2] or (0.0-5.2)
89
- pattern1 = r'[\[\(](\d+\.?\d*)\s*-\s*(\d+\.?\d*)[\]\)]'
90
-
91
- # Pattern 2: 0.0-5.2 seconds:
92
- pattern2 = r'(\d+\.?\d*)\s*-\s*(\d+\.?\d*)\s*seconds?:'
93
-
94
- # Try both patterns
95
- for pattern in [pattern1, pattern2]:
96
- matches = re.finditer(pattern, text, re.IGNORECASE)
97
- for match in matches:
98
- start = float(match.group(1))
99
- end = float(match.group(2))
100
- segments.append([start, end])
101
-
102
- return segments
103
-
104
 
105
  def group_records_by_dataset(data):
106
  """Group DVC records by dataset for per-dataset evaluation."""
@@ -130,6 +221,26 @@ def group_records_by_dataset(data):
130
  return dict(dataset_groups)
131
 
132
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
133
  def evaluate_dataset_dvc(dataset_name, records, skip_llm_judge=False):
134
  """Evaluate DVC for a specific dataset using caption quality + temporal F1."""
135
  print(f"\nEvaluating {dataset_name} ({len(records)} records)...")
@@ -150,65 +261,71 @@ def evaluate_dataset_dvc(dataset_name, records, skip_llm_judge=False):
150
  temp_file = f.name
151
 
152
  try:
153
- # Use caption evaluator for caption quality
154
  caption_result = evaluate_caption_task(temp_file, 'dense_captioning')
155
  caption_score = caption_result['score']
156
  caption_method = caption_result['method']
157
  finally:
158
  os.unlink(temp_file)
159
 
160
- # Step 2: Compute temporal F1 for segment localization
161
  all_f1_scores = []
 
 
162
 
163
  for record in records:
164
- # Get FPS for time-to-frame conversion
165
  fps = record.get('fps', record.get('metadata', {}).get('fps', 1.0))
166
  if isinstance(fps, str):
167
  fps = float(fps)
168
 
169
- # Parse predicted segments from answer
170
- pred_text = record.get('answer', '')
171
- pred_segments = parse_dvc_segments(pred_text)
 
 
 
172
 
173
- # Get ground truth segments from struc_info
174
- struc_info = record.get('struc_info', [])
175
- gt_segments = []
176
 
177
- if isinstance(struc_info, list):
178
- for item in struc_info:
179
- if isinstance(item, dict):
180
- # Handle different formats
181
- if 'dc_segments' in item:
182
- # NurViD format
183
- segments = item['dc_segments']
184
- elif 'start' in item and 'end' in item:
185
- # Direct segment format
186
- segments = [item]
187
- else:
188
- continue
189
-
190
- for seg in (segments if isinstance(segments, list) else [segments]):
191
- if 'start' in seg and 'end' in seg:
192
- # Convert to seconds (struc_info is in seconds)
193
- gt_segments.append([
194
- float(seg['start']),
195
- float(seg['end'])
196
- ])
197
-
198
- # Compute F1 for this sample
199
  if pred_segments and gt_segments:
200
- f1_result = compute_temporal_f1(pred_segments, gt_segments, iou_threshold=0.5)
201
- all_f1_scores.append(f1_result['f1'])
 
 
 
 
 
 
 
202
 
203
- # Aggregate F1 scores
204
  avg_f1 = np.mean(all_f1_scores) if all_f1_scores else 0.0
 
 
205
 
206
- # Return both caption quality and temporal F1
207
  return {
208
  'overall': {
209
  'caption_score': caption_score,
210
  'caption_method': caption_method,
211
  'temporal_f1': avg_f1,
 
 
212
  'count': len(records),
213
  'f1_samples': len(all_f1_scores)
214
  }
@@ -228,7 +345,7 @@ def main():
228
 
229
  print(f"Loading results from: {output_file}")
230
  if skip_llm_judge:
231
- print("⚠️ --skip-llm-judge flag detected: Skipping caption evaluation, computing temporal F1 only")
232
 
233
  with open(output_file, "r") as f:
234
  infer_output = json.load(f)
@@ -253,7 +370,6 @@ def main():
253
  print("DENSE VIDEO CAPTIONING EVALUATION SUMMARY")
254
  print(f"{'='*80}")
255
 
256
- # Aggregate overall metrics
257
  all_caption_scores = []
258
  all_f1_scores = []
259
 
@@ -263,15 +379,15 @@ def main():
263
  for key, metrics in results.items():
264
  if isinstance(metrics, dict):
265
  print(f" Caption Score ({metrics.get('caption_method', 'unknown')}): {metrics.get('caption_score', 0):.4f}")
266
- print(f" Temporal F1@0.5: {metrics.get('temporal_f1', 0):.4f}")
 
 
267
  print(f" Total samples: {metrics.get('count', 0)}")
268
  print(f" F1 computed on: {metrics.get('f1_samples', 0)} samples")
269
 
270
- # Collect for overall average
271
  all_caption_scores.append(metrics.get('caption_score', 0))
272
  all_f1_scores.append(metrics.get('temporal_f1', 0))
273
 
274
- # Return overall aggregated results
275
  return {
276
  'caption_score': np.mean(all_caption_scores) if all_caption_scores else 0.0,
277
  'temporal_f1': np.mean(all_f1_scores) if all_f1_scores else 0.0,
 
1
+ """Dense Video Captioning evaluation using LLM judge + temporal F1.
2
+
3
+ Temporal F1 algorithm matches Qwen2.5-VL/my_eval/eval_dvc.py exactly:
4
+ - process_raw_output() + flatten_overlapping_segments() for parsing
5
+ - Frame-based coordinates (multiply by FPS)
6
+ - Many-to-many threshold matching across IoU (0.3, 0.5, 0.7, 0.9)
7
+ - F1 = 2 * mean_precision * mean_recall / (mean_precision + mean_recall)
8
+ """
9
 
10
  import json
11
+ import re
12
  import sys
13
  import numpy as np
14
  from collections import defaultdict
15
  from eval_caption_llm_judge import evaluate_caption_task
16
 
17
 
18
+ # =============================================================================
19
+ # Ported from Qwen2.5-VL/my_eval_old/eval_dvc.py - exact same algorithms
20
+ # =============================================================================
21
+
22
+ def zs_parse_multi_segment_annotations(raw_text: str):
23
+ """Parse raw multiline string with multiple timestamped captions per line."""
24
+ all_segments = []
25
+ lines = raw_text.strip().split('\n')
26
+ for line in lines:
27
+ matches = re.findall(
28
+ r"(?:\*\*Start Time:\*\*|Start\s*\(?Time\)?|Time\s*Range:|Time\s*Interval:|^|\n)\s*(\d+\.?\d*)\s*[-–]\s*(\d+\.?\d*)\s*seconds?.*?(?:\*\*Description:\*\*|-)\s*(.+?)(?=\n\d|$)",
29
+ line, flags=re.DOTALL
30
+ )
31
+ for start, end, caption in matches:
32
+ all_segments.append({
33
+ "start": float(start),
34
+ "end": float(end),
35
+ "caption": caption.strip().rstrip('.')
36
+ })
37
+ return all_segments
38
+
39
+
40
+ def process_raw_output(raw_descriptions: str):
41
+ """Process raw frame-wise descriptions into structured segments."""
42
+ pattern = r"(\d+(?:\.\d+)?)-(\d+(?:\.\d+)?)\s+seconds?:\s+(.*?)(?=\n\d+(?:\.\d+)?-\d+(?:\.\d+)?\s+seconds?:|\Z)"
43
+ matches = re.findall(pattern, raw_descriptions, re.DOTALL)
44
 
45
+ segments = []
46
+ for start, end, desc in matches:
47
+ segments.append({
48
+ "start": float(start),
49
+ "end": float(end),
50
+ "caption": desc.strip().replace("\n", " ")
51
+ })
52
+
53
+ # Remove duplicate (start, end) segments
54
+ seen = set()
55
+ unique_segments = []
56
+ for seg in segments:
57
+ key = (seg["start"], seg["end"])
58
+ if key not in seen:
59
+ seen.add(key)
60
+ unique_segments.append(seg)
61
+
62
+ if not unique_segments:
63
+ unique_segments = zs_parse_multi_segment_annotations(raw_descriptions)
64
+
65
+ return unique_segments
66
+
67
+
68
+ def check_for_overlaps(segments):
69
+ """Check a list of temporal segments for any overlaps."""
70
+ sorted_segs = sorted(segments, key=lambda x: (x['start'], x['end']))
71
+ overlaps = []
72
+ for i in range(len(sorted_segs) - 1):
73
+ seg1 = sorted_segs[i]
74
+ seg2 = sorted_segs[i + 1]
75
+ if seg2["start"] < seg1["end"]:
76
+ overlaps.append((seg1, seg2))
77
+ return overlaps
78
+
79
+
80
+ def flatten_overlapping_segments(segments, caption_strategy="longest"):
81
+ """Split overlapping segments into non-overlapping intervals."""
82
+ time_points = sorted(set([s["start"] for s in segments] + [s["end"] for s in segments]))
83
+ result = []
84
+ for i in range(len(time_points) - 1):
85
+ start = time_points[i]
86
+ end = time_points[i + 1]
87
+ overlapping = []
88
+ for s in segments:
89
+ if s["start"] < end and s["end"] > start:
90
+ overlapping.append(s)
91
+ if not overlapping:
92
+ continue
93
+ if caption_strategy == "longest":
94
+ selected = max(overlapping, key=lambda x: x["end"] - x["start"])
95
+ elif caption_strategy == "first":
96
+ selected = overlapping[0]
97
+ else:
98
+ raise ValueError("Unsupported strategy")
99
+ result.append({
100
+ "start": start,
101
+ "end": end,
102
+ "caption": selected["caption"]
103
+ })
104
+ return result
105
+
106
+
107
+ def iou(interval_1, interval_2):
108
+ """Compute IoU between two intervals - matches old eval exactly."""
109
+ start_1, end_1 = min(*interval_1), max(*interval_1)
110
+ start_2, end_2 = min(*interval_2), max(*interval_2)
111
+
112
+ intersection = max(0, min(end_1, end_2) - max(start_1, start_2))
113
+ union = min(
114
+ max(end_1, end_2) - min(start_1, start_2),
115
+ end_1 - start_1 + end_2 - start_2)
116
+ result = float(intersection) / (union + 1e-8)
117
+ return result
118
+
119
+
120
+ def evaluate_detections(predicted_segments, gt_segments, splits,
121
+ iou_thresholds=(0.3, 0.5, 0.7, 0.9)):
122
+ """Compute P/R between predicted and ground truth segments.
123
+
124
+ Many-to-many matching: any pred-gt pair exceeding threshold counts as covered.
125
  """
126
+ best_recall = []
127
+ best_precision = []
128
+
129
+ predicted_shape = predicted_segments.shape[0]
130
+
131
+ for split in set(splits):
132
+ metrics = {}
133
+ for threshold in iou_thresholds:
134
+ metrics[str(threshold)] = {
135
+ 'gt_covered': set(),
136
+ 'pred_covered': set(),
137
+ }
138
+ split_idx = np.where(splits == split)[0]
139
+ split_gt_segments = np.array([gt_segments[idx] for idx in split_idx])
140
+ gt_shape = split_gt_segments.shape[0]
141
+
142
+ for idx_g, gt_segment in enumerate(split_gt_segments):
143
+ for idx_p, segment in enumerate(predicted_segments):
144
+ sample_iou = iou(segment, gt_segment)
145
+ for threshold in iou_thresholds:
146
+ if sample_iou > threshold:
147
+ metrics[str(threshold)]['pred_covered'].add(idx_p)
148
+ metrics[str(threshold)]['gt_covered'].add(idx_g)
149
+
150
+ for threshold, m in metrics.items():
151
+ pred_covered = m['pred_covered']
152
+ gt_covered = m['gt_covered']
153
+ m['precision'] = float(len(pred_covered)) / max(float(predicted_shape), 1.0)
154
+ m['recall'] = float(len(gt_covered)) / float(gt_shape)
155
+
156
+ precision = [m['precision'] for m in metrics.values()]
157
+ recall = [m['recall'] for m in metrics.values()]
158
+ if best_precision:
159
+ best_precision = [max(precision[i], best_precision[i]) for i in range(len(precision))]
160
+ best_recall = [max(recall[i], best_recall[i]) for i in range(len(recall))]
161
+ else:
162
+ best_precision, best_recall = precision, recall
163
+
164
+ return best_precision, best_recall
165
+
166
+
167
+ def compute_temporal_f1_single(predicted_segments, gt_segments, splits,
168
+ iou_thresholds=(0.3, 0.5, 0.7)):
169
+ """Compute temporal F1 for a single sample using the old eval algorithm.
170
+
171
+ Returns dict with Precision_Mean, Recall_Mean, F1_Score.
172
  """
173
+ if predicted_segments.shape[0] == 0 or gt_segments.shape[0] == 0:
174
+ return {'Precision_Mean': 0.0, 'Recall_Mean': 0.0, 'F1_Score': 0.0}
 
 
 
 
 
 
 
 
175
 
176
+ detection_precision, detection_recall = evaluate_detections(
177
+ predicted_segments, gt_segments, splits, iou_thresholds
178
+ )
179
 
180
+ mean_precision = sum(detection_precision) / len(detection_precision)
181
+ mean_recall = sum(detection_recall) / len(detection_recall)
182
+ f1 = 2 * mean_recall * mean_precision / (mean_recall + mean_precision) \
183
+ if (mean_recall + mean_precision) > 0 else 0.0
 
 
 
 
 
 
 
 
 
184
 
185
  return {
186
+ 'Precision_Mean': float(mean_precision),
187
+ 'Recall_Mean': float(mean_recall),
188
+ 'F1_Score': float(f1),
189
  }
190
 
191
 
192
+ # =============================================================================
193
+ # Dataset grouping and evaluation
194
+ # =============================================================================
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
195
 
196
  def group_records_by_dataset(data):
197
  """Group DVC records by dataset for per-dataset evaluation."""
 
221
  return dict(dataset_groups)
222
 
223
 
224
+ def _extract_gt_segments(record):
225
+ """Extract ground truth segments from struc_info, matching Qwen2.5-VL logic."""
226
+ struc_info = record.get('struc_info', [])
227
+
228
+ if isinstance(struc_info, list) and len(struc_info) > 0:
229
+ if isinstance(struc_info[0], list):
230
+ # Format: [[{segments...}]]
231
+ gnd = struc_info[0]
232
+ elif isinstance(struc_info[0], dict) and 'dc_segments' in struc_info[0]:
233
+ # NurViD format: [{'dc_segments': [...]}]
234
+ gnd = struc_info[0]['dc_segments']
235
+ else:
236
+ # Format: [{segments...}]
237
+ gnd = struc_info
238
+ else:
239
+ gnd = struc_info
240
+
241
+ return gnd
242
+
243
+
244
  def evaluate_dataset_dvc(dataset_name, records, skip_llm_judge=False):
245
  """Evaluate DVC for a specific dataset using caption quality + temporal F1."""
246
  print(f"\nEvaluating {dataset_name} ({len(records)} records)...")
 
261
  temp_file = f.name
262
 
263
  try:
 
264
  caption_result = evaluate_caption_task(temp_file, 'dense_captioning')
265
  caption_score = caption_result['score']
266
  caption_method = caption_result['method']
267
  finally:
268
  os.unlink(temp_file)
269
 
270
+ # Step 2: Compute temporal F1 matching Qwen2.5-VL algorithm exactly
271
  all_f1_scores = []
272
+ all_precision_scores = []
273
+ all_recall_scores = []
274
 
275
  for record in records:
276
+ # Get FPS
277
  fps = record.get('fps', record.get('metadata', {}).get('fps', 1.0))
278
  if isinstance(fps, str):
279
  fps = float(fps)
280
 
281
+ # Parse predicted segments using process_raw_output (same as Qwen2.5-VL)
282
+ raw_answer = record.get('answer', '')
283
+ processed_answer = process_raw_output(raw_answer)
284
+ overlaps = check_for_overlaps(processed_answer)
285
+ if overlaps:
286
+ processed_answer = flatten_overlapping_segments(processed_answer, caption_strategy="longest")
287
 
288
+ # Get ground truth segments
289
+ gnd = _extract_gt_segments(record)
 
290
 
291
+ # Convert both to frame-based coordinates (multiply by fps, cast to int)
292
+ # IMPORTANT: require 'caption' field to match Qwen2.5-VL's prepare_eval_arrays
293
+ gt_segments = []
294
+ if isinstance(gnd, list):
295
+ for g in gnd:
296
+ if isinstance(g, dict) and 'start' in g and 'end' in g and 'caption' in g:
297
+ gt_segments.append([int(float(g['start']) * fps), int(float(g['end']) * fps)])
298
+
299
+ pred_segments = []
300
+ if isinstance(processed_answer, list):
301
+ for p in processed_answer:
302
+ if isinstance(p, dict) and 'start' in p and 'end' in p and 'caption' in p:
303
+ pred_segments.append([int(p['start'] * fps), int(p['end'] * fps)])
304
+
305
+ # Compute F1 using many-to-many matching across IoU thresholds (0.3, 0.5, 0.7)
 
 
 
 
 
 
 
306
  if pred_segments and gt_segments:
307
+ pred_np = np.array(pred_segments)
308
+ gt_np = np.array(gt_segments)
309
+ splits = np.ones(len(gt_segments), dtype=int)
310
+
311
+ result = compute_temporal_f1_single(pred_np, gt_np, splits,
312
+ iou_thresholds=(0.3, 0.5, 0.7))
313
+ all_f1_scores.append(result['F1_Score'])
314
+ all_precision_scores.append(result['Precision_Mean'])
315
+ all_recall_scores.append(result['Recall_Mean'])
316
 
317
+ # Aggregate scores
318
  avg_f1 = np.mean(all_f1_scores) if all_f1_scores else 0.0
319
+ avg_precision = np.mean(all_precision_scores) if all_precision_scores else 0.0
320
+ avg_recall = np.mean(all_recall_scores) if all_recall_scores else 0.0
321
 
 
322
  return {
323
  'overall': {
324
  'caption_score': caption_score,
325
  'caption_method': caption_method,
326
  'temporal_f1': avg_f1,
327
+ 'temporal_precision': avg_precision,
328
+ 'temporal_recall': avg_recall,
329
  'count': len(records),
330
  'f1_samples': len(all_f1_scores)
331
  }
 
345
 
346
  print(f"Loading results from: {output_file}")
347
  if skip_llm_judge:
348
+ print(" --skip-llm-judge flag detected: Skipping caption evaluation, computing temporal F1 only")
349
 
350
  with open(output_file, "r") as f:
351
  infer_output = json.load(f)
 
370
  print("DENSE VIDEO CAPTIONING EVALUATION SUMMARY")
371
  print(f"{'='*80}")
372
 
 
373
  all_caption_scores = []
374
  all_f1_scores = []
375
 
 
379
  for key, metrics in results.items():
380
  if isinstance(metrics, dict):
381
  print(f" Caption Score ({metrics.get('caption_method', 'unknown')}): {metrics.get('caption_score', 0):.4f}")
382
+ print(f" Temporal F1: {metrics.get('temporal_f1', 0):.4f}")
383
+ print(f" Temporal Precision: {metrics.get('temporal_precision', 0):.4f}")
384
+ print(f" Temporal Recall: {metrics.get('temporal_recall', 0):.4f}")
385
  print(f" Total samples: {metrics.get('count', 0)}")
386
  print(f" F1 computed on: {metrics.get('f1_samples', 0)} samples")
387
 
 
388
  all_caption_scores.append(metrics.get('caption_score', 0))
389
  all_f1_scores.append(metrics.get('temporal_f1', 0))
390
 
 
391
  return {
392
  'caption_score': np.mean(all_caption_scores) if all_caption_scores else 0.0,
393
  'temporal_f1': np.mean(all_f1_scores) if all_f1_scores else 0.0,
evaluation/eval_next_action.py CHANGED
@@ -462,8 +462,11 @@ def get_action_list_for_dataset(dataset, procedure=None):
462
  for actions in NURVID_PROCEDURE_ACTIONS.values():
463
  all_actions.update(actions)
464
  return sorted(list(all_actions))
 
 
 
465
  else:
466
- raise ValueError(f"Unknown dataset: {dataset}")
467
 
468
  def normalize_action_text(text, dataset):
469
  """
@@ -487,6 +490,7 @@ def create_class_map_for_dataset(actions):
487
 
488
  def group_records_by_dataset(data):
489
  """Group next_action records by dataset for per-dataset evaluation."""
 
490
  dataset_groups = defaultdict(list)
491
 
492
  for key, record in data.items():
@@ -494,54 +498,32 @@ def group_records_by_dataset(data):
494
  if 'next_action' not in qa_type.lower():
495
  continue
496
 
497
- # Check data_source first (leaderboard format), then fall back to dataset/dataset_name
498
- dataset = record.get('data_source', record.get('dataset', record.get('dataset_name', record.get('metadata', {}).get('dataset', 'Unknown'))))
499
- video_id = record.get('video_id', record.get('metadata', {}).get('video_id', ''))
500
-
501
- if dataset == 'Unknown' and video_id:
502
- video_id_lower = str(video_id).lower()
503
- if len(video_id) == 11 and any(c.isalpha() for c in video_id):
504
- dataset = "AVOS"
505
- elif "_part" in video_id_lower:
506
- dataset = "CoPESD"
507
- elif "video" in video_id_lower:
508
- dataset = "CholecT50"
 
 
 
 
 
 
 
 
509
 
510
- dataset_groups[dataset].append(record)
511
 
512
  return dict(dataset_groups)
513
 
514
 
515
- def normalize_action_text(action_text, dataset_name):
516
- """Normalize action text for comparison."""
517
- action_text = action_text.strip().lower()
518
-
519
- # Dataset-specific mappings
520
- if dataset_name == "CoPESD":
521
- action_text = COPESD_ACTION_MAPPING.get(action_text, action_text)
522
-
523
- return action_text
524
-
525
-
526
- def get_action_list_for_dataset(dataset_name, procedure=None):
527
- """Get action list for a specific dataset."""
528
- if dataset_name == "AVOS":
529
- return AVOS_ACTIONS
530
- elif dataset_name == "CholecT50":
531
- return T50_PHASES
532
- elif dataset_name == "CoPESD":
533
- return TOTAL_NEW_ACTION_LIST
534
- elif dataset_name == "NurViD" and procedure:
535
- return NURVID_PROCEDURE_ACTIONS.get(procedure, [])
536
- elif dataset_name == "EgoSurgery":
537
- # EgoSurgery uses free-form actions, return empty list
538
- return []
539
- return []
540
-
541
-
542
- def create_class_map_for_dataset(actions):
543
- """Create mapping from action name to index."""
544
- return {action: idx for idx, action in enumerate(actions)}
545
 
546
 
547
  def evaluate_dataset_next_action(dataset_name, records):
@@ -568,13 +550,7 @@ def evaluate_dataset_next_action(dataset_name, records):
568
  temp_records = []
569
 
570
  for record in proc_records:
571
- struc_info = record.get('struc_info', {})
572
- if isinstance(struc_info, list) and len(struc_info) > 0:
573
- struc_info = struc_info[0]
574
-
575
- gnd_text = struc_info.get('next_action', '')
576
- if not gnd_text:
577
- gnd_text = record.get('gnd', '')
578
 
579
  gnd_text = normalize_action_text(gnd_text, dataset_name)
580
  if gnd_text:
@@ -627,15 +603,8 @@ def evaluate_dataset_next_action(dataset_name, records):
627
 
628
  pred_text = normalize_action_text(record.get('answer', ''), dataset_name)
629
 
630
- # Get ground truth - try struc_info first, then gnd field
631
- struc_info = record.get('struc_info', {})
632
- if isinstance(struc_info, list) and len(struc_info) > 0:
633
- struc_info = struc_info[0]
634
-
635
- gnd_text = struc_info.get('next_action', '')
636
- if not gnd_text:
637
- # Fallback to gnd field (used for CholecT50 and others)
638
- gnd_text = record.get('gnd', '')
639
 
640
  gnd_text = normalize_action_text(gnd_text, dataset_name)
641
 
 
462
  for actions in NURVID_PROCEDURE_ACTIONS.values():
463
  all_actions.update(actions)
464
  return sorted(list(all_actions))
465
+ elif dataset == "EgoSurgery":
466
+ # EgoSurgery uses free-form actions, return empty list
467
+ return []
468
  else:
469
+ return []
470
 
471
  def normalize_action_text(text, dataset):
472
  """
 
490
 
491
  def group_records_by_dataset(data):
492
  """Group next_action records by dataset for per-dataset evaluation."""
493
+ from dataset_utils import get_dataset_name
494
  dataset_groups = defaultdict(list)
495
 
496
  for key, record in data.items():
 
498
  if 'next_action' not in qa_type.lower():
499
  continue
500
 
501
+ # Detect dataset
502
+ dataset = get_dataset_name(record)
503
+
504
+ # Extract procedure for NurViD
505
+ procedure = None
506
+ if dataset == "NurViD":
507
+ question_lower = record.get("question", "").lower()
508
+ for proc_name in NURVID_PROCEDURE_ACTIONS.keys():
509
+ if proc_name.lower() in question_lower:
510
+ procedure = proc_name
511
+ break
512
+
513
+ # Restructure record to only include needed fields (consistent with Qwen2.5-VL)
514
+ record_data = {
515
+ "answer": record.get("answer", ""),
516
+ "gnd": record.get("gnd", ""),
517
+ "question": record.get("question", ""),
518
+ "video_id": record.get("metadata", {}).get("video_id", record.get("video_id", "")),
519
+ "procedure": procedure
520
+ }
521
 
522
+ dataset_groups[dataset].append(record_data)
523
 
524
  return dict(dataset_groups)
525
 
526
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
527
 
528
 
529
  def evaluate_dataset_next_action(dataset_name, records):
 
550
  temp_records = []
551
 
552
  for record in proc_records:
553
+ gnd_text = record.get('gnd', '')
 
 
 
 
 
 
554
 
555
  gnd_text = normalize_action_text(gnd_text, dataset_name)
556
  if gnd_text:
 
603
 
604
  pred_text = normalize_action_text(record.get('answer', ''), dataset_name)
605
 
606
+ # Get ground truth from gnd field only (consistent with Qwen2.5-VL)
607
+ gnd_text = record.get('gnd', '')
 
 
 
 
 
 
 
608
 
609
  gnd_text = normalize_action_text(gnd_text, dataset_name)
610
 
evaluation/evaluate_all_pai.py CHANGED
@@ -596,14 +596,31 @@ def print_overall_evaluation_results(output_file, tasks, all_task_results, skip_
596
  for dataset_name, ds_records in dataset_records_dict.items():
597
  if ds_records:
598
  # Silently evaluate each dataset
599
- with contextlib.redirect_stdout(io.StringIO()):
600
- ds_results = module.evaluate_dataset_next_action(dataset_name, ds_records)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
601
  if "overall" in ds_results:
602
  accuracy = ds_results["overall"].get("accuracy", 0.0)
 
 
603
  all_accuracies.append(accuracy)
604
- # Track weighted metrics
605
- total_correct += int(accuracy * len(ds_records))
606
- total_samples += len(ds_records)
607
 
608
  # Print only final aggregate metrics
609
  if all_accuracies:
 
596
  for dataset_name, ds_records in dataset_records_dict.items():
597
  if ds_records:
598
  # Silently evaluate each dataset
599
+ # Suppress SentenceTransformer/safetensors warnings at fd level
600
+ import logging, os
601
+ logging.disable(logging.WARNING)
602
+ old_fd_out = os.dup(1)
603
+ old_fd_err = os.dup(2)
604
+ devnull = os.open(os.devnull, os.O_WRONLY)
605
+ os.dup2(devnull, 1)
606
+ os.dup2(devnull, 2)
607
+ try:
608
+ with contextlib.redirect_stdout(io.StringIO()), contextlib.redirect_stderr(io.StringIO()):
609
+ ds_results = module.evaluate_dataset_next_action(dataset_name, ds_records)
610
+ finally:
611
+ os.dup2(old_fd_out, 1)
612
+ os.dup2(old_fd_err, 2)
613
+ os.close(old_fd_out)
614
+ os.close(old_fd_err)
615
+ os.close(devnull)
616
+ logging.disable(logging.NOTSET)
617
  if "overall" in ds_results:
618
  accuracy = ds_results["overall"].get("accuracy", 0.0)
619
+ # Use actual evaluated count, not input count (some records may be skipped)
620
+ evaluated_count = ds_results["overall"].get("count", len(ds_records))
621
  all_accuracies.append(accuracy)
622
+ total_correct += int(accuracy * evaluated_count)
623
+ total_samples += evaluated_count
 
624
 
625
  # Print only final aggregate metrics
626
  if all_accuracies:
evaluation/evaluate_predictions.py CHANGED
@@ -306,7 +306,7 @@ def main():
306
  help="Grouping strategy: 'per-dataset' or 'overall' (default: overall)")
307
  parser.add_argument("--analyze-only", action="store_true",
308
  help="Only analyze the file structure without running evaluations")
309
- parser.add_argument("--skip-llm-judge", action="store_true",
310
  help="Skip LLM judge evaluation for caption tasks (use when LLM scores are pre-computed)")
311
 
312
  args = parser.parse_args()
 
306
  help="Grouping strategy: 'per-dataset' or 'overall' (default: overall)")
307
  parser.add_argument("--analyze-only", action="store_true",
308
  help="Only analyze the file structure without running evaluations")
309
+ parser.add_argument("--skip-llm-judge", default=True, action="store_true",
310
  help="Skip LLM judge evaluation for caption tasks (use when LLM scores are pre-computed)")
311
 
312
  args = parser.parse_args()