File size: 12,166 Bytes
4a5921f
 
 
 
 
 
 
 
 
 
 
 
e46e8bc
4a5921f
 
4e87a2c
 
ed767be
4a5921f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8fd7ff0
ed767be
4a5921f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ed767be
 
 
 
 
4a5921f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ed767be
 
 
 
4a5921f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ed767be
 
 
 
 
4a5921f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0e5ff65
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8fd7ff0
 
 
 
 
0e5ff65
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
import ast
import functools
import json
import multiprocessing
from pathlib import Path
from typing import Dict, List, Tuple
from datasets import load_dataset
import numpy as np
import pandas as pd
from scipy.spatial import KDTree
from skimage.draw import line_aa
from skimage.draw import line as sk_line
from stqdm import stqdm as tqdm  # Display tqdm progress bar in Streamlit app


# Tab-separated penalty table read by create_penalty_lookup: it must have a
# "category" column plus one column per embodiment name. Relative paths are
# resolved against the current working directory.
PENALTY_SCORES_PATH = "./src/score_calculation/category_penalty.tsv"
# Mask2Former config JSON; only its "id2label" mapping is used in this module.
M2F_CONFIG_PATH = "./src/score_calculation/mask2former_config.json"
# Raw score that normalize_score maps to 0; raw scores above it become
# negative. Per the note in normalize_score, this is the average raw score of
# predicting a vertical line through the center.
BAD_SCORE_THRESHOLD = 3234.75


@functools.lru_cache(maxsize=4)
def create_penalty_lookup(embodiment: str) -> Dict[int, float]:
    """Map every Mask2Former category ID (``label_id``) to its penalty factor.

    Category names come from the ``id2label`` entry of the JSON config at
    ``M2F_CONFIG_PATH``; each name is matched against the ``category`` column of
    the TSV at ``PENALTY_SCORES_PATH`` and the value in the ``embodiment``
    column, scaled by 0.8, becomes its penalty.

    The result is cached per embodiment, so every caller receives the same
    dict object and must not mutate it.

    Args:
        embodiment: Name of the penalty-table column to read.

    Returns:
        Dict mapping integer label IDs to float penalty values.

    Raises:
        KeyError: If ``embodiment`` is not a column of the penalty table, or if
            a category named in the config has no row in the penalty table.
    """
    # Load fixed penalty values
    penalty_values_df = pd.read_csv(PENALTY_SCORES_PATH, sep="\t")

    # Load Mask2Former mapping from IDs to labels
    with open(M2F_CONFIG_PATH, "r", encoding="utf-8") as f:
        config = json.load(f)
    id2label = {int(k): v for k, v in config["id2label"].items()}

    # Index the table by category once instead of scanning the whole
    # DataFrame per label; keep the first row of any duplicated category.
    penalty_by_category = penalty_values_df.drop_duplicates(
        subset="category", keep="first"
    ).set_index("category")[embodiment]

    missing = [
        name for name in id2label.values() if name not in penalty_by_category.index
    ]
    if missing:
        raise KeyError(
            f"No penalty values in {PENALTY_SCORES_PATH} for categories: {missing}"
        )

    return {
        label_id: float(penalty_by_category[category_name]) * 0.8  # Adjust scale
        for label_id, category_name in id2label.items()
    }

def rasterize_gt_trace(
    gt_trace: List[List[float]], height: int, width: int
) -> np.ndarray:
    """Convert a polyline trace into the pixel coordinates it covers.

    Trace points are ``(x, y)`` pairs. Consecutive points are joined with
    anti-aliased segments from ``skimage.draw.line_aa``; a single-point trace
    becomes one pixel. Pixels outside a ``height`` x ``width`` image are
    dropped.

    Args:
        gt_trace: Sequence of ``(x, y)`` points.
        height: Image height (number of rows).
        width: Image width (number of columns).

    Returns:
        Integer array of shape ``(N, 2)`` holding ``(row, col)`` pairs. It is
        ``(0, 2)`` when the trace is empty or lies entirely outside the image,
        so callers always receive a 2-D array.
    """
    gt_trace_np = np.array(gt_trace)
    gt_line_pixels = []
    if len(gt_trace_np) > 1:
        for p1, p2 in zip(gt_trace_np[:-1], gt_trace_np[1:]):
            # Points are (x, y); skimage expects (row, col) = (y, x).
            r0, c0 = int(round(p1[1])), int(round(p1[0]))
            r1, c1 = int(round(p2[1])), int(round(p2[0]))
            rr, cc, _ = line_aa(r0, c0, r1, c1)
            valid = (rr >= 0) & (rr < height) & (cc >= 0) & (cc < width)
            gt_line_pixels.extend(zip(rr[valid], cc[valid]))
    elif len(gt_trace_np) == 1:
        r, c = int(round(gt_trace_np[0][1])), int(round(gt_trace_np[0][0]))
        if 0 <= r < height and 0 <= c < width:
            gt_line_pixels.append((r, c))

    # reshape keeps the (N, 2) layout even when no pixel survived clipping.
    return np.array(gt_line_pixels, dtype=int).reshape(-1, 2)

