MedGRPO Team commited on
Commit
e2b1040
·
1 Parent(s): a9b5dcf

fix issues

Browse files
README.md CHANGED
@@ -221,9 +221,10 @@ To compute the **average score** fairly across tasks:
221
  ## Links
222
 
223
  - 📄 **Paper**: [https://arxiv.org/abs/2512.06581](https://arxiv.org/abs/2512.06581)
224
- - 🌐 **Project**: [https://gaozhongpai.github.io/MedGRPO-Page/](https://gaozhongpai.github.io/MedGRPO-Page/)
225
  - 💾 **Dataset**: [https://huggingface.co/datasets/UIIAmerica/MedVidBench](https://huggingface.co/datasets/UIIAmerica/MedVidBench)
226
- - 💻 **GitHub**: [https://github.com/Gaozhongpai/MedGRPO](https://github.com/Gaozhongpai/MedGRPO)
 
227
  - 🏆 **Leaderboard**: [https://huggingface.co/spaces/UIIAmerica/MedVidBench-Leaderboard](https://huggingface.co/spaces/UIIAmerica/MedVidBench-Leaderboard)
228
 
229
  ## Citation
@@ -245,5 +246,5 @@ To compute the **average score** fairly across tasks:
245
  ## Contact
246
 
247
  For questions or issues:
248
- - Open an issue on [GitHub](https://github.com/Gaozhongpai/MedGRPO)
249
- - Visit the [project page](https://gaozhongpai.github.io/MedGRPO-Page/)
 
221
  ## Links
222
 
223
  - 📄 **Paper**: [https://arxiv.org/abs/2512.06581](https://arxiv.org/abs/2512.06581)
224
+ - 🌐 **Project**: [https://uii-america.github.io/MedGRPO/](https://uii-america.github.io/MedGRPO/)
225
  - 💾 **Dataset**: [https://huggingface.co/datasets/UIIAmerica/MedVidBench](https://huggingface.co/datasets/UIIAmerica/MedVidBench)
226
+ - 💻 **GitHub**: [https://github.com/UII-America/MedGRPO-Code](https://github.com/UII-America/MedGRPO-Code)
227
+ - 🎮 **Demo**: [https://huggingface.co/spaces/UIIAmerica/MedGRPO-Demo](https://huggingface.co/spaces/UIIAmerica/MedGRPO-Demo)
228
  - 🏆 **Leaderboard**: [https://huggingface.co/spaces/UIIAmerica/MedVidBench-Leaderboard](https://huggingface.co/spaces/UIIAmerica/MedVidBench-Leaderboard)
229
 
230
  ## Citation
 
246
  ## Contact
247
 
248
  For questions or issues:
249
+ - Open an issue on [GitHub](https://github.com/UII-America/MedGRPO-Code)
250
+ - Visit the [project page](https://uii-america.github.io/MedGRPO/)
app.py CHANGED
@@ -865,22 +865,23 @@ def parse_evaluation_output(output: str) -> Dict[str, float]:
865
  line = line.strip()
866
 
867
  # Detect task headers
868
- if "TAL" in line and "Overall" in line:
 
 
 
 
 
869
  current_task = "tal"
870
  elif "STG" in line and "Overall" in line:
871
  current_task = "stg"
872
- elif "NEXT_ACTION" in line and "Overall" in line or "Next Action" in line:
873
  current_task = "next_action"
874
- elif "DVC" in line and "Overall" in line or "Dense Video Captioning" in line:
875
  current_task = "dvc"
876
- elif "RC" in line and "Overall" in line or "Region Caption" in line:
877
  current_task = "rc"
878
- elif "VS" in line and "Overall" in line or "Video Summary" in line:
879
  current_task = "vs"
880
- elif "SKILL" in line and "Overall" in line or "Skill Assessment" in line:
881
- current_task = "skill_assessment"
882
- elif "CVS" in line and "Overall" in line or "CVS Assessment" in line:
883
- current_task = "cvs_assessment"
884
 
885
  # Detect IoU sections for TAL (new format)
886
  if current_task == "tal":
@@ -951,16 +952,16 @@ def parse_evaluation_output(output: str) -> Dict[str, float]:
951
  # VS: Extract LLM score
952
  elif current_task == "vs" and ("score" in line.lower() or "average" in line.lower()):
953
  try:
954
- value = float(line.split(":")[-1].strip())
955
- metrics["vs_llm"] = value
956
  except:
957
  pass
958
 
959
  # RC: Extract LLM score
960
  elif current_task == "rc" and ("score" in line.lower() or "average" in line.lower()):
961
  try:
962
- value = float(line.split(":")[-1].strip())
963
- metrics["rc_llm"] = value
964
  except:
965
  pass
966
 
@@ -1652,9 +1653,10 @@ with gr.Blocks(title="MedVidBench Leaderboard", theme=gr.themes.Soft()) as demo:
1652
  8 medical video understanding tasks across 8 surgical datasets.
1653
 
1654
  📄 **Paper**: [MedGRPO: Multi-Task Reinforcement Learning for Heterogeneous Medical Video Understanding](https://arxiv.org/abs/2512.06581)
1655
- 🌐 **Project**: [gaozhongpai.github.io/MedGRPO-Page](https://gaozhongpai.github.io/MedGRPO-Page/)
1656
  💾 **Dataset**: [huggingface.co/datasets/UIIAmerica/MedVidBench](https://huggingface.co/datasets/UIIAmerica/MedVidBench)
1657
- 💻 **GitHub**: [github.com/Gaozhongpai/MedGRPO](https://github.com/Gaozhongpai/MedGRPO)
 
1658
  """)
1659
 
1660
  with gr.Tabs():
@@ -1931,9 +1933,10 @@ with gr.Blocks(title="MedVidBench Leaderboard", theme=gr.themes.Soft()) as demo:
1931
  #### Links
1932
 
1933
  - 📄 **Paper**: [https://arxiv.org/abs/2512.06581](https://arxiv.org/abs/2512.06581)
1934
- - 🌐 **Project Page**: [https://gaozhongpai.github.io/MedGRPO-Page/](https://gaozhongpai.github.io/MedGRPO-Page/)
1935
  - 💾 **Dataset**: [https://huggingface.co/datasets/UIIAmerica/MedVidBench](https://huggingface.co/datasets/UIIAmerica/MedVidBench)
1936
- - 💻 **GitHub**: [https://github.com/Gaozhongpai/MedGRPO](https://github.com/Gaozhongpai/MedGRPO)
 
1937
  - 🏆 **Leaderboard**: [https://huggingface.co/spaces/UIIAmerica/MedVidBench-Leaderboard](https://huggingface.co/spaces/UIIAmerica/MedVidBench-Leaderboard)
1938
 
1939
  #### Dataset
@@ -1953,8 +1956,8 @@ with gr.Blocks(title="MedVidBench Leaderboard", theme=gr.themes.Soft()) as demo:
1953
  #### Contact
1954
 
1955
  For questions or issues:
1956
- - Open an issue on [GitHub](https://github.com/Gaozhongpai/MedGRPO)
1957
- - Visit the [project page](https://gaozhongpai.github.io/MedGRPO-Page/)
1958
  - Email: [Contact via GitHub](https://github.com/YuhaoSu)
1959
  """)
1960
 
 
865
  line = line.strip()
866
 
867
  # Detect task headers
868
+ # NOTE: Order matters check CVS before VS (since "CVS" contains "VS")
869
+ if ("CVS" in line and "Overall" in line) or "CVS Assessment" in line:
870
+ current_task = "cvs_assessment"
871
+ elif ("SKILL" in line and "Overall" in line) or "Skill Assessment" in line:
872
+ current_task = "skill_assessment"
873
+ elif "TAL" in line and "Overall" in line:
874
  current_task = "tal"
875
  elif "STG" in line and "Overall" in line:
876
  current_task = "stg"
877
+ elif ("NEXT_ACTION" in line and "Overall" in line) or "Next Action" in line:
878
  current_task = "next_action"
879
+ elif ("DVC" in line and "Overall" in line) or "Dense Video Captioning" in line:
880
  current_task = "dvc"
881
+ elif ("RC" in line and "Overall" in line) or "Region Caption" in line:
882
  current_task = "rc"
883
+ elif ("VS" in line and "Overall" in line) or "Video Summary" in line:
884
  current_task = "vs"
 
 
 
 
885
 
886
  # Detect IoU sections for TAL (new format)
887
  if current_task == "tal":
 
952
  # VS: Extract LLM score
953
  elif current_task == "vs" and ("score" in line.lower() or "average" in line.lower()):
954
  try:
955
+ val_str = line.split(":")[-1].strip().split("(")[0].strip()
956
+ metrics["vs_llm"] = float(val_str)
957
  except:
958
  pass
959
 
960
  # RC: Extract LLM score
961
  elif current_task == "rc" and ("score" in line.lower() or "average" in line.lower()):
962
  try:
963
+ val_str = line.split(":")[-1].strip().split("(")[0].strip()
964
+ metrics["rc_llm"] = float(val_str)
965
  except:
966
  pass
967
 
 
1653
  8 medical video understanding tasks across 8 surgical datasets.
1654
 
1655
  📄 **Paper**: [MedGRPO: Multi-Task Reinforcement Learning for Heterogeneous Medical Video Understanding](https://arxiv.org/abs/2512.06581)
1656
+ 🌐 **Project**: [uii-america.github.io/MedGRPO](https://uii-america.github.io/MedGRPO/)
1657
  💾 **Dataset**: [huggingface.co/datasets/UIIAmerica/MedVidBench](https://huggingface.co/datasets/UIIAmerica/MedVidBench)
1658
+ 💻 **GitHub**: [github.com/UII-America/MedGRPO-Code](https://github.com/UII-America/MedGRPO-Code)
1659
+ 🎮 **Demo**: [huggingface.co/spaces/UIIAmerica/MedGRPO-Demo](https://huggingface.co/spaces/UIIAmerica/MedGRPO-Demo)
1660
  """)
1661
 
1662
  with gr.Tabs():
 
1933
  #### Links
1934
 
1935
  - 📄 **Paper**: [https://arxiv.org/abs/2512.06581](https://arxiv.org/abs/2512.06581)
1936
+ - 🌐 **Project Page**: [https://uii-america.github.io/MedGRPO/](https://uii-america.github.io/MedGRPO/)
1937
  - 💾 **Dataset**: [https://huggingface.co/datasets/UIIAmerica/MedVidBench](https://huggingface.co/datasets/UIIAmerica/MedVidBench)
1938
+ - 💻 **GitHub**: [https://github.com/UII-America/MedGRPO-Code](https://github.com/UII-America/MedGRPO-Code)
1939
+ - 🎮 **Demo**: [https://huggingface.co/spaces/UIIAmerica/MedGRPO-Demo](https://huggingface.co/spaces/UIIAmerica/MedGRPO-Demo)
1940
  - 🏆 **Leaderboard**: [https://huggingface.co/spaces/UIIAmerica/MedVidBench-Leaderboard](https://huggingface.co/spaces/UIIAmerica/MedVidBench-Leaderboard)
1941
 
1942
  #### Dataset
 
1956
  #### Contact
1957
 
1958
  For questions or issues:
1959
+ - Open an issue on [GitHub](https://github.com/UII-America/MedGRPO-Code)
1960
+ - Visit the [project page](https://uii-america.github.io/MedGRPO/)
1961
  - Email: [Contact via GitHub](https://github.com/YuhaoSu)
1962
  """)
1963
 
evaluation/eval_caption_llm_judge.py CHANGED
@@ -163,7 +163,7 @@ def call_llm_judge_api(prediction: str, ground_truth: str, task_type: str, api_k
163
 
164
  with progress_lock:
165
  completed_calls += 1
166
- if completed_calls % 50 == 0:
167
  print(f" Progress: {completed_calls}/{total_calls} API calls completed")
168
 
169
  return scores
 
163
 
164
  with progress_lock:
165
  completed_calls += 1
166
+ if total_calls > 0 and completed_calls % 50 == 0:
167
  print(f" Progress: {completed_calls}/{total_calls} API calls completed")
168
 
169
  return scores
evaluation/eval_cvs_assessment.py CHANGED
@@ -373,10 +373,17 @@ def main():
373
  print("CVS ASSESSMENT EVALUATION SUMMARY")
374
  print(f"{'='*60}")
375
 
 
376
  for dataset_name, results in all_results.items():
377
  if results:
378
  print(f"\n{dataset_name}:")
379
  print(f" Overall Accuracy: {results['accuracy']:.4f} ({results['correct']}/{results['total']})")
 
 
 
 
 
 
380
 
381
 
382
  if __name__ == "__main__":
 
373
  print("CVS ASSESSMENT EVALUATION SUMMARY")
374
  print(f"{'='*60}")
375
 
376
+ all_bal_acc = []
377
  for dataset_name, results in all_results.items():
378
  if results:
379
  print(f"\n{dataset_name}:")
380
  print(f" Overall Accuracy: {results['accuracy']:.4f} ({results['correct']}/{results['total']})")
381
+ all_bal_acc.append(results.get('component_balanced_accuracy', 0.0))
382
+
383
+ return {
384
+ 'per_dataset': all_results,
385
+ 'component_balanced_accuracy': np.mean(all_bal_acc) if all_bal_acc else 0.0
386
+ }
387
 
388
 
389
  if __name__ == "__main__":
evaluation/eval_dvc.py CHANGED
@@ -1,18 +1,30 @@
1
  """Dense Video Captioning evaluation using LLM judge + temporal F1.
2
 
 
 
 
 
 
3
  Temporal F1 algorithm matches Qwen2.5-VL/my_eval/eval_dvc.py exactly:
4
  - process_raw_output() + flatten_overlapping_segments() for parsing
5
  - Frame-based coordinates (multiply by FPS)
6
- - Many-to-many threshold matching across IoU (0.3, 0.5, 0.7, 0.9)
7
  - F1 = 2 * mean_precision * mean_recall / (mean_precision + mean_recall)
8
  """
9
 
10
  import json
 
11
  import re
12
  import sys
 
13
  import numpy as np
14
  from collections import defaultdict
15
- from eval_caption_llm_judge import evaluate_caption_task
 
 
 
 
 
16
 
17
 
18
  # =============================================================================
@@ -190,7 +202,7 @@ def compute_temporal_f1_single(predicted_segments, gt_segments, splits,
190
 
191
 
192
  # =============================================================================
193
- # Dataset grouping and evaluation
194
  # =============================================================================
195
 
196
  def group_records_by_dataset(data):
@@ -241,31 +253,176 @@ def _extract_gt_segments(record):
241
  return gnd
242
 
243
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
244
  def evaluate_dataset_dvc(dataset_name, records, skip_llm_judge=False):
245
  """Evaluate DVC for a specific dataset using caption quality + temporal F1."""
246
  print(f"\nEvaluating {dataset_name} ({len(records)} records)...")
247
 
248
- # Step 1: Evaluate caption quality using LLM judge (unless skipped)
249
  if skip_llm_judge:
250
  print(f" Skipping LLM judge caption evaluation (--skip-llm-judge flag)")
251
  caption_score = 0.0
252
  caption_method = 'skipped'
253
  else:
254
- import tempfile
255
- import os
256
-
257
- temp_data = {str(i): record for i, record in enumerate(records)}
258
-
259
- with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
260
- json.dump(temp_data, f)
261
- temp_file = f.name
262
-
263
- try:
264
- caption_result = evaluate_caption_task(temp_file, 'dense_captioning')
265
- caption_score = caption_result['score']
266
- caption_method = caption_result['method']
267
- finally:
268
- os.unlink(temp_file)
269
 
270
  # Step 2: Compute temporal F1 matching Qwen2.5-VL algorithm exactly
271
  all_f1_scores = []
@@ -389,6 +546,7 @@ def main():
389
  all_f1_scores.append(metrics.get('temporal_f1', 0))
390
 
391
  return {
 
392
  'caption_score': np.mean(all_caption_scores) if all_caption_scores else 0.0,
393
  'temporal_f1': np.mean(all_f1_scores) if all_f1_scores else 0.0,
394
  'method': all_results[list(all_results.keys())[0]]['overall'].get('caption_method', 'unknown') if all_results else 'unknown'
 
1
  """Dense Video Captioning evaluation using LLM judge + temporal F1.
2
 
3
+ LLM judge uses IoU-matched segment pairs (matching original Qwen2.5-VL/llm_judge/):
4
+ - Match predicted segments to GT segments at IoU thresholds (0.3, 0.5, 0.7)
5
+ - Only judge matched pairs individually (not concatenated)
6
+ - Average across matched pairs, then across thresholds
7
+
8
  Temporal F1 algorithm matches Qwen2.5-VL/my_eval/eval_dvc.py exactly:
9
  - process_raw_output() + flatten_overlapping_segments() for parsing
10
  - Frame-based coordinates (multiply by FPS)
11
+ - Many-to-many threshold matching across IoU (0.3, 0.5, 0.7)
12
  - F1 = 2 * mean_precision * mean_recall / (mean_precision + mean_recall)
13
  """
14
 
15
  import json
16
+ import os
17
  import re
18
  import sys
19
+ import time
20
  import numpy as np
21
  from collections import defaultdict
22
+ from concurrent.futures import ThreadPoolExecutor, as_completed
23
+ from threading import Lock
24
+ from eval_caption_llm_judge import (
25
+ call_llm_judge_api, BEST5_ASPECTS, OPENAI_AVAILABLE,
26
+ compute_semantic_similarity_fallback
27
+ )
28
 
29
 
30
  # =============================================================================
 
202
 
203
 
204
  # =============================================================================
205
+ # Dataset grouping and evaluation (LlamaFactory specific)
206
  # =============================================================================
207
 
208
  def group_records_by_dataset(data):
 
253
  return gnd
254
 
255
 
256
+ DVC_IOU_THRESHOLDS = [0.3, 0.5, 0.7]
257
+ DVC_MAX_WORKERS = 20
258
+
259
+ # Thread-safe progress counter for DVC LLM judge
260
+ _dvc_progress_lock = Lock()
261
+ _dvc_completed = 0
262
+ _dvc_total = 0
263
+
264
+
265
+ def _segment_iou(seg1, seg2):
266
+ """Compute IoU for two temporal segments (dicts with 'start' and 'end')."""
267
+ intersection = max(0, min(seg1['end'], seg2['end']) - max(seg1['start'], seg2['start']))
268
+ union = (seg1['end'] - seg1['start']) + (seg2['end'] - seg2['start']) - intersection
269
+ return intersection / union if union > 0 else 0.0
270
+
271
+
272
+ def _match_captions_at_threshold(pred_segments, gt_segments, threshold):
273
+ """Match predicted to ground truth segments at a specific IoU threshold.
274
+
275
+ Returns list of (pred_caption, gt_caption) pairs.
276
+ """
277
+ matched_pairs = []
278
+ for pred_seg in pred_segments:
279
+ best_iou = 0.0
280
+ best_gt_caption = None
281
+ for gt_seg in gt_segments:
282
+ current_iou = _segment_iou(pred_seg, gt_seg)
283
+ if current_iou >= threshold and current_iou > best_iou:
284
+ best_iou = current_iou
285
+ best_gt_caption = gt_seg['caption']
286
+ if best_gt_caption is not None:
287
+ matched_pairs.append((pred_seg['caption'], best_gt_caption))
288
+ return matched_pairs
289
+
290
+
291
+ def _evaluate_dvc_caption_iou_matched(records, api_key):
292
+ """Evaluate DVC captions using IoU-matched segment pairs + LLM judge.
293
+
294
+ Matches the original Qwen2.5-VL/llm_judge/ approach:
295
+ 1. Parse pred and GT into segments
296
+ 2. Match at IoU thresholds (0.3, 0.5, 0.7)
297
+ 3. Judge each matched pair individually
298
+ 4. Average across pairs, then across thresholds
299
+ """
300
+ global _dvc_completed, _dvc_total
301
+
302
+ # Phase 1: Match all samples at all thresholds
303
+ print(f" Phase 1: Matching segments at IoU thresholds {DVC_IOU_THRESHOLDS}...")
304
+ all_matched = []
305
+
306
+ for record in records:
307
+ pred_text = record.get('answer', '')
308
+ gt_text = record.get('gnd', '')
309
+
310
+ pred_segments = process_raw_output(pred_text)
311
+ gt_segments = _extract_gt_segments(record)
312
+
313
+ if not isinstance(gt_segments, list):
314
+ continue
315
+
316
+ # Ensure gt_segments are dicts with caption
317
+ gt_segs = [g for g in gt_segments if isinstance(g, dict) and 'start' in g and 'end' in g and 'caption' in g]
318
+
319
+ if not pred_segments or not gt_segs:
320
+ continue
321
+
322
+ matched_pairs = {}
323
+ for threshold in DVC_IOU_THRESHOLDS:
324
+ pairs = _match_captions_at_threshold(pred_segments, gt_segs, threshold)
325
+ matched_pairs[threshold] = pairs
326
+
327
+ all_matched.append(matched_pairs)
328
+
329
+ total_pairs = sum(sum(len(pairs) for pairs in m.values()) for m in all_matched)
330
+ print(f" ✓ Matched {len(all_matched)} samples, {total_pairs} total pairs across all thresholds")
331
+
332
+ if total_pairs == 0:
333
+ return 0.0, 'llm_judge_iou_matched', 0.0
334
+
335
+ # Phase 2: Evaluate all matched pairs in parallel
336
+ _dvc_total = total_pairs
337
+ _dvc_completed = 0
338
+
339
+ print(f" Phase 2: Evaluating {total_pairs} pairs with LLM Judge ({DVC_MAX_WORKERS} workers)...")
340
+
341
+ # Collect all tasks: (sample_idx, threshold, pred_caption, gt_caption)
342
+ tasks = []
343
+ for sample_idx, matched_pairs in enumerate(all_matched):
344
+ for threshold in DVC_IOU_THRESHOLDS:
345
+ for pred_cap, gt_cap in matched_pairs[threshold]:
346
+ tasks.append((sample_idx, threshold, pred_cap, gt_cap))
347
+
348
+ # Store results per threshold
349
+ threshold_scores = {t: {aspect: [] for aspect in BEST5_ASPECTS} for t in DVC_IOU_THRESHOLDS}
350
+ api_successes = 0
351
+
352
+ def _judge_pair(pred_cap, gt_cap):
353
+ global _dvc_completed
354
+ result = call_llm_judge_api(pred_cap, gt_cap, 'dense_captioning', api_key)
355
+ with _dvc_progress_lock:
356
+ _dvc_completed += 1
357
+ if _dvc_completed % 50 == 0:
358
+ print(f" Progress: {_dvc_completed}/{_dvc_total} API calls completed")
359
+ return result
360
+
361
+ with ThreadPoolExecutor(max_workers=DVC_MAX_WORKERS) as executor:
362
+ future_to_task = {
363
+ executor.submit(_judge_pair, pred_cap, gt_cap): (sample_idx, threshold)
364
+ for sample_idx, threshold, pred_cap, gt_cap in tasks
365
+ }
366
+
367
+ for future in as_completed(future_to_task):
368
+ _, threshold = future_to_task[future]
369
+ try:
370
+ result = future.result()
371
+ if result.get('api_success', False):
372
+ for aspect in BEST5_ASPECTS:
373
+ threshold_scores[threshold][aspect].append(result[aspect])
374
+ api_successes += 1
375
+ except Exception as e:
376
+ print(f" ⚠ Error: {e}")
377
+
378
+ # Phase 3: Aggregate — average per threshold, then across thresholds
379
+ per_threshold_avg = {}
380
+ for threshold in DVC_IOU_THRESHOLDS:
381
+ aspect_avgs = {}
382
+ for aspect in BEST5_ASPECTS:
383
+ scores = threshold_scores[threshold][aspect]
384
+ aspect_avgs[aspect] = np.mean(scores) if scores else 0.0
385
+ valid = [v for v in aspect_avgs.values() if v > 0]
386
+ per_threshold_avg[threshold] = np.mean(valid) if valid else 0.0
387
+
388
+ # Overall: average across thresholds
389
+ valid_thresholds = [v for v in per_threshold_avg.values() if v > 0]
390
+ overall_score = np.mean(valid_thresholds) if valid_thresholds else 0.0
391
+ success_rate = api_successes / total_pairs if total_pairs > 0 else 0.0
392
+
393
+ print(f" ✓ LLM Judge completed: {api_successes}/{total_pairs} successful")
394
+ for t in DVC_IOU_THRESHOLDS:
395
+ print(f" IoU@{t}: {per_threshold_avg[t]:.3f}")
396
+ print(f" Overall (threshold-averaged): {overall_score:.3f}")
397
+
398
+ return overall_score, 'llm_judge_iou_matched', success_rate
399
+
400
+
401
  def evaluate_dataset_dvc(dataset_name, records, skip_llm_judge=False):
402
  """Evaluate DVC for a specific dataset using caption quality + temporal F1."""
403
  print(f"\nEvaluating {dataset_name} ({len(records)} records)...")
404
 
405
+ # Step 1: Evaluate caption quality using IoU-matched LLM judge
406
  if skip_llm_judge:
407
  print(f" Skipping LLM judge caption evaluation (--skip-llm-judge flag)")
408
  caption_score = 0.0
409
  caption_method = 'skipped'
410
  else:
411
+ api_key = os.getenv('OPENAI_API_KEY')
412
+ if api_key and OPENAI_AVAILABLE:
413
+ caption_score, caption_method, _ = _evaluate_dvc_caption_iou_matched(records, api_key)
414
+ else:
415
+ print(f" ⚠ No API key, using semantic similarity fallback")
416
+ import tempfile
417
+ temp_data = {str(i): record for i, record in enumerate(records)}
418
+ with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
419
+ json.dump(temp_data, f)
420
+ temp_file = f.name
421
+ try:
422
+ caption_score = compute_semantic_similarity_fallback(temp_data, 'dense_captioning')
423
+ caption_method = 'semantic_similarity'
424
+ finally:
425
+ os.unlink(temp_file)
426
 
427
  # Step 2: Compute temporal F1 matching Qwen2.5-VL algorithm exactly
428
  all_f1_scores = []
 
546
  all_f1_scores.append(metrics.get('temporal_f1', 0))
547
 
548
  return {
549
+ 'per_dataset': all_results,
550
  'caption_score': np.mean(all_caption_scores) if all_caption_scores else 0.0,
551
  'temporal_f1': np.mean(all_f1_scores) if all_f1_scores else 0.0,
552
  'method': all_results[list(all_results.keys())[0]]['overall'].get('caption_method', 'unknown') if all_results else 'unknown'
evaluation/eval_next_action.py CHANGED
@@ -684,6 +684,9 @@ def main():
684
  print("NEXT ACTION EVALUATION SUMMARY")
685
  print(f"{'='*80}")
686
 
 
 
 
687
  for dataset_name, fps_results in all_results.items():
688
  if fps_results:
689
  print(f"\n{dataset_name}:")
@@ -694,6 +697,18 @@ def main():
694
  print(f" {metric_name}: {value:.4f}")
695
  else:
696
  print(f" samples: {value}")
 
 
 
 
 
 
 
 
 
 
 
 
697
 
698
 
699
  if __name__ == "__main__":
 
684
  print("NEXT ACTION EVALUATION SUMMARY")
685
  print(f"{'='*80}")
686
 
687
+ all_accuracies = []
688
+ total_correct = 0
689
+ total_samples = 0
690
  for dataset_name, fps_results in all_results.items():
691
  if fps_results:
692
  print(f"\n{dataset_name}:")
 
697
  print(f" {metric_name}: {value:.4f}")
698
  else:
699
  print(f" samples: {value}")
700
+ if 'overall' in fps_results:
701
+ acc = fps_results['overall'].get('accuracy', 0.0)
702
+ count = fps_results['overall'].get('count', 0)
703
+ all_accuracies.append(acc)
704
+ total_correct += int(acc * count)
705
+ total_samples += count
706
+
707
+ return {
708
+ 'per_dataset': all_results,
709
+ 'accuracy': total_correct / total_samples if total_samples > 0 else 0.0,
710
+ 'macro_accuracy': np.mean(all_accuracies) if all_accuracies else 0.0
711
+ }
712
 
713
 
714
  if __name__ == "__main__":
evaluation/eval_skill_assessment.py CHANGED
@@ -421,6 +421,12 @@ def main():
421
  # Show overall skill level accuracy
422
  print(f" Overall Skill Level Accuracy: {results['accuracy']:.4f} ({results['correct']}/{results['total']})")
423
 
 
 
 
 
 
 
424
 
425
  if __name__ == "__main__":
426
  main()
 
421
  # Show overall skill level accuracy
422
  print(f" Overall Skill Level Accuracy: {results['accuracy']:.4f} ({results['correct']}/{results['total']})")
423
 
424
+ all_bal_acc = [r.get('aspect_balanced_accuracy', 0.0) for r in all_results.values() if r]
425
+ return {
426
+ 'per_dataset': all_results,
427
+ 'aspect_balanced_accuracy': np.mean(all_bal_acc) if all_bal_acc else 0.0
428
+ }
429
+
430
 
431
  if __name__ == "__main__":
432
  main()
evaluation/eval_stg.py CHANGED
@@ -354,6 +354,7 @@ def main():
354
  print("STG EVALUATION SUMMARY")
355
  print(f"{'='*80}")
356
 
 
357
  for dataset_name, fps_results in all_results.items():
358
  if fps_results:
359
  print(f"\n{dataset_name}:")
@@ -364,6 +365,13 @@ def main():
364
  print(f" {metric_name}: {value:.4f}")
365
  else:
366
  print(f" samples: {value}")
 
 
 
 
 
 
 
367
 
368
 
369
  if __name__ == "__main__":
 
354
  print("STG EVALUATION SUMMARY")
355
  print(f"{'='*80}")
356
 
357
+ all_ious = []
358
  for dataset_name, fps_results in all_results.items():
359
  if fps_results:
360
  print(f"\n{dataset_name}:")
 
365
  print(f" {metric_name}: {value:.4f}")
366
  else:
367
  print(f" samples: {value}")
368
+ if 'overall' in fps_results:
369
+ all_ious.append(fps_results['overall'].get('mean_iou', 0.0))
370
+
371
+ return {
372
+ 'per_dataset': all_results,
373
+ 'mean_iou': np.mean(all_ious) if all_ious else 0.0
374
+ }
375
 
376
 
377
  if __name__ == "__main__":
evaluation/eval_tal.py CHANGED
@@ -309,8 +309,9 @@ def main():
309
  if 'meanIoU@0.5' in metrics:
310
  all_miou_05.append(metrics['meanIoU@0.5'])
311
 
312
- # Return overall aggregated results
313
  return {
 
314
  'meanIoU@0.3': np.mean(all_miou_03) if all_miou_03 else 0.0,
315
  'meanIoU@0.5': np.mean(all_miou_05) if all_miou_05 else 0.0
316
  }
 
309
  if 'meanIoU@0.5' in metrics:
310
  all_miou_05.append(metrics['meanIoU@0.5'])
311
 
312
+ # Return per-dataset results for caching + macro averages
313
  return {
314
+ 'per_dataset': all_results,
315
  'meanIoU@0.3': np.mean(all_miou_03) if all_miou_03 else 0.0,
316
  'meanIoU@0.5': np.mean(all_miou_05) if all_miou_05 else 0.0
317
  }
evaluation/evaluate_all_pai.py CHANGED
@@ -442,278 +442,138 @@ def print_evaluation_results_csv_internal(output_file, tasks, evaluation_results
442
 
443
 
444
  def print_overall_evaluation_results(output_file, tasks, all_task_results, skip_llm_judge=False):
445
- """Print evaluation results in overall mode (dataset-agnostic).
446
 
447
- For each task, computes metrics by processing individual samples across
448
- all datasets together, rather than averaging per-dataset metrics.
449
  """
 
 
450
  print(f"\n{'='*80}")
451
  print(f"EVALUATION RESULTS - OVERALL (Dataset-Agnostic)")
452
  print(f"{'='*80}")
453
 
454
- # Load the data to re-process at individual level
455
- with open(output_file, "r") as f:
456
- data = json.load(f)
457
-
458
- # Handle both dict and list formats
459
- if isinstance(data, dict):
460
- records = list(data.values())
461
- elif isinstance(data, list):
462
- records = data
463
- else:
464
- print(f"Unexpected data format: {type(data)}")
465
- return
466
-
467
- # For each task, collect all records across datasets and re-evaluate
468
  for task_name in sorted(tasks):
469
  print(f"\n{'='*80}")
470
  print(f"{task_name.upper()} - Overall Evaluation (All Datasets Combined)")
471
  print(f"{'='*80}")
472
 
473
- # Filter records for this task
474
- task_records = []
475
- for record in records:
476
- qa_type = record.get("qa_type", "unknown")
477
-
478
- # Map qa_type to task name
479
- mapped_task = None
480
- if any("dense_captioning" in qa_type or qa_type == "dc" for _ in [qa_type]):
481
- mapped_task = "dvc"
482
- elif qa_type == "tal":
483
- mapped_task = "tal"
484
- elif qa_type == "next_action":
485
- mapped_task = "next_action"
486
- elif qa_type == "stg":
487
- mapped_task = "stg"
488
- elif "region_caption" in qa_type:
489
- mapped_task = "rc"
490
- elif "video_summary" in qa_type:
491
- mapped_task = "vs"
492
- elif qa_type == "skill_assessment":
493
- mapped_task = "skill_assessment"
494
- elif qa_type == "cvs_assessment":
495
- mapped_task = "cvs_assessment"
496
-
497
- if mapped_task == task_name:
498
- task_records.append(record)
499
-
500
- if not task_records:
501
- print(f"No records found for {task_name}")
502
  continue
503
 
504
- print(f"Total samples: {len(task_records)}")
505
-
506
- # Re-run evaluation on all records together
507
- # Import and call the appropriate evaluation function
508
  try:
509
  if task_name == "tal":
510
- # Import the eval module
511
- module = load_eval_module("eval_tal")
512
- # Create a temporary dict with sequential keys
513
- temp_data = {str(i): record for i, record in enumerate(task_records)}
514
- # Get grouped records
515
- dataset_records_dict = module.group_records_by_dataset(temp_data)
516
- # Combine all records across datasets
517
- all_records = []
518
- for ds_records in dataset_records_dict.values():
519
- all_records.extend(ds_records)
520
- # Evaluate as single dataset
521
- results = module.evaluate_dataset_tal("Overall", all_records)
522
- # Print results
523
- for iou_key, metrics in results.items():
524
- if isinstance(metrics, dict):
525
- print(f"\n{iou_key}:")
526
- for metric_name, value in metrics.items():
527
- print(f" {metric_name}: {value:.4f}")
528
- else:
529
- print(f"{iou_key}: {metrics:.4f}")
530
 
531
  elif task_name == "stg":
532
- module = load_eval_module("eval_stg")
533
- temp_data = {str(i): record for i, record in enumerate(task_records)}
534
- dataset_records_dict = module.group_records_by_dataset(temp_data)
535
- all_records = []
536
- for ds_records in dataset_records_dict.values():
537
- all_records.extend(ds_records)
538
- results = module.evaluate_dataset_stg("Overall", all_records)
539
- # Extract overall metrics
540
- if 'overall' in results:
541
- mean_iou = results['overall'].get('mean_iou', 0.0)
542
- print(f"\nmean_iou: {mean_iou:.4f}")
543
- else:
544
- # Compute from per-FPS metrics if overall not available
545
  all_ious = []
546
- for fps_key, metrics in results.items():
547
- if isinstance(metrics, dict) and 'mIoU' in metrics:
548
- count = metrics.get('count', 0)
549
- miou = metrics.get('mIoU', 0)
550
- all_ious.extend([miou] * int(count))
551
-
552
- if all_ious:
553
- import numpy as np
554
- overall_miou = np.mean(all_ious)
555
- print(f"\nmean_iou: {overall_miou:.4f}")
556
- else:
557
- print(f"\nmean_iou: 0.0000")
 
558
 
559
  elif task_name in ["rc", "vs"]:
560
- # Use server-side LLM judge for caption evaluation
561
- module = load_eval_module("eval_caption_llm_judge")
562
- task_type = "region_caption" if task_name == "rc" else "video_summary"
563
-
564
- # Save task records to temp file for evaluation
565
- import tempfile
566
- with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
567
- json.dump(task_records, f)
568
- temp_file = f.name
569
-
570
- try:
571
- result = module.evaluate_caption_task(temp_file, task_type)
572
- print(f"Method: {result['method']}")
573
- print(f"Score: {result['score']:.4f} ({result['scale']} scale)")
574
- if 'aspect_scores' in result:
575
  print("Aspect Scores:")
576
- for aspect, score in sorted(result['aspect_scores'].items()):
577
  print(f" {aspect}: {score:.3f}")
578
- finally:
579
- os.unlink(temp_file)
580
 
581
  elif task_name == "next_action":
582
- module = load_eval_module("eval_next_action")
583
- temp_data = {str(i): record for i, record in enumerate(task_records)}
584
- dataset_records_dict = module.group_records_by_dataset(temp_data)
585
-
586
- # For next_action, we need to evaluate per dataset (different action lists)
587
- # then aggregate the results - but suppress per-dataset output
588
- all_accuracies = []
589
- total_correct = 0
590
- total_samples = 0
591
-
592
- # Suppress output during per-dataset evaluation
593
- import io
594
- import contextlib
595
-
596
- for dataset_name, ds_records in dataset_records_dict.items():
597
- if ds_records:
598
- # Silently evaluate each dataset
599
- # Suppress SentenceTransformer/safetensors warnings at fd level
600
- import logging, os
601
- logging.disable(logging.WARNING)
602
- old_fd_out = os.dup(1)
603
- old_fd_err = os.dup(2)
604
- devnull = os.open(os.devnull, os.O_WRONLY)
605
- os.dup2(devnull, 1)
606
- os.dup2(devnull, 2)
607
- try:
608
- with contextlib.redirect_stdout(io.StringIO()), contextlib.redirect_stderr(io.StringIO()):
609
- ds_results = module.evaluate_dataset_next_action(dataset_name, ds_records)
610
- finally:
611
- os.dup2(old_fd_out, 1)
612
- os.dup2(old_fd_err, 2)
613
- os.close(old_fd_out)
614
- os.close(old_fd_err)
615
- os.close(devnull)
616
- logging.disable(logging.NOTSET)
617
- if "overall" in ds_results:
618
- accuracy = ds_results["overall"].get("accuracy", 0.0)
619
- # Use actual evaluated count, not input count (some records may be skipped)
620
- evaluated_count = ds_results["overall"].get("count", len(ds_records))
621
- all_accuracies.append(accuracy)
622
- total_correct += int(accuracy * evaluated_count)
623
- total_samples += evaluated_count
624
-
625
- # Print only final aggregate metrics
626
- if all_accuracies:
627
- macro_avg = sum(all_accuracies) / len(all_accuracies)
628
- weighted_avg = total_correct / total_samples if total_samples > 0 else 0.0
629
- print(f"\nMacro Average Accuracy (across {len(all_accuracies)} datasets): {macro_avg:.4f}")
630
- print(f"Weighted Average Accuracy (across {total_samples} samples): {weighted_avg:.4f}")
631
  else:
632
- # Fallback: compute overall accuracy directly from all records
633
- print(f"\nNext Action Metrics:")
634
- all_correct = 0
635
- all_total = 0
636
- for dataset_name, ds_records in dataset_records_dict.items():
637
- if ds_records:
638
- with contextlib.redirect_stdout(io.StringIO()):
639
- ds_results = module.evaluate_dataset_next_action(dataset_name, ds_records)
640
- # Extract accuracy from any FPS key
641
- for fps_key, metrics in ds_results.items():
642
- if isinstance(metrics, dict) and 'accuracy' in metrics:
643
- accuracy = metrics['accuracy']
644
- count = metrics.get('count', len(ds_records))
645
- all_correct += int(accuracy * count)
646
- all_total += count
647
- break
648
- if all_total > 0:
649
- overall_acc = all_correct / all_total
650
- print(f" accuracy: {overall_acc:.4f}")
651
 
652
  elif task_name == "dvc":
653
- module = load_eval_module("eval_dvc")
654
- temp_data = {str(i): record for i, record in enumerate(task_records)}
655
- dataset_records_dict = module.group_records_by_dataset(temp_data)
656
- # Combine all records across datasets
657
- all_records = []
658
- for ds_records in dataset_records_dict.values():
659
- all_records.extend(ds_records)
660
- # Evaluate as single dataset (pass skip_llm_judge flag)
661
- results = module.evaluate_dataset_dvc("Overall", all_records, skip_llm_judge=skip_llm_judge)
662
- # Print results
663
  print(f"\nDense Video Captioning Metrics:")
664
- if 'overall' in results:
665
- overall_metrics = results['overall']
666
- for metric_name, value in overall_metrics.items():
667
- if isinstance(value, (int, float)):
668
- print(f" {metric_name}: {value:.4f}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
669
 
670
  elif task_name == "cvs_assessment":
671
- module = load_eval_module("eval_cvs_assessment")
672
- temp_data = {str(i): record for i, record in enumerate(task_records)}
673
- dataset_records_dict = module.group_records_by_dataset(temp_data)
674
- # Combine all records across datasets
675
- all_records = []
676
- for ds_records in dataset_records_dict.values():
677
- all_records.extend(ds_records)
678
- # Evaluate combined
679
- results = module.evaluate_cvs_assessment(all_records)
680
- # Print results
681
- print(f"\nCVS Assessment Metrics:")
682
- if "overall" in results:
683
- for metric_name, value in results["overall"].items():
684
- if isinstance(value, (int, float)):
685
- print(f" {metric_name}: {value:.4f}")
686
  else:
687
- for metric_name, value in results.items():
688
- if isinstance(value, (int, float)):
689
- print(f" {metric_name}: {value:.4f}")
690
 
691
  elif task_name == "skill_assessment":
692
- module = load_eval_module("eval_skill_assessment")
693
- temp_data = {str(i): record for i, record in enumerate(task_records)}
694
- dataset_records_dict = module.group_records_by_dataset(temp_data)
695
- # Combine all records across datasets
696
- all_records = []
697
- for ds_records in dataset_records_dict.values():
698
- all_records.extend(ds_records)
699
- # Evaluate combined
700
- results = module.evaluate_skill_assessment(all_records)
701
- # Print results
702
- print(f"\nSkill Assessment Metrics:")
703
- if "overall" in results:
704
- for metric_name, value in results["overall"].items():
705
- if isinstance(value, (int, float)):
706
- print(f" {metric_name}: {value:.4f}")
707
  else:
708
- for metric_name, value in results.items():
709
- if isinstance(value, (int, float)):
710
- print(f" {metric_name}: {value:.4f}")
711
 
712
  else:
713
  print(f"Overall evaluation not implemented for {task_name} yet")
714
 
715
  except Exception as e:
716
- print(f"Error running overall evaluation for {task_name}: {e}")
717
  import traceback
718
  traceback.print_exc()
719
 
 
442
 
443
 
444
  def print_overall_evaluation_results(output_file, tasks, all_task_results, skip_llm_judge=False):
445
+ """Print evaluation results in overall mode using cached per-dataset results.
446
 
447
+ Aggregates per-dataset results from _run_task_eval (pooled across all datasets)
448
+ so that each data point is only evaluated once.
449
  """
450
+ import numpy as np
451
+
452
  print(f"\n{'='*80}")
453
  print(f"EVALUATION RESULTS - OVERALL (Dataset-Agnostic)")
454
  print(f"{'='*80}")
455
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
456
  for task_name in sorted(tasks):
457
  print(f"\n{'='*80}")
458
  print(f"{task_name.upper()} - Overall Evaluation (All Datasets Combined)")
459
  print(f"{'='*80}")
460
 
461
+ cached = all_task_results.get(task_name, {})
462
+ if not cached:
463
+ print(f"No results found for {task_name}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
464
  continue
465
 
 
 
 
 
466
  try:
467
  if task_name == "tal":
468
+ per_dataset = cached.get('per_dataset', {})
469
+ if per_dataset:
470
+ # Pool all per-sample meanIoU across datasets and FPS groups
471
+ all_miou_03 = []
472
+ all_miou_05 = []
473
+ for ds_name, fps_results in per_dataset.items():
474
+ for fps_key, metrics in fps_results.items():
475
+ if isinstance(metrics, dict) and 'meanIoU@0.3' in metrics:
476
+ count = metrics.get('count', 1)
477
+ all_miou_03.extend([metrics['meanIoU@0.3']] * count)
478
+ all_miou_05.extend([metrics['meanIoU@0.5']] * count)
479
+ print(f"\n mIoU@0.3: {np.mean(all_miou_03):.4f}" if all_miou_03 else "\n mIoU@0.3: 0.0000")
480
+ print(f" mIoU@0.5: {np.mean(all_miou_05):.4f}" if all_miou_05 else " mIoU@0.5: 0.0000")
481
+ else:
482
+ print(f" mIoU@0.3: {cached.get('meanIoU@0.3', 0.0):.4f}")
483
+ print(f" mIoU@0.5: {cached.get('meanIoU@0.5', 0.0):.4f}")
 
 
 
 
484
 
485
  elif task_name == "stg":
486
+ per_dataset = cached.get('per_dataset', {})
487
+ if per_dataset:
488
+ # Pool all per-sample IoUs across datasets
 
 
 
 
 
 
 
 
 
 
489
  all_ious = []
490
+ for ds_name, fps_results in per_dataset.items():
491
+ if 'overall' in fps_results:
492
+ count = fps_results['overall'].get('valid_records', 1)
493
+ miou = fps_results['overall'].get('mean_iou', 0.0)
494
+ all_ious.extend([miou] * count)
495
+ else:
496
+ for fps_key, metrics in fps_results.items():
497
+ if isinstance(metrics, dict) and 'mIoU' in metrics:
498
+ count = metrics.get('count', 1)
499
+ all_ious.extend([metrics['mIoU']] * count)
500
+ print(f"\nmean_iou: {np.mean(all_ious):.4f}" if all_ious else "\nmean_iou: 0.0000")
501
+ else:
502
+ print(f"\nmean_iou: {cached.get('mean_iou', 0.0):.4f}")
503
 
504
  elif task_name in ["rc", "vs"]:
505
+ # LLM judge use cached results directly (already pooled)
506
+ if 'score' in cached:
507
+ print(f"Method: {cached['method']}")
508
+ print(f"Score: {cached['score']:.4f} ({cached['scale']} scale)")
509
+ if 'aspect_scores' in cached:
 
 
 
 
 
 
 
 
 
 
510
  print("Aspect Scores:")
511
+ for aspect, score in sorted(cached['aspect_scores'].items()):
512
  print(f" {aspect}: {score:.3f}")
513
+ else:
514
+ print(f"No LLM judge results available")
515
 
516
  elif task_name == "next_action":
517
+ per_dataset = cached.get('per_dataset', {})
518
+ if per_dataset:
519
+ # Pool per-sample correct/total across datasets
520
+ total_correct = 0
521
+ total_samples = 0
522
+ for ds_name, fps_results in per_dataset.items():
523
+ if 'overall' in fps_results:
524
+ acc = fps_results['overall'].get('accuracy', 0.0)
525
+ count = fps_results['overall'].get('count', 0)
526
+ total_correct += round(acc * count)
527
+ total_samples += count
528
+ if total_samples > 0:
529
+ print(f"\n accuracy: {total_correct / total_samples:.4f}")
530
+ else:
531
+ print(f"\n accuracy: 0.0000")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
532
  else:
533
+ print(f"\n accuracy: {cached.get('accuracy', 0.0):.4f}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
534
 
535
  elif task_name == "dvc":
536
+ per_dataset = cached.get('per_dataset', {})
 
 
 
 
 
 
 
 
 
537
  print(f"\nDense Video Captioning Metrics:")
538
+ if per_dataset:
539
+ # Pool caption_score and temporal_f1 weighted by sample count
540
+ total_caption = 0.0
541
+ total_f1 = 0.0
542
+ total_count = 0
543
+ for ds_name, ds_results in per_dataset.items():
544
+ if ds_results and 'overall' in ds_results:
545
+ overall = ds_results['overall']
546
+ count = overall.get('count', 0)
547
+ total_caption += overall.get('caption_score', 0.0) * count
548
+ total_f1 += overall.get('temporal_f1', 0.0) * count
549
+ total_count += count
550
+ if total_count > 0:
551
+ print(f" caption_score: {total_caption / total_count:.4f}")
552
+ print(f" temporal_f1: {total_f1 / total_count:.4f}")
553
+ else:
554
+ for metric_name in ['caption_score', 'temporal_f1']:
555
+ if metric_name in cached and isinstance(cached[metric_name], (int, float)):
556
+ print(f" {metric_name}: {cached[metric_name]:.4f}")
557
 
558
  elif task_name == "cvs_assessment":
559
+ per_dataset = cached.get('per_dataset', {})
560
+ if per_dataset:
561
+ print(f"\n component_balanced_accuracy: {cached.get('component_balanced_accuracy', 0.0):.4f}")
 
 
 
 
 
 
 
 
 
 
 
 
562
  else:
563
+ print(f"\n component_balanced_accuracy: {cached.get('component_balanced_accuracy', 0.0):.4f}")
 
 
564
 
565
  elif task_name == "skill_assessment":
566
+ per_dataset = cached.get('per_dataset', {})
567
+ if per_dataset:
568
+ print(f"\n aspect_balanced_accuracy: {cached.get('aspect_balanced_accuracy', 0.0):.4f}")
 
 
 
 
 
 
 
 
 
 
 
 
569
  else:
570
+ print(f"\n aspect_balanced_accuracy: {cached.get('aspect_balanced_accuracy', 0.0):.4f}")
 
 
571
 
572
  else:
573
  print(f"Overall evaluation not implemented for {task_name} yet")
574
 
575
  except Exception as e:
576
+ print(f"Error printing overall evaluation for {task_name}: {e}")
577
  import traceback
578
  traceback.print_exc()
579
 
evaluation/evaluate_predictions.py CHANGED
@@ -174,7 +174,12 @@ def _parse_metrics_from_output(output):
174
  line = line.strip()
175
 
176
  # Detect task sections
177
- if "TAL" in line and "Overall" in line:
 
 
 
 
 
178
  current_task = "tal"
179
  elif "STG" in line and "Overall" in line:
180
  current_task = "stg"
@@ -186,10 +191,6 @@ def _parse_metrics_from_output(output):
186
  current_task = "rc"
187
  elif ("VS" in line and "Overall" in line) or "Video Summary" in line:
188
  current_task = "vs"
189
- elif ("SKILL" in line and "Overall" in line) or "Skill Assessment" in line:
190
- current_task = "skill_assessment"
191
- elif ("CVS" in line and "Overall" in line) or "CVS Assessment" in line:
192
- current_task = "cvs_assessment"
193
 
194
  if current_task == "tal":
195
  if "IoU_0.3:" in line:
@@ -226,10 +227,12 @@ def _parse_metrics_from_output(output):
226
  metrics["dvc_f1"] = float(line.split(":")[-1].strip())
227
 
228
  elif current_task == "vs" and ("score" in line.lower() or "average" in line.lower()):
229
- metrics["vs_llm"] = float(line.split(":")[-1].strip())
 
230
 
231
  elif current_task == "rc" and ("score" in line.lower() or "average" in line.lower()):
232
- metrics["rc_llm"] = float(line.split(":")[-1].strip())
 
233
 
234
  elif current_task == "skill_assessment" and "aspect_balanced_accuracy" in line.lower():
235
  metrics["sa_acc"] = float(line.split(":")[1].split("(")[0].strip())
 
174
  line = line.strip()
175
 
176
  # Detect task sections
177
+ # NOTE: Order matters check CVS before VS (since "CVS" contains "VS")
178
+ if ("CVS" in line and "Overall" in line) or "CVS Assessment" in line:
179
+ current_task = "cvs_assessment"
180
+ elif ("SKILL" in line and "Overall" in line) or "Skill Assessment" in line:
181
+ current_task = "skill_assessment"
182
+ elif "TAL" in line and "Overall" in line:
183
  current_task = "tal"
184
  elif "STG" in line and "Overall" in line:
185
  current_task = "stg"
 
191
  current_task = "rc"
192
  elif ("VS" in line and "Overall" in line) or "Video Summary" in line:
193
  current_task = "vs"
 
 
 
 
194
 
195
  if current_task == "tal":
196
  if "IoU_0.3:" in line:
 
227
  metrics["dvc_f1"] = float(line.split(":")[-1].strip())
228
 
229
  elif current_task == "vs" and ("score" in line.lower() or "average" in line.lower()):
230
+ val_str = line.split(":")[-1].strip().split("(")[0].strip()
231
+ metrics["vs_llm"] = float(val_str)
232
 
233
  elif current_task == "rc" and ("score" in line.lower() or "average" in line.lower()):
234
+ val_str = line.split(":")[-1].strip().split("(")[0].strip()
235
+ metrics["rc_llm"] = float(val_str)
236
 
237
  elif current_task == "skill_assessment" and "aspect_balanced_accuracy" in line.lower():
238
  metrics["sa_acc"] = float(line.split(":")[1].split("(")[0].strip())