MedGRPO Team
committed on
Commit
·
b28cd8f
1
Parent(s):
2362e57
upload prediction only
Browse files
evaluation/evaluate_predictions.py
CHANGED
|
@@ -80,10 +80,10 @@ def parse_id(id_str):
|
|
| 80 |
|
| 81 |
|
| 82 |
def merge_with_ground_truth(predictions_file, ground_truth_file):
|
| 83 |
-
"""Merge prediction-only file with ground-truth.
|
| 84 |
|
| 85 |
Args:
|
| 86 |
-
predictions_file: Path to predictions JSON (
|
| 87 |
ground_truth_file: Path to ground-truth JSON
|
| 88 |
|
| 89 |
Returns:
|
|
@@ -97,64 +97,63 @@ def merge_with_ground_truth(predictions_file, ground_truth_file):
|
|
| 97 |
with open(ground_truth_file, 'r') as f:
|
| 98 |
ground_truth = json.load(f)
|
| 99 |
|
| 100 |
-
|
| 101 |
-
print("[EvaluationWrapper]
|
| 102 |
-
gt_index = {}
|
| 103 |
-
for record in ground_truth:
|
| 104 |
-
metadata = record.get('metadata', {})
|
| 105 |
-
# Create key from metadata
|
| 106 |
-
key = f"{metadata.get('video_id')}&&{metadata.get('input_video_start_frame')}&&{metadata.get('input_video_end_frame')}&&{metadata.get('fps')}"
|
| 107 |
-
gt_index[key] = record
|
| 108 |
|
| 109 |
-
|
| 110 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 111 |
|
| 112 |
-
# Merge predictions with ground-truth
|
| 113 |
merged = {}
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 131 |
data_source = gt_record.get('data_source', 'Unknown')
|
| 132 |
-
# Fallback to dataset_name if data_source is missing
|
| 133 |
if data_source == 'Unknown' or not data_source:
|
| 134 |
data_source = gt_record.get('dataset_name', 'Unknown')
|
| 135 |
|
|
|
|
| 136 |
merged_record = {
|
| 137 |
'metadata': gt_record.get('metadata', {}),
|
| 138 |
-
'qa_type':
|
| 139 |
'struc_info': gt_record.get('struc_info', []),
|
| 140 |
-
'question':
|
| 141 |
-
'gnd':
|
| 142 |
'answer': pred.get('prediction', ''), # Model prediction
|
| 143 |
'data_source': data_source
|
| 144 |
}
|
| 145 |
|
| 146 |
# Use sequential keys like results.json
|
| 147 |
merged[str(i)] = merged_record
|
| 148 |
-
matched_count += 1
|
| 149 |
|
| 150 |
-
|
|
|
|
| 151 |
|
| 152 |
-
|
| 153 |
-
print(f"[EvaluationWrapper] ⚠️ WARNING: {len(unmatched_ids)} predictions not found in ground-truth")
|
| 154 |
-
if len(unmatched_ids) <= 5:
|
| 155 |
-
print(f"[EvaluationWrapper] Unmatched IDs: {unmatched_ids}")
|
| 156 |
-
else:
|
| 157 |
-
print(f"[EvaluationWrapper] First 5 unmatched IDs: {unmatched_ids[:5]}")
|
| 158 |
|
| 159 |
return merged
|
| 160 |
|
|
|
|
| 80 |
|
| 81 |
|
| 82 |
def merge_with_ground_truth(predictions_file, ground_truth_file):
|
| 83 |
+
"""Merge prediction-only file with ground-truth by array index.
|
| 84 |
|
| 85 |
Args:
|
| 86 |
+
predictions_file: Path to predictions JSON (array format, same order as ground truth)
|
| 87 |
ground_truth_file: Path to ground-truth JSON
|
| 88 |
|
| 89 |
Returns:
|
|
|
|
| 97 |
with open(ground_truth_file, 'r') as f:
|
| 98 |
ground_truth = json.load(f)
|
| 99 |
|
| 100 |
+
print(f"[EvaluationWrapper] Predictions: {len(predictions)} records")
|
| 101 |
+
print(f"[EvaluationWrapper] Ground-truth: {len(ground_truth)} records")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 102 |
|
| 103 |
+
# Check lengths match
|
| 104 |
+
if len(predictions) != len(ground_truth):
|
| 105 |
+
raise ValueError(
|
| 106 |
+
f"Length mismatch: predictions ({len(predictions)}) != ground truth ({len(ground_truth)}). "
|
| 107 |
+
f"Predictions must be in the same order as ground truth."
|
| 108 |
+
)
|
| 109 |
|
| 110 |
+
# Merge predictions with ground-truth by index
|
| 111 |
merged = {}
|
| 112 |
+
mismatched_qa_types = []
|
| 113 |
+
|
| 114 |
+
for i, (pred, gt_record) in enumerate(zip(predictions, ground_truth)):
|
| 115 |
+
# Validate prediction has 'prediction' field
|
| 116 |
+
if 'prediction' not in pred:
|
| 117 |
+
raise ValueError(f"Prediction at index {i} missing 'prediction' field")
|
| 118 |
+
|
| 119 |
+
# Optional: check qa_type matches
|
| 120 |
+
if 'qa_type' in pred and pred['qa_type'] != gt_record.get('qa_type'):
|
| 121 |
+
mismatched_qa_types.append(i)
|
| 122 |
+
|
| 123 |
+
# Extract question and ground truth from conversations
|
| 124 |
+
question = ''
|
| 125 |
+
gnd = ''
|
| 126 |
+
if 'conversations' in gt_record:
|
| 127 |
+
for msg in gt_record['conversations']:
|
| 128 |
+
if msg.get('from') in ['human', 'user']:
|
| 129 |
+
# Remove <video> token to match original format
|
| 130 |
+
question = msg.get('value', '').replace('<video>\n', '').replace('<video>', '')
|
| 131 |
+
elif msg.get('from') in ['gpt', 'assistant']:
|
| 132 |
+
gnd = msg.get('value', '')
|
| 133 |
+
|
| 134 |
+
# Get data_source
|
| 135 |
data_source = gt_record.get('data_source', 'Unknown')
|
|
|
|
| 136 |
if data_source == 'Unknown' or not data_source:
|
| 137 |
data_source = gt_record.get('dataset_name', 'Unknown')
|
| 138 |
|
| 139 |
+
# Create merged record
|
| 140 |
merged_record = {
|
| 141 |
'metadata': gt_record.get('metadata', {}),
|
| 142 |
+
'qa_type': gt_record.get('qa_type', ''),
|
| 143 |
'struc_info': gt_record.get('struc_info', []),
|
| 144 |
+
'question': question,
|
| 145 |
+
'gnd': gnd,
|
| 146 |
'answer': pred.get('prediction', ''), # Model prediction
|
| 147 |
'data_source': data_source
|
| 148 |
}
|
| 149 |
|
| 150 |
# Use sequential keys like results.json
|
| 151 |
merged[str(i)] = merged_record
|
|
|
|
| 152 |
|
| 153 |
+
if mismatched_qa_types:
|
| 154 |
+
print(f"[EvaluationWrapper] ⚠️ Warning: {len(mismatched_qa_types)} samples with mismatched qa_type")
|
| 155 |
|
| 156 |
+
print(f"[EvaluationWrapper] ✓ Successfully merged {len(merged)}/{len(predictions)} predictions")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 157 |
|
| 158 |
return merged
|
| 159 |
|
evaluation/extract_predictions.py
ADDED
|
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""Extract predictions from results.json for user submission format.
|
| 3 |
+
|
| 4 |
+
This script extracts only the prediction-related fields from results.json,
|
| 5 |
+
creating a format that users would submit (without ground truth data).
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import json
|
| 9 |
+
import sys
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def extract_predictions(results_file: str, output_file: str) -> None:
    """Extract predictions from results.json into the user-submission format.

    Reads a results.json dict keyed by numeric strings ("0", "1", ...) and
    writes a JSON array of {'id', 'qa_type', 'prediction'} records in the
    same order (dicts preserve insertion order).

    Args:
        results_file: Path to results.json (dict format with numeric keys)
        output_file: Path to save predictions (list format)
    """
    print(f"Loading results from: {results_file}")
    with open(results_file) as f:
        results = json.load(f)

    print(f"Loaded {len(results)} results")

    # Extract predictions; the dict keys are just sequential indices, so
    # iterate values directly instead of items() with an unused key.
    predictions = []
    for idx, result in enumerate(results.values()):
        # Create ID from metadata
        metadata = result.get('metadata', {})
        video_id = metadata.get('video_id', '')

        # Try both naming conventions for frame numbers. Use explicit
        # None checks rather than `a or b`, so a legitimate frame number
        # of 0 (falsy) is not silently replaced by the fallback key.
        start_frame = metadata.get('input_video_start_frame')
        if start_frame is None:
            start_frame = metadata.get('start_frame', '')
        end_frame = metadata.get('input_video_end_frame')
        if end_frame is None:
            end_frame = metadata.get('end_frame', '')
        fps = metadata.get('fps', '')

        # ID format: video_id&&start_frame&&end_frame&&fps
        sample_id = f"{video_id}&&{start_frame}&&{end_frame}&&{fps}"

        predictions.append({
            'id': sample_id,
            'qa_type': result.get('qa_type', ''),
            'prediction': result.get('answer', '')
        })

        if (idx + 1) % 1000 == 0:
            print(f"Processed {idx + 1} predictions...")

    # Save predictions
    print(f"Saving {len(predictions)} predictions to: {output_file}")
    with open(output_file, 'w') as f:
        json.dump(predictions, f, indent=2)

    print(f"✓ Successfully extracted {len(predictions)} predictions")

    # Show sample
    if predictions:
        print("\nSample prediction (first entry):")
        print(json.dumps(predictions[0], indent=2))