def create_penalty_mask(
    segmentation_mask: np.ndarray,
    gt_trace: List[List[float]],
    embodiment: str,
    distance_threshold: float = 35,
) -> np.ndarray:
    """Build a per-pixel semantic penalty map for one ground-truth trace.

    A pixel receives the penalty of its segmentation label, as given by
    ``create_penalty_lookup(embodiment)``, when two conditions hold:

    - its label has an entry in that lookup, and
    - its nearest rasterized ground-truth pixel is more than
      ``distance_threshold`` pixels away.

    Every other pixel is 0.

    Args:
        segmentation_mask: 2-D array of label IDs, indexed ``[row, col]``.
        gt_trace: Ground-truth polyline as ``(x, y)`` points.
        embodiment: Penalty-table column passed to ``create_penalty_lookup``.
        distance_threshold: Pixels at or within this distance of the trace are
            never penalized.

    Returns:
        Float array with the same shape as ``segmentation_mask``.
    """
    # Initialize mask with default no penalty
    height, width = segmentation_mask.shape
    penalty_mask = np.zeros((height, width), dtype=float)

    label_id_to_penalty = create_penalty_lookup(embodiment)

    # Pixels whose label has a penalty entry are candidates for penalization.
    undesired = np.isin(segmentation_mask, list(label_id_to_penalty.keys()))
    rows, cols = np.nonzero(undesired)
    if rows.size == 0:
        return penalty_mask
    undesired_coords = np.column_stack((rows, cols))

    gt_line_pixels = rasterize_gt_trace(gt_trace, height, width)
    if len(gt_line_pixels) == 0:
        # The trace is empty or entirely outside the image, so every candidate
        # pixel is infinitely far from it; this also avoids building a KDTree
        # from zero points.
        distances = np.full(len(undesired_coords), np.inf)
    else:
        # One batch nearest-neighbour query for all candidate pixels.
        distances, _ = KDTree(gt_line_pixels).query(undesired_coords)

    far = distances > distance_threshold
    if np.any(far):
        rows_pen, cols_pen = rows[far], cols[far]
        label_ids_to_penalize = segmentation_mask[rows_pen, cols_pen]
        penalties = np.vectorize(label_id_to_penalty.get)(label_ids_to_penalize)
        penalty_mask[rows_pen, cols_pen] = penalties

    return penalty_mask

def resample_to_match_length(
    trace_1: np.ndarray, trace_2: np.ndarray
) -> Tuple[np.ndarray, np.ndarray]:
    """Resample the shorter of two traces so both have the same point count.

    The shorter trace is parameterized by normalized cumulative arc length and
    linearly re-interpolated at ``len(longer)`` evenly spaced parameters. Its
    geometry is preserved while its point count grows. The longer trace is
    returned unchanged, and traces of equal length are both returned as-is.

    If the shorter trace has no length (one point, or all points coincide),
    its first point is repeated instead.

    Args:
        trace_1: Array of shape ``(N, D)``; ``D`` is 2 for ``(x, y)`` traces.
        trace_2: Array of shape ``(M, D)``.

    Returns:
        The pair in the same order as the arguments, both with ``max(N, M)``
        points.

    Raises:
        ValueError: If either trace is empty.
    """
    trace_1 = np.asarray(trace_1)
    trace_2 = np.asarray(trace_2)
    if len(trace_1) == 0 or len(trace_2) == 0:
        raise ValueError("One of the traces is empty")
    if len(trace_1) == len(trace_2):
        return trace_1, trace_2

    first_is_longer = len(trace_1) > len(trace_2)
    longer, shorter = (trace_1, trace_2) if first_is_longer else (trace_2, trace_1)
    target_len = len(longer)

    # Parameterize shorter trajectory by cumulative distance
    segment_lengths = np.linalg.norm(np.diff(shorter, axis=0), axis=1)
    dists = np.concatenate(([0.0], np.cumsum(segment_lengths)))

    if len(shorter) == 1 or dists[-1] == 0:
        # Nothing to interpolate along: repeat the point rather than scaling
        # its coordinates or dividing by a zero total length.
        resampled = np.repeat(shorter[:1], target_len, axis=0).astype(float)
    else:
        dists = dists / dists[-1]  # Normalize to [0, 1]
        new_params = np.linspace(0, 1, target_len)
        # Interpolate each coordinate separately.
        resampled = np.column_stack(
            [np.interp(new_params, dists, shorter[:, d]) for d in range(shorter.shape[1])]
        )

    return (longer, resampled) if first_is_longer else (resampled, longer)

