Add score normalization
Browse files
src/score_calculation/score.py
CHANGED
|
@@ -15,6 +15,7 @@ from tqdm import tqdm
|
|
| 15 |
|
| 16 |
PENALTY_SCORES_PATH = "./category_penalty.tsv"
|
| 17 |
M2F_CONFIG_PATH = "./mask2former_config.json"
|
|
|
|
| 18 |
|
| 19 |
|
| 20 |
@functools.lru_cache(maxsize=4)
|
|
@@ -118,7 +119,7 @@ def resample_to_match_length(
|
|
| 118 |
) -> Tuple[np.ndarray, np.ndarray]:
|
| 119 |
|
| 120 |
if len(trace_1) == 0 or len(trace_2) == 0:
|
| 121 |
-
raise ValueError("One of the
|
| 122 |
if len(trace_1) == len(trace_2):
|
| 123 |
return trace_1, trace_2
|
| 124 |
elif len(trace_1) > len(trace_2):
|
|
@@ -193,6 +194,11 @@ def calculate_dtw(prediction: np.ndarray, ground_truth: np.ndarray):
|
|
| 193 |
|
| 194 |
return cost_matrix[n, m]
|
| 195 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 196 |
def score(
|
| 197 |
prediction: List[List[float]],
|
| 198 |
ground_truths: List[List[List[float]]],
|
|
@@ -221,7 +227,10 @@ def score(
|
|
| 221 |
scores.append(dtw + fde + sem_penalty)
|
| 222 |
|
| 223 |
# Select the best score
|
| 224 |
-
|
|
|
|
|
|
|
|
|
|
| 225 |
|
| 226 |
def _initialize_worker(results_path, dataset_id, split_name):
|
| 227 |
|
|
@@ -246,11 +255,6 @@ def _score_chunk(indices: List[int]) -> List[Tuple[int, float]]:
|
|
| 246 |
for idx in indices:
|
| 247 |
row = _results_df.loc[idx]
|
| 248 |
|
| 249 |
-
# Skip invalid predictions
|
| 250 |
-
if len(row["prediction"]) == 0:
|
| 251 |
-
results.append((idx, np.nan))
|
| 252 |
-
continue
|
| 253 |
-
|
| 254 |
# Extract prediction and ground truth
|
| 255 |
sample = _get_sample(row["sample_id"])
|
| 256 |
embodiment = row["embodiment"]
|
|
@@ -261,6 +265,11 @@ def _score_chunk(indices: List[int]) -> List[Tuple[int, float]]:
|
|
| 261 |
# Check that ground-truth is not hidden as it is for the test split
|
| 262 |
if ground_truths is None:
|
| 263 |
raise ValueError(f"The sample {sample} has hidden ground-truths")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 264 |
|
| 265 |
# Calculate score
|
| 266 |
s = score(prediction, ground_truths, segmentation_mask, embodiment)
|
|
|
|
| 15 |
|
| 16 |
PENALTY_SCORES_PATH = "./category_penalty.tsv"
|
| 17 |
M2F_CONFIG_PATH = "./mask2former_config.json"
|
| 18 |
+
BAD_SCORE_THRESHOLD = 3234.75
|
| 19 |
|
| 20 |
|
| 21 |
@functools.lru_cache(maxsize=4)
|
|
|
|
| 119 |
) -> Tuple[np.ndarray, np.ndarray]:
|
| 120 |
|
| 121 |
if len(trace_1) == 0 or len(trace_2) == 0:
|
| 122 |
+
raise ValueError("One of the traces is empty")
|
| 123 |
if len(trace_1) == len(trace_2):
|
| 124 |
return trace_1, trace_2
|
| 125 |
elif len(trace_1) > len(trace_2):
|
|
|
|
| 194 |
|
| 195 |
return cost_matrix[n, m]
|
| 196 |
|
| 197 |
+
def normalize_score(score: float) -> float:
|
| 198 |
+
|
| 199 |
+
# Normalize score so that a perferct score is at 100 and a score worse than the avg. performance of predicting a vertical line through the center is < 0
|
| 200 |
+
return (BAD_SCORE_THRESHOLD - score) / BAD_SCORE_THRESHOLD * 100
|
| 201 |
+
|
| 202 |
def score(
|
| 203 |
prediction: List[List[float]],
|
| 204 |
ground_truths: List[List[List[float]]],
|
|
|
|
| 227 |
scores.append(dtw + fde + sem_penalty)
|
| 228 |
|
| 229 |
# Select the best score
|
| 230 |
+
score = min(scores)
|
| 231 |
+
|
| 232 |
+
# Normalize
|
| 233 |
+
return normalize_score(score)
|
| 234 |
|
| 235 |
def _initialize_worker(results_path, dataset_id, split_name):
|
| 236 |
|
|
|
|
| 255 |
for idx in indices:
|
| 256 |
row = _results_df.loc[idx]
|
| 257 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 258 |
# Extract prediction and ground truth
|
| 259 |
sample = _get_sample(row["sample_id"])
|
| 260 |
embodiment = row["embodiment"]
|
|
|
|
| 265 |
# Check that ground-truth is not hidden as it is for the test split
|
| 266 |
if ground_truths is None:
|
| 267 |
raise ValueError(f"The sample {sample} has hidden ground-truths")
|
| 268 |
+
|
| 269 |
+
# Skip invalid predictions
|
| 270 |
+
if len(prediction) == 0:
|
| 271 |
+
results.append((idx, np.nan))
|
| 272 |
+
continue
|
| 273 |
|
| 274 |
# Calculate score
|
| 275 |
s = score(prediction, ground_truths, segmentation_mask, embodiment)
|