Spaces:
Sleeping
Sleeping
Implement initial version of SegFormer training pipeline with dataset parsing and model training functionalities. Added Dockerfile for environment setup, utility scripts for parsing and training, and Gradio interface for user interaction.
e4aef33
| import os | |
| import shutil | |
| import gradio as gr | |
| import pandas as pd | |
| from utils import run_supervisely_parser, train_model | |
def run_pipeline(
    base_model_zip,
    supervisely_project_zip,
    train_ratio,
    seed,
    data_percent,
    batch_size,
    num_epochs,
    learning_rate,
    image_width,
    image_height,
    early_stopping,
    validate_every,
    pr=gr.Progress(track_tqdm=True),
):
    """Parse the dataset, train the model and stream UI updates.

    This is a generator: each yielded tuple carries one ``gr.update`` per
    output component, always in this order: run button, status textbox,
    model download button, results tab, training IoU plot, validation IoU
    plot, metrics table.

    Args:
        base_model_zip: Path of the uploaded base model archive. It is
            forwarded to ``train_model`` and its directory is used as the
            location for the saved best model and its zip archive.
        supervisely_project_zip: Path of the uploaded Supervisely project
            archive, forwarded to ``run_supervisely_parser``.
        train_ratio: Forwarded to ``run_supervisely_parser``.
        seed: Forwarded to ``run_supervisely_parser``.
        data_percent: Forwarded to ``train_model``.
        batch_size: Forwarded to ``train_model``.
        num_epochs: Forwarded to ``train_model``.
        learning_rate: Forwarded to ``train_model``.
        image_width: Forwarded to ``train_model``.
        image_height: Forwarded to ``train_model``.
        early_stopping: Forwarded to ``train_model``.
        validate_every: Forwarded to ``train_model``.
        pr: Gradio progress tracker. It is not used directly in the body;
            it is declared with ``track_tqdm=True`` so Gradio can pick up
            tqdm progress (see the Gradio ``Progress`` documentation).

    Yields:
        A 7-tuple of component updates (order described above).

    Raises:
        gr.Error: If any pipeline stage fails. Before raising, a final
            update is yielded that re-enables the run button and shows the
            failure in the status textbox, so the UI is never left locked.
    """

    def _busy(message):
        # UI state while a stage is running: run button locked, status
        # message shown, previous results hidden and cleared.
        return (
            gr.update(interactive=False),  # run button
            gr.update(value=message, visible=True),  # status textbox
            gr.update(visible=False),  # model download button
            gr.update(visible=False),  # output tab
            gr.update(value=None),  # train IoU plot
            gr.update(value=None),  # val IoU plot
            gr.update(value=None),  # metrics table
        )

    try:
        # Parsing
        yield _busy("Parsing Supervisely project ...")
        dataset_dir = run_supervisely_parser(
            project_path=supervisely_project_zip,
            train_ratio=train_ratio,
            seed=seed,
        )

        # Training
        yield _busy("Starting model training...")
        best_model, metrics, dice = train_model(
            dataset_dir,
            base_model_zip,
            image_width,
            image_height,
            batch_size,
            data_percent,
            num_epochs,
            learning_rate,
            early_stopping,
            validate_every,
        )

        # Saving model
        yield _busy("Saving best model...")
        best_model_dir = os.path.join(
            os.path.dirname(base_model_zip),
            "best_model",
        )
        best_model.save_pretrained(best_model_dir)
        # make_archive returns the path of the created "<best_model_dir>.zip".
        best_model_zip_path = shutil.make_archive(
            base_name=best_model_dir,
            format="zip",
            root_dir=best_model_dir,
        )

        metrics_df = pd.DataFrame(metrics)
        if metrics_df.empty:
            # Without this check, iloc[0] below fails with an opaque
            # IndexError.
            raise ValueError("Training returned no metrics.")
        initial_epoch_metrics = metrics_df.iloc[0]
        final_epoch_metrics = metrics_df.iloc[-1]

        # Metrics comparison table: first recorded row is "Before", last
        # recorded row is "After".
        # NOTE(review): assumes `dice` is a (before, after) pair - confirm
        # against train_model.
        metrics_comparison_df = pd.DataFrame(
            {
                "Metric": ["Accuracy", "IoU", "Loss", "Dice"],
                "Before": [
                    initial_epoch_metrics["val_acc"],
                    initial_epoch_metrics["val_iou"],
                    initial_epoch_metrics["val_loss"],
                    dice[0],
                ],
                "After": [
                    final_epoch_metrics["val_acc"],
                    final_epoch_metrics["val_iou"],
                    final_epoch_metrics["val_loss"],
                    dice[1],
                ],
            }
        )
    except Exception as exc:
        # Unlock the UI before reporting the failure; otherwise the run
        # button stays disabled from the last progress update.
        yield (
            gr.update(interactive=True),  # run button
            gr.update(  # status textbox
                value=f"Pipeline failed: {exc}",
                visible=True,
            ),
            gr.update(visible=False),  # model download button
            gr.update(visible=False),  # output tab
            gr.update(value=None),  # train IoU plot
            gr.update(value=None),  # val IoU plot
            gr.update(value=None),  # metrics table
        )
        raise gr.Error(f"Pipeline failed: {exc}") from exc

    # Success: unlock the run button, hide the status and reveal results.
    yield (
        gr.update(interactive=True),  # run button
        gr.update(visible=False),  # status textbox
        gr.update(  # model download button
            value=best_model_zip_path,
            visible=True,
        ),
        gr.update(visible=True),  # output tab
        gr.update(value=metrics_df),  # train IoU plot
        gr.update(value=metrics_df),  # val IoU plot
        gr.update(  # metrics table
            value=metrics_comparison_df,
            visible=True,
        ),
    )
