# MedGRPO Team
# Fix eval_dvc.py main() to support --skip-llm-judge flag
# 5f41159
"""Dense Video Captioning evaluation using LLM judge + temporal F1."""
import json
import sys
import numpy as np
from collections import defaultdict
from eval_caption_llm_judge import evaluate_caption_task
def compute_iou(pred_segment, gt_segment):
    """Compute IoU (intersection over union) between two [start, end] segments."""
    p_start, p_end = pred_segment
    g_start, g_end = gt_segment
    # Overlap length, clamped to zero for disjoint segments.
    overlap = max(0, min(p_end, g_end) - max(p_start, g_start))
    # Inclusion-exclusion: union = |pred| + |gt| - |overlap|.
    union = (p_end - p_start) + (g_end - g_start) - overlap
    # Guard against division by zero (two zero-length identical segments).
    if union == 0:
        return 0
    return overlap / union
def compute_temporal_f1(pred_segments, gt_segments, iou_threshold=0.5):
    """
    Compute F1 score for temporal segment matching.

    Each predicted segment is greedily matched to the unclaimed ground-truth
    segment with the highest IoU, provided the IoU clears the threshold.

    Args:
        pred_segments: List of predicted [start, end] segments.
        gt_segments: List of ground truth [start, end] segments.
        iou_threshold: IoU threshold for matching (default 0.5).

    Returns:
        Dict with precision, recall, and f1 scores.
    """
    # Nothing to match if either side is empty.
    if not pred_segments or not gt_segments:
        return {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}

    claimed = set()   # ground-truth indices already matched
    hits = 0          # number of matched prediction/GT pairs
    for pred in pred_segments:
        best_score, best_idx = 0, -1
        for idx, gt in enumerate(gt_segments):
            if idx in claimed:
                continue
            score = compute_iou(pred, gt)
            # Strict '>' keeps the earliest GT segment on an IoU tie.
            if score >= iou_threshold and score > best_score:
                best_score, best_idx = score, idx
        if best_idx != -1:
            claimed.add(best_idx)
            hits += 1

    # One hit counts toward both sides, so precision and recall share it.
    precision = hits / len(pred_segments)
    recall = hits / len(gt_segments)
    f1 = 2 * precision * recall / (precision + recall) if precision + recall > 0 else 0
    return {'precision': precision, 'recall': recall, 'f1': f1}
def parse_dvc_segments(text):
    """
    Parse DVC output to extract [start, end] segments.

    Supports multiple formats:
    - [start-end] caption
    - (start-end) caption
    - start-end seconds: caption
    """
    import re
    # Bracketed ranges, e.g. [0.0-5.2] or (0.0-5.2).
    bracketed = r'[\[\(](\d+\.?\d*)\s*-\s*(\d+\.?\d*)[\]\)]'
    # Bare ranges followed by a "seconds:" suffix, e.g. 0.0-5.2 seconds:
    suffixed = r'(\d+\.?\d*)\s*-\s*(\d+\.?\d*)\s*seconds?:'
    segments = []
    for pattern in (bracketed, suffixed):
        segments.extend(
            [float(match.group(1)), float(match.group(2))]
            for match in re.finditer(pattern, text, re.IGNORECASE)
        )
    return segments
def group_records_by_dataset(data):
    """Group DVC records by dataset for per-dataset evaluation."""
    groups = defaultdict(list)
    # Any dense_captioning variant counts (dense_captioning, dense_captioning_gpt,
    # dense_captioning_gemini, dc).
    dvc_markers = ('dense_captioning', 'dense_caption', 'dc')
    for record in data.values():
        qa_type = record.get('qa_type', '').lower()
        if not any(marker in qa_type for marker in dvc_markers):
            continue
        meta = record.get('metadata', {})
        # data_source (leaderboard format) wins, then dataset / dataset_name,
        # then the metadata dict.
        dataset = record.get(
            'data_source',
            record.get('dataset', record.get('dataset_name', meta.get('dataset', 'Unknown'))),
        )
        video_id = record.get('video_id', meta.get('video_id', ''))
        if dataset == 'Unknown' and video_id:
            vid_lower = str(video_id).lower()
            # Heuristic fallbacks based on video-id shape.
            if len(video_id) == 11 and any(ch.isalpha() for ch in video_id):
                dataset = "AVOS"
            elif "_part" in vid_lower:
                dataset = "CoPESD"
            elif "video" in vid_lower:
                dataset = "CholecT50"
        groups[dataset].append(record)
    return dict(groups)
