import ast
import functools
import json
import multiprocessing
from pathlib import Path
from typing import Dict, List, Tuple

from datasets import load_dataset
import numpy as np
import pandas as pd
from scipy.spatial import KDTree
from skimage.draw import line_aa
from skimage.draw import line as sk_line
from stqdm import stqdm as tqdm  # Display tqdm progress bar in Streamlit app

PENALTY_SCORES_PATH = "./src/score_calculation/category_penalty.tsv"
M2F_CONFIG_PATH = "./src/score_calculation/mask2former_config.json"
# Raw score that `normalize_score` maps to 0 (worse raw scores become negative).
BAD_SCORE_THRESHOLD = 3234.75


@functools.lru_cache(maxsize=4)
def create_penalty_lookup(embodiment: str) -> Dict[int, float]:
    """Create a mapping from a Mask2Former category ID (`label_id`) to its penalty factor.

    For every entry of the Mask2Former ``id2label`` config, the penalty is read
    from the ``embodiment`` column of the first matching ``category`` row in the
    penalty TSV and scaled by 0.8.

    The result is cached per embodiment, so the returned dict is shared between
    calls and must be treated as read-only.

    Raises:
        KeyError: if a configured category has no row in the penalty table, or
            ``embodiment`` is not a column of that table.
    """
    # Load fixed penalty values
    penalty_values_df = pd.read_csv(PENALTY_SCORES_PATH, sep="\t")

    # Load Mask2Former mapping from IDs to labels
    with open(M2F_CONFIG_PATH, "r", encoding="utf-8") as f:
        config = json.load(f)
    id2label = {int(k): v for k, v in config["id2label"].items()}

    label_id_to_penalty = {}
    for label_id, category_name in id2label.items():
        rows = penalty_values_df[penalty_values_df["category"] == category_name]
        if rows.empty:
            # Fail descriptively instead of with an opaque IndexError from `.iloc[0]`.
            raise KeyError(
                f"No penalty entry for category {category_name!r} in {PENALTY_SCORES_PATH}"
            )
        label_id_to_penalty[label_id] = float(rows.iloc[0][embodiment]) * 0.8  # Adjust scale
    return label_id_to_penalty


def rasterize_gt_trace(
    gt_trace: List[List[float]], height: int, width: int
) -> np.ndarray:
    """Convert a polyline trace into the (row, col) pixels it covers.

    Each trace point is ``[x, y]``: x is used as the column and y as the row.
    Consecutive points are joined with anti-aliased lines (``line_aa``), and
    only pixels inside the ``height`` x ``width`` image are kept. A single-point
    trace yields that one pixel if it is inside the image.

    Returns:
        Integer array of shape ``(k, 2)`` holding (row, col) pairs. Duplicates
        are possible. The shape is ``(0, 2)`` when no pixel lies inside the image.
    """
    gt_trace_np = np.array(gt_trace)
    gt_line_pixels = []
    if len(gt_trace_np) > 1:
        for p1, p2 in zip(gt_trace_np[:-1], gt_trace_np[1:]):
            r0, c0 = int(round(p1[1])), int(round(p1[0]))
            r1, c1 = int(round(p2[1])), int(round(p2[0]))
            rr, cc, _ = line_aa(r0, c0, r1, c1)
            valid = (rr >= 0) & (rr < height) & (cc >= 0) & (cc < width)
            gt_line_pixels.extend(zip(rr[valid], cc[valid]))
    elif len(gt_trace_np) == 1:
        r, c = int(round(gt_trace_np[0][1])), int(round(gt_trace_np[0][0]))
        if 0 <= r < height and 0 <= c < width:
            gt_line_pixels.append((r, c))
    # reshape keeps the result 2-D even when nothing was collected, so KDTree accepts it.
    return np.array(gt_line_pixels, dtype=int).reshape(-1, 2)


