TimWindecker committed
Commit 0e5ff65 · verified · 1 Parent(s): 641159b

Update src/score_calculation/score.py

Files changed (1):
  1. src/score_calculation/score.py  +36 -10
src/score_calculation/score.py CHANGED

@@ -41,7 +41,6 @@ def create_penalty_lookup(embodiment: str) -> Dict[int, float]:
 
     return label_id_to_penalty
 
-
 def rasterize_gt_trace(
     gt_trace: List[List[float]], height: int, width: int
 ) -> np.ndarray:
@@ -68,7 +67,6 @@ def rasterize_gt_trace(
 
     return np.array(gt_line_pixels)
 
-
 def create_penalty_mask(
     segmentation_mask: np.ndarray,
     gt_trace: List[List[float]],
@@ -115,7 +113,6 @@ def create_penalty_mask(
 
     return penalty_mask
 
-
 def resample_to_match_length(
     trace_1: np.ndarray, trace_2: np.ndarray
 ) -> Tuple[np.ndarray, np.ndarray]:
@@ -149,7 +146,6 @@ def resample_to_match_length(
     else:
         return shorter, longer
 
-
 def calculate_semantic_penalty(
     prediction: np.ndarray, penalty_mask: np.ndarray
 ) -> List[float]:
@@ -169,12 +165,10 @@ def calculate_semantic_penalty(
 
     return np.mean(penalties)
 
-
 def calculate_fde(prediction: np.ndarray, ground_truth: np.ndarray):
 
     return np.linalg.norm(prediction[-1] - ground_truth[-1])
 
-
 def calculate_dtw(prediction: np.ndarray, ground_truth: np.ndarray):
 
     # Create cost matrix
@@ -197,7 +191,6 @@ def calculate_dtw(prediction: np.ndarray, ground_truth: np.ndarray):
 
     return cost_matrix[n, m]
 
-
 def score(
     prediction: List[List[float]],
     ground_truths: List[List[List[float]]],
@@ -227,7 +220,6 @@ def score(
 
     # Select the best score
     return min(scores)
-
 
 def _initialize_worker(results_path, dataset_id, split_name):
 
@@ -246,7 +238,6 @@ def _initialize_worker(results_path, dataset_id, split_name):
 
     _get_sample = get_sample
 
-
 def _score_chunk(indices: List[int]) -> List[Tuple[int, float]]:
 
     results = []
@@ -275,7 +266,6 @@ def _score_chunk(indices: List[int]) -> List[Tuple[int, float]]:
 
     return results
 
-
 def score_predictions_parallel(results_path, dataset_id, split_name, num_processes=4):
 
     # Load results file
@@ -309,3 +299,39 @@ def score_predictions_parallel(results_path, dataset_id, split_name, num_processes=4):
 
     return scored_df
 
+def score_predictions(results_df, dataset):
+
+    # Build a lookup dictionary for efficient sample retrieval by ID
+    id_to_index = {sample_id: i for i, sample_id in enumerate(dataset["sample_id"])}
+
+    # Iterate over each row in the results DataFrame with a progress bar
+    scores = []
+    for _, row in tqdm(results_df.iterrows(), total=len(results_df), desc="Scoring predictions"):
+
+        # Skip invalid predictions
+        if len(row["prediction"]) == 0:
+            scores.append(np.nan)
+            continue
+
+        # Get the corresponding ground truth sample using the lookup
+        sample_id = row["sample_id"]
+        sample = dataset[id_to_index[sample_id]]
+
+        # Extract necessary data for scoring
+        embodiment = row["embodiment"]
+        prediction = json.loads(row["prediction"])
+        ground_truths = sample["ground_truth"][row["embodiment"]]
+        segmentation_mask = np.array(sample["segmentation_mask"])
+
+        if ground_truths is None:
+            raise ValueError(f"The sample {sample} has hidden ground-truths")
+
+        # Calculate the score and append it to the list
+        s = score(prediction, ground_truths, segmentation_mask, embodiment)
+        scores.append(s)
+
+    # Create a copy and add the new 'score' column
+    scored_df = results_df.copy()
+    scored_df["score"] = scores
+
+    return scored_df
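To exercise the new `score_predictions` helper, something like the following should work. This is a hypothetical sketch, not part of the commit: the import path, the `"human"` embodiment label, and the toy traces are assumptions, while the column names (`sample_id`, `embodiment`, `prediction`) and dataset fields (`sample_id`, `ground_truth`, `segmentation_mask`) are taken from the function body above.

```python
# Hypothetical usage sketch. Assumes the module is importable as
# src.score_calculation.score and that pandas and datasets are installed;
# the embodiment label and traces below are made up.
import json

import numpy as np
import pandas as pd
from datasets import Dataset

from src.score_calculation.score import score_predictions

# Toy ground-truth dataset with the fields score_predictions reads.
dataset = Dataset.from_dict({
    "sample_id": ["sample_0"],
    "ground_truth": [{"human": [[[0.0, 0.0], [1.0, 1.0], [2.0, 2.0]]]}],
    "segmentation_mask": [np.zeros((4, 4), dtype=np.int64).tolist()],
})

# Toy results frame; "prediction" is a JSON-encoded trace, as in the diff.
results_df = pd.DataFrame({
    "sample_id": ["sample_0"],
    "embodiment": ["human"],
    "prediction": [json.dumps([[0.0, 0.2], [1.0, 1.1], [2.0, 2.0]])],
})

scored_df = score_predictions(results_df, dataset)
print(scored_df[["sample_id", "score"]])
```

One behavioral detail worth noting: an empty `prediction` string is recorded as `np.nan` rather than raising, so downstream aggregation over the `score` column should use NaN-aware reductions.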