# SegFormer training pipeline -- Supervisely parser.
# Commit e4aef33: "Implement initial version of SegFormer training pipeline
# with dataset parsing and model training functionalities. Added Dockerfile
# for environment setup, utility scripts for parsing and training, and
# Gradio interface for user interaction."
| """ | |
| Supervisely Parser Script | |
| This script parses Supervisely projects and converts them to a format | |
| suitable for training segmentation models. It extracts class information, | |
| creates train/validation splits, and converts annotations to indexed | |
| color masks. | |
| """ | |
| import os | |
| import json | |
| import random | |
| import shutil | |
| import argparse | |
| import numpy as np | |
| from PIL import Image | |
| from tqdm import tqdm | |
| try: | |
| import supervisely as sly | |
| from supervisely import Annotation | |
| except ImportError as e: | |
| print(f"Failed to import supervisely: {e}") | |
| print( | |
| "Please ensure that the 'supervisely' package is installed and " | |
| "compatible with your environment." | |
| ) | |
| raise | |
def extract_class_info(project, output_dir):
    """Read class metadata from a Supervisely project and persist the mappings.

    Object class names are expected to look like ``"<number>. <label>"``; the
    numeric prefix (1-based) becomes a 0-based class index. Names that do not
    match this pattern are silently ignored.

    Writes ``id2label.json`` and ``id2color.json`` into *output_dir* and
    returns ``(id2label, id2color, label2id)``.
    """
    id2label, id2color = {}, {}
    for obj_class in project.meta.obj_classes:
        prefix, _, name = obj_class.name.partition(". ")
        if name and prefix.isdigit():
            idx = int(prefix) - 1
            id2label[idx] = name
            id2color[idx] = obj_class.color

    # Persist both mappings next to the converted dataset.
    for filename, mapping in (
        ("id2label.json", id2label),
        ("id2color.json", id2color),
    ):
        with open(f"{output_dir}/{filename}", "w") as f:
            json.dump(mapping, f, sort_keys=True, indent=2)

    label2id = {name: idx for idx, name in id2label.items()}
    return id2label, id2color, label2id
def create_output_directories(output_dir):
    """Create the images/annotations directory tree for both splits.

    Layout: ``<output_dir>/{images,annotations}/{training,validation}``.
    Existing directories are left untouched.
    """
    for kind in ("images", "annotations"):
        for split in ("training", "validation"):
            os.makedirs(f"{output_dir}/{kind}/{split}", exist_ok=True)
def calculate_split_counts(datasets, train_ratio=0.8):
    """Return ``(train_count, val_count)`` over the combined item total.

    The training count is ``floor(total * train_ratio)``; everything else
    goes to validation. A short summary is printed as a side effect.
    """
    total = sum(len(ds.get_items_names()) for ds in datasets)
    n_train = int(total * train_ratio)
    n_val = total - n_train
    print(f"Total items: {total}")
    print(f"Train items: {n_train}")
    print(f"Validation items: {n_val}")
    return n_train, n_val
def to_class_index_mask(
    annotation: Annotation,
    label2id: dict,
    mask_path: str,
):
    """Rasterize a Supervisely annotation into a class-index PNG mask.

    Pixels default to 0. Each recognized bitmap label paints its class index
    into the mask at the bitmap's origin. Labels are skipped (with a tqdm
    notice where applicable) when their class name is not in *label2id*,
    their geometry is not a bitmap, or their bitmap would extend past the
    image bounds.
    """
    rows, cols = annotation.img_size
    mask = np.zeros((rows, cols), dtype=np.uint8)

    for label in annotation.labels:
        # Class names carry a "<number>. " prefix; strip it to match label2id.
        name = label.obj_class.name.partition(". ")[2]
        if name not in label2id:
            tqdm.write(f"Skipping unrecognized label: {label}")
            continue
        # Only bitmap geometries are rasterized; everything else is ignored.
        if label.geometry.geometry_name() != "bitmap":
            continue
        row0 = label.geometry.origin.row
        col0 = label.geometry.origin.col
        bitmap = label.geometry.data  # binary numpy array, shape (h, w)
        h, w = bitmap.shape
        if row0 + h > rows or col0 + w > cols:
            tqdm.write(f"Skipping label '{name}': size mismatch.")
            continue
        mask[row0 : row0 + h, col0 : col0 + w][bitmap] = label2id[name]

    Image.fromarray(mask).save(mask_path)
