import os
import shutil
import tempfile

import gradio as gr
import pandas as pd

from utils import run_supervisely_parser, train_model

# Help text rendered inside the "Parameter Help" accordion.
_PARAMETER_HELP_MD = """
**Base PyTorch Model (.zip)**: Archive containing a folder with weights and configuration file.\n
**Supervisely Project (.zip)**: Archive containing Exported Supervisely project containing images and annotation JSONs.\n
**Train Split Ratio**: Fraction of dataset used for training; remainder becomes validation.\n
**Random Seed**: Controls shuffling for reproducible splits & training.\n
**Data Percent**: Subsample percentage of training split (use <100 for quick experiments).\n
**Batch Size**: Samples processed before each optimizer step.\n
**Epochs**: Maximum complete passes over the (subsampled) training set.\n
**Learning Rate**: Initial optimizer step size.\n
**Image Width / Height**: Target spatial size for preprocessing (resize/crop).\n
**Early Stopping Patience**: Stop after this many validation checks without improvement.\n
**Validate Every**: Run validation after this many epochs.\n
"""


def _status_updates(message, *, running):
    """Return the output updates for an in-progress or failed pipeline state.

    The tuple order matches the ``outputs`` list wired to ``run_btn.click``:
    run button, status textbox, model download button, output tab,
    train IoU plot, val IoU plot, metrics table.

    Args:
        message: Text to show in the status textbox.
        running: ``True`` while a stage is executing (run button disabled);
            ``False`` after a failure (run button re-enabled so the user can
            retry).
    """
    return (
        gr.update(interactive=not running),  # run button
        gr.update(value=message, visible=True),  # status textbox
        gr.update(visible=False),  # model download button
        gr.update(visible=False),  # output tab
        gr.update(value=None),  # train IoU plot
        gr.update(value=None),  # val IoU plot
        gr.update(value=None),  # metrics table
    )


def _save_best_model(best_model, base_model_zip):
    """Save ``best_model`` and zip it; return the path of the ``.zip`` file.

    The model is written to ``<unique run dir>/best_model`` and archived as
    ``<unique run dir>/best_model.zip``. The run directory is created next to
    the uploaded base model. Because it is unique per call, files from an
    earlier or concurrent run can never leak into this run's archive, while
    the download keeps the ``best_model.zip`` name.

    Args:
        best_model: Object exposing ``save_pretrained(path)`` (as returned by
            ``train_model``).
        base_model_zip: Filesystem path of the uploaded base-model archive.
    """
    run_dir = tempfile.mkdtemp(
        prefix="run_",
        dir=os.path.dirname(base_model_zip),
    )
    best_model_dir = os.path.join(run_dir, "best_model")
    best_model.save_pretrained(best_model_dir)
    # base_name lives in run_dir, not inside best_model_dir, so the archive
    # does not try to include itself.
    return shutil.make_archive(
        base_name=best_model_dir,
        format="zip",
        root_dir=best_model_dir,
    )


def _build_metrics_comparison(metrics_df, dice):
    """Build the before/after metrics table from per-epoch metrics.

    The first row of ``metrics_df`` is used as "Before" and the last row as
    "After". ``dice[0]`` / ``dice[1]`` supply the before/after Dice values
    (presumably a ``(before, after)`` pair from ``train_model`` — confirm).

    Args:
        metrics_df: Per-epoch metrics with at least ``val_acc``, ``val_iou``
            and ``val_loss`` columns.
        dice: Indexable with ``[0]`` (before) and ``[1]`` (after).
    """
    initial_epoch_metrics = metrics_df.iloc[0]
    final_epoch_metrics = metrics_df.iloc[-1]
    return pd.DataFrame(
        {
            "Metric": ["Accuracy", "IoU", "Loss", "Dice"],
            "Before": [
                initial_epoch_metrics["val_acc"],
                initial_epoch_metrics["val_iou"],
                initial_epoch_metrics["val_loss"],
                dice[0],
            ],
            "After": [
                final_epoch_metrics["val_acc"],
                final_epoch_metrics["val_iou"],
                final_epoch_metrics["val_loss"],
                dice[1],
            ],
        }
    )


def run_pipeline(
    base_model_zip,
    supervisely_project_zip,
    train_ratio,
    seed,
    data_percent,
    batch_size,
    num_epochs,
    learning_rate,
    image_width,
    image_height,
    early_stopping,
    validate_every,
    pr=gr.Progress(track_tqdm=True),
):
    """Parse the dataset, train the model and publish the results.

    This is a generator Gradio event handler. Each yielded tuple updates, in
    order: run button, status textbox, model download button, output tab,
    train IoU plot, val IoU plot, metrics table.

    Stages:
        1. ``run_supervisely_parser`` splits the uploaded project.
        2. ``train_model`` fine-tunes the base model.
        3. The best model is saved and zipped for download.

    ``pr`` is not used directly. Its ``track_tqdm=True`` default makes Gradio
    mirror tqdm progress bars from the helpers into the UI.

    Raises:
        gr.Error: If any stage fails. Before raising, the UI is reset so the
            run button is usable again and the status shows the error.
    """
    try:
        yield _status_updates("Parsing Supervisely project ...", running=True)
        dataset_dir = run_supervisely_parser(
            project_path=supervisely_project_zip,
            train_ratio=train_ratio,
            seed=seed,
        )

        yield _status_updates("Starting model training...", running=True)
        best_model, metrics, dice = train_model(
            dataset_dir,
            base_model_zip,
            image_width,
            image_height,
            batch_size,
            data_percent,
            num_epochs,
            learning_rate,
            early_stopping,
            validate_every,
        )

        yield _status_updates("Saving best model...", running=True)
        best_model_zip_path = _save_best_model(best_model, base_model_zip)

        metrics_df = pd.DataFrame(metrics)
        metrics_comparison_df = _build_metrics_comparison(metrics_df, dice)
    except Exception as err:
        # Top-level UI boundary: without this the generator would die right
        # after a "busy" update, leaving the run button permanently disabled.
        yield _status_updates(f"Pipeline failed: {err}", running=False)
        if isinstance(err, gr.Error):
            raise
        raise gr.Error(f"Pipeline failed: {err}") from err

    yield (
        gr.update(interactive=True),  # run button
        gr.update(visible=False),  # status textbox
        gr.update(  # model download button
            value=best_model_zip_path,
            visible=True,
        ),
        gr.update(visible=True),  # output tab
        gr.update(value=metrics_df),  # train IoU plot
        gr.update(value=metrics_df),  # val IoU plot
        gr.update(
            value=metrics_comparison_df,
            visible=True,
        ),
    )


