# File: HomeSenseTest/scripts/supervisely_parser.py
# Author: YusufMesbah
# Commit e4aef33: Implement initial version of SegFormer training pipeline with
# dataset parsing and model training functionalities. Added Dockerfile for
# environment setup, utility scripts for parsing and training, and Gradio
# interface for user interaction.
"""
Supervisely Parser Script
This script parses Supervisely projects and converts them to a format
suitable for training segmentation models. It extracts class information,
creates train/validation splits, and converts annotations to
single-channel class-index PNG masks.
"""
import os
import json
import random
import shutil
import argparse
import numpy as np
from PIL import Image
from tqdm import tqdm
try:
    import supervisely as sly
    from supervisely import Annotation
except ImportError as import_error:
    # Print an actionable hint first, then propagate the original
    # ImportError so the script still fails immediately.
    for hint in (
        f"Failed to import supervisely: {import_error}",
        "Please ensure that the 'supervisely' package is installed and "
        "compatible with your environment.",
    ):
        print(hint)
    raise
def extract_class_info(project, output_dir):
    """Build class-index mappings from the project's object classes.

    Only object classes named ``"<number>. <label>"`` are used; any other
    name is ignored. ``<number>`` is 1-based in the class name and is stored
    here as a 0-based index (``int(number) - 1``).

    Side effects:
        Writes ``id2label.json`` and ``id2color.json`` into ``output_dir``.
        JSON object keys become strings, and entries are sorted by index.

    Returns:
        tuple: ``(id2label, id2color, label2id)``, where ``label2id`` is the
        inverse of ``id2label``.
    """
    id2label, id2color = {}, {}
    for obj_class in project.meta.obj_classes:
        number, _sep, name = obj_class.name.partition(". ")
        if name and number.isdigit():
            idx = int(number) - 1
            id2label[idx] = name
            id2color[idx] = obj_class.color

    # Persist both mappings (label file first, then colors).
    for filename, mapping in (
        ("id2label.json", id2label),
        ("id2color.json", id2color),
    ):
        with open(f"{output_dir}/{filename}", "w") as handle:
            json.dump(mapping, handle, sort_keys=True, indent=2)

    label2id = {name: idx for idx, name in id2label.items()}
    return id2label, id2color, label2id
def create_output_directories(output_dir):
    """Ensure the per-split image and annotation directories exist.

    Creates ``images/training``, ``annotations/training``,
    ``images/validation`` and ``annotations/validation`` under
    ``output_dir``. Directories that already exist are left as they are.
    """
    for split in ("training", "validation"):
        for subtree in ("images", "annotations"):
            os.makedirs(f"{output_dir}/{subtree}/{split}", exist_ok=True)
def calculate_split_counts(datasets, train_ratio=0.8):
    """Split the combined item count of ``datasets`` into train/val sizes.

    The training size is ``int(total * train_ratio)`` (the fractional part
    is dropped), and validation receives the remainder. Both sizes and the
    total are printed.

    Returns:
        tuple[int, int]: ``(train_items, val_items)``.
    """
    total_items = sum(len(ds.get_items_names()) for ds in datasets)
    train_items = int(total_items * train_ratio)
    val_items = total_items - train_items
    summary = (
        f"Total items: {total_items}\n"
        f"Train items: {train_items}\n"
        f"Validation items: {val_items}"
    )
    print(summary)
    return train_items, val_items
def to_class_index_mask(
    annotation: Annotation,
    label2id: dict,
    mask_path: str,
):
    """Rasterize ``annotation`` into a single-channel class-index PNG.

    The mask starts as all zeros (uint8). Each accepted bitmap label then
    overwrites the pixels it covers with its class index from ``label2id``.
    Labels are applied in ``annotation.labels`` order, so later labels win
    where they overlap. The mask is saved with ``Image.fromarray(...).save``.

    A label is skipped, with a message via ``tqdm.write``, when:
      * its class name (the text after ``"N. "``) is not in ``label2id``;
      * its geometry is not a bitmap (only bitmaps are rasterized);
      * its bitmap does not lie fully inside the image, including a
        negative origin.

    Args:
        annotation: Supervisely annotation; ``img_size`` is unpacked as
            ``(height, width)``.
        label2id: Mapping of class label -> integer class index.
        mask_path: Destination path of the PNG mask.
    """
    height, width = annotation.img_size
    # NOTE(review): pixels not covered by any label stay 0, which is also the
    # index extract_class_info assigns to class "1. ..." -- confirm intended.
    class_mask = np.zeros((height, width), dtype=np.uint8)
    for label in annotation.labels:
        class_name = label.obj_class.name.partition(". ")[2]
        if class_name not in label2id:
            tqdm.write(f"Skipping unrecognized label: {label}")
            continue  # skip unrecognized labels

        geometry = label.geometry
        geometry_name = geometry.geometry_name()
        if geometry_name != "bitmap":
            # Previously dropped silently; surface it so missing masks for
            # polygon/rectangle annotations are noticed.
            tqdm.write(
                f"Skipping label '{class_name}': unsupported geometry "
                f"'{geometry_name}'."
            )
            continue

        top = geometry.origin.row
        left = geometry.origin.col
        # assumes a boolean (h, w) array -- used as a boolean index below
        bitmap = geometry.data
        h, w = bitmap.shape
        # A negative origin must be rejected as well. A negative slice start
        # wraps to the far edge of the mask, which either raises IndexError
        # on the boolean assignment below or paints the wrong rows/columns.
        if top < 0 or left < 0 or top + h > height or left + w > width:
            tqdm.write(f"Skipping label '{class_name}': size mismatch.")
            continue
        class_mask[top : top + h, left : left + w][bitmap] = label2id[class_name]
    Image.fromarray(class_mask).save(mask_path)