def calculate_semantic_penalty(
    prediction: np.ndarray, penalty_mask: np.ndarray
) -> float:
    """Average the ``penalty_mask`` values along a predicted polyline.

    Consecutive ``(x, y)`` points are rounded to pixels and joined with
    ``skimage.draw.line``. Every in-image pixel on those segments contributes
    its mask value; a vertex shared by two segments is counted once per
    segment. A single-point prediction samples just that pixel.

    Args:
        prediction: Sequence of ``(x, y)`` points.
        penalty_mask: 2-D array indexed ``[row, col]`` (i.e. ``[y, x]``).

    Returns:
        Mean penalty over the sampled pixels. Returns 0.0 rather than NaN when
        no sampled pixel lies inside the mask (empty prediction or entirely
        off-image), so a NaN cannot leak into the caller's ``min`` over scores.
    """
    # Invariant across segments, so read the bounds once.
    height, width = penalty_mask.shape
    pixels = [(int(round(p[0])), int(round(p[1]))) for p in prediction]

    if len(pixels) == 1:
        # No segment to draw: sample the single point itself.
        x, y = pixels[0]
        segments = [(np.array([y]), np.array([x]))]
    else:
        # Use scikit-image's line drawing; it takes (row, col) = (y, x).
        segments = [
            sk_line(y1, x1, y2, x2)
            for (x1, y1), (x2, y2) in zip(pixels, pixels[1:])
        ]

    penalties = []
    for rr, cc in segments:
        valid = (rr >= 0) & (rr < height) & (cc >= 0) & (cc < width)
        penalties.extend(penalty_mask[rr[valid], cc[valid]].tolist())

    if not penalties:
        return 0.0
    return np.mean(penalties)

def calculate_fde(prediction: np.ndarray, ground_truth: np.ndarray):
    """Final displacement error: Euclidean distance between the two last points."""
    final_offset = prediction[-1] - ground_truth[-1]
    return np.linalg.norm(final_offset)

def calculate_dtw(prediction: np.ndarray, ground_truth: np.ndarray):
    """Return the dynamic-time-warping distance between two point sequences.

    The local cost is the Euclidean distance between points. The recurrence is
    the classic three-predecessor one (insertion, deletion, match), with no
    window constraint.

    Args:
        prediction: Sequence of points, shape ``(n, D)`` (1-D input is treated
            as ``n`` scalar points).
        ground_truth: Sequence of points, shape ``(m, D)``.

    Returns:
        Accumulated cost of the optimal alignment. This is 0.0 when both
        sequences are empty and ``inf`` when exactly one is.
    """
    prediction = np.asarray(prediction, dtype=float)
    ground_truth = np.asarray(ground_truth, dtype=float)
    n, m = len(prediction), len(ground_truth)

    # Create cost matrix
    cost_matrix = np.full((n + 1, m + 1), np.inf)
    cost_matrix[0, 0] = 0
    if n == 0 or m == 0:
        return cost_matrix[n, m]

    # All pairwise distances in one vectorized call, instead of n*m separate
    # np.linalg.norm calls inside the Python loop below.
    pred_pts = prediction.reshape(n, -1)
    gt_pts = ground_truth.reshape(m, -1)
    pairwise = np.linalg.norm(pred_pts[:, np.newaxis, :] - gt_pts[np.newaxis, :, :], axis=-1)

    for i in range(1, n + 1):
        for j in range(1, m + 1):
            # Find the minimum from the three possible previous cells
            min_prev_cost = min(
                cost_matrix[i - 1, j],  # Insertion
                cost_matrix[i, j - 1],  # Deletion
                cost_matrix[i - 1, j - 1],  # Match
            )
            cost_matrix[i, j] = pairwise[i - 1, j - 1] + min_prev_cost

    return cost_matrix[n, m]

