is-it-max / functions.py
paddeh's picture
visualise-segmentation (#1)
9073e25 verified
raw
history blame contribute delete
820 Bytes
import os
import json
def import_class_labels(model_path) -> list:
    """Load class labels from ``classes.json``, ordered by their class index.

    Reads ``<model_path>/classes.json`` and returns the keys of its
    ``"class_to_idx"`` mapping, sorted by the integer index each key maps to.
    Position ``i`` of the returned list therefore holds the label whose index
    is ``i`` (assuming the indices are contiguous from 0 — not checked here).
    Labels that share an index keep their order of appearance in the file.

    Args:
        model_path: Directory containing ``classes.json`` (str or path-like).

    Returns:
        List of class-name strings sorted by ascending index.

    Raises:
        OSError: If ``classes.json`` cannot be opened (e.g. FileNotFoundError).
        json.JSONDecodeError: If the file is not valid JSON.
        KeyError: If the JSON object has no ``"class_to_idx"`` entry.
    """
    classes_file_path = os.path.join(model_path, "classes.json")
    # JSON is UTF-8 by spec; being explicit avoids locale-dependent decoding
    # of non-ASCII class names.
    with open(classes_file_path, "r", encoding="utf-8") as f:
        class_data = json.load(f)

    class_to_idx = class_data["class_to_idx"]
    # sorted() is stable, so labels with equal indices keep file order.
    return sorted(class_to_idx, key=class_to_idx.get)