def create_penalty_mask(
    segmentation_mask: np.ndarray,
    gt_trace: List[List[float]],
    embodiment: str,
    distance_threshold: float = 35,
) -> np.ndarray:
    """Build a per-pixel penalty map for one ground-truth trace.

    A pixel receives its category's penalty (from `create_penalty_lookup`) when
    both of the following hold:

    - its label appears in the lookup;
    - its Euclidean distance, in pixels, to the nearest rasterized ground-truth
      pixel is strictly greater than ``distance_threshold``.

    All other pixels are 0.

    Args:
        segmentation_mask: 2-D array of label IDs with the image's (height, width).
        gt_trace: Ground-truth polyline as ``[x, y]`` points.
        embodiment: Column of the penalty table to use.
        distance_threshold: Pixels within this distance of the trace are not penalized.

    Returns:
        Float array with the same shape as ``segmentation_mask``.
    """
    height, width = segmentation_mask.shape
    penalty_mask = np.zeros((height, width), dtype=float)

    label_id_to_penalty = create_penalty_lookup(embodiment)

    # Pixels whose label has an entry in the penalty lookup
    all_label_ids = segmentation_mask.ravel()
    undesired_mask = np.isin(all_label_ids, list(label_id_to_penalty.keys()))
    undesired_indices = np.flatnonzero(undesired_mask)
    if undesired_indices.size == 0:
        return penalty_mask

    rows, cols = np.unravel_index(undesired_indices, (height, width))
    undesired_coords = np.column_stack((rows, cols))

    gt_line_pixels = rasterize_gt_trace(gt_trace, height, width)
    if len(gt_line_pixels) == 0:
        # No ground-truth pixel inside the image: every candidate is infinitely far away.
        distances = np.full(len(undesired_coords), np.inf)
    else:
        # One batched nearest-neighbour query for all candidate pixels.
        distances, _ = KDTree(gt_line_pixels).query(undesired_coords)

    far = distances > distance_threshold
    rows_pen, cols_pen = rows[far], cols[far]
    if rows_pen.size > 0:
        label_ids_to_penalize = segmentation_mask[rows_pen, cols_pen]
        penalties = np.vectorize(label_id_to_penalty.get, otypes=[float])(label_ids_to_penalize)
        penalty_mask[rows_pen, cols_pen] = penalties
    return penalty_mask


def resample_to_match_length(
    trace_1: np.ndarray, trace_2: np.ndarray
) -> Tuple[np.ndarray, np.ndarray]:
    """Resample the shorter of two traces so both have the same number of points.

    The shorter trace is parameterized by normalized cumulative arc length. It
    is then linearly interpolated (x and y separately) at evenly spaced
    parameters, one per point of the longer trace. The longer trace is returned
    untouched.

    The outputs are always in input order: ``(trace_1', trace_2')``.

    Degenerate cases:

    - A one-point shorter trace is repeated.
    - A shorter trace whose points all coincide (zero length) is also repeated.

    Raises:
        ValueError: if either trace is empty.
    """
    if len(trace_1) == 0 or len(trace_2) == 0:
        raise ValueError("One of the traces is empty")
    if len(trace_1) == len(trace_2):
        return trace_1, trace_2

    first_is_longer = len(trace_1) > len(trace_2)
    longer, shorter = (trace_1, trace_2) if first_is_longer else (trace_2, trace_1)
    shorter = np.asarray(shorter, dtype=float)
    target_len = len(longer)

    # Parameterize shorter trajectory by cumulative distance
    segment_lengths = np.linalg.norm(np.diff(shorter, axis=0), axis=1)
    dists = np.concatenate(([0.0], np.cumsum(segment_lengths)))

    if len(shorter) == 1 or dists[-1] == 0:
        # Nothing to interpolate along: repeat the (single / coincident) point.
        resampled = np.repeat(shorter[:1], target_len, axis=0)
    else:
        dists = dists / dists[-1]  # Normalize to [0, 1]
        new_params = np.linspace(0, 1, target_len)
        resampled = np.column_stack(
            [
                np.interp(new_params, dists, shorter[:, 0]),
                np.interp(new_params, dists, shorter[:, 1]),
            ]
        )

    return (longer, resampled) if first_is_longer else (resampled, longer)


