Spaces:
Sleeping
Sleeping
File size: 6,720 Bytes
e4aef33 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 |
"""
Supervisely Parser Script
This script parses Supervisely projects and converts them to a format
suitable for training segmentation models. It extracts class information,
creates train/validation splits, and converts annotations to single-channel
PNG masks whose pixel values are class indices.
"""
import os
import json
import random
import shutil
import argparse
import numpy as np
from PIL import Image
from tqdm import tqdm
try:
import supervisely as sly
from supervisely import Annotation
except ImportError as e:
print(f"Failed to import supervisely: {e}")
print(
"Please ensure that the 'supervisely' package is installed and "
"compatible with your environment."
)
raise
def extract_class_info(project, output_dir):
    """Extract class information from the project metadata and save it.

    Only object classes whose name looks like ``"<N>. <label>"`` are used.
    The decimal prefix ``N`` becomes the class index ``N - 1``. Classes
    without that prefix are skipped.

    Writes ``id2label.json`` and ``id2color.json`` into *output_dir*.

    Args:
        project: Opened Supervisely project; only ``project.meta.obj_classes``
            (each with ``.name`` and ``.color``) is read.
        output_dir: Existing directory that receives the two JSON files.

    Returns:
        Tuple ``(id2label, id2color, label2id)``:
        index -> label, index -> class color, and label -> index.
    """
    id2label = {}
    id2color = {}
    for obj in project.meta.obj_classes:
        id_str, _, label = obj.name.partition(". ")
        # isdecimal(), not isdigit(): "²".isdigit() is True but int("²")
        # raises ValueError, which would abort the whole conversion.
        if not label or not id_str.isdecimal():
            continue
        # NOTE(review): a "0." prefix yields index -1, and duplicate prefixes
        # silently overwrite each other. Confirm the naming convention.
        index = int(id_str) - 1
        id2label[index] = label
        id2color[index] = obj.color

    # Save class mappings (int keys are written as JSON strings).
    for filename, mapping in (
        ("id2label.json", id2label),
        ("id2color.json", id2color),
    ):
        with open(os.path.join(output_dir, filename), "w", encoding="utf-8") as f:
            json.dump(mapping, f, sort_keys=True, indent=2)

    label2id = {v: k for k, v in id2label.items()}
    return id2label, id2color, label2id
def create_output_directories(output_dir):
    """Ensure the image/annotation folders for both splits exist.

    Creates ``<output_dir>/{images,annotations}/{training,validation}``
    (parents included). Existing directories are left untouched.
    """
    for split_name in ("training", "validation"):
        for kind in ("images", "annotations"):
            os.makedirs(f"{output_dir}/{kind}/{split_name}", exist_ok=True)
def calculate_split_counts(datasets, train_ratio=0.8):
    """Calculate the number of items for training and validation.

    Counts the items of every dataset together and puts
    ``floor(total * train_ratio)`` of them in training; the rest go to
    validation.

    Args:
        datasets: Iterable of datasets, each exposing ``get_items_names()``.
        train_ratio: Fraction of all items assigned to training, in [0, 1].

    Returns:
        Tuple ``(train_items, val_items)``; the two sum to the total count.

    Raises:
        ValueError: If ``train_ratio`` is outside [0, 1] (or NaN).
    """
    # Negated chained comparison also rejects NaN.
    if not 0.0 <= train_ratio <= 1.0:
        raise ValueError(f"train_ratio must be between 0 and 1, got {train_ratio}")

    total_items = sum(len(dataset.get_items_names()) for dataset in datasets)
    # Round away float noise before truncating: 100 * 0.29 evaluates to
    # 28.999999999999996, which int() alone would turn into 28.
    train_items = int(round(total_items * train_ratio, 6))
    val_items = total_items - train_items
    print(
        f"Total items: {total_items}\n"
        f"Train items: {train_items}\n"
        f"Validation items: {val_items}"
    )
    return train_items, val_items
def to_class_index_mask(
    annotation: Annotation,
    label2id: dict,
    mask_path: str,
):
    """Rasterise an annotation's bitmap labels into a class-index PNG.

    The mask has the annotation's image size, dtype uint8, and starts at 0.
    Each bitmap label whose class (name after the ``"N. "`` prefix) is in
    *label2id* is painted with its index. Later labels overwrite earlier ones
    where they overlap. Labels of unknown classes, non-bitmap geometries,
    and bitmaps that extend past the image bounds are skipped.

    Args:
        annotation: Supervisely annotation providing ``img_size`` and ``labels``.
        label2id: Class name -> class index.
        mask_path: Destination path of the PNG written with PIL.
    """
    img_h, img_w = annotation.img_size
    canvas = np.zeros((img_h, img_w), dtype=np.uint8)

    for lbl in annotation.labels:
        name = lbl.obj_class.name.partition(". ")[2]
        if name not in label2id:
            tqdm.write(f"Skipping unrecognized label: {lbl}")
            continue  # skip unrecognized labels

        geom = lbl.geometry
        if geom.geometry_name() != "bitmap":
            continue  # only bitmap geometries are rasterised

        row0 = geom.origin.row
        col0 = geom.origin.col
        # Presumably a boolean (h, w) array; boolean indexing below relies on that.
        mask_bits = geom.data
        bit_h, bit_w = mask_bits.shape
        if row0 + bit_h > img_h or col0 + bit_w > img_w:
            tqdm.write(f"Skipping label '{name}': size mismatch.")
            continue

        # Paint only the set pixels of the bitmap inside its bounding window.
        window = canvas[row0 : row0 + bit_h, col0 : col0 + bit_w]
        window[mask_bits] = label2id[name]

    Image.fromarray(canvas).save(mask_path)