def process_datasets(
    project,
    datasets,
    output_dir,
    label2id,
    train_items,
):
    """Copy images and write class-index masks into train/validation splits.

    Item names from every dataset are pooled and shuffled together with the
    module-level ``random`` generator. The first ``train_items`` entries of
    that shuffled pool go to ``training`` and the rest to ``validation``.

    For each item, the image is copied to ``<output_dir>/images/<split>/``
    and its annotation is rasterized by ``to_class_index_mask`` into
    ``<output_dir>/annotations/<split>/<item stem>.png``.

    Args:
        project: Opened Supervisely project; ``project.meta`` is used to load
            each annotation file.
        datasets: Datasets to process (snapshotted into a list once).
        output_dir: Root output directory. The ``images/<split>`` and
            ``annotations/<split>`` subdirectories are expected to exist
            already (``create_output_directories`` creates them).
        label2id: Class label -> class index, forwarded to the mask writer.
        train_items: Number of items, counted across *all* datasets combined,
            to place in the training split. Values <= 0 send everything to
            validation.
    """
    dataset_list = list(datasets)
    # Snapshot item names once so split assignment and the copy loop agree.
    items_per_dataset = [list(ds.get_items_names()) for ds in dataset_list]

    # ``train_items`` is a global count (calculate_split_counts sums over all
    # datasets), so the split must be decided on one global ordering.
    # Comparing each dataset's *local* index against it would put almost
    # every item into "training" whenever there is more than one dataset.
    pooled = [
        (ds_index, item)
        for ds_index, items in enumerate(items_per_dataset)
        for item in items
    ]
    random.shuffle(pooled)
    training_keys = set(pooled[: max(train_items, 0)])

    for ds_index, (dataset, items) in enumerate(
        tqdm(
            zip(dataset_list, items_per_dataset),
            desc="Processing datasets",
            total=len(dataset_list),
        )
    ):
        for item in tqdm(
            items,
            desc=f"Processing dataset: {dataset.name}",
            total=len(items),
            leave=False,
        ):
            # Determine split
            in_training = (ds_index, item) in training_keys
            split = "training" if in_training else "validation"
            # Copy image
            item_paths = dataset.get_item_paths(item)
            img_path = item_paths.img_path
            img_filename = os.path.basename(img_path)
            dest_path = f"{output_dir}/images/{split}/{img_filename}"
            shutil.copy(img_path, dest_path)
            # Convert the annotation to a class-index mask
            ann = sly.Annotation.load_json_file(item_paths.ann_path, project.meta)
            mask_filename = f"{os.path.splitext(item)[0]}.png"
            mask_path = f"{output_dir}/annotations/{split}/{mask_filename}"
            to_class_index_mask(ann, label2id, mask_path)
def parse_arguments():
    """Define and parse the command-line interface.

    Required: ``--project_dir`` and ``--output_base_dir`` (strings).
    Optional: ``--train_ratio`` (float, default 0.8) and ``--seed``
    (int, default 42).

    Returns:
        argparse.Namespace: The parsed arguments.
    """
    parser = argparse.ArgumentParser(
        description="Parse Supervisely project and convert to training format"
    )
    # Both path arguments are mandatory strings.
    required_paths = (
        ("--project_dir", "Path to the Supervisely project directory"),
        ("--output_base_dir", "Base output directory for parsed data"),
    )
    for flag, help_text in required_paths:
        parser.add_argument(flag, type=str, required=True, help=help_text)

    parser.add_argument(
        "--train_ratio",
        type=float,
        default=0.8,
        help="Ratio of data to use for training (default: 0.8)",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="Random seed for reproducible splits (default: 42)",
    )
    return parser.parse_args()
def main():
    """Convert a Supervisely project into the segmentation training layout.

    Steps, in order:
      1. Parse CLI arguments and seed ``random`` for reproducible splits.
      2. Open the project and create ``<output_base_dir>/<project name>``.
      3. Write the class mappings and compute the train/validation sizes.
      4. Copy images and write masks, then print a completion message.
    """
    args = parse_arguments()
    random.seed(args.seed)  # makes the train/validation shuffle reproducible

    project = sly.Project(args.project_dir, sly.OpenMode.READ)
    print(f"Project: {project.name}")

    output_dir = os.path.join(args.output_base_dir, project.name)
    create_output_directories(output_dir)

    # Only label2id is needed below; extract_class_info already wrote the
    # id2label / id2color mappings to disk.
    _id2label, _id2color, label2id = extract_class_info(project, output_dir)

    datasets = project.datasets
    print(f"Datasets: {len(datasets)}")
    train_items, _val_items = calculate_split_counts(datasets, args.train_ratio)

    process_datasets(project, datasets, output_dir, label2id, train_items)
    print("Processing completed successfully!")


if __name__ == "__main__":
    main()