def evaluate_dataset_dvc(dataset_name, records, skip_llm_judge=False):
"""Evaluate DVC for a specific dataset using caption quality + temporal F1."""
print(f"\nEvaluating {dataset_name} ({len(records)} records)...")
# Step 1: Evaluate caption quality using LLM judge (unless skipped)
if skip_llm_judge:
print(f" Skipping LLM judge caption evaluation (--skip-llm-judge flag)")
caption_score = 0.0
caption_method = 'skipped'
else:
import tempfile
import os
temp_data = {str(i): record for i, record in enumerate(records)}
with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
json.dump(temp_data, f)
temp_file = f.name
try:
# Use caption evaluator for caption quality
caption_result = evaluate_caption_task(temp_file, 'dense_captioning')
caption_score = caption_result['score']
caption_method = caption_result['method']
finally:
os.unlink(temp_file)
# Step 2: Compute temporal F1 for segment localization
all_f1_scores = []
for record in records:
# Get FPS for time-to-frame conversion
fps = record.get('fps', record.get('metadata', {}).get('fps', 1.0))
if isinstance(fps, str):
fps = float(fps)
# Parse predicted segments from answer
pred_text = record.get('answer', '')
pred_segments = parse_dvc_segments(pred_text)
# Get ground truth segments from struc_info
struc_info = record.get('struc_info', [])
gt_segments = []
if isinstance(struc_info, list):
for item in struc_info:
if isinstance(item, dict):
# Handle different formats
if 'dc_segments' in item:
# NurViD format
segments = item['dc_segments']
elif 'start' in item and 'end' in item:
# Direct segment format
segments = [item]
else:
continue
for seg in (segments if isinstance(segments, list) else [segments]):
if 'start' in seg and 'end' in seg:
# Convert to seconds (struc_info is in seconds)
gt_segments.append([
float(seg['start']),
float(seg['end'])
])
# Compute F1 for this sample
if pred_segments and gt_segments:
f1_result = compute_temporal_f1(pred_segments, gt_segments, iou_threshold=0.5)
all_f1_scores.append(f1_result['f1'])
# Aggregate F1 scores
avg_f1 = np.mean(all_f1_scores) if all_f1_scores else 0.0
# Return both caption quality and temporal F1
return {
'overall': {
'caption_score': caption_score,
'caption_method': caption_method,
'temporal_f1': avg_f1,
'count': len(records),
'f1_samples': len(all_f1_scores)
}
}
def main():
    """Main evaluation function for DVC.

    Usage: python eval_dvc.py <results_json_file> [--skip-llm-judge]

    Returns:
        Dict with 'caption_score', 'temporal_f1' (each averaged across
        datasets) and 'method', or {} when no DVC records are found.
    """
    # Separate positional args from flags so '--skip-llm-judge' may appear in
    # any position (previously sys.argv[1] was assumed to be the file, which
    # broke when the flag was passed first).
    positional = [arg for arg in sys.argv[1:] if not arg.startswith('--')]
    if not positional:
        print("Usage: python eval_dvc.py <results_json_file> [--skip-llm-judge]")
        print("Example: python eval_dvc.py results/model_results.json")
        print("Example: python eval_dvc.py results/model_results.json --skip-llm-judge")
        sys.exit(1)
    output_file = positional[0]
    skip_llm_judge = '--skip-llm-judge' in sys.argv

    print(f"Loading results from: {output_file}")
    if skip_llm_judge:
        print("⚠️ --skip-llm-judge flag detected: Skipping caption evaluation, computing temporal F1 only")

    with open(output_file, "r") as f:
        infer_output = json.load(f)

    dataset_records = group_records_by_dataset(infer_output)
    print(f"\nFound datasets: {list(dataset_records.keys())}")
    for dataset, records in dataset_records.items():
        print(f" {dataset}: {len(records)} DVC records")

    # True when there are no datasets at all, or every bucket is empty.
    if not any(dataset_records.values()):
        print("No DVC records found!")
        return {}

    all_results = {}
    for dataset_name, records in dataset_records.items():
        if records:
            all_results[dataset_name] = evaluate_dataset_dvc(
                dataset_name, records, skip_llm_judge=skip_llm_judge
            )

    print(f"\n{'='*80}")
    print("DENSE VIDEO CAPTIONING EVALUATION SUMMARY")
    print(f"{'='*80}")

    # Print per-dataset metrics and collect them for the overall average.
    all_caption_scores = []
    all_f1_scores = []
    for dataset_name, results in all_results.items():
        if results:
            print(f"\n{dataset_name}:")
            for metrics in results.values():
                if isinstance(metrics, dict):
                    print(f" Caption Score ({metrics.get('caption_method', 'unknown')}): {metrics.get('caption_score', 0):.4f}")
                    print(f" Temporal F1@0.5: {metrics.get('temporal_f1', 0):.4f}")
                    print(f" Total samples: {metrics.get('count', 0)}")
                    print(f" F1 computed on: {metrics.get('f1_samples', 0)} samples")
                    all_caption_scores.append(metrics.get('caption_score', 0))
                    all_f1_scores.append(metrics.get('temporal_f1', 0))

    # Report the caption method from the first evaluated dataset.
    first_results = next(iter(all_results.values())) if all_results else None
    return {
        'caption_score': np.mean(all_caption_scores) if all_caption_scores else 0.0,
        'temporal_f1': np.mean(all_f1_scores) if all_f1_scores else 0.0,
        'method': first_results['overall'].get('caption_method', 'unknown') if first_results else 'unknown',
    }


if __name__ == "__main__":
    main()