def _assign_splits(datasets, train_items):
    """Map every ``(dataset position, item name)`` to "training" or "validation".

    Items from all datasets are pooled and shuffled together. The first
    ``train_items`` entries of the shuffled pool go to training, the rest to
    validation.
    """
    # Sort before shuffling so the seeded shuffle does not depend on the
    # order in which item names happen to be listed.
    pool = [
        (ds_pos, item)
        for ds_pos, dataset in enumerate(datasets)
        for item in sorted(dataset.get_items_names())
    ]
    random.shuffle(pool)
    return {
        key: "training" if rank < train_items else "validation"
        for rank, key in enumerate(pool)
    }


def process_datasets(
    project,
    datasets,
    output_dir,
    label2id,
    train_items,
):
    """Process all datasets and create train/validation splits.

    Copies each image to ``<output_dir>/images/<split>/`` and writes its
    class-index mask to ``<output_dir>/annotations/<split>/<item stem>.png``.

    Args:
        project: Opened Supervisely project; ``project.meta`` is used to
            load annotation JSON.
        datasets: The project's datasets (iterated twice, same order).
        output_dir: Root of the output layout; split sub-folders must exist.
        label2id: Class name -> class index, forwarded to the mask writer.
        train_items: Number of items, counted over ALL datasets, that go to
            the training split; all remaining items go to validation.
    """
    # The split is decided over the pooled items of every dataset, because
    # train_items is a project-wide count. Comparing it with a per-dataset
    # index would send (almost) everything to training when there are
    # several datasets.
    splits = _assign_splits(datasets, train_items)

    for ds_pos, dataset in enumerate(tqdm(datasets, desc="Processing datasets")):
        items = dataset.get_items_names()
        for item in tqdm(
            items,
            desc=f"Processing dataset: {dataset.name}",
            total=len(items),
            leave=False,
        ):
            split = splits[(ds_pos, item)]

            # Copy the image under its original file name.
            # NOTE: output paths do not include the dataset name, so
            # same-named items in different datasets overwrite each other.
            item_paths = dataset.get_item_paths(item)
            img_path = item_paths.img_path
            dest_path = f"{output_dir}/images/{split}/{os.path.basename(img_path)}"
            shutil.copy(img_path, dest_path)

            # Convert the annotation into a class-index mask PNG.
            ann = sly.Annotation.load_json_file(item_paths.ann_path, project.meta)
            mask_filename = f"{os.path.splitext(item)[0]}.png"
            mask_path = f"{output_dir}/annotations/{split}/{mask_filename}"
            to_class_index_mask(ann, label2id, mask_path)
def parse_arguments():
    """Build the command-line interface and return the parsed namespace.

    Options:
        --project_dir (str, required): Supervisely project to read.
        --output_base_dir (str, required): parent folder for the output.
        --train_ratio (float, default 0.8): training fraction.
        --seed (int, default 42): seed for the random split.
    """
    cli = argparse.ArgumentParser(
        description="Parse Supervisely project and convert to training format"
    )

    required_paths = (
        ("--project_dir", "Path to the Supervisely project directory"),
        ("--output_base_dir", "Base output directory for parsed data"),
    )
    for flag, help_text in required_paths:
        cli.add_argument(flag, type=str, required=True, help=help_text)

    cli.add_argument(
        "--train_ratio",
        type=float,
        default=0.8,
        help="Ratio of data to use for training (default: 0.8)",
    )
    cli.add_argument(
        "--seed",
        type=int,
        default=42,
        help="Random seed for reproducible splits (default: 42)",
    )
    return cli.parse_args()
def main():
    """Convert a Supervisely project into a train/validation folder layout.

    Parses the CLI, seeds ``random`` so the split is repeatable, opens the
    project read-only, and writes everything under
    ``<output_base_dir>/<project name>``.
    """
    cli_args = parse_arguments()
    # Seed the global RNG used by the shuffle that decides the split.
    random.seed(cli_args.seed)

    project = sly.Project(cli_args.project_dir, sly.OpenMode.READ)
    print(f"Project: {project.name}")

    target_dir = os.path.join(cli_args.output_base_dir, project.name)
    create_output_directories(target_dir)

    # Only the name -> index mapping is needed downstream; the other two
    # mappings are already saved to JSON by extract_class_info.
    _, _, label2id = extract_class_info(project, target_dir)

    project_datasets = project.datasets
    print(f"Datasets: {len(project_datasets)}")
    n_train, _ = calculate_split_counts(project_datasets, cli_args.train_ratio)

    process_datasets(project, project_datasets, target_dir, label2id, n_train)
    print("Processing completed successfully!")


if __name__ == "__main__":
    main()
|