HomeSenseTest / app.py
YusufMesbah's picture
Implement initial version of SegFormer training pipeline with dataset parsing and model training functionalities. Added Dockerfile for environment setup, utility scripts for parsing and training, and Gradio interface for user interaction.
e4aef33
import os
import shutil
import gradio as gr
import pandas as pd
from utils import run_supervisely_parser, train_model
def _busy_outputs(message):
    """Build the output tuple shown while a pipeline stage is running.

    The run button is locked, the status textbox shows *message*, and every
    result widget is hidden or cleared.  The tuple order matches the
    ``outputs`` list wired to the run button's click handler.
    """
    return (
        gr.update(interactive=False),  # run button
        gr.update(value=message, visible=True),  # status textbox
        gr.update(visible=False),  # model download button
        gr.update(visible=False),  # output tab
        gr.update(value=None),  # train IoU plot
        gr.update(value=None),  # val IoU plot
        gr.update(value=None),  # metrics table
    )


def run_pipeline(
    base_model_zip,
    supervisely_project_zip,
    train_ratio,
    seed,
    data_percent,
    batch_size,
    num_epochs,
    learning_rate,
    image_width,
    image_height,
    early_stopping,
    validate_every,
    pr=gr.Progress(track_tqdm=True),
):
    """Parse the dataset, train the model, and stream UI updates.

    This is a generator: each ``yield`` is a 7-tuple of ``gr.update`` objects
    for, in order, the run button, status textbox, model download button,
    output tab, train IoU plot, val IoU plot and metrics table.

    Args:
        base_model_zip: Path of the uploaded base-model archive.  Its parent
            directory is also where the trained model is written.
        supervisely_project_zip: Uploaded project archive, forwarded to
            ``run_supervisely_parser`` as ``project_path``.
        train_ratio: Forwarded to ``run_supervisely_parser``.
        seed: Forwarded to ``run_supervisely_parser``.
        data_percent, batch_size, num_epochs, learning_rate, image_width,
        image_height, early_stopping, validate_every: Forwarded unchanged to
            ``train_model``.
        pr: Gradio progress tracker.  It is not referenced in the body; it is
            declared so Gradio can attach progress reporting to this handler.

    Yields:
        Three "busy" states (parsing, training, saving) followed by one final
        state that re-enables the run button and reveals the results.

    Raises:
        gr.Error: If any stage fails.  Before raising, one extra state is
            yielded that re-enables the run button so the user can retry.
    """
    try:
        # Parsing
        yield _busy_outputs("Parsing Supervisely project ...")
        dataset_dir = run_supervisely_parser(
            project_path=supervisely_project_zip,
            train_ratio=train_ratio,
            seed=seed,
        )

        # Training
        yield _busy_outputs("Starting model training...")
        best_model, metrics, dice = train_model(
            dataset_dir,
            base_model_zip,
            image_width,
            image_height,
            batch_size,
            data_percent,
            num_epochs,
            learning_rate,
            early_stopping,
            validate_every,
        )

        # Saving model
        yield _busy_outputs("Saving best model...")
        best_model_dir = os.path.join(
            os.path.dirname(base_model_zip),
            "best_model",
        )
        best_model.save_pretrained(best_model_dir)
        # The archive is written next to the directory as "<best_model_dir>.zip".
        best_model_zip_path = shutil.make_archive(
            base_name=best_model_dir,
            format="zip",
            root_dir=best_model_dir,
        )

        metrics_df = pd.DataFrame(metrics)
        initial_epoch_metrics = metrics_df.iloc[0]
        final_epoch_metrics = metrics_df.iloc[-1]
        # Comparison table: first metrics row is "Before", last row is "After".
        metrics_comparison_df = pd.DataFrame(
            {
                "Metric": ["Accuracy", "IoU", "Loss", "Dice"],
                "Before": [
                    initial_epoch_metrics["val_acc"],
                    initial_epoch_metrics["val_iou"],
                    initial_epoch_metrics["val_loss"],
                    dice[0],
                ],
                "After": [
                    final_epoch_metrics["val_acc"],
                    final_epoch_metrics["val_iou"],
                    final_epoch_metrics["val_loss"],
                    dice[1],
                ],
            }
        )
    except Exception as exc:
        # Without this, a failure would leave the run button disabled and the
        # status box stuck on the last "in progress" message.
        failure_message = f"Pipeline failed: {exc}"
        yield (
            gr.update(interactive=True),  # run button
            gr.update(value=failure_message, visible=True),  # status textbox
            gr.update(visible=False),  # model download button
            gr.update(visible=False),  # output tab
            gr.update(value=None),  # train IoU plot
            gr.update(value=None),  # val IoU plot
            gr.update(value=None),  # metrics table
        )
        raise gr.Error(failure_message) from exc

    yield (
        gr.update(interactive=True),  # run button
        gr.update(visible=False),  # status textbox
        gr.update(  # model download button
            value=best_model_zip_path,
            visible=True,
        ),
        gr.update(visible=True),  # output tab
        gr.update(value=metrics_df),  # train IoU plot
        gr.update(value=metrics_df),  # val IoU plot
        gr.update(  # metrics table
            value=metrics_comparison_df,
            visible=True,
        ),
    )
def _toggle_run_btn(base_model, project):
    """Unlock the run button once a base model and a project are both chosen."""
    # all() is True only when neither upload slot is empty/None.
    return gr.update(interactive=all((base_model, project)))
def _zip_upload(label):
    """Create a single-file upload widget restricted to ``.zip`` archives."""
    return gr.File(label=label, file_types=[".zip"], file_count="single")