def calculate_semantic_penalty(
    prediction: np.ndarray, penalty_mask: np.ndarray
) -> float:
    """Mean penalty of the pixels traversed by the predicted polyline.

    Each ``[x, y]`` point is rounded to a pixel. Consecutive pixels are joined
    with ``skimage.draw.line``, and the penalties of all in-image pixels are
    averaged. Pixels shared by adjacent segments are counted once per segment.

    A single-point prediction samples just its own pixel. If no traversed pixel
    lies inside the mask, NaN is returned.
    """
    height, width = penalty_mask.shape
    pixels = [(int(round(p[0])), int(round(p[1]))) for p in prediction]
    if len(pixels) == 1:
        segments = [(pixels[0], pixels[0])]
    else:
        segments = zip(pixels[:-1], pixels[1:])

    penalties = []
    for (x1, y1), (x2, y2) in segments:
        # The mask is indexed (row, col) == (y, x)
        rr, cc = sk_line(y1, x1, y2, x2)
        valid = (rr >= 0) & (rr < height) & (cc >= 0) & (cc < width)
        penalties.extend(penalty_mask[rr[valid], cc[valid]].tolist())

    if not penalties:
        return float("nan")
    return float(np.mean(penalties))


def calculate_fde(prediction: np.ndarray, ground_truth: np.ndarray) -> float:
    """Final displacement error: Euclidean distance between the two last points."""
    return np.linalg.norm(prediction[-1] - ground_truth[-1])


def calculate_dtw(prediction: np.ndarray, ground_truth: np.ndarray) -> float:
    """Dynamic-time-warping distance with Euclidean point-to-point cost.

    This is the unconstrained DTW recurrence: each cell adds its distance to the
    cheapest of the insertion, deletion and match predecessors. The return value
    is the total cost of the optimal alignment.

    Returns 0.0 when both inputs are empty, and ``inf`` when exactly one is empty.
    """
    prediction = np.asarray(prediction, dtype=float)
    ground_truth = np.asarray(ground_truth, dtype=float)
    n, m = len(prediction), len(ground_truth)
    if n == 0 or m == 0:
        return 0.0 if n == m else float("inf")

    # Pairwise distances computed once, vectorized, instead of per DP cell.
    diff = prediction.reshape(n, 1, -1) - ground_truth.reshape(1, m, -1)
    distance_rows = np.linalg.norm(diff, axis=-1).tolist()

    # Rolling rows of the (n+1) x (m+1) cost matrix; row 0 is [0, inf, inf, ...].
    prev = [0.0] + [float("inf")] * m
    for i in range(n):
        row_dist = distance_rows[i]
        curr = [float("inf")] * (m + 1)
        for j in range(1, m + 1):
            curr[j] = row_dist[j - 1] + min(
                prev[j],  # Insertion
                curr[j - 1],  # Deletion
                prev[j - 1],  # Match
            )
        prev = curr
    return prev[m]


def normalize_score(score: float) -> float:
    """Map a raw (lower-is-better) score to a 100-is-perfect scale.

    A raw score of 0 maps to 100. A raw score equal to BAD_SCORE_THRESHOLD maps
    to 0. Per the original intent, that threshold is the average performance of
    predicting a vertical line through the image center, so anything worse than
    it maps below 0.
    """
    return (BAD_SCORE_THRESHOLD - score) / BAD_SCORE_THRESHOLD * 100


def score(
    prediction: List[List[float]],
    ground_truths: List[List[List[float]]],
    segmentation_mask: np.ndarray,
    embodiment: str,
) -> float:
    """Score one prediction against its best-matching ground truth.

    For each ground truth, the raw score is ``DTW + FDE + semantic penalty``.
    It is computed after resampling so that the two traces have equal length.
    The smallest raw score is normalized with `normalize_score`.

    Raises:
        ValueError: if ``ground_truths`` is empty.
    """
    if len(ground_truths) == 0:
        raise ValueError("At least one ground-truth trace is required for scoring")

    prediction_np = np.array(prediction)
    scores = []
    for ground_truth in ground_truths:
        penalty_mask = create_penalty_mask(segmentation_mask, ground_truth, embodiment)

        # Start from the original prediction for every ground truth; reusing a
        # version resampled for a previous ground truth would compound interpolation.
        pred_arr, gt_arr = prediction_np, np.array(ground_truth)
        if len(pred_arr) != len(gt_arr):
            pred_arr, gt_arr = resample_to_match_length(pred_arr, gt_arr)

        sem_penalty = calculate_semantic_penalty(pred_arr, penalty_mask)
        fde = calculate_fde(pred_arr, gt_arr)
        dtw = calculate_dtw(pred_arr, gt_arr)
        scores.append(dtw + fde + sem_penalty)

    best_score = min(scores)
    return normalize_score(best_score)


