MedGRPO Team committed on
Commit b28cd8f · 1 Parent(s): 2362e57

upload prediction only

evaluation/evaluate_predictions.py CHANGED
@@ -80,10 +80,10 @@ def parse_id(id_str):
 
 
 def merge_with_ground_truth(predictions_file, ground_truth_file):
-    """Merge prediction-only file with ground-truth.
+    """Merge prediction-only file with ground-truth by array index.
 
     Args:
-        predictions_file: Path to predictions JSON (id, qa_type, prediction format)
+        predictions_file: Path to predictions JSON (array format, same order as ground truth)
         ground_truth_file: Path to ground-truth JSON
 
     Returns:
@@ -97,64 +97,63 @@ def merge_with_ground_truth(predictions_file, ground_truth_file):
     with open(ground_truth_file, 'r') as f:
         ground_truth = json.load(f)
 
-    # Build lookup index for ground-truth
-    print("[EvaluationWrapper] Building ground-truth index...")
-    gt_index = {}
-    for record in ground_truth:
-        metadata = record.get('metadata', {})
-        # Create key from metadata
-        key = f"{metadata.get('video_id')}&&{metadata.get('input_video_start_frame')}&&{metadata.get('input_video_end_frame')}&&{metadata.get('fps')}"
-        gt_index[key] = record
+    print(f"[EvaluationWrapper] Predictions: {len(predictions)} records")
+    print(f"[EvaluationWrapper] Ground-truth: {len(ground_truth)} records")
 
-    print(f"[EvaluationWrapper] Ground-truth index size: {len(gt_index)} records")
-    print(f"[EvaluationWrapper] Predictions to merge: {len(predictions)} records")
+    # Check lengths match
+    if len(predictions) != len(ground_truth):
+        raise ValueError(
+            f"Length mismatch: predictions ({len(predictions)}) != ground truth ({len(ground_truth)}). "
+            f"Predictions must be in the same order as ground truth."
+        )
 
-    # Merge predictions with ground-truth
+    # Merge predictions with ground-truth by index
     merged = {}
-    matched_count = 0
-    unmatched_ids = []
-
-    for i, pred in enumerate(predictions):
-        pred_id = pred.get('id')
-        if not pred_id:
-            print(f"[EvaluationWrapper] ⚠️ WARNING: Prediction {i} missing 'id' field, skipping")
-            continue
-
-        # Look up ground-truth
-        if pred_id not in gt_index:
-            unmatched_ids.append(pred_id)
-            continue
-
-        gt_record = gt_index[pred_id]
-
-        # Create merged record (ensure data_source is properly set)
+    mismatched_qa_types = []
+
+    for i, (pred, gt_record) in enumerate(zip(predictions, ground_truth)):
+        # Validate prediction has 'prediction' field
+        if 'prediction' not in pred:
+            raise ValueError(f"Prediction at index {i} missing 'prediction' field")
+
+        # Optional: check qa_type matches
+        if 'qa_type' in pred and pred['qa_type'] != gt_record.get('qa_type'):
+            mismatched_qa_types.append(i)
+
+        # Extract question and ground truth from conversations
+        question = ''
+        gnd = ''
+        if 'conversations' in gt_record:
+            for msg in gt_record['conversations']:
+                if msg.get('from') in ['human', 'user']:
+                    # Remove <video> token to match original format
+                    question = msg.get('value', '').replace('<video>\n', '').replace('<video>', '')
+                elif msg.get('from') in ['gpt', 'assistant']:
+                    gnd = msg.get('value', '')
+
+        # Get data_source
         data_source = gt_record.get('data_source', 'Unknown')
-        # Fallback to dataset_name if data_source is missing
         if data_source == 'Unknown' or not data_source:
            data_source = gt_record.get('dataset_name', 'Unknown')
 
+        # Create merged record
         merged_record = {
             'metadata': gt_record.get('metadata', {}),
-            'qa_type': pred.get('qa_type'),
+            'qa_type': gt_record.get('qa_type', ''),
             'struc_info': gt_record.get('struc_info', []),
-            'question': gt_record.get('question', ''),
-            'gnd': gt_record.get('answer', ''),  # Ground-truth answer
+            'question': question,
+            'gnd': gnd,
             'answer': pred.get('prediction', ''),  # Model prediction
             'data_source': data_source
         }
 
         # Use sequential keys like results.json
         merged[str(i)] = merged_record
-        matched_count += 1
 
-    print(f"[EvaluationWrapper] ✓ Successfully merged {matched_count}/{len(predictions)} predictions")
+    if mismatched_qa_types:
+        print(f"[EvaluationWrapper] ⚠️ Warning: {len(mismatched_qa_types)} samples with mismatched qa_type")
 
-    if unmatched_ids:
-        print(f"[EvaluationWrapper] ⚠️ WARNING: {len(unmatched_ids)} predictions not found in ground-truth")
-        if len(unmatched_ids) <= 5:
-            print(f"[EvaluationWrapper] Unmatched IDs: {unmatched_ids}")
-        else:
-            print(f"[EvaluationWrapper] First 5 unmatched IDs: {unmatched_ids[:5]}")
+    print(f"[EvaluationWrapper] ✓ Successfully merged {len(merged)}/{len(predictions)} predictions")
 
     return merged
 
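The rewritten merge_with_ground_truth expects the submission to be a JSON array aligned index-for-index with the ground-truth file. A minimal sketch of that contract follows; the file names, qa_type values, and prediction strings are hypothetical, while the required 'prediction' field, optional 'qa_type' field, and ordering requirement come from the diff above.

import json

# Hypothetical two-sample submission; predictions[i] must correspond to ground_truth[i].
predictions = [
    {"qa_type": "caption", "prediction": "The grasper retracts the gallbladder."},
    {"qa_type": "temporal_action_localization", "prediction": "12.0s - 45.5s"},
]
with open("predictions.json", "w") as f:
    json.dump(predictions, f, indent=2)

# merged = merge_with_ground_truth("predictions.json", "ground_truth.json")
# merged is keyed by sequential strings ("0", "1", ...); merged["0"]["answer"] holds
# predictions[0]["prediction"] next to the recovered ground-truth 'gnd' field.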
 
evaluation/extract_predictions.py ADDED
@@ -0,0 +1,86 @@
+#!/usr/bin/env python3
+"""Extract predictions from results.json for user submission format.
+
+This script extracts only the prediction-related fields from results.json,
+creating a format that users would submit (without ground truth data).
+"""
+
+import json
+import sys
+from pathlib import Path
+
+
+def extract_predictions(results_file: str, output_file: str) -> None:
+    """
+    Extract predictions from results.json.
+
+    Args:
+        results_file: Path to results.json (dict format with numeric keys)
+        output_file: Path to save predictions (list format)
+    """
+    print(f"Loading results from: {results_file}")
+    with open(results_file) as f:
+        results = json.load(f)
+
+    # results.json is a dict with numeric string keys ("0", "1", "2", ...)
+    # We need to convert to list format with proper IDs
+
+    print(f"Loaded {len(results)} results")
+
+    # Extract predictions
+    predictions = []
+    for idx, (key, result) in enumerate(results.items()):
+        # Create ID from metadata
+        metadata = result.get('metadata', {})
+        video_id = metadata.get('video_id', '')
+
+        # Try both naming conventions for frame numbers
+        start_frame = metadata.get('input_video_start_frame', '') or metadata.get('start_frame', '')
+        end_frame = metadata.get('input_video_end_frame', '') or metadata.get('end_frame', '')
+        fps = metadata.get('fps', '')
+
+        # ID format: video_id&&start_frame&&end_frame&&fps
+        sample_id = f"{video_id}&&{start_frame}&&{end_frame}&&{fps}"
+
+        prediction = {
+            'id': sample_id,
+            'qa_type': result.get('qa_type', ''),
+            'prediction': result.get('answer', '')
+        }
+
+        predictions.append(prediction)
+
+        if (idx + 1) % 1000 == 0:
+            print(f"Processed {idx + 1} predictions...")
+
+    # Save predictions
+    print(f"Saving {len(predictions)} predictions to: {output_file}")
+    with open(output_file, 'w') as f:
+        json.dump(predictions, f, indent=2)
+
+    print(f"✓ Successfully extracted {len(predictions)} predictions")
+
+    # Show sample
+    if predictions:
+        print("\nSample prediction (first entry):")
+        print(json.dumps(predictions[0], indent=2))
+
+
+def main():
+    """Command-line interface."""
+    if len(sys.argv) != 3:
+        print("Usage: python extract_predictions.py results.json predictions.json")
+        print()
+        print("Arguments:")
+        print("  results.json      - Input results file (with ground truth)")
+        print("  predictions.json  - Output predictions file (user format)")
+        sys.exit(1)
+
+    results_file = sys.argv[1]
+    output_file = sys.argv[2]
+
+    extract_predictions(results_file, output_file)
+
+
+if __name__ == "__main__":
+    main()
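A small usage sketch for the new helper, assuming it is run from the repository root with an existing results.json and a Python interpreter on the path; the CLI form and the {'id', 'qa_type', 'prediction'} output shape are taken from the script above.

import json
import subprocess

# Documented CLI: python extract_predictions.py results.json predictions.json
subprocess.run(
    ["python", "evaluation/extract_predictions.py", "results.json", "predictions.json"],
    check=True,
)

with open("predictions.json") as f:
    predictions = json.load(f)

# Each entry carries the composite id "video_id&&start_frame&&end_frame&&fps".
assert isinstance(predictions, list)
assert {"id", "qa_type", "prediction"} <= set(predictions[0])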
evaluation/merge_predictions_with_gt.py CHANGED
@@ -16,10 +16,10 @@ def merge_predictions_with_ground_truth(
     output_file: str
 ) -> Tuple[bool, str]:
     """
-    Merge user predictions with server-side ground truth.
+    Merge user predictions with server-side ground truth by array index.
 
     Args:
-        predictions_file: Path to user's predictions JSON (id, qa_type, prediction)
+        predictions_file: Path to user's predictions JSON array (same order as ground truth)
         ground_truth_file: Path to ground truth JSON (struc_info, GPT responses)
         output_file: Path to save merged JSON for evaluation
 
@@ -44,73 +44,71 @@ def merge_predictions_with_ground_truth(
     if not isinstance(predictions, list):
         return False, "Predictions must be a JSON array"
 
-    # Create prediction lookup by id
-    pred_lookup = {}
+    # Check lengths match for index-based merging
+    if len(predictions) != len(ground_truth):
+        return False, f"Predictions ({len(predictions)}) and ground truth ({len(ground_truth)}) must have the same length"
+
+    # Validate predictions have required fields
     for i, pred in enumerate(predictions):
-        if 'id' not in pred:
-            return False, f"Prediction {i} missing 'id' field"
         if 'prediction' not in pred:
-            return False, f"Prediction {pred['id']} missing 'prediction' field"
-
-        pred_lookup[pred['id']] = pred
+            return False, f"Prediction at index {i} missing 'prediction' field"
 
-    # Merge predictions with ground truth
-    merged = []
-    missing_predictions = []
+    # Merge predictions with ground truth by index
+    merged = {}
     mismatched_qa_types = []
 
-    for gt_sample in ground_truth:
-        sample_id = gt_sample['id']
-
-        # Skip if user didn't provide prediction for this sample
-        if sample_id not in pred_lookup:
-            missing_predictions.append(sample_id)
-            continue
-
-        pred = pred_lookup[sample_id]
+    for idx, gt_sample in enumerate(ground_truth):
+        pred = predictions[idx]
 
         # Verify qa_type matches (optional validation)
         if 'qa_type' in pred and pred['qa_type'] != gt_sample.get('qa_type'):
             mismatched_qa_types.append({
-                'id': sample_id,
-                'predicted': pred['qa_type'],
+                'index': idx,
+                'predicted': pred.get('qa_type'),
                 'actual': gt_sample.get('qa_type')
             })
 
-        # Start with ground truth sample (contains all metadata + ground truth)
-        merged_sample = gt_sample.copy()
-
-        # Add user's prediction
-        merged_sample['prediction'] = pred['prediction']
-
-        # Ground truth fields that are already present:
-        # - struc_info (for TAL/STG evaluation)
-        # - conversations with GPT responses (for caption evaluation)
-        # - all metadata (video paths, fps, etc.)
-
-        merged.append(merged_sample)
-
-    # Save merged data
+        # Create minimal format matching original results.json
+        # Only include essential fields: metadata, qa_type, struc_info, question, gnd, answer, data_source
+        merged_sample = {
+            'metadata': gt_sample.get('metadata', {}),
+            'qa_type': gt_sample.get('qa_type', ''),
+            'struc_info': gt_sample.get('struc_info', []),
+            'question': '',  # Extract from conversations if present
+            'gnd': '',  # Extract from conversations if present
+            'answer': pred['prediction'],
+            'data_source': gt_sample.get('data_source', '')
+        }
+
+        # Extract question and ground truth answer from conversations
+        if 'conversations' in gt_sample:
+            for msg in gt_sample['conversations']:
+                if msg.get('from') in ['human', 'user']:
+                    # Remove <video> token from question to match original format
+                    question = msg.get('value', '')
+                    merged_sample['question'] = question.replace('<video>\n', '').replace('<video>', '')
+                elif msg.get('from') in ['gpt', 'assistant']:
+                    merged_sample['gnd'] = msg.get('value', '')
+
+        # Use numeric string key to match original format
+        merged[str(idx)] = merged_sample
+
+    # Save merged data as dict with numeric string keys
     print(f"Saving merged data to: {output_file}")
     with open(output_file, 'w') as f:
         json.dump(merged, f, indent=2)
 
     # Build result message
     message_parts = [
-        f"Successfully merged {len(merged)}/{len(ground_truth)} samples"
+        f"Successfully merged {len(merged)} samples"
    ]
 
-    if missing_predictions:
-        message_parts.append(
-            f"Warning: {len(missing_predictions)} samples without predictions"
-        )
-
     if mismatched_qa_types:
         message_parts.append(
             f"Warning: {len(mismatched_qa_types)} samples with mismatched qa_type"
         )
         for mismatch in mismatched_qa_types[:5]:  # Show first 5
-            print(f"  Mismatch: {mismatch['id']} - predicted: {mismatch['predicted']}, actual: {mismatch['actual']}")
+            print(f"  Mismatch at index {mismatch['index']}: predicted: {mismatch['predicted']}, actual: {mismatch['actual']}")
 
     message = ". ".join(message_parts)
     print(message)
@@ -131,9 +129,11 @@ def main():
         print("Usage: python merge_predictions_with_gt.py predictions.json ground_truth.json output.json")
         print()
         print("Arguments:")
-        print("  predictions.json  - User's predictions (id, qa_type, prediction)")
+        print("  predictions.json  - User's predictions array (same length/order as ground truth)")
         print("  ground_truth.json - Server's ground truth (struc_info, GPT responses)")
         print("  output.json       - Merged output for evaluation")
+        print()
+        print("Note: Predictions and ground truth are merged by array index (0-based).")
         sys.exit(1)
 
     predictions_file = sys.argv[1]
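Both merge paths now recover the question and the reference answer from the ground-truth 'conversations' list rather than from flat fields. A minimal sketch of that parsing, using a hypothetical conversation; the 'from'/'value' keys and the <video> token stripping mirror the code above.

# Hypothetical ground-truth sample in the conversations format.
gt_sample = {
    "conversations": [
        {"from": "human", "value": "<video>\nWhich instrument appears first?"},
        {"from": "gpt", "value": "A laparoscopic grasper."},
    ]
}

question, gnd = "", ""
for msg in gt_sample["conversations"]:
    if msg.get("from") in ["human", "user"]:
        # Strip the <video> token, as the merge scripts do.
        question = msg.get("value", "").replace("<video>\n", "").replace("<video>", "")
    elif msg.get("from") in ["gpt", "assistant"]:
        gnd = msg.get("value", "")

print(question)  # Which instrument appears first?
print(gnd)       # A laparoscopic grasper.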