Update src/score_calculation/score.py
Browse files
src/score_calculation/score.py
CHANGED
|
@@ -117,6 +117,8 @@ def resample_to_match_length(
|
|
| 117 |
trace_1: np.ndarray, trace_2: np.ndarray
|
| 118 |
) -> Tuple[np.ndarray, np.ndarray]:
|
| 119 |
|
|
|
|
|
|
|
| 120 |
if len(trace_1) == len(trace_2):
|
| 121 |
return trace_1, trace_2
|
| 122 |
elif len(trace_1) > len(trace_2):
|
|
@@ -308,11 +310,6 @@ def score_predictions(results_df, dataset):
|
|
| 308 |
scores = []
|
| 309 |
for _, row in tqdm(results_df.iterrows(), total=len(results_df), desc="Scoring predictions"):
|
| 310 |
|
| 311 |
-
# Skip invalid predictions
|
| 312 |
-
if len(row["prediction"]) == 0:
|
| 313 |
-
scores.append(np.nan)
|
| 314 |
-
continue
|
| 315 |
-
|
| 316 |
# Get the corresponding ground truth sample using the lookup
|
| 317 |
sample_id = row["sample_id"]
|
| 318 |
sample = dataset[id_to_index[sample_id]]
|
|
@@ -326,6 +323,11 @@ def score_predictions(results_df, dataset):
|
|
| 326 |
if ground_truths is None:
|
| 327 |
raise ValueError(f"The sample {sample} has hidden ground-truths")
|
| 328 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 329 |
# Calculate the score and append it to the list
|
| 330 |
s = score(prediction, ground_truths, segmentation_mask, embodiment)
|
| 331 |
scores.append(s)
|
|
|
|
| 117 |
trace_1: np.ndarray, trace_2: np.ndarray
|
| 118 |
) -> Tuple[np.ndarray, np.ndarray]:
|
| 119 |
|
| 120 |
+
if len(trace_1) == 0 or len(trace_2) == 0:
|
| 121 |
+
raise ValueError("One of the trace is empty")
|
| 122 |
if len(trace_1) == len(trace_2):
|
| 123 |
return trace_1, trace_2
|
| 124 |
elif len(trace_1) > len(trace_2):
|
|
|
|
| 310 |
scores = []
|
| 311 |
for _, row in tqdm(results_df.iterrows(), total=len(results_df), desc="Scoring predictions"):
|
| 312 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 313 |
# Get the corresponding ground truth sample using the lookup
|
| 314 |
sample_id = row["sample_id"]
|
| 315 |
sample = dataset[id_to_index[sample_id]]
|
|
|
|
| 323 |
if ground_truths is None:
|
| 324 |
raise ValueError(f"The sample {sample} has hidden ground-truths")
|
| 325 |
|
| 326 |
+
# Skip invalid predictions
|
| 327 |
+
if len(prediction) == 0:
|
| 328 |
+
scores.append(np.nan)
|
| 329 |
+
continue
|
| 330 |
+
|
| 331 |
# Calculate the score and append it to the list
|
| 332 |
s = score(prediction, ground_truths, segmentation_mask, embodiment)
|
| 333 |
scores.append(s)
|