def _toggle_run_btn(base_model, project):
    """Enable run button only when both required files are selected."""
    ready = bool(base_model and project)
    return gr.update(interactive=ready)


with gr.Blocks(title="SegFormer Training & Dataset Pipeline") as demo:
    gr.Markdown(
        "# SegFormer Training Pipeline\n"
        "Upload your base model and Supervisely project, "
        "tweak parsing & training hyperparameters, then click "
        "**Run Training**."
    )

    with gr.Row():
        base_model_zip = gr.File(
            label="Base PyTorch Model (.zip)",
            file_types=[".zip"],
            file_count="single",
        )
        supervisely_project_zip = gr.File(
            label="Supervisely Project (.zip)",
            file_types=[".zip"],
            file_count="single",
        )

    with gr.Tab("Training"):
        gr.Markdown("Adjust training hyperparameters.")
        with gr.Row():
            data_percent = gr.Slider(
                minimum=1,
                maximum=100,
                step=1,
                value=100,
                label="Data Percent (%) used for training",
            )
            batch_size = gr.Number(
                value=32,
                label="Batch Size (samples/step)",
                precision=0,
                minimum=1,
            )
            num_epochs = gr.Number(
                value=60,
                label="Epochs (max passes)",
                precision=0,
                minimum=1,
            )
        with gr.Row():
            learning_rate = gr.Number(
                value=5e-5,
                label="Learning Rate",
                minimum=0.0,
                maximum=1.0,
            )
            image_width = gr.Number(
                value=640,
                label="Image Width (px)",
                precision=0,
                minimum=1,
            )
            image_height = gr.Number(
                value=640,
                label="Image Height (px)",
                precision=0,
                minimum=1,
            )
        with gr.Row():
            early_stopping = gr.Number(
                value=3,
                label="Early Stopping Patience (epochs w/o improvement)",
                precision=0,
                minimum=0,
            )
            # "Validate every 0 epochs" is meaningless (and likely a modulo
            # by zero inside train_model — NOTE(review): confirm), so require
            # at least one epoch between validations.
            validate_every = gr.Number(
                value=1,
                label="Validate Every (epochs)",
                precision=0,
                minimum=1,
            )

    with gr.Tab("Dataset Parsing"):
        gr.Markdown("Configure how the dataset is split and seeded.")
        with gr.Row():
            train_ratio = gr.Slider(
                minimum=0.1,
                maximum=0.95,
                step=0.01,
                value=0.8,
                label="Train Split Ratio (rest used for validation)",
            )
            seed = gr.Number(
                value=42,
                label="Random Seed (reproducibility)",
                precision=0,
            )

    with gr.Accordion("Parameter Help", open=False):
        gr.Markdown(_PARAMETER_HELP_MD)

    run_btn = gr.Button(
        "Run Training",
        variant="primary",
        interactive=False,
    )
    status = gr.Textbox(
        show_label=False,
        visible=False,
    )

    with gr.Tab("Results", visible=False) as output_tab:
        model_download_btn = gr.DownloadButton(
            label="Download Trained Model (.zip)",
            value=None,
            visible=False,
        )
        # Table comparing first-epoch ("Before") and final ("After") metrics.
        metrics_table = gr.DataFrame(
            label="Metrics Comparison",
            interactive=False,
            wrap=True,
        )
        with gr.Row():
            train_iou_plot = gr.LinePlot(
                label="Training IoU",
                x="epoch",
                y="train_iou",
                x_title="Epoch",
                y_title="IoU",
                height=400,
            )
            val_iou_plot = gr.LinePlot(
                label="Validation IoU",
                x="epoch",
                y="val_iou",
                x_title="Epoch",
                y_title="IoU",
                height=400,
            )

    # Enable run button only when both archives provided
    for _upload in (base_model_zip, supervisely_project_zip):
        _upload.change(
            _toggle_run_btn,
            inputs=[base_model_zip, supervisely_project_zip],
            outputs=run_btn,
        )

    # Click handler; the outputs order must match the tuples yielded by
    # run_pipeline.
    run_btn.click(
        run_pipeline,
        inputs=[
            base_model_zip,
            supervisely_project_zip,
            train_ratio,
            seed,
            data_percent,
            batch_size,
            num_epochs,
            learning_rate,
            image_width,
            image_height,
            early_stopping,
            validate_every,
        ],
        outputs=[
            run_btn,
            status,
            model_download_btn,
            output_tab,
            train_iou_plot,
            val_iou_plot,
            metrics_table,
        ],
        show_progress_on=status,
        scroll_to_output=True,
    )


if __name__ == "__main__":
    demo.launch()