def normalize_score(score: float) -> float:
    """Rescale a raw, lower-is-better score so that a perfect raw score of 0 maps to 100.

    A raw score equal to ``BAD_SCORE_THRESHOLD`` maps to 0. That threshold is
    the average performance of predicting a vertical line through the center,
    so anything worse than it comes out negative.
    """
    gap_to_threshold = BAD_SCORE_THRESHOLD - score
    relative_gap = gap_to_threshold / BAD_SCORE_THRESHOLD
    return relative_gap * 100

def score(
    prediction: List[List[float]],
    ground_truths: List[List[List[float]]],
    segmentation_mask: np.ndarray,
    embodiment: str,
):
    """Score one prediction against all of its ground-truth traces.

    For each ground truth, the raw score is ``DTW + FDE + semantic penalty``
    (lower is better). The best (minimum) raw score is then passed through
    ``normalize_score``.

    Args:
        prediction: Predicted trace as ``(x, y)`` points.
        ground_truths: Alternative ground-truth traces; the prediction is
            credited with whichever one it matches best.
        segmentation_mask: 2-D array of label IDs for the image.
        embodiment: Penalty-table column used for the semantic penalty.

    Returns:
        Normalized score of the best-matching ground truth.

    Raises:
        ValueError: If ``ground_truths`` is empty.
    """
    if len(ground_truths) == 0:
        raise ValueError("No ground-truth traces to score against")

    # Convert once. Each ground truth gets its own (possibly resampled) pair
    # below, so the prediction is never re-resampled to a previous ground
    # truth's length and the result does not depend on ground-truth order.
    prediction_np = np.array(prediction)

    raw_scores = []
    for ground_truth in ground_truths:
        penalty_mask = create_penalty_mask(segmentation_mask, ground_truth, embodiment)

        pred_i, gt_i = prediction_np, np.array(ground_truth)
        if len(pred_i) != len(gt_i):
            pred_i, gt_i = resample_to_match_length(pred_i, gt_i)

        sem_penalty = calculate_semantic_penalty(pred_i, penalty_mask)
        fde = calculate_fde(pred_i, gt_i)
        dtw = calculate_dtw(pred_i, gt_i)
        raw_scores.append(dtw + fde + sem_penalty)

    # Select the best score, then normalize
    return normalize_score(min(raw_scores))

def _initialize_worker(results_path, dataset_id, split_name):
    """Pool initializer that loads the per-process state read by ``_score_chunk``.

    Sets two module globals in the worker process:

    - ``_results_df``: the results TSV loaded as a DataFrame.
    - ``_get_sample``: a callable returning the dataset row for a ``sample_id``.
    """
    global _results_df, _get_sample

    _results_df = pd.read_csv(results_path, sep="\t")
    split_data = load_dataset(dataset_id)[split_name]

    # sample_id -> row position, built once so each lookup is a dict access.
    position_of = {}
    for position, sid in enumerate(split_data["sample_id"]):
        position_of[sid] = position

    def _lookup(sample_id):
        return split_data[position_of[sample_id]]

    _get_sample = _lookup

def _score_chunk(indices: List[int]) -> List[Tuple[int, float]]:
    """Score the rows of the worker's ``_results_df`` at the given index labels.

    Must run in a process initialized by ``_initialize_worker``, which provides
    the ``_results_df`` and ``_get_sample`` globals.

    Args:
        indices: Index labels of ``_results_df`` rows to score.

    Returns:
        ``(index, score)`` pairs; the score is NaN for empty predictions.

    Raises:
        ValueError: If a sample's ground truths for the row's embodiment are
            hidden (``None``).
    """
    results = []
    for idx in indices:
        row = _results_df.loc[idx]
        sample_id = row["sample_id"]
        embodiment = row["embodiment"]

        # Extract prediction and ground truth
        sample = _get_sample(sample_id)
        prediction = json.loads(row["prediction"])
        ground_truths = sample["ground_truth"][embodiment]

        # Check that ground-truth is not hidden as it is for the test split.
        # Report the ID, not the whole sample: its repr would include the full
        # segmentation mask.
        if ground_truths is None:
            raise ValueError(f"The sample {sample_id} has hidden ground-truths")

        # Skip invalid predictions
        if len(prediction) == 0:
            results.append((idx, np.nan))
            continue

        # Only convert the mask for rows that are actually scored.
        segmentation_mask = np.array(sample["segmentation_mask"])
        results.append((idx, score(prediction, ground_truths, segmentation_mask, embodiment)))

    return results