def _toggle_run_btn(base_model, project):
    """Return a run-button update that is interactive only if both files are set."""
    # A missing (falsy) upload on either side keeps the button disabled.
    if not base_model or not project:
        return gr.update(interactive=False)
    return gr.update(interactive=True)
# Build the UI: uploads, hyperparameter tabs, parameter help, run controls
# and an initially hidden results tab, then wire up the event handlers.
with gr.Blocks(title="SegFormer Training & Dataset Pipeline") as demo:
    gr.Markdown(
        "# SegFormer Training Pipeline\n"
        "Upload your base model and Supervisely project, "
        "tweak parsing & training hyperparameters, then click "
        "**Run Training**."
    )

    # --- Required uploads -------------------------------------------------
    with gr.Row():
        base_model_zip = gr.File(
            file_count="single",
            file_types=[".zip"],
            label="Base PyTorch Model (.zip)",
        )
        supervisely_project_zip = gr.File(
            file_count="single",
            file_types=[".zip"],
            label="Supervisely Project (.zip)",
        )

    # --- Training hyperparameters -----------------------------------------
    with gr.Tab("Training"):
        gr.Markdown("Adjust training hyperparameters.")
        with gr.Row():
            data_percent = gr.Slider(
                label="Data Percent (%) used for training",
                value=100,
                minimum=1,
                maximum=100,
                step=1,
            )
            batch_size = gr.Number(
                label="Batch Size (samples/step)",
                value=32,
                minimum=1,
                precision=0,
            )
            num_epochs = gr.Number(
                label="Epochs (max passes)",
                value=60,
                minimum=1,
                precision=0,
            )
        with gr.Row():
            learning_rate = gr.Number(
                label="Learning Rate",
                value=5e-5,
                minimum=0.0,
                maximum=1.0,
            )
            image_width = gr.Number(
                label="Image Width (px)",
                value=640,
                minimum=1,
                precision=0,
            )
            image_height = gr.Number(
                label="Image Height (px)",
                value=640,
                minimum=1,
                precision=0,
            )
        with gr.Row():
            early_stopping = gr.Number(
                label="Early Stopping Patience (epochs w/o improvement)",
                value=3,
                minimum=0,
                precision=0,
            )
            validate_every = gr.Number(
                label="Validate Every (epochs)",
                value=1,
                minimum=0,
                precision=0,
            )

    # --- Dataset parsing options ------------------------------------------
    with gr.Tab("Dataset Parsing"):
        gr.Markdown("Configure how the dataset is split and seeded.")
        with gr.Row():
            train_ratio = gr.Slider(
                label="Train Split Ratio (rest used for validation)",
                value=0.8,
                minimum=0.1,
                maximum=0.95,
                step=0.01,
            )
            seed = gr.Number(
                label="Random Seed (reproducibility)",
                value=42,
                precision=0,
            )

    # --- Collapsed reference for every parameter --------------------------
    with gr.Accordion("Parameter Help", open=False):
        gr.Markdown(
            """
            **Base PyTorch Model (.zip)**: Archive containing a folder with
            weights and configuration file.\n
            **Supervisely Project (.zip)**: Archive containing Exported
            Supervisely project
            containing images and annotation JSONs.\n
            **Train Split Ratio**: Fraction of dataset used for training;
            remainder becomes validation.\n
            **Random Seed**: Controls shuffling for reproducible splits &
            training.\n
            **Data Percent**: Subsample percentage of training split (use
            <100 for quick experiments).\n
            **Batch Size**: Samples processed before each optimizer step.\n
            **Epochs**: Maximum complete passes over the (subsampled)
            training set.\n
            **Learning Rate**: Initial optimizer step size.\n
            **Image Width / Height**: Target spatial size for preprocessing
            (resize/crop).\n
            **Early Stopping Patience**: Stop after this many validation
            checks without improvement.\n
            **Validate Every**: Run validation after this many epochs.\n
            """
        )

    # --- Run controls (button starts disabled, status starts hidden) ------
    run_btn = gr.Button("Run Training", interactive=False, variant="primary")
    status = gr.Textbox(visible=False, show_label=False)

    # --- Results (hidden until a run has finished) ------------------------
    with gr.Tab("Results", visible=False) as output_tab:
        model_download_btn = gr.DownloadButton(
            label="Download Trained Model (.zip)",
            visible=False,
            value=None,
        )
        # Before/after comparison of the validation metrics.
        metrics_table = gr.DataFrame(
            label="Metrics Comparison",
            wrap=True,
            interactive=False,
        )
        with gr.Row():
            train_iou_plot = gr.LinePlot(
                label="Training IoU",
                x="epoch",
                y="train_iou",
                x_title="Epoch",
                y_title="IoU",
                height=400,
            )
            val_iou_plot = gr.LinePlot(
                label="Validation IoU",
                x="epoch",
                y="val_iou",
                x_title="Epoch",
                y_title="IoU",
                height=400,
            )

    # A change to either upload re-evaluates whether the run button is
    # enabled; both uploads are always passed to the handler.
    for upload in (base_model_zip, supervisely_project_zip):
        upload.change(
            _toggle_run_btn,
            inputs=[base_model_zip, supervisely_project_zip],
            outputs=run_btn,
        )

    # The order of both lists must match run_pipeline's parameters and the
    # tuples it yields.
    pipeline_inputs = [
        base_model_zip,
        supervisely_project_zip,
        train_ratio,
        seed,
        data_percent,
        batch_size,
        num_epochs,
        learning_rate,
        image_width,
        image_height,
        early_stopping,
        validate_every,
    ]
    pipeline_outputs = [
        run_btn,
        status,
        model_download_btn,
        output_tab,
        train_iou_plot,
        val_iou_plot,
        metrics_table,
    ]
    run_btn.click(
        run_pipeline,
        inputs=pipeline_inputs,
        outputs=pipeline_outputs,
        show_progress_on=status,
        scroll_to_output=True,
    )

# Launch the app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()