|
|
"""Auto-detect prediction format and evaluate with ground-truth merging if needed.""" |
|
|
|
|
|
import json |
|
|
import sys |
|
|
import argparse |
|
|
import os |
|
|
from pathlib import Path |
|
|
|
|
|
|
|
|
eval_dir = Path(__file__).parent |
|
|
sys.path.insert(0, str(eval_dir)) |
|
|
|
|
|
import evaluate_all_pai |
|
|
|
|
|
|
|
|
def detect_has_ground_truth(data): |
|
|
"""Detect if prediction file already contains ground-truth. |
|
|
|
|
|
Args: |
|
|
data: Loaded JSON data (dict or list) |
|
|
|
|
|
Returns: |
|
|
bool: True if ground-truth is present, False otherwise |
|
|
""" |
|
|
|
|
|
if isinstance(data, dict): |
|
|
|
|
|
first_key = next(iter(data)) |
|
|
sample = data[first_key] |
|
|
elif isinstance(data, list): |
|
|
if not data: |
|
|
return False |
|
|
sample = data[0] |
|
|
else: |
|
|
return False |
|
|
|
|
|
|
|
|
|
|
|
has_question = 'question' in sample |
|
|
has_gnd = 'gnd' in sample |
|
|
has_struc_info = 'struc_info' in sample |
|
|
has_metadata_dict = isinstance(sample.get('metadata'), dict) |
|
|
|
|
|
|
|
|
has_id = 'id' in sample |
|
|
has_prediction = 'prediction' in sample |
|
|
|
|
|
|
|
|
if has_id and has_prediction and not has_gnd: |
|
|
return False |
|
|
|
|
|
|
|
|
if has_question and has_gnd and has_struc_info: |
|
|
return True |
|
|
|
|
|
|
|
|
return False |
|
|
|
|
|
|
|
|
def parse_id(id_str): |
|
|
"""Parse ID string into components. |
|
|
|
|
|
Format: video_id&&start_frame&&end_frame&&fps |
|
|
Example: "kcOqlifSukA&&22425&&25124&&1.0" |
|
|
|
|
|
Returns: |
|
|
dict: {'video_id': str, 'input_video_start_frame': str, |
|
|
'input_video_end_frame': str, 'fps': str} |
|
|
""" |
|
|
parts = id_str.split('&&') |
|
|
if len(parts) != 4: |
|
|
raise ValueError(f"Invalid ID format: {id_str}") |
|
|
|
|
|
return { |
|
|
'video_id': parts[0], |
|
|
'input_video_start_frame': parts[1], |
|
|
'input_video_end_frame': parts[2], |
|
|
'fps': parts[3] |
|
|
} |
|
|
|
|
|
|
|
|
def merge_with_ground_truth(predictions_file, ground_truth_file): |
|
|
"""Merge prediction-only file with ground-truth by array index. |
|
|
|
|
|
Args: |
|
|
predictions_file: Path to predictions JSON (array format, same order as ground truth) |
|
|
ground_truth_file: Path to ground-truth JSON |
|
|
|
|
|
Returns: |
|
|
dict: Merged data in results.json format |
|
|
""" |
|
|
print(f"[EvaluationWrapper] Loading predictions from {predictions_file}") |
|
|
with open(predictions_file, 'r') as f: |
|
|
predictions = json.load(f) |
|
|
|
|
|
print(f"[EvaluationWrapper] Loading ground-truth from {ground_truth_file}") |
|
|
with open(ground_truth_file, 'r') as f: |
|
|
ground_truth = json.load(f) |
|
|
|
|
|
print(f"[EvaluationWrapper] Predictions: {len(predictions)} records") |
|
|
print(f"[EvaluationWrapper] Ground-truth: {len(ground_truth)} records") |
|
|
|
|
|
|
|
|
if len(predictions) != len(ground_truth): |
|
|
raise ValueError( |
|
|
f"Length mismatch: predictions ({len(predictions)}) != ground truth ({len(ground_truth)}). " |
|
|
f"Predictions must be in the same order as ground truth." |
|
|
) |
|
|
|
|
|
|
|
|
merged = {} |
|
|
mismatched_qa_types = [] |
|
|
|
|
|
for i, (pred, gt_record) in enumerate(zip(predictions, ground_truth)): |
|
|
|
|
|
if 'prediction' not in pred: |
|
|
raise ValueError(f"Prediction at index {i} missing 'prediction' field") |
|
|
|
|
|
|
|
|
if 'qa_type' in pred and pred['qa_type'] != gt_record.get('qa_type'): |
|
|
mismatched_qa_types.append(i) |
|
|
|
|
|
|
|
|
question = '' |
|
|
gnd = '' |
|
|
if 'conversations' in gt_record: |
|
|
for msg in gt_record['conversations']: |
|
|
if msg.get('from') in ['human', 'user']: |
|
|
|
|
|
question = msg.get('value', '').replace('<video>\n', '').replace('<video>', '') |
|
|
elif msg.get('from') in ['gpt', 'assistant']: |
|
|
gnd = msg.get('value', '') |
|
|
|
|
|
|
|
|
data_source = gt_record.get('data_source', 'Unknown') |
|
|
if data_source == 'Unknown' or not data_source: |
|
|
data_source = gt_record.get('dataset_name', 'Unknown') |
|
|
|
|
|
|
|
|
merged_record = { |
|
|
'metadata': gt_record.get('metadata', {}), |
|
|
'qa_type': gt_record.get('qa_type', ''), |
|
|
'struc_info': gt_record.get('struc_info', []), |
|
|
'question': question, |
|
|
'gnd': gnd, |
|
|
'answer': pred.get('prediction', ''), |
|
|
'data_source': data_source |
|
|
} |
|
|
|
|
|
|
|
|
merged[str(i)] = merged_record |
|
|
|
|
|
if mismatched_qa_types: |
|
|
print(f"[EvaluationWrapper] ⚠️ Warning: {len(mismatched_qa_types)} samples with mismatched qa_type") |
|
|
|
|
|
print(f"[EvaluationWrapper] ✓ Successfully merged {len(merged)}/{len(predictions)} predictions") |
|
|
|
|
|
return merged |
|
|
|
|
|
|
|
|
def main(): |
|
|
"""Main function with command line interface.""" |
|
|
parser = argparse.ArgumentParser( |
|
|
description="Evaluate predictions with automatic ground-truth merging" |
|
|
) |
|
|
parser.add_argument("predictions_file", |
|
|
help="Path to predictions JSON file (can be merged or prediction-only format)") |
|
|
parser.add_argument("--ground-truth", |
|
|
default="/root/code/MedVidBench-Leaderboard/data/ground_truth.json", |
|
|
help="Path to ground-truth JSON file (default: data/ground_truth.json)") |
|
|
parser.add_argument("--tasks", nargs="+", |
|
|
choices=["dvc", "tal", "next_action", "stg", "rc", "vs", |
|
|
"skill_assessment", "cvs_assessment", "gemini_structured", "gpt_structured"], |
|
|
help="Specific tasks to evaluate (default: all available tasks)") |
|
|
parser.add_argument("--grouping", choices=["per-dataset", "overall"], default="per-dataset", |
|
|
help="Grouping strategy: 'per-dataset' or 'overall' (default: per-dataset)") |
|
|
parser.add_argument("--analyze-only", action="store_true", |
|
|
help="Only analyze the file structure without running evaluations") |
|
|
parser.add_argument("--skip-llm-judge", action="store_true", |
|
|
help="Skip LLM judge evaluation for caption tasks (use when LLM scores are pre-computed)") |
|
|
|
|
|
args = parser.parse_args() |
|
|
|
|
|
|
|
|
print(f"[EvaluationWrapper] Loading predictions from {args.predictions_file}", flush=True) |
|
|
with open(args.predictions_file, 'r') as f: |
|
|
predictions_data = json.load(f) |
|
|
|
|
|
|
|
|
has_ground_truth = detect_has_ground_truth(predictions_data) |
|
|
|
|
|
if has_ground_truth: |
|
|
print("[EvaluationWrapper] ✓ Detected: Predictions already contain ground-truth", flush=True) |
|
|
print("[EvaluationWrapper] Using predictions file directly for evaluation", flush=True) |
|
|
eval_file = args.predictions_file |
|
|
else: |
|
|
print("[EvaluationWrapper] ✓ Detected: Prediction-only format (id, qa_type, prediction)", flush=True) |
|
|
print("[EvaluationWrapper] Merging with ground-truth...", flush=True) |
|
|
|
|
|
|
|
|
if not os.path.exists(args.ground_truth): |
|
|
print(f"[EvaluationWrapper] ❌ ERROR: Ground-truth file not found: {args.ground_truth}", flush=True) |
|
|
sys.exit(1) |
|
|
|
|
|
|
|
|
merged_data = merge_with_ground_truth(args.predictions_file, args.ground_truth) |
|
|
|
|
|
|
|
|
import tempfile |
|
|
with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f: |
|
|
json.dump(merged_data, f, indent=2) |
|
|
eval_file = f.name |
|
|
|
|
|
print(f"[EvaluationWrapper] ✓ Merged data saved to temporary file: {eval_file}", flush=True) |
|
|
|
|
|
|
|
|
print(f"\n[EvaluationWrapper] {'='*80}", flush=True) |
|
|
print(f"[EvaluationWrapper] Starting evaluation with evaluate_all_pai.py", flush=True) |
|
|
print(f"[EvaluationWrapper] {'='*80}\n", flush=True) |
|
|
|
|
|
|
|
|
eval_args = [eval_file] |
|
|
if args.tasks: |
|
|
eval_args.extend(["--tasks"] + args.tasks) |
|
|
if args.grouping: |
|
|
eval_args.extend(["--grouping", args.grouping]) |
|
|
if args.analyze_only: |
|
|
eval_args.append("--analyze-only") |
|
|
if args.skip_llm_judge: |
|
|
eval_args.append("--skip-llm-judge") |
|
|
|
|
|
original_argv = sys.argv |
|
|
sys.argv = ["evaluate_all_pai.py"] + eval_args |
|
|
|
|
|
try: |
|
|
|
|
|
if args.analyze_only: |
|
|
qa_type_counts, dataset_counts = evaluate_all_pai.analyze_output_file(eval_file) |
|
|
|
|
|
available_tasks = [] |
|
|
if any("dense_captioning" in qa_type or qa_type == "dc" for qa_type in qa_type_counts): |
|
|
available_tasks.append("dvc") |
|
|
if qa_type_counts.get("tal", 0) > 0: |
|
|
available_tasks.append("tal") |
|
|
if qa_type_counts.get("next_action", 0) > 0: |
|
|
available_tasks.append("next_action") |
|
|
if qa_type_counts.get("stg", 0) > 0: |
|
|
available_tasks.append("stg") |
|
|
if any("region_caption" in qa_type for qa_type in qa_type_counts): |
|
|
available_tasks.append("rc") |
|
|
if any("video_summary" in qa_type for qa_type in qa_type_counts): |
|
|
available_tasks.append("vs") |
|
|
if qa_type_counts.get("skill_assessment", 0) > 0: |
|
|
available_tasks.append("skill_assessment") |
|
|
if qa_type_counts.get("cvs_assessment", 0) > 0: |
|
|
available_tasks.append("cvs_assessment") |
|
|
|
|
|
evaluate_all_pai.print_evaluation_results_csv(eval_file, available_tasks) |
|
|
else: |
|
|
silent_eval = (args.grouping == "overall") |
|
|
evaluate_all_pai.run_evaluation( |
|
|
eval_file, |
|
|
args.tasks, |
|
|
grouping=args.grouping, |
|
|
silent_eval=silent_eval, |
|
|
skip_llm_judge=args.skip_llm_judge |
|
|
) |
|
|
finally: |
|
|
sys.argv = original_argv |
|
|
|
|
|
|
|
|
if not has_ground_truth and os.path.exists(eval_file): |
|
|
os.unlink(eval_file) |
|
|
print(f"\n[EvaluationWrapper] ✓ Cleaned up temporary file: {eval_file}") |
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
main() |
|
|
|