def _whole_number(label, value, minimum):
    """Create a ``precision=0`` number field with the given lower bound."""
    return gr.Number(value=value, label=label, precision=0, minimum=minimum)


def _iou_plot(label, y):
    """Create a 400px-high line plot of column *y* against ``epoch``."""
    return gr.LinePlot(
        label=label,
        x="epoch",
        y=y,
        x_title="Epoch",
        y_title="IoU",
        height=400,
    )


with gr.Blocks(title="SegFormer Training & Dataset Pipeline") as demo:
    # ---- Page header -------------------------------------------------------
    gr.Markdown(
        "# SegFormer Training Pipeline\n"
        "Upload your base model and Supervisely project, "
        "tweak parsing & training hyperparameters, then click "
        "**Run Training**."
    )

    # ---- Required uploads --------------------------------------------------
    with gr.Row():
        base_model_zip = _zip_upload("Base PyTorch Model (.zip)")
        supervisely_project_zip = _zip_upload("Supervisely Project (.zip)")

    # ---- Training hyperparameters -----------------------------------------
    with gr.Tab("Training"):
        gr.Markdown("Adjust training hyperparameters.")
        with gr.Row():
            data_percent = gr.Slider(
                label="Data Percent (%) used for training",
                value=100,
                minimum=1,
                maximum=100,
                step=1,
            )
            batch_size = _whole_number("Batch Size (samples/step)", 32, 1)
            num_epochs = _whole_number("Epochs (max passes)", 60, 1)
        with gr.Row():
            learning_rate = gr.Number(
                label="Learning Rate",
                value=5e-5,
                minimum=0.0,
                maximum=1.0,
            )
            image_width = _whole_number("Image Width (px)", 640, 1)
            image_height = _whole_number("Image Height (px)", 640, 1)
        with gr.Row():
            early_stopping = _whole_number(
                "Early Stopping Patience (epochs w/o improvement)", 3, 0
            )
            validate_every = _whole_number("Validate Every (epochs)", 1, 0)

    # ---- Dataset split settings -------------------------------------------
    with gr.Tab("Dataset Parsing"):
        gr.Markdown("Configure how the dataset is split and seeded.")
        with gr.Row():
            train_ratio = gr.Slider(
                label="Train Split Ratio (rest used for validation)",
                value=0.8,
                minimum=0.1,
                maximum=0.95,
                step=0.01,
            )
            seed = gr.Number(
                label="Random Seed (reproducibility)",
                value=42,
                precision=0,
            )

    # ---- Collapsible reference for every parameter ------------------------
    with gr.Accordion("Parameter Help", open=False):
        # The literal is kept flush-left: any leading whitespace would become
        # part of the Markdown text.
        gr.Markdown(
            """
**Base PyTorch Model (.zip)**: Archive containing a folder with
weights and configuration file.\n
**Supervisely Project (.zip)**: Archive containing Exported
Supervisely project
containing images and annotation JSONs.\n
**Train Split Ratio**: Fraction of dataset used for training;
remainder becomes validation.\n
**Random Seed**: Controls shuffling for reproducible splits &
training.\n
**Data Percent**: Subsample percentage of training split (use
<100 for quick experiments).\n
**Batch Size**: Samples processed before each optimizer step.\n
**Epochs**: Maximum complete passes over the (subsampled)
training set.\n
**Learning Rate**: Initial optimizer step size.\n
**Image Width / Height**: Target spatial size for preprocessing
(resize/crop).\n
**Early Stopping Patience**: Stop after this many validation
checks without improvement.\n
**Validate Every**: Run validation after this many epochs.\n
"""
        )

    # ---- Action button and status line ------------------------------------
    # The button starts disabled; the upload listeners below enable it.
    run_btn = gr.Button(
        "Run Training",
        variant="primary",
        interactive=False,
    )
    status = gr.Textbox(
        show_label=False,
        visible=False,
    )

    # ---- Results (hidden until a run finishes) ----------------------------
    with gr.Tab("Results", visible=False) as output_tab:
        model_download_btn = gr.DownloadButton(
            label="Download Trained Model (.zip)",
            value=None,
            visible=False,
        )
        # Before/after comparison of the validation metrics.
        metrics_table = gr.DataFrame(
            label="Metrics Comparison",
            interactive=False,
            wrap=True,
        )
        with gr.Row():
            train_iou_plot = _iou_plot("Training IoU", "train_iou")
            val_iou_plot = _iou_plot("Validation IoU", "val_iou")

    # ---- Event wiring -----------------------------------------------------
    # Re-evaluate the run button whenever either upload slot changes.
    for _uploader in (base_model_zip, supervisely_project_zip):
        _uploader.change(
            _toggle_run_btn,
            inputs=[base_model_zip, supervisely_project_zip],
            outputs=run_btn,
        )

    # Order must match run_pipeline's positional parameters.
    _pipeline_inputs = [
        base_model_zip,
        supervisely_project_zip,
        train_ratio,
        seed,
        data_percent,
        batch_size,
        num_epochs,
        learning_rate,
        image_width,
        image_height,
        early_stopping,
        validate_every,
    ]
    # Order must match the tuples yielded by run_pipeline.
    _pipeline_outputs = [
        run_btn,
        status,
        model_download_btn,
        output_tab,
        train_iou_plot,
        val_iou_plot,
        metrics_table,
    ]
    run_btn.click(
        run_pipeline,
        inputs=_pipeline_inputs,
        outputs=_pipeline_outputs,
        show_progress_on=status,
        scroll_to_output=True,
    )

# Start the web app only when executed as a script.
if __name__ == "__main__":
    demo.launch()