def _initialize_worker(results_path, dataset_id, split_name):
    """Pool initializer: load the results table and dataset split once per worker.

    Sets the module globals ``_results_df`` and ``_get_sample`` used by `_score_chunk`.
    """
    global _results_df, _get_sample

    _results_df = pd.read_csv(results_path, sep="\t")
    data_split = load_dataset(dataset_id)[split_name]

    # Build lookup index for efficient sample retrieval
    id_to_index = {sample_id: i for i, sample_id in enumerate(data_split["sample_id"])}

    def get_sample(sample_id):
        return data_split[id_to_index[sample_id]]

    _get_sample = get_sample


def _score_row(row, sample) -> float:
    """Score a single results row against its dataset sample.

    Returns NaN for an empty prediction.

    Raises:
        ValueError: if the sample's ground truth for the row's embodiment is
            hidden (``None``), as it is for the test split.
    """
    embodiment = row["embodiment"]
    prediction = json.loads(row["prediction"])
    ground_truths = sample["ground_truth"][embodiment]

    if ground_truths is None:
        raise ValueError(f"The sample {row['sample_id']} has hidden ground-truths")

    # Skip invalid predictions
    if len(prediction) == 0:
        return np.nan

    segmentation_mask = np.array(sample["segmentation_mask"])
    return score(prediction, ground_truths, segmentation_mask, embodiment)


def _score_chunk(indices: List[int]) -> List[Tuple[int, float]]:
    """Worker task: score the rows at the given index labels of ``_results_df``."""
    results = []
    for idx in indices:
        row = _results_df.loc[idx]
        sample = _get_sample(row["sample_id"])
        results.append((idx, _score_row(row, sample)))
    return results


def score_predictions_parallel(
    results_path, dataset_id, split_name, num_processes=4
) -> pd.DataFrame:
    """Score every row of a results TSV using a process pool.

    Each worker loads the results file and dataset split itself (see
    `_initialize_worker`). Rows are split into at most ``num_processes``
    contiguous chunks.

    Returns:
        A copy of the results table with an added ``score`` column.

    Raises:
        ValueError: if ``num_processes`` is smaller than 1.
    """
    if num_processes < 1:
        raise ValueError("num_processes must be at least 1")

    results_df = pd.read_csv(results_path, sep="\t")
    scored_df = results_df.copy()
    scored_df["score"] = np.nan

    total_rows = len(results_df)
    if total_rows == 0:
        # Nothing to score; also avoids a zero chunk size below.
        return scored_df

    # Never spawn more workers (each loading the dataset) than there are rows.
    num_processes = min(num_processes, total_rows)
    chunk_size = (total_rows + num_processes - 1) // num_processes  # Ceiling division
    # Indices are positional ranges; they match `.loc`/`.at` labels because
    # read_csv yields a default RangeIndex.
    indices_chunks = [
        list(range(i, min(i + chunk_size, total_rows)))
        for i in range(0, total_rows, chunk_size)
    ]

    with multiprocessing.Pool(
        processes=num_processes,
        initializer=_initialize_worker,
        initargs=(results_path, dataset_id, split_name),
    ) as pool:
        with tqdm(total=total_rows, desc="Scoring predictions") as pbar:
            for chunk_results in pool.imap_unordered(_score_chunk, indices_chunks):
                for idx, s in chunk_results:
                    scored_df.at[idx, "score"] = s
                pbar.update(len(chunk_results))

    return scored_df


def score_predictions(results_df, dataset) -> pd.DataFrame:
    """Score every row of ``results_df`` sequentially against ``dataset``.

    Returns:
        A copy of ``results_df`` with an added ``score`` column.
    """
    # Build a lookup dictionary for efficient sample retrieval by ID
    id_to_index = {sample_id: i for i, sample_id in enumerate(dataset["sample_id"])}

    scores = []
    for _, row in tqdm(results_df.iterrows(), total=len(results_df), desc="Scoring predictions"):
        sample = dataset[id_to_index[row["sample_id"]]]
        scores.append(_score_row(row, sample))

    scored_df = results_df.copy()
    scored_df["score"] = scores
    return scored_df