MedGRPO Team Claude Sonnet 4.5 committed on
Commit a36b7fe · 1 Parent(s): 331979f

Update evaluation metrics and leaderboard display


- Modified app.py to fix metric definitions and display format
- Updated eval_dvc.py and eval_tal.py for consistent metric computation
- Aligned with standard evaluation pipeline

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>

Files changed (3)
  1. app.py +158 -115
  2. evaluation/eval_dvc.py +186 -26
  3. evaluation/eval_tal.py +31 -10
app.py CHANGED
@@ -25,58 +25,83 @@ EVAL_SCRIPT = Path("evaluation/evaluate_all_pai.py") # Local copy in repo
25
  SUBMISSIONS_DIR.mkdir(exist_ok=True)
26
  RESULTS_DIR.mkdir(exist_ok=True)
27
 
28
- # MedGRPO Task Definitions (8 tasks)
29
- TASKS = {
30
- "tal": {
31
- "name": "Temporal Action Localization",
32
- "metric": "mAP@0.5",
33
  "higher_better": True,
34
- "description": "Identify start/end times of surgical actions"
35
  },
36
- "stg": {
37
- "name": "Spatiotemporal Grounding",
38
- "metric": "mIoU",
39
  "higher_better": True,
40
- "description": "Locate actions in both space (bbox) and time"
41
  },
42
- "next_action": {
43
- "name": "Next Action Prediction",
44
- "metric": "Accuracy",
45
  "higher_better": True,
46
- "description": "Predict the next surgical step"
47
  },
48
- "dvc": {
49
- "name": "Dense Video Captioning",
50
- "metric": "LLM Judge (Avg)",
51
  "higher_better": True,
52
- "description": "Generate detailed segment descriptions"
53
  },
54
- "vs": {
55
- "name": "Video Summary",
56
- "metric": "LLM Judge (Avg)",
57
  "higher_better": True,
58
- "description": "Summarize entire surgical videos"
59
  },
60
- "rc": {
61
- "name": "Region Caption",
62
- "metric": "LLM Judge (Avg)",
63
  "higher_better": True,
64
- "description": "Describe regions indicated by bounding boxes"
65
  },
66
- "skill_assessment": {
67
- "name": "Skill Assessment",
68
- "metric": "Accuracy",
69
  "higher_better": True,
70
- "description": "Evaluate surgical skill levels (JIGSAWS)"
71
  },
72
- "cvs_assessment": {
73
- "name": "CVS Assessment",
74
- "metric": "Accuracy",
75
  "higher_better": True,
76
- "description": "Clinical variable scoring"
77
  },
78
  }
79
80
  # Test set statistics
81
  TEST_SET_STATS = {
82
  "total_samples": 6245,
@@ -92,13 +117,13 @@ def load_leaderboard() -> pd.DataFrame:
92
  data = json.load(f)
93
  if data:
94
  df = pd.DataFrame(data)
95
- # Sort by average score descending
96
- if 'average' in df.columns:
97
- df = df.sort_values('average', ascending=False).reset_index(drop=True)
98
  return df
99
 
100
- # Return empty dataframe with correct structure
101
- columns = ["rank", "model_name", "organization", "average"] + list(TASKS.keys()) + ["date", "contact"]
102
  return pd.DataFrame(columns=columns)
103
 
104
 
@@ -218,15 +243,11 @@ def run_evaluation(results_file: str, model_name: str) -> Tuple[bool, Dict, str]
218
 
219
  def parse_evaluation_output(output: str) -> Dict[str, float]:
220
  """
221
- Parse evaluation output from evaluate_all_pai.py to extract metrics.
222
-
223
- Expected output format (from --grouping overall):
224
- ================================================================================
225
- TAL - Overall Evaluation (All Datasets Combined)
226
- ================================================================================
227
- Total samples: 1234
228
- mAP@0.5: 0.4567
229
- ...
230
  """
231
  metrics = {}
232
 
@@ -237,62 +258,100 @@ def parse_evaluation_output(output: str) -> Dict[str, float]:
237
  line = line.strip()
238
 
239
  # Detect task headers
240
- if "TAL - Overall Evaluation" in line:
241
  current_task = "tal"
242
- elif "STG - Overall Evaluation" in line:
243
  current_task = "stg"
244
- elif "NEXT_ACTION - Overall Evaluation" in line:
245
  current_task = "next_action"
246
- elif "DVC - Overall Evaluation" in line:
247
  current_task = "dvc"
248
- elif "RC - Overall Evaluation" in line:
249
  current_task = "rc"
250
- elif "VS - Overall Evaluation" in line:
251
  current_task = "vs"
252
- elif "SKILL_ASSESSMENT - Overall Evaluation" in line:
253
  current_task = "skill_assessment"
254
- elif "CVS_ASSESSMENT - Overall Evaluation" in line:
255
  current_task = "cvs_assessment"
256
 
257
  # Extract metrics based on task
258
  if current_task:
259
- if current_task == "tal" and "mAP@0.5:" in line:
260
- try:
261
- value = float(line.split("mAP@0.5:")[-1].strip())
262
- metrics["tal"] = value
263
- except:
264
- pass
265
 
266
- elif current_task == "stg" and "mean_iou:" in line:
 
267
  try:
268
- value = float(line.split("mean_iou:")[-1].strip())
269
- metrics["stg"] = value
270
  except:
271
  pass
272
 
273
- elif current_task == "next_action" and "Weighted Average Accuracy" in line:
 
274
  try:
275
  value = float(line.split(":")[-1].strip())
276
- metrics["next_action"] = value
277
  except:
278
  pass
279
 
280
- elif current_task in ["dvc", "vs", "rc"]:
281
- # For caption tasks, look for average LLM judge score
282
- if "Average" in line or "Mean" in line:
283
  try:
284
- parts = line.split(":")
285
- if len(parts) == 2:
286
- value = float(parts[-1].strip())
287
- if 0 <= value <= 5: # LLM judge scores are 1-5
288
- metrics[current_task] = value
289
  except:
290
  pass
291
 
292
- elif current_task in ["skill_assessment", "cvs_assessment"] and "accuracy:" in line.lower():
 
293
  try:
294
  value = float(line.split(":")[-1].strip())
295
- metrics[current_task] = value
296
  except:
297
  pass
298
 
@@ -328,38 +387,23 @@ def submit_model(file, model_name: str, organization: str, contact: str = "") ->
328
  if not success:
329
  return False, f"❌ Evaluation failed: {eval_msg}"
330
 
331
- # Check if we got metrics for all tasks
332
- missing_tasks = [task for task in TASKS.keys() if task not in metrics]
333
- if len(missing_tasks) > 0:
334
- return False, f"❌ Evaluation incomplete. Missing metrics for: {missing_tasks}"
335
-
336
- # Calculate average score (normalized across all tasks)
337
- # Normalize each task score to 0-1 range, then average
338
- task_scores = []
339
- for task in TASKS.keys():
340
- if task in metrics:
341
- score = metrics[task]
342
- # LLM judge scores are 1-5, others are 0-1
343
- if task in ["dvc", "vs", "rc"]:
344
- normalized = (score - 1) / 4 # Normalize 1-5 to 0-1
345
- else:
346
- normalized = score # Already 0-1
347
- task_scores.append(normalized)
348
-
349
- average_score = sum(task_scores) / len(task_scores) if task_scores else 0.0
350
-
351
- # Add to leaderboard
352
  new_entry = {
353
  "model_name": model_name,
354
  "organization": organization,
355
- "average": round(average_score, 4),
356
- **{task: round(metrics.get(task, 0.0), 4) for task in TASKS.keys()},
357
  "date": datetime.now().strftime("%Y-%m-%d"),
358
  "contact": contact
359
  }
360
 
361
  df = pd.concat([df, pd.DataFrame([new_entry])], ignore_index=True)
362
- df = df.sort_values('average', ascending=False).reset_index(drop=True)
 
363
 
364
  save_leaderboard(df)
365
 
@@ -368,13 +412,12 @@ def submit_model(file, model_name: str, organization: str, contact: str = "") ->
368
 
369
  **Model**: {model_name}
370
  **Organization**: {organization}
371
- **Average Score**: {average_score:.4f}
372
 
373
- **Task Scores**:
374
  """
375
- for task, info in TASKS.items():
376
- score = metrics.get(task, 0.0)
377
- success_msg += f"\n- **{info['name']}**: {score:.4f}"
378
 
379
  success_msg += f"\n\n🏆 **Rank**: #{df[df['model_name'] == model_name].index[0] + 1} / {len(df)}"
380
 
@@ -382,24 +425,24 @@ def submit_model(file, model_name: str, organization: str, contact: str = "") ->
382
 
383
 
384
  def format_leaderboard_display(df: pd.DataFrame) -> pd.DataFrame:
385
- """Format leaderboard dataframe for display."""
386
  if df.empty:
387
  return df
388
 
389
- # Create display dataframe with selected columns
390
- display_cols = ["rank", "model_name", "organization", "average"]
391
 
392
- # Add task columns
393
- for task in TASKS.keys():
394
- if task in df.columns:
395
- display_cols.append(task)
396
 
397
  display_cols.append("date")
398
 
399
  # Rename columns for display
400
  display_df = df[display_cols].copy()
401
- display_df.columns = ["Rank", "Model", "Organization", "Average"] + \
402
- [TASKS[task]["name"] for task in TASKS.keys() if task in df.columns] + \
403
  ["Date"]
404
 
405
  return display_df
 
25
  SUBMISSIONS_DIR.mkdir(exist_ok=True)
26
  RESULTS_DIR.mkdir(exist_ok=True)
27
 
28
+ # MedGRPO Metrics Definitions (10 metrics from 8 tasks)
29
+ # Note: TAL has 2 metrics, DVC has 2 metrics, others have 1 metric each
30
+ METRICS = {
31
+ "cvs_acc": {
32
+ "name": "CVS_acc",
33
+ "full_name": "CVS Assessment Accuracy",
34
+ "higher_better": True,
35
+ "description": "Clinical variable scoring accuracy"
36
+ },
37
+ "nap_acc": {
38
+ "name": "NAP_acc",
39
+ "full_name": "Next Action Prediction Accuracy",
40
+ "higher_better": True,
41
+ "description": "Accuracy in predicting next surgical step"
42
+ },
43
+ "sa_acc": {
44
+ "name": "SA_acc",
45
+ "full_name": "Skill Assessment Accuracy",
46
  "higher_better": True,
47
+ "description": "Surgical skill level evaluation accuracy"
48
  },
49
+ "stg_miou": {
50
+ "name": "STG_mIoU",
51
+ "full_name": "Spatiotemporal Grounding mIoU",
52
  "higher_better": True,
53
+ "description": "Mean IoU for spatial+temporal localization"
54
  },
55
+ "tag_miou_03": {
56
+ "name": "TAG_mIoU@0.3",
57
+ "full_name": "Temporal Action Grounding mIoU@0.3",
58
  "higher_better": True,
59
+ "description": "Mean IoU at threshold 0.3 for temporal localization"
60
  },
61
+ "tag_miou_05": {
62
+ "name": "TAG_mIoU@0.5",
63
+ "full_name": "Temporal Action Grounding mIoU@0.5",
64
  "higher_better": True,
65
+ "description": "Mean IoU at threshold 0.5 for temporal localization"
66
  },
67
+ "dvc_llm": {
68
+ "name": "DVC_llm",
69
+ "full_name": "Dense Video Captioning LLM Score",
70
  "higher_better": True,
71
+ "description": "Caption quality score (LLM judge or semantic similarity)"
72
  },
73
+ "dvc_f1": {
74
+ "name": "DVC_F1",
75
+ "full_name": "Dense Video Captioning F1",
76
  "higher_better": True,
77
+ "description": "F1 score for temporal segment localization"
78
  },
79
+ "vs_llm": {
80
+ "name": "VS_llm",
81
+ "full_name": "Video Summary LLM Score",
82
  "higher_better": True,
83
+ "description": "Video summary quality score"
84
  },
85
+ "rc_llm": {
86
+ "name": "RC_llm",
87
+ "full_name": "Region Caption LLM Score",
88
  "higher_better": True,
89
+ "description": "Region caption quality score"
90
  },
91
  }
92
 
93
+ # Keep TASKS for backward compatibility and task descriptions
94
+ TASKS = {
95
+ "tal": "Temporal Action Localization",
96
+ "stg": "Spatiotemporal Grounding",
97
+ "next_action": "Next Action Prediction",
98
+ "dvc": "Dense Video Captioning",
99
+ "vs": "Video Summary",
100
+ "rc": "Region Caption",
101
+ "skill_assessment": "Skill Assessment",
102
+ "cvs_assessment": "CVS Assessment",
103
+ }
104
+
105
  # Test set statistics
106
  TEST_SET_STATS = {
107
  "total_samples": 6245,
 
117
  data = json.load(f)
118
  if data:
119
  df = pd.DataFrame(data)
120
+ # Sort by first metric (CVS_acc) descending - no overall average
121
+ if 'cvs_acc' in df.columns:
122
+ df = df.sort_values('cvs_acc', ascending=False).reset_index(drop=True)
123
  return df
124
 
125
+ # Return empty dataframe with correct structure (no average column)
126
+ columns = ["rank", "model_name", "organization"] + list(METRICS.keys()) + ["date", "contact"]
127
  return pd.DataFrame(columns=columns)
128
 
129
 
 
243
 
244
  def parse_evaluation_output(output: str) -> Dict[str, float]:
245
  """
246
+ Parse evaluation output to extract 10 metrics.
247
+
248
+ Returns dict with keys:
249
+ cvs_acc, nap_acc, sa_acc, stg_miou,
250
+ tag_miou_03, tag_miou_05, dvc_llm, dvc_f1, vs_llm, rc_llm
 
 
 
 
251
  """
252
  metrics = {}
253
 
 
258
  line = line.strip()
259
 
260
  # Detect task headers
261
+ if "TAL" in line and "Overall" in line:
262
  current_task = "tal"
263
+ elif "STG" in line and "Overall" in line:
264
  current_task = "stg"
265
+ elif ("NEXT_ACTION" in line and "Overall" in line) or "Next Action" in line:
266
  current_task = "next_action"
267
+ elif ("DVC" in line and "Overall" in line) or "Dense Video Captioning" in line:
268
  current_task = "dvc"
269
+ elif ("RC" in line and "Overall" in line) or "Region Caption" in line:
270
  current_task = "rc"
271
+ elif ("VS" in line and "Overall" in line) or "Video Summary" in line:
272
  current_task = "vs"
273
+ elif ("SKILL" in line and "Overall" in line) or "Skill Assessment" in line:
274
  current_task = "skill_assessment"
275
+ elif ("CVS" in line and "Overall" in line) or "CVS Assessment" in line:
276
  current_task = "cvs_assessment"
277
 
278
  # Extract metrics based on task
279
  if current_task:
280
+ # TAL: Extract both mIoU@0.3 and mIoU@0.5
281
+ if current_task == "tal":
282
+ if "meanIoU@0.3" in line or "mIoU@0.3" in line:
283
+ try:
284
+ value = float(line.split(":")[-1].strip())
285
+ metrics["tag_miou_03"] = value
286
+ except:
287
+ pass
288
+ if "meanIoU@0.5" in line or "mIoU@0.5" in line:
289
+ try:
290
+ value = float(line.split(":")[-1].strip())
291
+ metrics["tag_miou_05"] = value
292
+ except:
293
+ pass
294
 
295
+ # STG: Extract mIoU
296
+ elif current_task == "stg" and ("mean_iou" in line.lower() or "miou" in line.lower()):
297
  try:
298
+ value = float(line.split(":")[-1].strip())
299
+ metrics["stg_miou"] = value
300
  except:
301
  pass
302
 
303
+ # Next Action: Extract accuracy
304
+ elif current_task == "next_action" and "accuracy" in line.lower():
305
  try:
306
  value = float(line.split(":")[-1].strip())
307
+ metrics["nap_acc"] = value
308
  except:
309
  pass
310
 
311
+ # DVC: Extract both caption_score and temporal_f1
312
+ elif current_task == "dvc":
313
+ if "caption_score" in line.lower() or "caption score" in line.lower():
314
  try:
315
+ value = float(line.split(":")[-1].strip())
316
+ metrics["dvc_llm"] = value
 
 
 
317
  except:
318
  pass
319
+ if "temporal_f1" in line.lower() or "temporal f1" in line.lower():
320
+ try:
321
+ value = float(line.split(":")[-1].strip())
322
+ metrics["dvc_f1"] = value
323
+ except:
324
+ pass
325
+
326
+ # VS: Extract LLM score
327
+ elif current_task == "vs" and ("score" in line.lower() or "average" in line.lower()):
328
+ try:
329
+ value = float(line.split(":")[-1].strip())
330
+ metrics["vs_llm"] = value
331
+ except:
332
+ pass
333
 
334
+ # RC: Extract LLM score
335
+ elif current_task == "rc" and ("score" in line.lower() or "average" in line.lower()):
336
  try:
337
  value = float(line.split(":")[-1].strip())
338
+ metrics["rc_llm"] = value
339
+ except:
340
+ pass
341
+
342
+ # Skill Assessment: Extract accuracy
343
+ elif current_task == "skill_assessment" and "accuracy" in line.lower():
344
+ try:
345
+ value = float(line.split(":")[-1].strip())
346
+ metrics["sa_acc"] = value
347
+ except:
348
+ pass
349
+
350
+ # CVS Assessment: Extract accuracy
351
+ elif current_task == "cvs_assessment" and "accuracy" in line.lower():
352
+ try:
353
+ value = float(line.split(":")[-1].strip())
354
+ metrics["cvs_acc"] = value
355
  except:
356
  pass
357
 
 
387
  if not success:
388
  return False, f"❌ Evaluation failed: {eval_msg}"
389
 
390
+ # Check if we got all 10 metrics
391
+ missing_metrics = [m for m in METRICS.keys() if m not in metrics]
392
+ if len(missing_metrics) > 0:
393
+ return False, f"❌ Evaluation incomplete. Missing metrics: {missing_metrics}"
394
+
395
+ # Add to leaderboard (no average calculation)
396
  new_entry = {
397
  "model_name": model_name,
398
  "organization": organization,
399
+ **{metric: round(metrics.get(metric, 0.0), 4) for metric in METRICS.keys()},
 
400
  "date": datetime.now().strftime("%Y-%m-%d"),
401
  "contact": contact
402
  }
403
 
404
  df = pd.concat([df, pd.DataFrame([new_entry])], ignore_index=True)
405
+ # Sort by first metric (CVS_acc)
406
+ df = df.sort_values('cvs_acc', ascending=False).reset_index(drop=True)
407
 
408
  save_leaderboard(df)
409
 
 
412
 
413
  **Model**: {model_name}
414
  **Organization**: {organization}
 
415
 
416
+ **Metric Scores**:
417
  """
418
+ for metric_key, metric_info in METRICS.items():
419
+ score = metrics.get(metric_key, 0.0)
420
+ success_msg += f"\n- **{metric_info['name']}**: {score:.4f}"
421
 
422
  success_msg += f"\n\n🏆 **Rank**: #{df[df['model_name'] == model_name].index[0] + 1} / {len(df)}"
423
 
 
425
 
426
 
427
  def format_leaderboard_display(df: pd.DataFrame) -> pd.DataFrame:
428
+ """Format leaderboard dataframe for display with 10 metrics (no average)."""
429
  if df.empty:
430
  return df
431
 
432
+ # Create display dataframe with selected columns (no average)
433
+ display_cols = ["rank", "model_name", "organization"]
434
 
435
+ # Add metric columns in order
436
+ for metric_key in METRICS.keys():
437
+ if metric_key in df.columns:
438
+ display_cols.append(metric_key)
439
 
440
  display_cols.append("date")
441
 
442
  # Rename columns for display
443
  display_df = df[display_cols].copy()
444
+ display_df.columns = ["Rank", "Model", "Organization"] + \
445
+ [METRICS[m]["name"] for m in METRICS.keys() if m in df.columns] + \
446
  ["Date"]
447
 
448
  return display_df
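For reference, a minimal sketch (not part of this commit) of how the updated parse_evaluation_output is expected to behave on a sample evaluator printout. The label strings in sample_output and the import path are assumptions inferred from the parsing code above, not confirmed output of evaluate_all_pai.py.

from app import parse_evaluation_output  # assumes app.py at the Space root is importable

sample_output = """
TAL - Overall Evaluation (All Datasets Combined)
meanIoU@0.3: 0.4123
meanIoU@0.5: 0.3567
DVC - Overall Evaluation
caption_score: 3.8
temporal_f1: 0.42
"""

metrics = parse_evaluation_output(sample_output)
# -> {'tag_miou_03': 0.4123, 'tag_miou_05': 0.3567, 'dvc_llm': 3.8, 'dvc_f1': 0.42}

expected = ["cvs_acc", "nap_acc", "sa_acc", "stg_miou", "tag_miou_03",
            "tag_miou_05", "dvc_llm", "dvc_f1", "vs_llm", "rc_llm"]
missing = [k for k in expected if k not in metrics]  # submit_model rejects the run if this is non-empty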
evaluation/eval_dvc.py CHANGED
@@ -1,18 +1,115 @@
1
- """Dense Video Captioning evaluation using LLM judge."""
2
 
3
  import json
4
  import sys
 
 
5
  from eval_caption_llm_judge import evaluate_caption_task
6
 
7
 
8
  def group_records_by_dataset(data):
9
  """Group DVC records by dataset for per-dataset evaluation."""
10
- from collections import defaultdict
11
  dataset_groups = defaultdict(list)
12
 
13
  for key, record in data.items():
14
  qa_type = record.get('qa_type', '')
15
- if 'dense_captioning' not in qa_type.lower():
 
16
  continue
17
 
18
  dataset = record.get('dataset', record.get('dataset_name', record.get('metadata', {}).get('dataset', 'Unknown')))
@@ -33,35 +130,84 @@ def group_records_by_dataset(data):
33
 
34
 
35
  def evaluate_dataset_dvc(dataset_name, records):
36
- """Evaluate DVC for a specific dataset using caption evaluator."""
37
  print(f"\nEvaluating {dataset_name} ({len(records)} records)...")
38
 
39
- # DVC uses same evaluation as caption tasks
40
- # Create a temporary file with just these records
41
  import tempfile
42
  import os
43
-
44
  temp_data = {str(i): record for i, record in enumerate(records)}
45
-
46
  with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
47
  json.dump(temp_data, f)
48
  temp_file = f.name
49
-
50
  try:
51
- # Use caption evaluator (treats DVC as extended caption task)
52
- result = evaluate_caption_task(temp_file, 'dense_captioning')
53
-
54
- # Return in expected format
55
- return {
56
- 'overall': {
57
- 'score': result['score'],
58
- 'method': result['method'],
59
- 'count': len(records)
60
- }
61
- }
62
  finally:
63
  os.unlink(temp_file)
64
 
65
 
66
  def main():
67
  """Main evaluation function for DVC."""
@@ -84,7 +230,7 @@ def main():
84
 
85
  if not any(dataset_records.values()):
86
  print("No DVC records found!")
87
- return
88
 
89
  all_results = {}
90
  for dataset_name, records in dataset_records.items():
@@ -96,16 +242,30 @@ def main():
96
  print("DENSE VIDEO CAPTIONING EVALUATION SUMMARY")
97
  print(f"{'='*80}")
98
 
 
 
 
 
99
  for dataset_name, results in all_results.items():
100
  if results:
101
  print(f"\n{dataset_name}:")
102
  for key, metrics in results.items():
103
  if isinstance(metrics, dict):
104
- for metric_name, value in metrics.items():
105
- if metric_name != 'count':
106
- print(f" {metric_name}: {value}")
107
- else:
108
- print(f" samples: {value}")
109
 
110
 
111
  if __name__ == "__main__":
 
1
+ """Dense Video Captioning evaluation using LLM judge + temporal F1."""
2
 
3
  import json
4
  import sys
5
+ import numpy as np
6
+ from collections import defaultdict
7
  from eval_caption_llm_judge import evaluate_caption_task
8
 
9
 
10
+ def compute_iou(pred_segment, gt_segment):
11
+ """Compute IoU between two segments [start, end]."""
12
+ pred_start, pred_end = pred_segment
13
+ gt_start, gt_end = gt_segment
14
+
15
+ # Compute intersection
16
+ inter_start = max(pred_start, gt_start)
17
+ inter_end = min(pred_end, gt_end)
18
+ intersection = max(0, inter_end - inter_start)
19
+
20
+ # Compute union
21
+ union = (pred_end - pred_start) + (gt_end - gt_start) - intersection
22
+
23
+ if union == 0:
24
+ return 0
25
+
26
+ return intersection / union
27
+
28
+
29
+ def compute_temporal_f1(pred_segments, gt_segments, iou_threshold=0.5):
30
+ """
31
+ Compute F1 score for temporal segment matching.
32
+
33
+ Args:
34
+ pred_segments: List of predicted [start, end] segments
35
+ gt_segments: List of ground truth [start, end] segments
36
+ iou_threshold: IoU threshold for matching (default 0.5)
37
+
38
+ Returns:
39
+ Dict with precision, recall, and f1 scores
40
+ """
41
+ if not pred_segments or not gt_segments:
42
+ return {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}
43
+
44
+ # Match predicted segments to ground truth
45
+ matched_gt = set()
46
+ matched_pred = set()
47
+
48
+ for pred_idx, pred_seg in enumerate(pred_segments):
49
+ best_iou = 0
50
+ best_gt_idx = -1
51
+
52
+ for gt_idx, gt_seg in enumerate(gt_segments):
53
+ if gt_idx in matched_gt:
54
+ continue
55
+
56
+ iou = compute_iou(pred_seg, gt_seg)
57
+ if iou >= iou_threshold and iou > best_iou:
58
+ best_iou = iou
59
+ best_gt_idx = gt_idx
60
+
61
+ if best_gt_idx >= 0:
62
+ matched_pred.add(pred_idx)
63
+ matched_gt.add(best_gt_idx)
64
+
65
+ # Compute precision, recall, F1
66
+ precision = len(matched_pred) / len(pred_segments) if pred_segments else 0
67
+ recall = len(matched_gt) / len(gt_segments) if gt_segments else 0
68
+ f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
69
+
70
+ return {
71
+ 'precision': precision,
72
+ 'recall': recall,
73
+ 'f1': f1
74
+ }
75
+
76
+
77
+ def parse_dvc_segments(text):
78
+ """
79
+ Parse DVC output to extract segments.
80
+ Supports multiple formats:
81
+ - [start-end] caption
82
+ - (start-end) caption
83
+ - start-end seconds: caption
84
+ """
85
+ import re
86
+ segments = []
87
+
88
+ # Pattern 1: [0.0-5.2] or (0.0-5.2)
89
+ pattern1 = r'[\[\(](\d+\.?\d*)\s*-\s*(\d+\.?\d*)[\]\)]'
90
+
91
+ # Pattern 2: 0.0-5.2 seconds:
92
+ pattern2 = r'(\d+\.?\d*)\s*-\s*(\d+\.?\d*)\s*seconds?:'
93
+
94
+ # Try both patterns
95
+ for pattern in [pattern1, pattern2]:
96
+ matches = re.finditer(pattern, text, re.IGNORECASE)
97
+ for match in matches:
98
+ start = float(match.group(1))
99
+ end = float(match.group(2))
100
+ segments.append([start, end])
101
+
102
+ return segments
103
+
104
+
105
  def group_records_by_dataset(data):
106
  """Group DVC records by dataset for per-dataset evaluation."""
 
107
  dataset_groups = defaultdict(list)
108
 
109
  for key, record in data.items():
110
  qa_type = record.get('qa_type', '')
111
+ # Match any dense_captioning variant (dense_captioning, dense_captioning_gpt, dense_captioning_gemini, dc)
112
+ if not any(x in qa_type.lower() for x in ['dense_captioning', 'dense_caption', 'dc']):
113
  continue
114
 
115
  dataset = record.get('dataset', record.get('dataset_name', record.get('metadata', {}).get('dataset', 'Unknown')))
 
130
 
131
 
132
  def evaluate_dataset_dvc(dataset_name, records):
133
+ """Evaluate DVC for a specific dataset using caption quality + temporal F1."""
134
  print(f"\nEvaluating {dataset_name} ({len(records)} records)...")
135
 
136
+ # Step 1: Evaluate caption quality using LLM judge
 
137
  import tempfile
138
  import os
139
+
140
  temp_data = {str(i): record for i, record in enumerate(records)}
141
+
142
  with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
143
  json.dump(temp_data, f)
144
  temp_file = f.name
145
+
146
  try:
147
+ # Use caption evaluator for caption quality
148
+ caption_result = evaluate_caption_task(temp_file, 'dense_captioning')
149
+ caption_score = caption_result['score']
150
+ caption_method = caption_result['method']
151
  finally:
152
  os.unlink(temp_file)
153
 
154
+ # Step 2: Compute temporal F1 for segment localization
155
+ all_f1_scores = []
156
+
157
+ for record in records:
158
+ # Get FPS for time-to-frame conversion
159
+ fps = record.get('fps', record.get('metadata', {}).get('fps', 1.0))
160
+ if isinstance(fps, str):
161
+ fps = float(fps)
162
+
163
+ # Parse predicted segments from answer
164
+ pred_text = record.get('answer', '')
165
+ pred_segments = parse_dvc_segments(pred_text)
166
+
167
+ # Get ground truth segments from struc_info
168
+ struc_info = record.get('struc_info', [])
169
+ gt_segments = []
170
+
171
+ if isinstance(struc_info, list):
172
+ for item in struc_info:
173
+ if isinstance(item, dict):
174
+ # Handle different formats
175
+ if 'dc_segments' in item:
176
+ # NurViD format
177
+ segments = item['dc_segments']
178
+ elif 'start' in item and 'end' in item:
179
+ # Direct segment format
180
+ segments = [item]
181
+ else:
182
+ continue
183
+
184
+ for seg in (segments if isinstance(segments, list) else [segments]):
185
+ if 'start' in seg and 'end' in seg:
186
+ # Convert to seconds (struc_info is in seconds)
187
+ gt_segments.append([
188
+ float(seg['start']),
189
+ float(seg['end'])
190
+ ])
191
+
192
+ # Compute F1 for this sample
193
+ if pred_segments and gt_segments:
194
+ f1_result = compute_temporal_f1(pred_segments, gt_segments, iou_threshold=0.5)
195
+ all_f1_scores.append(f1_result['f1'])
196
+
197
+ # Aggregate F1 scores
198
+ avg_f1 = np.mean(all_f1_scores) if all_f1_scores else 0.0
199
+
200
+ # Return both caption quality and temporal F1
201
+ return {
202
+ 'overall': {
203
+ 'caption_score': caption_score,
204
+ 'caption_method': caption_method,
205
+ 'temporal_f1': avg_f1,
206
+ 'count': len(records),
207
+ 'f1_samples': len(all_f1_scores)
208
+ }
209
+ }
210
+
211
 
212
  def main():
213
  """Main evaluation function for DVC."""
 
230
 
231
  if not any(dataset_records.values()):
232
  print("No DVC records found!")
233
+ return {}
234
 
235
  all_results = {}
236
  for dataset_name, records in dataset_records.items():
 
242
  print("DENSE VIDEO CAPTIONING EVALUATION SUMMARY")
243
  print(f"{'='*80}")
244
 
245
+ # Aggregate overall metrics
246
+ all_caption_scores = []
247
+ all_f1_scores = []
248
+
249
  for dataset_name, results in all_results.items():
250
  if results:
251
  print(f"\n{dataset_name}:")
252
  for key, metrics in results.items():
253
  if isinstance(metrics, dict):
254
+ print(f" Caption Score ({metrics.get('caption_method', 'unknown')}): {metrics.get('caption_score', 0):.4f}")
255
+ print(f" Temporal F1@0.5: {metrics.get('temporal_f1', 0):.4f}")
256
+ print(f" Total samples: {metrics.get('count', 0)}")
257
+ print(f" F1 computed on: {metrics.get('f1_samples', 0)} samples")
258
+
259
+ # Collect for overall average
260
+ all_caption_scores.append(metrics.get('caption_score', 0))
261
+ all_f1_scores.append(metrics.get('temporal_f1', 0))
262
+
263
+ # Return overall aggregated results
264
+ return {
265
+ 'caption_score': np.mean(all_caption_scores) if all_caption_scores else 0.0,
266
+ 'temporal_f1': np.mean(all_f1_scores) if all_f1_scores else 0.0,
267
+ 'method': all_results[list(all_results.keys())[0]]['overall'].get('caption_method', 'unknown') if all_results else 'unknown'
268
+ }
269
 
270
 
271
  if __name__ == "__main__":
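A minimal usage sketch (not part of this commit) for the segment helpers added above. The prediction text and timestamps are illustrative only, and the import assumes the snippet is run from the evaluation/ directory.

from eval_dvc import parse_dvc_segments, compute_temporal_f1

pred_text = "[0.0-5.2] grasp the gallbladder\n[5.2-12.0] dissect the cystic duct"
pred_segments = parse_dvc_segments(pred_text)      # [[0.0, 5.2], [5.2, 12.0]]
gt_segments = [[0.0, 5.0], [6.0, 12.0]]            # ground truth, in seconds

result = compute_temporal_f1(pred_segments, gt_segments, iou_threshold=0.5)
# IoU([0.0, 5.2],  [0.0, 5.0])  = 5.0 / 5.2 ≈ 0.96 -> matched
# IoU([5.2, 12.0], [6.0, 12.0]) = 6.0 / 6.8 ≈ 0.88 -> matched
print(result)                                      # {'precision': 1.0, 'recall': 1.0, 'f1': 1.0}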
evaluation/eval_tal.py CHANGED
@@ -194,22 +194,27 @@ def evaluate_dataset_tal(dataset_name, records):
194
  'ground_truth': gt_spans
195
  }]
196
 
197
- # Evaluate this record
198
- result = evaluate_tal_record(formatted_record, tiou_thresh=0.5)
199
- results_by_fps[fps].append(result)
 
200
 
201
  # Aggregate results
202
  aggregated = {}
203
  for fps, results_list in results_by_fps.items():
204
- # Extract metrics from results (recall and meanIoU)
205
- all_recalls = [r.get(f'Recall@0.50', 0) for r in results_list if r]
206
- all_mean_ious = [r.get(f'meanIoU@0.50', 0) for r in results_list if r]
 
 
207
 
208
- if all_recalls:
209
  aggregated[f'fps_{fps}'] = {
210
- 'recall@0.5': np.mean(all_recalls),
211
- 'meanIoU@0.5': np.mean(all_mean_ious),
212
- 'count': len(all_recalls)
 
 
213
  }
214
 
215
  return aggregated
@@ -254,6 +259,10 @@ def main():
254
  print("TAL EVALUATION SUMMARY")
255
  print(f"{'='*80}")
256
 
 
 
 
 
257
  for dataset_name, fps_results in all_results.items():
258
  if fps_results:
259
  print(f"\n{dataset_name}:")
@@ -265,6 +274,18 @@ def main():
265
  else:
266
  print(f" samples: {value}")
267
 
268
 
269
  if __name__ == "__main__":
270
  main()
 
194
  'ground_truth': gt_spans
195
  }]
196
 
197
+ # Evaluate this record at both IoU thresholds
198
+ result_03 = evaluate_tal_record(formatted_record, tiou_thresh=0.3)
199
+ result_05 = evaluate_tal_record(formatted_record, tiou_thresh=0.5)
200
+ results_by_fps[fps].append({'0.3': result_03, '0.5': result_05})
201
 
202
  # Aggregate results
203
  aggregated = {}
204
  for fps, results_list in results_by_fps.items():
205
+ # Extract metrics from results at both thresholds
206
+ all_recalls_03 = [r['0.3'].get(f'Recall@0.30', 0) for r in results_list if r]
207
+ all_mean_ious_03 = [r['0.3'].get(f'meanIoU@0.30', 0) for r in results_list if r]
208
+ all_recalls_05 = [r['0.5'].get(f'Recall@0.50', 0) for r in results_list if r]
209
+ all_mean_ious_05 = [r['0.5'].get(f'meanIoU@0.50', 0) for r in results_list if r]
210
 
211
+ if all_recalls_03:
212
  aggregated[f'fps_{fps}'] = {
213
+ 'recall@0.3': np.mean(all_recalls_03),
214
+ 'meanIoU@0.3': np.mean(all_mean_ious_03),
215
+ 'recall@0.5': np.mean(all_recalls_05),
216
+ 'meanIoU@0.5': np.mean(all_mean_ious_05),
217
+ 'count': len(all_recalls_03)
218
  }
219
 
220
  return aggregated
 
259
  print("TAL EVALUATION SUMMARY")
260
  print(f"{'='*80}")
261
 
262
+ # Aggregate metrics across all datasets
263
+ all_miou_03 = []
264
+ all_miou_05 = []
265
+
266
  for dataset_name, fps_results in all_results.items():
267
  if fps_results:
268
  print(f"\n{dataset_name}:")
 
274
  else:
275
  print(f" samples: {value}")
276
 
277
+ # Collect for overall average
278
+ if 'meanIoU@0.3' in metrics:
279
+ all_miou_03.append(metrics['meanIoU@0.3'])
280
+ if 'meanIoU@0.5' in metrics:
281
+ all_miou_05.append(metrics['meanIoU@0.5'])
282
+
283
+ # Return overall aggregated results
284
+ return {
285
+ 'meanIoU@0.3': np.mean(all_miou_03) if all_miou_03 else 0.0,
286
+ 'meanIoU@0.5': np.mean(all_miou_05) if all_miou_05 else 0.0
287
+ }
288
+
289
 
290
  if __name__ == "__main__":
291
  main()
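evaluate_tal_record itself is not shown in this diff; the sketch below is an assumption about the per-record values it returns (Recall@τ and meanIoU@τ, the keys the aggregation above reads), using a best-IoU-per-ground-truth matching rule. It is illustrative only, not the shipped implementation.

import numpy as np

def tal_metrics_for_record(pred_spans, gt_spans, tiou_thresh=0.5):
    """Assumed shape of evaluate_tal_record output; spans are [start, end] in seconds."""
    def iou(a, b):
        inter = max(0.0, min(a[1], b[1]) - max(a[0], b[0]))
        union = (a[1] - a[0]) + (b[1] - b[0]) - inter
        return inter / union if union > 0 else 0.0

    if not gt_spans:
        return {}
    # Best IoU over all predictions for each ground-truth span.
    best_ious = [max((iou(p, g) for p in pred_spans), default=0.0) for g in gt_spans]
    return {
        f'Recall@{tiou_thresh:.2f}': float(np.mean([i >= tiou_thresh for i in best_ious])),
        f'meanIoU@{tiou_thresh:.2f}': float(np.mean(best_ious)),
    }

# One span hit at tIoU 0.3 but missed at 0.5: IoU([0, 12], [0, 5]) = 5/12 ≈ 0.42
print(tal_metrics_for_record([[0.0, 12.0]], [[0.0, 5.0]], 0.3))  # {'Recall@0.30': 1.0, 'meanIoU@0.30': ≈0.42}
print(tal_metrics_for_record([[0.0, 12.0]], [[0.0, 5.0]], 0.5))  # {'Recall@0.50': 0.0, 'meanIoU@0.50': ≈0.42}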