def score_predictions_parallel(results_path, dataset_id, split_name, num_processes=4):
    """Score every row of a results TSV using a pool of worker processes.

    Rows are split into at most ``num_processes`` contiguous chunks. Each worker
    is initialized with ``_initialize_worker``, which re-reads ``results_path``
    and loads ``split_name`` of ``dataset_id``, and then scores whole chunks
    with ``_score_chunk``.

    Args:
        results_path: Path to a tab-separated results file.
        dataset_id: Dataset identifier passed to ``load_dataset`` in each worker.
        split_name: Split of that dataset holding the ground truths.
        num_processes: Maximum number of worker processes (must be >= 1).

    Returns:
        A copy of the results DataFrame with an added ``score`` column.

    Raises:
        ValueError: If ``num_processes`` is less than 1.
    """
    if num_processes < 1:
        raise ValueError(f"num_processes must be >= 1, got {num_processes}")

    # Load results file
    results_df = pd.read_csv(results_path, sep="\t")
    scored_df = results_df.copy()
    scored_df["score"] = np.nan

    total_rows = len(results_df)
    if total_rows == 0:
        # Nothing to score; this also avoids range() with a zero step below.
        return scored_df

    # Split work into chunks. Positions double as index labels because
    # read_csv produces a default RangeIndex here and in the workers.
    chunk_size = (total_rows + num_processes - 1) // num_processes  # Ceiling division
    indices_chunks = [
        list(range(start, min(start + chunk_size, total_rows)))
        for start in range(0, total_rows, chunk_size)
    ]

    # Process chunks in parallel. Never start more workers than there are
    # chunks, since each worker loads the dataset in its initializer.
    with multiprocessing.Pool(
        processes=min(num_processes, len(indices_chunks)),
        initializer=_initialize_worker,
        initargs=(results_path, dataset_id, split_name),
    ) as pool:
        with tqdm(total=total_rows, desc="Scoring predictions") as pbar:
            for chunk_results in pool.imap_unordered(_score_chunk, indices_chunks):
                for idx, s in chunk_results:
                    scored_df.at[idx, "score"] = s
                pbar.update(len(chunk_results))

    return scored_df
    
def score_predictions(results_df, dataset):
    """Score every row of ``results_df`` sequentially against ``dataset``.

    Args:
        results_df: DataFrame with ``sample_id``, ``embodiment`` and
            ``prediction`` columns; ``prediction`` holds a JSON-encoded list of
            points.
        dataset: Indexable sample collection. It must support
            ``dataset["sample_id"]`` (all IDs) and ``dataset[i]`` (one sample
            with ``ground_truth`` and ``segmentation_mask`` fields).

    Returns:
        A copy of ``results_df`` with an added ``score`` column, which is NaN
        for empty predictions.

    Raises:
        ValueError: If a sample's ground truths for the row's embodiment are
            hidden (``None``).
    """
    # Build a lookup dictionary for efficient sample retrieval by ID
    id_to_index = {sample_id: i for i, sample_id in enumerate(dataset["sample_id"])}

    scores = []
    for _, row in tqdm(results_df.iterrows(), total=len(results_df), desc="Scoring predictions"):
        sample_id = row["sample_id"]
        embodiment = row["embodiment"]
        sample = dataset[id_to_index[sample_id]]

        prediction = json.loads(row["prediction"])
        ground_truths = sample["ground_truth"][embodiment]

        # Report the ID, not the whole sample: its repr would include the full
        # segmentation mask.
        if ground_truths is None:
            raise ValueError(f"The sample {sample_id} has hidden ground-truths")

        # Skip invalid predictions
        if len(prediction) == 0:
            scores.append(np.nan)
            continue

        # Only convert the mask for rows that are actually scored.
        segmentation_mask = np.array(sample["segmentation_mask"])
        scores.append(score(prediction, ground_truths, segmentation_mask, embodiment))

    # Create a copy and add the new 'score' column
    scored_df = results_df.copy()
    scored_df["score"] = scores
    return scored_df