| 69 |
+
def main():
    """Command-line entry point: validate argv, then run the extraction."""
    argv = sys.argv
    if len(argv) != 3:
        # Wrong arity: print usage text and exit with a failure status.
        for usage_line in (
            "Usage: python extract_predictions.py results.json predictions.json",
            "",
            "Arguments:",
            " results.json - Input results file (with ground truth)",
            " predictions.json - Output predictions file (user format)",
        ):
            print(usage_line)
        sys.exit(1)

    _, src_path, dst_path = argv
    extract_predictions(src_path, dst_path)


if __name__ == "__main__":
    main()
|
evaluation/merge_predictions_with_gt.py
CHANGED
|
@@ -16,10 +16,10 @@ def merge_predictions_with_ground_truth(
|
|
| 16 |
output_file: str
|
| 17 |
) -> Tuple[bool, str]:
|
| 18 |
"""
|
| 19 |
-
Merge user predictions with server-side ground truth.
|
| 20 |
|
| 21 |
Args:
|
| 22 |
-
predictions_file: Path to user's predictions JSON (
|
| 23 |
ground_truth_file: Path to ground truth JSON (struc_info, GPT responses)
|
| 24 |
output_file: Path to save merged JSON for evaluation
|
| 25 |
|
|
@@ -44,73 +44,71 @@ def merge_predictions_with_ground_truth(
|
|
| 44 |
if not isinstance(predictions, list):
|
| 45 |
return False, "Predictions must be a JSON array"
|
| 46 |
|
| 47 |
-
#
|
| 48 |
-
|
|
|
|
|
|
|
|
|
|
| 49 |
for i, pred in enumerate(predictions):
|
| 50 |
-
if 'id' not in pred:
|
| 51 |
-
return False, f"Prediction {i} missing 'id' field"
|
| 52 |
if 'prediction' not in pred:
|
| 53 |
-
return False, f"Prediction {
|
| 54 |
-
|
| 55 |
-
pred_lookup[pred['id']] = pred
|
| 56 |
|
| 57 |
-
# Merge predictions with ground truth
|
| 58 |
-
merged =
|
| 59 |
-
missing_predictions = []
|
| 60 |
mismatched_qa_types = []
|
| 61 |
|
| 62 |
-
for gt_sample in ground_truth:
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
# Skip if user didn't provide prediction for this sample
|
| 66 |
-
if sample_id not in pred_lookup:
|
| 67 |
-
missing_predictions.append(sample_id)
|
| 68 |
-
continue
|
| 69 |
-
|
| 70 |
-
pred = pred_lookup[sample_id]
|
| 71 |
|
| 72 |
# Verify qa_type matches (optional validation)
|
| 73 |
if 'qa_type' in pred and pred['qa_type'] != gt_sample.get('qa_type'):
|
| 74 |
mismatched_qa_types.append({
|
| 75 |
-
'
|
| 76 |
-
'predicted': pred
|
| 77 |
'actual': gt_sample.get('qa_type')
|
| 78 |
})
|
| 79 |
|
| 80 |
-
#
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 94 |
print(f"Saving merged data to: {output_file}")
|
| 95 |
with open(output_file, 'w') as f:
|
| 96 |
json.dump(merged, f, indent=2)
|
| 97 |
|
| 98 |
# Build result message
|
| 99 |
message_parts = [
|
| 100 |
-
f"Successfully merged {len(merged)}
|
| 101 |
]
|
| 102 |
|
| 103 |
-
if missing_predictions:
|
| 104 |
-
message_parts.append(
|
| 105 |
-
f"Warning: {len(missing_predictions)} samples without predictions"
|
| 106 |
-
)
|
| 107 |
-
|
| 108 |
if mismatched_qa_types:
|
| 109 |
message_parts.append(
|
| 110 |
f"Warning: {len(mismatched_qa_types)} samples with mismatched qa_type"
|
| 111 |
)
|
| 112 |
for mismatch in mismatched_qa_types[:5]: # Show first 5
|
| 113 |
-
print(f" Mismatch
|
| 114 |
|
| 115 |
message = ". ".join(message_parts)
|
| 116 |
print(message)
|
|
@@ -131,9 +129,11 @@ def main():
|
|
| 131 |
print("Usage: python merge_predictions_with_gt.py predictions.json ground_truth.json output.json")
|
| 132 |
print()
|
| 133 |
print("Arguments:")
|
| 134 |
-
print(" predictions.json - User's predictions (
|
| 135 |
print(" ground_truth.json - Server's ground truth (struc_info, GPT responses)")
|
| 136 |
print(" output.json - Merged output for evaluation")
|
|
|
|
|
|
|
| 137 |
sys.exit(1)
|
| 138 |
|
| 139 |
predictions_file = sys.argv[1]
|
|
|
|
| 16 |
output_file: str
|
| 17 |
) -> Tuple[bool, str]:
|
| 18 |
"""
|
| 19 |
+
Merge user predictions with server-side ground truth by array index.
|
| 20 |
|
| 21 |
Args:
|
| 22 |
+
predictions_file: Path to user's predictions JSON array (same order as ground truth)
|
| 23 |
ground_truth_file: Path to ground truth JSON (struc_info, GPT responses)
|
| 24 |
output_file: Path to save merged JSON for evaluation
|
| 25 |
|
|
|
|
| 44 |
if not isinstance(predictions, list):
|
| 45 |
return False, "Predictions must be a JSON array"
|
| 46 |
|
| 47 |
+
# Check lengths match for index-based merging
|
| 48 |
+
if len(predictions) != len(ground_truth):
|
| 49 |
+
return False, f"Predictions ({len(predictions)}) and ground truth ({len(ground_truth)}) must have the same length"
|
| 50 |
+
|
| 51 |
+
# Validate predictions have required fields
|
| 52 |
for i, pred in enumerate(predictions):
|
|
|
|
|
|
|
| 53 |
if 'prediction' not in pred:
|
| 54 |
+
return False, f"Prediction at index {i} missing 'prediction' field"
|
|
|
|
|
|
|
| 55 |
|
| 56 |
+
# Merge predictions with ground truth by index
|
| 57 |
+
merged = {}
|
|
|
|
| 58 |
mismatched_qa_types = []
|
| 59 |
|
| 60 |
+
for idx, gt_sample in enumerate(ground_truth):
|
| 61 |
+
pred = predictions[idx]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 62 |
|
| 63 |
# Verify qa_type matches (optional validation)
|
| 64 |
if 'qa_type' in pred and pred['qa_type'] != gt_sample.get('qa_type'):
|
| 65 |
mismatched_qa_types.append({
|
| 66 |
+
'index': idx,
|
| 67 |
+
'predicted': pred.get('qa_type'),
|
| 68 |
'actual': gt_sample.get('qa_type')
|
| 69 |
})
|
| 70 |
|
| 71 |
+
# Create minimal format matching original results.json
|
| 72 |
+
# Only include essential fields: metadata, qa_type, struc_info, question, gnd, answer, data_source
|
| 73 |
+
merged_sample = {
|
| 74 |
+
'metadata': gt_sample.get('metadata', {}),
|
| 75 |
+
'qa_type': gt_sample.get('qa_type', ''),
|
| 76 |
+
'struc_info': gt_sample.get('struc_info', []),
|
| 77 |
+
'question': '', # Extract from conversations if present
|
| 78 |
+
'gnd': '', # Extract from conversations if present
|
| 79 |
+
'answer': pred['prediction'],
|
| 80 |
+
'data_source': gt_sample.get('data_source', '')
|
| 81 |
+
}
|
| 82 |
+
|
| 83 |
+
# Extract question and ground truth answer from conversations
|
| 84 |
+
if 'conversations' in gt_sample:
|
| 85 |
+
for msg in gt_sample['conversations']:
|
| 86 |
+
if msg.get('from') in ['human', 'user']:
|
| 87 |
+
# Remove <video> token from question to match original format
|
| 88 |
+
question = msg.get('value', '')
|
| 89 |
+
merged_sample['question'] = question.replace('<video>\n', '').replace('<video>', '')
|
| 90 |
+
elif msg.get('from') in ['gpt', 'assistant']:
|
| 91 |
+
merged_sample['gnd'] = msg.get('value', '')
|
| 92 |
+
|
| 93 |
+
# Use numeric string key to match original format
|
| 94 |
+
merged[str(idx)] = merged_sample
|
| 95 |
+
|
| 96 |
+
# Save merged data as dict with numeric string keys
|
| 97 |
print(f"Saving merged data to: {output_file}")
|
| 98 |
with open(output_file, 'w') as f:
|
| 99 |
json.dump(merged, f, indent=2)
|
| 100 |
|
| 101 |
# Build result message
|
| 102 |
message_parts = [
|
| 103 |
+
f"Successfully merged {len(merged)} samples"
|
| 104 |
]
|
| 105 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 106 |
if mismatched_qa_types:
|
| 107 |
message_parts.append(
|
| 108 |
f"Warning: {len(mismatched_qa_types)} samples with mismatched qa_type"
|
| 109 |
)
|
| 110 |
for mismatch in mismatched_qa_types[:5]: # Show first 5
|
| 111 |
+
print(f" Mismatch at index {mismatch['index']}: predicted: {mismatch['predicted']}, actual: {mismatch['actual']}")
|
| 112 |
|
| 113 |
message = ". ".join(message_parts)
|
| 114 |
print(message)
|
|
|
|
| 129 |
print("Usage: python merge_predictions_with_gt.py predictions.json ground_truth.json output.json")
|
| 130 |
print()
|
| 131 |
print("Arguments:")
|
| 132 |
+
print(" predictions.json - User's predictions array (same length/order as ground truth)")
|
| 133 |
print(" ground_truth.json - Server's ground truth (struc_info, GPT responses)")
|
| 134 |
print(" output.json - Merged output for evaluation")
|
| 135 |
+
print()
|
| 136 |
+
print("Note: Predictions and ground truth are merged by array index (0-based).")
|
| 137 |
sys.exit(1)
|
| 138 |
|
| 139 |
predictions_file = sys.argv[1]
|