def process_datasets(
    project,
    datasets,
    output_dir,
    label2id,
    train_items,
):
    """Process all datasets and create train/validation splits.

    For every item: copies its image into ``images/<split>`` and writes its
    class-index mask into ``annotations/<split>``. The first *train_items*
    items processed — counted across ALL datasets, after a per-dataset
    shuffle — go to training; the remainder go to validation.

    NOTE: seed ``random`` before calling for reproducible splits (main() does).
    """
    # BUG FIX: the original compared the per-dataset enumeration index
    # against `train_items`, but `train_items` is computed over the TOTAL
    # item count across all datasets (see calculate_split_counts). With
    # more than one dataset, every dataset contributed its first
    # `train_items` items to training, so the validation split could end
    # up empty. Track a single running counter across datasets instead.
    processed = 0
    for dataset in tqdm(datasets, desc="Processing datasets"):
        items = dataset.get_items_names()
        random.shuffle(items)
        for item in tqdm(
            items,
            desc=f"Processing dataset: {dataset.name}",
            total=len(items),
            leave=False,
        ):
            # Determine split from the global running counter.
            split = "training" if processed < train_items else "validation"
            processed += 1

            # Copy the source image into the split directory.
            item_paths = dataset.get_item_paths(item)
            img_path = item_paths.img_path
            img_filename = os.path.basename(img_path)
            dest_path = f"{output_dir}/images/{split}/{img_filename}"
            shutil.copy(img_path, dest_path)

            # Convert the annotation to an indexed mask in the matching split.
            ann_path = item_paths.ann_path
            ann = sly.Annotation.load_json_file(ann_path, project.meta)
            mask_filename = f"{os.path.splitext(item)[0]}.png"
            mask_path = f"{output_dir}/annotations/{split}/{mask_filename}"
            to_class_index_mask(ann, label2id, mask_path)
def parse_arguments():
    """Build and evaluate the command-line interface for this script."""
    parser = argparse.ArgumentParser(
        description="Parse Supervisely project and convert to training format"
    )
    # Required path arguments share the same shape; declare them in one place.
    for flag, help_text in (
        ("--project_dir", "Path to the Supervisely project directory"),
        ("--output_base_dir", "Base output directory for parsed data"),
    ):
        parser.add_argument(flag, type=str, required=True, help=help_text)
    parser.add_argument(
        "--train_ratio",
        type=float,
        default=0.8,
        help="Ratio of data to use for training (default: 0.8)",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="Random seed for reproducible splits (default: 42)",
    )
    return parser.parse_args()
def main():
    """Entry point: parse CLI args, load the project, and export the splits."""
    args = parse_arguments()
    random.seed(args.seed)  # reproducible shuffling across runs

    # Open the Supervisely project read-only.
    project = sly.Project(args.project_dir, sly.OpenMode.READ)
    print(f"Project: {project.name}")

    # All output lands under <output_base_dir>/<project name>.
    output_dir = os.path.join(args.output_base_dir, project.name)
    create_output_directories(output_dir)

    # Class mappings are derived from project metadata and saved as JSON.
    id2label, id2color, label2id = extract_class_info(project, output_dir)

    datasets = project.datasets
    print(f"Datasets: {len(datasets)}")
    train_items, _val_items = calculate_split_counts(datasets, args.train_ratio)

    process_datasets(
        project,
        datasets,
        output_dir,
        label2id,
        train_items,
    )
    print("Processing completed successfully!")


if __name__ == "__main__":
    main()