"""Auto-detect prediction format and evaluate with ground-truth merging if needed.""" import json import sys import argparse import os from pathlib import Path # Add evaluation directory to path to import evaluate_all_pai eval_dir = Path(__file__).parent sys.path.insert(0, str(eval_dir)) import evaluate_all_pai def detect_has_ground_truth(data): """Detect if prediction file already contains ground-truth. Args: data: Loaded JSON data (dict or list) Returns: bool: True if ground-truth is present, False otherwise """ # Handle both dict and list formats if isinstance(data, dict): # Check first record first_key = next(iter(data)) sample = data[first_key] elif isinstance(data, list): if not data: return False sample = data[0] else: return False # Check for ground-truth indicators # results.json format has: question, gnd, answer, struc_info, metadata, qa_type, data_source has_question = 'question' in sample has_gnd = 'gnd' in sample has_struc_info = 'struc_info' in sample has_metadata_dict = isinstance(sample.get('metadata'), dict) # predictions_only.json format has: id, qa_type, prediction has_id = 'id' in sample has_prediction = 'prediction' in sample # If it has id + prediction format, it's prediction-only if has_id and has_prediction and not has_gnd: return False # If it has question + gnd + struc_info, it's already merged if has_question and has_gnd and has_struc_info: return True # Default: assume needs merging if unclear return False def parse_id(id_str): """Parse ID string into components. Format: video_id&&start_frame&&end_frame&&fps Example: "kcOqlifSukA&&22425&&25124&&1.0" Returns: dict: {'video_id': str, 'input_video_start_frame': str, 'input_video_end_frame': str, 'fps': str} """ parts = id_str.split('&&') if len(parts) != 4: raise ValueError(f"Invalid ID format: {id_str}") return { 'video_id': parts[0], 'input_video_start_frame': parts[1], 'input_video_end_frame': parts[2], 'fps': parts[3] } def merge_with_ground_truth(predictions_file, ground_truth_file): """Merge prediction-only file with ground-truth by array index. Args: predictions_file: Path to predictions JSON (array format, same order as ground truth) ground_truth_file: Path to ground-truth JSON Returns: dict: Merged data in results.json format """ print(f"[EvaluationWrapper] Loading predictions from {predictions_file}") with open(predictions_file, 'r') as f: predictions = json.load(f) print(f"[EvaluationWrapper] Loading ground-truth from {ground_truth_file}") with open(ground_truth_file, 'r') as f: ground_truth = json.load(f) print(f"[EvaluationWrapper] Predictions: {len(predictions)} records") print(f"[EvaluationWrapper] Ground-truth: {len(ground_truth)} records") # Check lengths match if len(predictions) != len(ground_truth): raise ValueError( f"Length mismatch: predictions ({len(predictions)}) != ground truth ({len(ground_truth)}). " f"Predictions must be in the same order as ground truth." ) # Merge predictions with ground-truth by index merged = {} mismatched_qa_types = [] for i, (pred, gt_record) in enumerate(zip(predictions, ground_truth)): # Validate prediction has 'prediction' field if 'prediction' not in pred: raise ValueError(f"Prediction at index {i} missing 'prediction' field") # Optional: check qa_type matches if 'qa_type' in pred and pred['qa_type'] != gt_record.get('qa_type'): mismatched_qa_types.append(i) # Extract question and ground truth from conversations question = '' gnd = '' if 'conversations' in gt_record: for msg in gt_record['conversations']: if msg.get('from') in ['human', 'user']: # Remove