Spaces:
Sleeping
Sleeping
Commit
·
e4aef33
1
Parent(s):
7a8be68
Implement initial version of SegFormer training pipeline with dataset parsing and model training functionalities. Added Dockerfile for environment setup, utility scripts for parsing and training, and Gradio interface for user interaction.
Browse files- Dockerfile +37 -0
- app.py +346 -0
- requirements-sly.txt +4 -0
- requirements.txt +6 -0
- scripts/supervisely_parser.py +221 -0
- utils/__init__.py +7 -0
- utils/parse.py +101 -0
- utils/train.py +602 -0
Dockerfile
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM pytorch/pytorch:2.8.0-cuda12.9-cudnn9-runtime
|
| 2 |
+
|
| 3 |
+
ENV PYTHONUNBUFFERED=1
|
| 4 |
+
|
| 5 |
+
RUN useradd -m -u 1000 user
|
| 6 |
+
WORKDIR /app
|
| 7 |
+
|
| 8 |
+
# Install system dependencies for building Python packages
|
| 9 |
+
RUN apt-get update && apt-get install -y \
|
| 10 |
+
build-essential \
|
| 11 |
+
libffi-dev \
|
| 12 |
+
libssl-dev \
|
| 13 |
+
ffmpeg \
|
| 14 |
+
libsm6 \
|
| 15 |
+
libxext6 \
|
| 16 |
+
libmagic1 \
|
| 17 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
# Create virtual environment for Sly (keep isolated)
|
| 21 |
+
RUN python -m venv /app/.venv-sly
|
| 22 |
+
RUN /app/.venv-sly/bin/pip install --upgrade pip
|
| 23 |
+
COPY --chown=user requirements-sly.txt requirements-sly.txt
|
| 24 |
+
RUN /app/.venv-sly/bin/pip install -r requirements-sly.txt
|
| 25 |
+
|
| 26 |
+
# Install Gradio and other dependencies
|
| 27 |
+
RUN pip install --upgrade pip
|
| 28 |
+
COPY --chown=user requirements.txt requirements.txt
|
| 29 |
+
RUN pip install -r requirements.txt
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
# Copy the rest of the app
|
| 33 |
+
COPY --chown=user . .
|
| 34 |
+
|
| 35 |
+
ENV GRADIO_SERVER_NAME="0.0.0.0"
|
| 36 |
+
|
| 37 |
+
CMD ["python", "app.py"]
|
app.py
ADDED
|
@@ -0,0 +1,346 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import shutil
|
| 3 |
+
import gradio as gr
|
| 4 |
+
import pandas as pd
|
| 5 |
+
|
| 6 |
+
from utils import run_supervisely_parser, train_model
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def run_pipeline(
|
| 10 |
+
base_model_zip,
|
| 11 |
+
supervisely_project_zip,
|
| 12 |
+
train_ratio,
|
| 13 |
+
seed,
|
| 14 |
+
data_percent,
|
| 15 |
+
batch_size,
|
| 16 |
+
num_epochs,
|
| 17 |
+
learning_rate,
|
| 18 |
+
image_width,
|
| 19 |
+
image_height,
|
| 20 |
+
early_stopping,
|
| 21 |
+
validate_every,
|
| 22 |
+
pr=gr.Progress(track_tqdm=True),
|
| 23 |
+
):
|
| 24 |
+
|
| 25 |
+
# Parsing
|
| 26 |
+
yield (
|
| 27 |
+
gr.update(interactive=False), # run button
|
| 28 |
+
gr.update( # status textbox
|
| 29 |
+
value="Parsing Supervisely project ...",
|
| 30 |
+
visible=True,
|
| 31 |
+
),
|
| 32 |
+
gr.update(visible=False), # model download button
|
| 33 |
+
gr.update(visible=False), # output tab
|
| 34 |
+
gr.update(value=None), # train IoU plot
|
| 35 |
+
gr.update(value=None), # val IoU plot
|
| 36 |
+
gr.update(value=None), # metrics table
|
| 37 |
+
)
|
| 38 |
+
|
| 39 |
+
dataset_dir = run_supervisely_parser(
|
| 40 |
+
project_path=supervisely_project_zip,
|
| 41 |
+
train_ratio=train_ratio,
|
| 42 |
+
seed=seed,
|
| 43 |
+
)
|
| 44 |
+
|
| 45 |
+
# Training
|
| 46 |
+
yield (
|
| 47 |
+
gr.update(interactive=False), # run button
|
| 48 |
+
gr.update( # status textbox
|
| 49 |
+
value="Starting model training...",
|
| 50 |
+
),
|
| 51 |
+
gr.update(visible=False), # model download button
|
| 52 |
+
gr.update(visible=False), # output tab
|
| 53 |
+
gr.update(value=None), # train IoU plot
|
| 54 |
+
gr.update(value=None), # val IoU plot
|
| 55 |
+
gr.update(value=None), # metrics table
|
| 56 |
+
)
|
| 57 |
+
|
| 58 |
+
best_model, metrics, dice = train_model(
|
| 59 |
+
dataset_dir,
|
| 60 |
+
base_model_zip,
|
| 61 |
+
image_width,
|
| 62 |
+
image_height,
|
| 63 |
+
batch_size,
|
| 64 |
+
data_percent,
|
| 65 |
+
num_epochs,
|
| 66 |
+
learning_rate,
|
| 67 |
+
early_stopping,
|
| 68 |
+
validate_every,
|
| 69 |
+
)
|
| 70 |
+
|
| 71 |
+
# Saving model
|
| 72 |
+
yield (
|
| 73 |
+
gr.update(interactive=False), # run button
|
| 74 |
+
gr.update( # status textbox
|
| 75 |
+
value="Saving best model...",
|
| 76 |
+
),
|
| 77 |
+
gr.update(visible=False), # model download button
|
| 78 |
+
gr.update(visible=False), # output tab
|
| 79 |
+
gr.update(value=None), # train IoU plot
|
| 80 |
+
gr.update(value=None), # val IoU plot
|
| 81 |
+
gr.update(value=None), # metrics table
|
| 82 |
+
)
|
| 83 |
+
best_model_dir = os.path.join(
|
| 84 |
+
os.path.dirname(base_model_zip),
|
| 85 |
+
"best_model",
|
| 86 |
+
)
|
| 87 |
+
best_model.save_pretrained(best_model_dir)
|
| 88 |
+
|
| 89 |
+
best_model_zip_path = shutil.make_archive(
|
| 90 |
+
base_name=best_model_dir,
|
| 91 |
+
format="zip",
|
| 92 |
+
root_dir=best_model_dir,
|
| 93 |
+
)
|
| 94 |
+
|
| 95 |
+
metrics_df = pd.DataFrame(metrics)
|
| 96 |
+
|
| 97 |
+
initial_epoch_metrics = metrics_df.iloc[0]
|
| 98 |
+
final_epoch_metrics = metrics_df.iloc[-1]
|
| 99 |
+
|
| 100 |
+
# metrics comparison table use epoch 0 as before and final as after
|
| 101 |
+
metrics_comparison_df = pd.DataFrame(
|
| 102 |
+
{
|
| 103 |
+
"Metric": ["Accuracy", "IoU", "Loss", "Dice"],
|
| 104 |
+
"Before": [
|
| 105 |
+
initial_epoch_metrics["val_acc"],
|
| 106 |
+
initial_epoch_metrics["val_iou"],
|
| 107 |
+
initial_epoch_metrics["val_loss"],
|
| 108 |
+
dice[0],
|
| 109 |
+
],
|
| 110 |
+
"After": [
|
| 111 |
+
final_epoch_metrics["val_acc"],
|
| 112 |
+
final_epoch_metrics["val_iou"],
|
| 113 |
+
final_epoch_metrics["val_loss"],
|
| 114 |
+
dice[1],
|
| 115 |
+
],
|
| 116 |
+
}
|
| 117 |
+
)
|
| 118 |
+
|
| 119 |
+
yield (
|
| 120 |
+
gr.update(interactive=True), # run button
|
| 121 |
+
gr.update(visible=False), # status textbox
|
| 122 |
+
gr.update( # model download button
|
| 123 |
+
value=best_model_zip_path,
|
| 124 |
+
visible=True,
|
| 125 |
+
),
|
| 126 |
+
gr.update(visible=True), # output tab
|
| 127 |
+
gr.update(value=metrics_df), # train IoU plot
|
| 128 |
+
gr.update(value=metrics_df), # val IoU plot
|
| 129 |
+
gr.update(
|
| 130 |
+
value=metrics_comparison_df,
|
| 131 |
+
visible=True,
|
| 132 |
+
),
|
| 133 |
+
)
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
def _toggle_run_btn(base_model, project):
|
| 137 |
+
"""Enable run button only when both required files are selected."""
|
| 138 |
+
ready = bool(base_model and project)
|
| 139 |
+
return gr.update(interactive=ready)
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
with gr.Blocks(title="SegFormer Training & Dataset Pipeline") as demo:
|
| 143 |
+
gr.Markdown(
|
| 144 |
+
"# SegFormer Training Pipeline\n"
|
| 145 |
+
"Upload your base model and Supervisely project, "
|
| 146 |
+
"tweak parsing & training hyperparameters, then click "
|
| 147 |
+
"**Run Training**."
|
| 148 |
+
)
|
| 149 |
+
|
| 150 |
+
with gr.Row():
|
| 151 |
+
base_model_zip = gr.File(
|
| 152 |
+
label="Base PyTorch Model (.zip)",
|
| 153 |
+
file_types=[".zip"],
|
| 154 |
+
file_count="single",
|
| 155 |
+
)
|
| 156 |
+
supervisely_project_zip = gr.File(
|
| 157 |
+
label="Supervisely Project (.zip)",
|
| 158 |
+
file_types=[".zip"],
|
| 159 |
+
file_count="single",
|
| 160 |
+
)
|
| 161 |
+
|
| 162 |
+
with gr.Tab("Training"):
|
| 163 |
+
gr.Markdown("Adjust training hyperparameters.")
|
| 164 |
+
with gr.Row():
|
| 165 |
+
data_percent = gr.Slider(
|
| 166 |
+
minimum=1,
|
| 167 |
+
maximum=100,
|
| 168 |
+
step=1,
|
| 169 |
+
value=100,
|
| 170 |
+
label="Data Percent (%) used for training",
|
| 171 |
+
)
|
| 172 |
+
batch_size = gr.Number(
|
| 173 |
+
value=32,
|
| 174 |
+
label="Batch Size (samples/step)",
|
| 175 |
+
precision=0,
|
| 176 |
+
minimum=1,
|
| 177 |
+
)
|
| 178 |
+
num_epochs = gr.Number(
|
| 179 |
+
value=60,
|
| 180 |
+
label="Epochs (max passes)",
|
| 181 |
+
precision=0,
|
| 182 |
+
minimum=1,
|
| 183 |
+
)
|
| 184 |
+
with gr.Row():
|
| 185 |
+
learning_rate = gr.Number(
|
| 186 |
+
value=5e-5,
|
| 187 |
+
label="Learning Rate",
|
| 188 |
+
minimum=0.0,
|
| 189 |
+
maximum=1.0,
|
| 190 |
+
)
|
| 191 |
+
image_width = gr.Number(
|
| 192 |
+
value=640,
|
| 193 |
+
label="Image Width (px)",
|
| 194 |
+
precision=0,
|
| 195 |
+
minimum=1,
|
| 196 |
+
)
|
| 197 |
+
image_height = gr.Number(
|
| 198 |
+
value=640,
|
| 199 |
+
label="Image Height (px)",
|
| 200 |
+
precision=0,
|
| 201 |
+
minimum=1,
|
| 202 |
+
)
|
| 203 |
+
with gr.Row():
|
| 204 |
+
early_stopping = gr.Number(
|
| 205 |
+
value=3,
|
| 206 |
+
label="Early Stopping Patience (epochs w/o improvement)",
|
| 207 |
+
precision=0,
|
| 208 |
+
minimum=0,
|
| 209 |
+
)
|
| 210 |
+
validate_every = gr.Number(
|
| 211 |
+
value=1,
|
| 212 |
+
label="Validate Every (epochs)",
|
| 213 |
+
precision=0,
|
| 214 |
+
minimum=0,
|
| 215 |
+
)
|
| 216 |
+
|
| 217 |
+
with gr.Tab("Dataset Parsing"):
|
| 218 |
+
gr.Markdown("Configure how the dataset is split and seeded.")
|
| 219 |
+
with gr.Row():
|
| 220 |
+
train_ratio = gr.Slider(
|
| 221 |
+
minimum=0.1,
|
| 222 |
+
maximum=0.95,
|
| 223 |
+
step=0.01,
|
| 224 |
+
value=0.8,
|
| 225 |
+
label="Train Split Ratio (rest used for validation)",
|
| 226 |
+
)
|
| 227 |
+
seed = gr.Number(
|
| 228 |
+
value=42,
|
| 229 |
+
label="Random Seed (reproducibility)",
|
| 230 |
+
precision=0,
|
| 231 |
+
)
|
| 232 |
+
|
| 233 |
+
with gr.Accordion("Parameter Help", open=False):
|
| 234 |
+
gr.Markdown(
|
| 235 |
+
"""
|
| 236 |
+
**Base PyTorch Model (.zip)**: Archive containing a folder with
|
| 237 |
+
weights and configuration file.\n
|
| 238 |
+
**Supervisely Project (.zip)**: Archive containing Exported
|
| 239 |
+
Supervisely project
|
| 240 |
+
containing images and annotation JSONs.\n
|
| 241 |
+
**Train Split Ratio**: Fraction of dataset used for training;
|
| 242 |
+
remainder becomes validation.\n
|
| 243 |
+
**Random Seed**: Controls shuffling for reproducible splits &
|
| 244 |
+
training.\n
|
| 245 |
+
**Data Percent**: Subsample percentage of training split (use
|
| 246 |
+
<100 for quick experiments).\n
|
| 247 |
+
**Batch Size**: Samples processed before each optimizer step.\n
|
| 248 |
+
**Epochs**: Maximum complete passes over the (subsampled)
|
| 249 |
+
training set.\n
|
| 250 |
+
**Learning Rate**: Initial optimizer step size.\n
|
| 251 |
+
**Image Width / Height**: Target spatial size for preprocessing
|
| 252 |
+
(resize/crop).\n
|
| 253 |
+
**Early Stopping Patience**: Stop after this many validation
|
| 254 |
+
checks without improvement.\n
|
| 255 |
+
**Validate Every**: Run validation after this many epochs.\n
|
| 256 |
+
"""
|
| 257 |
+
)
|
| 258 |
+
|
| 259 |
+
run_btn = gr.Button(
|
| 260 |
+
"Run Training",
|
| 261 |
+
variant="primary",
|
| 262 |
+
interactive=False,
|
| 263 |
+
)
|
| 264 |
+
status = gr.Textbox(
|
| 265 |
+
show_label=False,
|
| 266 |
+
visible=False,
|
| 267 |
+
)
|
| 268 |
+
|
| 269 |
+
with gr.Tab("Results", visible=False) as output_tab:
|
| 270 |
+
|
| 271 |
+
model_download_btn = gr.DownloadButton(
|
| 272 |
+
label="Download Trained Model (.zip)",
|
| 273 |
+
value=None,
|
| 274 |
+
visible=False,
|
| 275 |
+
)
|
| 276 |
+
|
| 277 |
+
# table to show before and after accuracy and iou
|
| 278 |
+
metrics_table = gr.DataFrame(
|
| 279 |
+
label="Metrics Comparison",
|
| 280 |
+
interactive=False,
|
| 281 |
+
wrap=True,
|
| 282 |
+
)
|
| 283 |
+
|
| 284 |
+
with gr.Row():
|
| 285 |
+
train_iou_plot = gr.LinePlot(
|
| 286 |
+
label="Training IoU",
|
| 287 |
+
x="epoch",
|
| 288 |
+
y="train_iou",
|
| 289 |
+
x_title="Epoch",
|
| 290 |
+
y_title="IoU",
|
| 291 |
+
height=400,
|
| 292 |
+
)
|
| 293 |
+
val_iou_plot = gr.LinePlot(
|
| 294 |
+
label="Validation IoU",
|
| 295 |
+
x="epoch",
|
| 296 |
+
y="val_iou",
|
| 297 |
+
x_title="Epoch",
|
| 298 |
+
y_title="IoU",
|
| 299 |
+
height=400,
|
| 300 |
+
)
|
| 301 |
+
|
| 302 |
+
# Enable run button only when both archives provided
|
| 303 |
+
base_model_zip.change(
|
| 304 |
+
_toggle_run_btn,
|
| 305 |
+
inputs=[base_model_zip, supervisely_project_zip],
|
| 306 |
+
outputs=run_btn,
|
| 307 |
+
)
|
| 308 |
+
supervisely_project_zip.change(
|
| 309 |
+
_toggle_run_btn,
|
| 310 |
+
inputs=[base_model_zip, supervisely_project_zip],
|
| 311 |
+
outputs=run_btn,
|
| 312 |
+
)
|
| 313 |
+
|
| 314 |
+
# Click handler
|
| 315 |
+
run_btn.click(
|
| 316 |
+
run_pipeline,
|
| 317 |
+
inputs=[
|
| 318 |
+
base_model_zip,
|
| 319 |
+
supervisely_project_zip,
|
| 320 |
+
train_ratio,
|
| 321 |
+
seed,
|
| 322 |
+
data_percent,
|
| 323 |
+
batch_size,
|
| 324 |
+
num_epochs,
|
| 325 |
+
learning_rate,
|
| 326 |
+
image_width,
|
| 327 |
+
image_height,
|
| 328 |
+
early_stopping,
|
| 329 |
+
validate_every,
|
| 330 |
+
],
|
| 331 |
+
outputs=[
|
| 332 |
+
run_btn,
|
| 333 |
+
status,
|
| 334 |
+
model_download_btn,
|
| 335 |
+
output_tab,
|
| 336 |
+
train_iou_plot,
|
| 337 |
+
val_iou_plot,
|
| 338 |
+
metrics_table,
|
| 339 |
+
],
|
| 340 |
+
show_progress_on=status,
|
| 341 |
+
scroll_to_output=True,
|
| 342 |
+
)
|
| 343 |
+
|
| 344 |
+
|
| 345 |
+
if __name__ == "__main__":
|
| 346 |
+
demo.launch()
|
requirements-sly.txt
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
supervisely>=6.73.418
|
| 2 |
+
numpy>=1.26.4
|
| 3 |
+
pillow>=10.2.0
|
| 4 |
+
tqdm>=4.67.1
|
requirements.txt
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch==2.8.0
|
| 2 |
+
gradio>=5.42.0
|
| 3 |
+
pillow>=11.3.0
|
| 4 |
+
transformers>=4.55.2
|
| 5 |
+
tqdm>=4.67.1
|
| 6 |
+
evaluate>=0.4.5
|
scripts/supervisely_parser.py
ADDED
|
@@ -0,0 +1,221 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Supervisely Parser Script
|
| 3 |
+
|
| 4 |
+
This script parses Supervisely projects and converts them to a format
|
| 5 |
+
suitable for training segmentation models. It extracts class information,
|
| 6 |
+
creates train/validation splits, and converts annotations to indexed
|
| 7 |
+
color masks.
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import os
|
| 11 |
+
import json
|
| 12 |
+
import random
|
| 13 |
+
import shutil
|
| 14 |
+
import argparse
|
| 15 |
+
import numpy as np
|
| 16 |
+
from PIL import Image
|
| 17 |
+
from tqdm import tqdm
|
| 18 |
+
|
| 19 |
+
try:
|
| 20 |
+
import supervisely as sly
|
| 21 |
+
from supervisely import Annotation
|
| 22 |
+
except ImportError as e:
|
| 23 |
+
print(f"Failed to import supervisely: {e}")
|
| 24 |
+
print(
|
| 25 |
+
"Please ensure that the 'supervisely' package is installed and "
|
| 26 |
+
"compatible with your environment."
|
| 27 |
+
)
|
| 28 |
+
raise
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def extract_class_info(project, output_dir):
|
| 32 |
+
"""Extract class information from the project metadata."""
|
| 33 |
+
id2label = {}
|
| 34 |
+
id2color = {}
|
| 35 |
+
|
| 36 |
+
for obj in project.meta.obj_classes:
|
| 37 |
+
id_str, _, label = obj.name.partition(". ")
|
| 38 |
+
if not label or not id_str.isdigit():
|
| 39 |
+
continue
|
| 40 |
+
index = int(id_str) - 1
|
| 41 |
+
|
| 42 |
+
id2label[index] = label
|
| 43 |
+
id2color[index] = obj.color
|
| 44 |
+
|
| 45 |
+
# Save class mappings
|
| 46 |
+
with open(f"{output_dir}/id2label.json", "w") as f:
|
| 47 |
+
json.dump(id2label, f, sort_keys=True, indent=2)
|
| 48 |
+
with open(f"{output_dir}/id2color.json", "w") as f:
|
| 49 |
+
json.dump(id2color, f, sort_keys=True, indent=2)
|
| 50 |
+
|
| 51 |
+
label2id = {v: k for k, v in id2label.items()}
|
| 52 |
+
return id2label, id2color, label2id
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def create_output_directories(output_dir):
|
| 56 |
+
"""Create necessary output directories."""
|
| 57 |
+
os.makedirs(f"{output_dir}/images/training", exist_ok=True)
|
| 58 |
+
os.makedirs(f"{output_dir}/annotations/training", exist_ok=True)
|
| 59 |
+
os.makedirs(f"{output_dir}/images/validation", exist_ok=True)
|
| 60 |
+
os.makedirs(f"{output_dir}/annotations/validation", exist_ok=True)
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def calculate_split_counts(datasets, train_ratio=0.8):
|
| 64 |
+
"""Calculate the number of items for training and validation."""
|
| 65 |
+
total_items = 0
|
| 66 |
+
for dataset in datasets:
|
| 67 |
+
total_items += len(dataset.get_items_names())
|
| 68 |
+
|
| 69 |
+
train_items = int(total_items * train_ratio)
|
| 70 |
+
val_items = total_items - train_items
|
| 71 |
+
|
| 72 |
+
print(
|
| 73 |
+
f"Total items: {total_items}\n"
|
| 74 |
+
f"Train items: {train_items}\n"
|
| 75 |
+
f"Validation items: {val_items}"
|
| 76 |
+
)
|
| 77 |
+
|
| 78 |
+
return train_items, val_items
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def to_class_index_mask(
|
| 82 |
+
annotation: Annotation,
|
| 83 |
+
label2id: dict,
|
| 84 |
+
mask_path: str,
|
| 85 |
+
):
|
| 86 |
+
"""Convert annotation to class index mask and save as PNG."""
|
| 87 |
+
height, width = annotation.img_size
|
| 88 |
+
class_mask = np.zeros((height, width), dtype=np.uint8)
|
| 89 |
+
|
| 90 |
+
for label in annotation.labels:
|
| 91 |
+
class_name = label.obj_class.name.partition(". ")[2]
|
| 92 |
+
if class_name not in label2id:
|
| 93 |
+
tqdm.write(f"Skipping unrecognized label: {label}")
|
| 94 |
+
continue # skip unrecognized labels
|
| 95 |
+
|
| 96 |
+
class_index = label2id[class_name]
|
| 97 |
+
|
| 98 |
+
if label.geometry.geometry_name() == "bitmap":
|
| 99 |
+
origin = label.geometry.origin
|
| 100 |
+
top = origin.row
|
| 101 |
+
left = origin.col
|
| 102 |
+
bitmap = label.geometry.data # binary numpy array, shape (h, w)
|
| 103 |
+
|
| 104 |
+
h, w = bitmap.shape
|
| 105 |
+
if top + h > height or left + w > width:
|
| 106 |
+
tqdm.write(f"Skipping label '{class_name}': size mismatch.")
|
| 107 |
+
continue
|
| 108 |
+
|
| 109 |
+
class_mask[top : top + h, left : left + w][bitmap] = class_index
|
| 110 |
+
else:
|
| 111 |
+
continue
|
| 112 |
+
|
| 113 |
+
Image.fromarray(class_mask).save(mask_path)
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
def process_datasets(
|
| 117 |
+
project,
|
| 118 |
+
datasets,
|
| 119 |
+
output_dir,
|
| 120 |
+
label2id,
|
| 121 |
+
train_items,
|
| 122 |
+
):
|
| 123 |
+
"""Process all datasets and create train/validation splits."""
|
| 124 |
+
for dataset in tqdm(datasets, desc="Processing datasets"):
|
| 125 |
+
items = dataset.get_items_names()
|
| 126 |
+
random.shuffle(items)
|
| 127 |
+
|
| 128 |
+
for i, item in tqdm(
|
| 129 |
+
enumerate(items),
|
| 130 |
+
desc=f"Processing dataset: {dataset.name}",
|
| 131 |
+
total=len(items),
|
| 132 |
+
leave=False,
|
| 133 |
+
):
|
| 134 |
+
# Determine split
|
| 135 |
+
split = "training" if i < train_items else "validation"
|
| 136 |
+
|
| 137 |
+
# Copy images
|
| 138 |
+
item_paths = dataset.get_item_paths(item)
|
| 139 |
+
img_path = item_paths.img_path
|
| 140 |
+
img_filename = os.path.basename(img_path)
|
| 141 |
+
dest_path = f"{output_dir}/images/{split}/{img_filename}"
|
| 142 |
+
shutil.copy(img_path, dest_path)
|
| 143 |
+
|
| 144 |
+
# Convert and copy annotations
|
| 145 |
+
ann_path = item_paths.ann_path
|
| 146 |
+
ann = sly.Annotation.load_json_file(ann_path, project.meta)
|
| 147 |
+
mask_filename = f"{os.path.splitext(item)[0]}.png"
|
| 148 |
+
mask_path = f"{output_dir}/annotations/{split}/{mask_filename}"
|
| 149 |
+
to_class_index_mask(ann, label2id, mask_path)
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
def parse_arguments():
|
| 153 |
+
"""Parse command line arguments."""
|
| 154 |
+
parser = argparse.ArgumentParser(
|
| 155 |
+
description="Parse Supervisely project and convert to training format"
|
| 156 |
+
)
|
| 157 |
+
parser.add_argument(
|
| 158 |
+
"--project_dir",
|
| 159 |
+
type=str,
|
| 160 |
+
required=True,
|
| 161 |
+
help="Path to the Supervisely project directory",
|
| 162 |
+
)
|
| 163 |
+
parser.add_argument(
|
| 164 |
+
"--output_base_dir",
|
| 165 |
+
type=str,
|
| 166 |
+
required=True,
|
| 167 |
+
help="Base output directory for parsed data",
|
| 168 |
+
)
|
| 169 |
+
parser.add_argument(
|
| 170 |
+
"--train_ratio",
|
| 171 |
+
type=float,
|
| 172 |
+
default=0.8,
|
| 173 |
+
help="Ratio of data to use for training (default: 0.8)",
|
| 174 |
+
)
|
| 175 |
+
parser.add_argument(
|
| 176 |
+
"--seed",
|
| 177 |
+
type=int,
|
| 178 |
+
default=42,
|
| 179 |
+
help="Random seed for reproducible splits (default: 42)",
|
| 180 |
+
)
|
| 181 |
+
return parser.parse_args()
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
def main():
|
| 185 |
+
"""Main function to parse Supervisely project."""
|
| 186 |
+
# Parse arguments
|
| 187 |
+
args = parse_arguments()
|
| 188 |
+
|
| 189 |
+
# Set random seed for reproducible splits
|
| 190 |
+
random.seed(args.seed)
|
| 191 |
+
|
| 192 |
+
# Load project
|
| 193 |
+
project = sly.Project(args.project_dir, sly.OpenMode.READ)
|
| 194 |
+
print(f"Project: {project.name}")
|
| 195 |
+
|
| 196 |
+
# Setup output directory
|
| 197 |
+
output_dir = os.path.join(args.output_base_dir, project.name)
|
| 198 |
+
create_output_directories(output_dir)
|
| 199 |
+
|
| 200 |
+
# Extract class information
|
| 201 |
+
id2label, id2color, label2id = extract_class_info(project, output_dir)
|
| 202 |
+
|
| 203 |
+
# Get datasets and calculate splits
|
| 204 |
+
datasets = project.datasets
|
| 205 |
+
print(f"Datasets: {len(datasets)}")
|
| 206 |
+
train_items, val_items = calculate_split_counts(datasets, args.train_ratio)
|
| 207 |
+
|
| 208 |
+
# Process datasets
|
| 209 |
+
process_datasets(
|
| 210 |
+
project,
|
| 211 |
+
datasets,
|
| 212 |
+
output_dir,
|
| 213 |
+
label2id,
|
| 214 |
+
train_items,
|
| 215 |
+
)
|
| 216 |
+
|
| 217 |
+
print("Processing completed successfully!")
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
if __name__ == "__main__":
|
| 221 |
+
main()
|
utils/__init__.py
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .parse import run_supervisely_parser
|
| 2 |
+
from .train import train_model
|
| 3 |
+
|
| 4 |
+
__all__ = [
|
| 5 |
+
"run_supervisely_parser",
|
| 6 |
+
"train_model",
|
| 7 |
+
]
|
utils/parse.py
ADDED
|
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import tempfile
|
| 3 |
+
import zipfile
|
| 4 |
+
import subprocess
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def run_supervisely_parser(
|
| 9 |
+
project_path: str,
|
| 10 |
+
train_ratio: float,
|
| 11 |
+
seed: int,
|
| 12 |
+
) -> str:
|
| 13 |
+
"""Extract a Supervisely project zip and run the parser script inside .venv-sly.
|
| 14 |
+
|
| 15 |
+
Parameters
|
| 16 |
+
----------
|
| 17 |
+
project_path : (str)
|
| 18 |
+
Path to the uploaded Supervisely project .zip.
|
| 19 |
+
train_ratio : float
|
| 20 |
+
Portion of data to allocate to training (remainder is validation).
|
| 21 |
+
seed : int
|
| 22 |
+
Random seed forwarded to the parser for reproducible splits.
|
| 23 |
+
|
| 24 |
+
Returns
|
| 25 |
+
-------
|
| 26 |
+
str
|
| 27 |
+
Path to the parsed dataset directory produced by the parser script.
|
| 28 |
+
"""
|
| 29 |
+
|
| 30 |
+
project_zip = Path(project_path)
|
| 31 |
+
if not project_zip.exists():
|
| 32 |
+
raise FileNotFoundError(
|
| 33 |
+
f"Provided project zip not found: {project_zip}"
|
| 34 |
+
)
|
| 35 |
+
if project_zip.suffix.lower() != ".zip":
|
| 36 |
+
raise ValueError("Supervisely project must be a .zip archive")
|
| 37 |
+
|
| 38 |
+
project_dir = project_zip.parent
|
| 39 |
+
extract_dir = Path(tempfile.mkdtemp(dir=project_dir))
|
| 40 |
+
output_base_dir = Path(tempfile.mkdtemp(dir=project_dir))
|
| 41 |
+
|
| 42 |
+
with zipfile.ZipFile(project_zip, "r") as zf:
|
| 43 |
+
zf.extractall(extract_dir)
|
| 44 |
+
|
| 45 |
+
def find_project_root(root: Path) -> Path:
|
| 46 |
+
if (root / "meta.json").exists():
|
| 47 |
+
return root
|
| 48 |
+
for child in root.iterdir():
|
| 49 |
+
if child.is_dir() and (child / "meta.json").exists():
|
| 50 |
+
return child
|
| 51 |
+
raise FileNotFoundError(
|
| 52 |
+
f"Could not locate 'meta.json' inside extracted archive at {root}"
|
| 53 |
+
)
|
| 54 |
+
|
| 55 |
+
project_root = find_project_root(extract_dir)
|
| 56 |
+
|
| 57 |
+
repo_root = Path(__file__).resolve().parent.parent
|
| 58 |
+
parser_script = repo_root / "scripts" / "supervisely_parser.py"
|
| 59 |
+
venv_python = repo_root / ".venv-sly" / "bin" / "python"
|
| 60 |
+
|
| 61 |
+
if not parser_script.exists():
|
| 62 |
+
raise FileNotFoundError(
|
| 63 |
+
f"Parser script not found: {parser_script}",
|
| 64 |
+
)
|
| 65 |
+
if not venv_python.exists():
|
| 66 |
+
raise FileNotFoundError(
|
| 67 |
+
"Expected .venv-sly Python interpreter at: " f"{venv_python}",
|
| 68 |
+
)
|
| 69 |
+
|
| 70 |
+
cmd = [
|
| 71 |
+
str(venv_python),
|
| 72 |
+
str(parser_script),
|
| 73 |
+
"--project_dir",
|
| 74 |
+
str(project_root),
|
| 75 |
+
"--output_base_dir",
|
| 76 |
+
str(output_base_dir),
|
| 77 |
+
"--train_ratio",
|
| 78 |
+
str(train_ratio),
|
| 79 |
+
"--seed",
|
| 80 |
+
str(seed),
|
| 81 |
+
]
|
| 82 |
+
|
| 83 |
+
result = subprocess.run(
|
| 84 |
+
cmd,
|
| 85 |
+
capture_output=True,
|
| 86 |
+
text=True,
|
| 87 |
+
env={**os.environ},
|
| 88 |
+
)
|
| 89 |
+
if result.returncode != 0:
|
| 90 |
+
raise RuntimeError(
|
| 91 |
+
"Supervisely parser failed.\n"
|
| 92 |
+
f"STDOUT:\n{result.stdout}\nSTDERR:\n{result.stderr}"
|
| 93 |
+
)
|
| 94 |
+
produced_dirs = [p for p in output_base_dir.iterdir() if p.is_dir()]
|
| 95 |
+
if len(produced_dirs) != 1:
|
| 96 |
+
raise RuntimeError(
|
| 97 |
+
"Could not unambiguously determine parsed dataset directory in "
|
| 98 |
+
f"{output_base_dir}. Found: {produced_dirs}"
|
| 99 |
+
)
|
| 100 |
+
dataset_dir = produced_dirs[0]
|
| 101 |
+
return str(dataset_dir)
|
utils/train.py
ADDED
|
@@ -0,0 +1,602 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
SegFormer Fine-tuning Script
|
| 3 |
+
|
| 4 |
+
This script fine-tunes a SegFormer model on a custom semantic segmentation
|
| 5 |
+
dataset. It provides configurable parameters for training hyperparameters
|
| 6 |
+
and dataset settings.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import json
|
| 10 |
+
import os
|
| 11 |
+
import zipfile
|
| 12 |
+
import numpy as np
|
| 13 |
+
import torch
|
| 14 |
+
import torch.nn as nn
|
| 15 |
+
from torch.utils.data import Dataset, DataLoader
|
| 16 |
+
from PIL import Image
|
| 17 |
+
from pathlib import Path
|
| 18 |
+
from datetime import datetime
|
| 19 |
+
from transformers import (
|
| 20 |
+
SegformerImageProcessor,
|
| 21 |
+
SegformerForSemanticSegmentation,
|
| 22 |
+
)
|
| 23 |
+
import evaluate
|
| 24 |
+
from tqdm import tqdm
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class SemanticSegmentationDataset(Dataset):
|
| 28 |
+
"""Image (semantic) segmentation dataset."""
|
| 29 |
+
|
| 30 |
+
def __init__(
|
| 31 |
+
self,
|
| 32 |
+
root_dir,
|
| 33 |
+
image_processor,
|
| 34 |
+
train=True,
|
| 35 |
+
data_percent=100,
|
| 36 |
+
):
|
| 37 |
+
"""
|
| 38 |
+
Args:
|
| 39 |
+
root_dir (string): Root directory of the dataset containing
|
| 40 |
+
the images + annotations.
|
| 41 |
+
image_processor (SegFormerImageProcessor): image processor to
|
| 42 |
+
prepare images + segmentation maps.
|
| 43 |
+
train (bool): Whether to load "training" or "validation"
|
| 44 |
+
images + annotations.
|
| 45 |
+
data_percent (int): Percentage of the dataset to use.
|
| 46 |
+
100 means all data, 50 means half of the data.
|
| 47 |
+
"""
|
| 48 |
+
self.root_dir = root_dir
|
| 49 |
+
self.image_processor = image_processor
|
| 50 |
+
self.train = train
|
| 51 |
+
|
| 52 |
+
sub_path = "training" if self.train else "validation"
|
| 53 |
+
self.img_dir = os.path.join(self.root_dir, "images", sub_path)
|
| 54 |
+
self.ann_dir = os.path.join(self.root_dir, "annotations", sub_path)
|
| 55 |
+
|
| 56 |
+
# read images
|
| 57 |
+
image_file_names = []
|
| 58 |
+
for root, dirs, files in os.walk(self.img_dir):
|
| 59 |
+
image_file_names.extend(files)
|
| 60 |
+
self.images = sorted(image_file_names)
|
| 61 |
+
|
| 62 |
+
# read annotations
|
| 63 |
+
annotation_file_names = []
|
| 64 |
+
for root, dirs, files in os.walk(self.ann_dir):
|
| 65 |
+
annotation_file_names.extend(files)
|
| 66 |
+
self.annotations = sorted(annotation_file_names)
|
| 67 |
+
|
| 68 |
+
assert len(self.images) == len(
|
| 69 |
+
self.annotations
|
| 70 |
+
), "There must be as many images as there are segmentation maps"
|
| 71 |
+
|
| 72 |
+
# Apply data_percent to limit the dataset size
|
| 73 |
+
data_percent = data_percent / 100.0
|
| 74 |
+
if data_percent < 1.0:
|
| 75 |
+
images_num_samples = int(len(self.images) * data_percent)
|
| 76 |
+
annotations_num_samples = int(len(self.annotations) * data_percent)
|
| 77 |
+
self.images = self.images[:images_num_samples]
|
| 78 |
+
self.annotations = self.annotations[:annotations_num_samples]
|
| 79 |
+
|
| 80 |
+
def __len__(self):
|
| 81 |
+
return len(self.images)
|
| 82 |
+
|
| 83 |
+
def __getitem__(self, idx):
|
| 84 |
+
image = Image.open(os.path.join(self.img_dir, self.images[idx]))
|
| 85 |
+
segmentation_map = Image.open(
|
| 86 |
+
os.path.join(
|
| 87 |
+
self.ann_dir,
|
| 88 |
+
self.annotations[idx],
|
| 89 |
+
),
|
| 90 |
+
)
|
| 91 |
+
encoded_inputs = self.image_processor(
|
| 92 |
+
image,
|
| 93 |
+
segmentation_map,
|
| 94 |
+
return_tensors="pt",
|
| 95 |
+
)
|
| 96 |
+
|
| 97 |
+
for k, v in encoded_inputs.items():
|
| 98 |
+
encoded_inputs[k].squeeze_() # remove batch dimension
|
| 99 |
+
|
| 100 |
+
return encoded_inputs
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
class MeanDice:
|
| 104 |
+
def __init__(self):
|
| 105 |
+
self.reset()
|
| 106 |
+
|
| 107 |
+
def reset(self):
|
| 108 |
+
"""Reset stored predictions and references."""
|
| 109 |
+
self.predictions = []
|
| 110 |
+
self.references = []
|
| 111 |
+
|
| 112 |
+
def add_batch(self, predictions, references):
|
| 113 |
+
"""
|
| 114 |
+
Add a batch of predictions and references.
|
| 115 |
+
|
| 116 |
+
Args:
|
| 117 |
+
predictions (np.ndarray): Predicted class indices
|
| 118 |
+
references (np.ndarray): Ground truth class indices
|
| 119 |
+
"""
|
| 120 |
+
self.predictions.append(predictions)
|
| 121 |
+
self.references.append(references)
|
| 122 |
+
|
| 123 |
+
def compute(self, num_labels, ignore_index=None):
|
| 124 |
+
"""Compute mean Dice score across all stored batches."""
|
| 125 |
+
predictions = np.concatenate([p.flatten() for p in self.predictions])
|
| 126 |
+
references = np.concatenate([r.flatten() for r in self.references])
|
| 127 |
+
|
| 128 |
+
dice_scores = []
|
| 129 |
+
|
| 130 |
+
for class_id in range(num_labels):
|
| 131 |
+
pred_mask = predictions == class_id
|
| 132 |
+
ref_mask = references == class_id
|
| 133 |
+
|
| 134 |
+
# Exclude ignore_index
|
| 135 |
+
if ignore_index is not None:
|
| 136 |
+
valid_mask = references != ignore_index
|
| 137 |
+
pred_mask = pred_mask & valid_mask
|
| 138 |
+
ref_mask = ref_mask & valid_mask
|
| 139 |
+
|
| 140 |
+
intersection = np.sum(pred_mask & ref_mask)
|
| 141 |
+
union = np.sum(pred_mask) + np.sum(ref_mask)
|
| 142 |
+
|
| 143 |
+
if union == 0:
|
| 144 |
+
dice = 1.0 if intersection == 0 else 0.0
|
| 145 |
+
else:
|
| 146 |
+
dice = 2.0 * intersection / union
|
| 147 |
+
|
| 148 |
+
dice_scores.append(dice)
|
| 149 |
+
|
| 150 |
+
return {
|
| 151 |
+
"mean_dice": float(np.mean(dice_scores)),
|
| 152 |
+
"per_class_dice": dice_scores,
|
| 153 |
+
}
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
def get_latest_model_dir(base_path: str = "./segformer_finetuned") -> Path:
|
| 157 |
+
"""
|
| 158 |
+
Returns the Path to the latest model directory based on
|
| 159 |
+
timestamp folder names.
|
| 160 |
+
|
| 161 |
+
Folder names must follow the format: YYYY-MM-DD_HH-MM-SS
|
| 162 |
+
"""
|
| 163 |
+
base = Path(base_path)
|
| 164 |
+
if not base.exists() or not base.is_dir():
|
| 165 |
+
raise FileNotFoundError(f"Directory not found: {base_path}")
|
| 166 |
+
|
| 167 |
+
model_dirs = []
|
| 168 |
+
for d in base.iterdir():
|
| 169 |
+
if d.is_dir():
|
| 170 |
+
try:
|
| 171 |
+
dt = datetime.strptime(d.name, "%Y-%m-%d_%H-%M-%S")
|
| 172 |
+
model_dirs.append((dt, d))
|
| 173 |
+
except ValueError:
|
| 174 |
+
continue # Skip non-matching directories
|
| 175 |
+
|
| 176 |
+
if not model_dirs:
|
| 177 |
+
raise FileNotFoundError(
|
| 178 |
+
"No model directories found with valid timestamp format."
|
| 179 |
+
)
|
| 180 |
+
|
| 181 |
+
# Return the directory with the latest timestamp
|
| 182 |
+
return max(model_dirs, key=lambda x: x[0])[1]
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
def load_model_and_labels(data_dir, model_path):
|
| 186 |
+
"""Load the model and label mappings."""
|
| 187 |
+
# Load id2label mapping from JSON file
|
| 188 |
+
id2label = json.load(open(f"{data_dir}/id2label.json", mode="r"))
|
| 189 |
+
id2label = {int(k): v for k, v in id2label.items()}
|
| 190 |
+
label2id = {v: k for k, v in id2label.items()}
|
| 191 |
+
|
| 192 |
+
# Load id2color mapping from JSON file
|
| 193 |
+
id2color = json.load(open(f"{data_dir}/id2color.json", "r"))
|
| 194 |
+
|
| 195 |
+
print(f"Loaded {len(id2label)} classes:")
|
| 196 |
+
for i, label in id2label.items():
|
| 197 |
+
print(f" {i}: {label}")
|
| 198 |
+
|
| 199 |
+
# Load model
|
| 200 |
+
model = SegformerForSemanticSegmentation.from_pretrained(
|
| 201 |
+
model_path,
|
| 202 |
+
num_labels=len(id2label),
|
| 203 |
+
id2label=id2label,
|
| 204 |
+
label2id=label2id,
|
| 205 |
+
)
|
| 206 |
+
return model, id2label, id2color
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
def create_datasets_and_dataloaders(
|
| 210 |
+
image_width,
|
| 211 |
+
image_height,
|
| 212 |
+
data_dir,
|
| 213 |
+
batch_size,
|
| 214 |
+
data_percent,
|
| 215 |
+
):
|
| 216 |
+
"""Create datasets and dataloaders."""
|
| 217 |
+
image_processor = SegformerImageProcessor(
|
| 218 |
+
size={"height": image_height, "width": image_width},
|
| 219 |
+
)
|
| 220 |
+
|
| 221 |
+
train_dataset = SemanticSegmentationDataset(
|
| 222 |
+
root_dir=data_dir,
|
| 223 |
+
image_processor=image_processor,
|
| 224 |
+
train=True,
|
| 225 |
+
data_percent=data_percent,
|
| 226 |
+
)
|
| 227 |
+
|
| 228 |
+
valid_dataset = SemanticSegmentationDataset(
|
| 229 |
+
root_dir=data_dir,
|
| 230 |
+
image_processor=image_processor,
|
| 231 |
+
train=False,
|
| 232 |
+
data_percent=data_percent,
|
| 233 |
+
)
|
| 234 |
+
|
| 235 |
+
print(f"Number of training examples: {len(train_dataset)}")
|
| 236 |
+
print(f"Number of validation examples: {len(valid_dataset)}")
|
| 237 |
+
|
| 238 |
+
train_dataloader = DataLoader(
|
| 239 |
+
train_dataset,
|
| 240 |
+
batch_size=batch_size,
|
| 241 |
+
shuffle=True,
|
| 242 |
+
)
|
| 243 |
+
valid_dataloader = DataLoader(
|
| 244 |
+
valid_dataset,
|
| 245 |
+
batch_size=batch_size,
|
| 246 |
+
)
|
| 247 |
+
|
| 248 |
+
return train_dataloader, valid_dataloader
|
| 249 |
+
|
| 250 |
+
|
| 251 |
+
def class_indices_to_rgb(class_indices, id2color):
|
| 252 |
+
"""Convert class indices to RGB colored image."""
|
| 253 |
+
# class_indices shape: (H, W) with integer class IDs
|
| 254 |
+
height, width = class_indices.shape
|
| 255 |
+
rgb_image = np.zeros((height, width, 3), dtype=np.uint8)
|
| 256 |
+
|
| 257 |
+
for class_id, color in id2color.items():
|
| 258 |
+
rgb_image[class_indices == class_id] = color
|
| 259 |
+
|
| 260 |
+
return rgb_image
|
| 261 |
+
|
| 262 |
+
|
| 263 |
+
def validate_model(
|
| 264 |
+
model: SegformerForSemanticSegmentation,
|
| 265 |
+
dataloader,
|
| 266 |
+
device,
|
| 267 |
+
id2label,
|
| 268 |
+
calc_dice=False,
|
| 269 |
+
epoch=None,
|
| 270 |
+
):
|
| 271 |
+
"""
|
| 272 |
+
Validate the model on a validation set and return loss, IoU, accuracy.
|
| 273 |
+
"""
|
| 274 |
+
model.eval()
|
| 275 |
+
metric = evaluate.load("mean_iou")
|
| 276 |
+
dice = MeanDice()
|
| 277 |
+
total_loss = 0.0
|
| 278 |
+
num_batches = 0
|
| 279 |
+
|
| 280 |
+
with torch.no_grad():
|
| 281 |
+
for batch in tqdm(
|
| 282 |
+
dataloader,
|
| 283 |
+
desc="Validating Epoch " + str(epoch if epoch is not None else ""),
|
| 284 |
+
leave=False,
|
| 285 |
+
unit="batches",
|
| 286 |
+
):
|
| 287 |
+
pixel_values = batch["pixel_values"].to(device)
|
| 288 |
+
labels = batch["labels"].to(device)
|
| 289 |
+
|
| 290 |
+
outputs = model(pixel_values=pixel_values, labels=labels)
|
| 291 |
+
logits = outputs.logits
|
| 292 |
+
loss = outputs.loss
|
| 293 |
+
|
| 294 |
+
total_loss += loss.item()
|
| 295 |
+
num_batches += 1
|
| 296 |
+
|
| 297 |
+
upsampled_logits = nn.functional.interpolate(
|
| 298 |
+
logits,
|
| 299 |
+
size=labels.shape[-2:],
|
| 300 |
+
mode="bilinear",
|
| 301 |
+
align_corners=False,
|
| 302 |
+
)
|
| 303 |
+
predicted = upsampled_logits.argmax(dim=1)
|
| 304 |
+
|
| 305 |
+
# Store predictions and references for additional metrics
|
| 306 |
+
pred_np = predicted.detach().cpu().numpy()
|
| 307 |
+
ref_np = labels.detach().cpu().numpy()
|
| 308 |
+
|
| 309 |
+
metric.add_batch(
|
| 310 |
+
predictions=pred_np,
|
| 311 |
+
references=ref_np,
|
| 312 |
+
)
|
| 313 |
+
if calc_dice:
|
| 314 |
+
dice.add_batch(
|
| 315 |
+
predictions=pred_np,
|
| 316 |
+
references=ref_np,
|
| 317 |
+
)
|
| 318 |
+
|
| 319 |
+
# Calculate IoU and accuracy
|
| 320 |
+
result = metric.compute(
|
| 321 |
+
num_labels=len(id2label),
|
| 322 |
+
ignore_index=10,
|
| 323 |
+
reduce_labels=False,
|
| 324 |
+
)
|
| 325 |
+
if calc_dice:
|
| 326 |
+
dice_result = dice.compute(
|
| 327 |
+
num_labels=len(id2label),
|
| 328 |
+
ignore_index=10,
|
| 329 |
+
)
|
| 330 |
+
|
| 331 |
+
avg_loss = total_loss / num_batches if num_batches > 0 else 0.0
|
| 332 |
+
return (
|
| 333 |
+
avg_loss,
|
| 334 |
+
result["mean_iou"],
|
| 335 |
+
result["per_category_iou"],
|
| 336 |
+
result["mean_accuracy"],
|
| 337 |
+
result["per_category_accuracy"],
|
| 338 |
+
dice_result["mean_dice"] if calc_dice else None,
|
| 339 |
+
dice_result["per_class_dice"] if calc_dice else None,
|
| 340 |
+
)
|
| 341 |
+
|
| 342 |
+
|
| 343 |
+
def run_training(
|
| 344 |
+
model: SegformerForSemanticSegmentation,
|
| 345 |
+
device,
|
| 346 |
+
train_dataloader,
|
| 347 |
+
valid_dataloader,
|
| 348 |
+
id2label,
|
| 349 |
+
num_epochs,
|
| 350 |
+
learning_rate,
|
| 351 |
+
early_stopping,
|
| 352 |
+
validate_every,
|
| 353 |
+
):
|
| 354 |
+
"""Train the model.
|
| 355 |
+
|
| 356 |
+
Returns
|
| 357 |
+
-------
|
| 358 |
+
tuple(best_model, metrics)
|
| 359 |
+
best_model : nn.Module
|
| 360 |
+
metrics : dict with lists for keys: 'epoch', 'train_loss', 'train_iou',
|
| 361 |
+
'train_acc', 'val_loss', 'val_iou', 'val_acc'
|
| 362 |
+
"""
|
| 363 |
+
# Setup device
|
| 364 |
+
model.to(device)
|
| 365 |
+
|
| 366 |
+
# Setup optimizer
|
| 367 |
+
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
|
| 368 |
+
|
| 369 |
+
# Setup metrics
|
| 370 |
+
metrics = {
|
| 371 |
+
"epoch": [],
|
| 372 |
+
"train_loss": [],
|
| 373 |
+
"train_iou": [],
|
| 374 |
+
"train_acc": [],
|
| 375 |
+
"val_loss": [],
|
| 376 |
+
"val_iou": [],
|
| 377 |
+
"val_acc": [],
|
| 378 |
+
}
|
| 379 |
+
|
| 380 |
+
metric = evaluate.load("mean_iou")
|
| 381 |
+
|
| 382 |
+
model.train()
|
| 383 |
+
|
| 384 |
+
# Initial validation
|
| 385 |
+
(
|
| 386 |
+
loss,
|
| 387 |
+
iou,
|
| 388 |
+
per_class_iou,
|
| 389 |
+
acc,
|
| 390 |
+
per_class_acc,
|
| 391 |
+
dice,
|
| 392 |
+
dice_per_class,
|
| 393 |
+
) = validate_model(
|
| 394 |
+
model=model,
|
| 395 |
+
dataloader=valid_dataloader,
|
| 396 |
+
device=device,
|
| 397 |
+
id2label=id2label,
|
| 398 |
+
calc_dice=True,
|
| 399 |
+
epoch=0,
|
| 400 |
+
)
|
| 401 |
+
# Add to metrics at epoch 0
|
| 402 |
+
metrics["epoch"].append(int(0))
|
| 403 |
+
metrics["val_loss"].append(loss)
|
| 404 |
+
metrics["val_iou"].append(iou)
|
| 405 |
+
metrics["val_acc"].append(acc)
|
| 406 |
+
metrics["train_loss"].append(None)
|
| 407 |
+
metrics["train_iou"].append(None)
|
| 408 |
+
metrics["train_acc"].append(None)
|
| 409 |
+
|
| 410 |
+
initial_dice = dice
|
| 411 |
+
|
| 412 |
+
best_model = model
|
| 413 |
+
|
| 414 |
+
best_iou = iou
|
| 415 |
+
patience = early_stopping
|
| 416 |
+
epochs_without_improvement = 0
|
| 417 |
+
for epoch in tqdm(
|
| 418 |
+
range(num_epochs),
|
| 419 |
+
desc="Training Epochs",
|
| 420 |
+
unit="epochs",
|
| 421 |
+
):
|
| 422 |
+
epoch_loss = 0.0
|
| 423 |
+
num_batches = 0
|
| 424 |
+
model.train() # Ensure model is in training mode
|
| 425 |
+
|
| 426 |
+
progress_bar = tqdm(
|
| 427 |
+
train_dataloader,
|
| 428 |
+
desc=f"Training Epoch {epoch + 1}",
|
| 429 |
+
leave=True,
|
| 430 |
+
unit="batches",
|
| 431 |
+
)
|
| 432 |
+
|
| 433 |
+
for idx, batch in enumerate(progress_bar):
|
| 434 |
+
# Get the inputs
|
| 435 |
+
pixel_values = batch["pixel_values"].to(device)
|
| 436 |
+
labels = batch["labels"].to(device)
|
| 437 |
+
|
| 438 |
+
# Zero the parameter gradients
|
| 439 |
+
optimizer.zero_grad()
|
| 440 |
+
|
| 441 |
+
# Forward + backward + optimize
|
| 442 |
+
outputs = model(pixel_values=pixel_values, labels=labels)
|
| 443 |
+
loss, logits = outputs.loss, outputs.logits
|
| 444 |
+
|
| 445 |
+
loss.backward()
|
| 446 |
+
optimizer.step()
|
| 447 |
+
|
| 448 |
+
epoch_loss += loss.item()
|
| 449 |
+
num_batches += 1
|
| 450 |
+
|
| 451 |
+
# Evaluate training batch
|
| 452 |
+
with torch.no_grad():
|
| 453 |
+
upsampled_logits = nn.functional.interpolate(
|
| 454 |
+
logits,
|
| 455 |
+
size=labels.shape[-2:],
|
| 456 |
+
mode="bilinear",
|
| 457 |
+
align_corners=False,
|
| 458 |
+
)
|
| 459 |
+
predicted = upsampled_logits.argmax(dim=1)
|
| 460 |
+
|
| 461 |
+
# Store for metric calculation
|
| 462 |
+
pred_np = predicted.detach().cpu().numpy()
|
| 463 |
+
ref_np = labels.detach().cpu().numpy()
|
| 464 |
+
|
| 465 |
+
# Note: metric expects predictions + labels as numpy arrays
|
| 466 |
+
metric.add_batch(
|
| 467 |
+
predictions=pred_np,
|
| 468 |
+
references=ref_np,
|
| 469 |
+
)
|
| 470 |
+
|
| 471 |
+
train_metrics = metric.compute(
|
| 472 |
+
num_labels=len(id2label),
|
| 473 |
+
ignore_index=10,
|
| 474 |
+
reduce_labels=False,
|
| 475 |
+
)
|
| 476 |
+
train_loss = epoch_loss / num_batches if num_batches else 0.0
|
| 477 |
+
|
| 478 |
+
# Validation
|
| 479 |
+
if (epoch + 1) % validate_every == 0:
|
| 480 |
+
(
|
| 481 |
+
val_loss,
|
| 482 |
+
val_iou,
|
| 483 |
+
val_per_class_iou,
|
| 484 |
+
val_acc,
|
| 485 |
+
val_per_class_acc,
|
| 486 |
+
val_dice,
|
| 487 |
+
val_dice_per_class,
|
| 488 |
+
) = validate_model(
|
| 489 |
+
model=model,
|
| 490 |
+
dataloader=valid_dataloader,
|
| 491 |
+
device=device,
|
| 492 |
+
id2label=id2label,
|
| 493 |
+
epoch=epoch + 1,
|
| 494 |
+
)
|
| 495 |
+
|
| 496 |
+
# Record metrics
|
| 497 |
+
metrics["epoch"].append(int(epoch + 1))
|
| 498 |
+
metrics["train_loss"].append(train_loss)
|
| 499 |
+
metrics["train_iou"].append(train_metrics["mean_iou"])
|
| 500 |
+
metrics["train_acc"].append(train_metrics["mean_accuracy"])
|
| 501 |
+
metrics["val_loss"].append(val_loss)
|
| 502 |
+
metrics["val_iou"].append(val_iou)
|
| 503 |
+
metrics["val_acc"].append(val_acc)
|
| 504 |
+
|
| 505 |
+
# Save the best model
|
| 506 |
+
if val_iou > best_iou:
|
| 507 |
+
best_model = model
|
| 508 |
+
best_iou = val_iou
|
| 509 |
+
epochs_without_improvement = 0
|
| 510 |
+
else:
|
| 511 |
+
epochs_without_improvement += 1
|
| 512 |
+
|
| 513 |
+
if epochs_without_improvement >= patience:
|
| 514 |
+
tqdm.write(
|
| 515 |
+
f"Early stopping after {patience} epochs with no improvement",
|
| 516 |
+
)
|
| 517 |
+
break
|
| 518 |
+
|
| 519 |
+
return best_model, metrics, initial_dice
|
| 520 |
+
|
| 521 |
+
|
| 522 |
+
def extract_model_zip(model_zip_path):
|
| 523 |
+
"""Extract model zip file and return the model directory."""
|
| 524 |
+
|
| 525 |
+
if not os.path.exists(model_zip_path):
|
| 526 |
+
raise FileNotFoundError(f"Model zip file not found: {model_zip_path}")
|
| 527 |
+
|
| 528 |
+
with zipfile.ZipFile(model_zip_path, "r") as zip_ref:
|
| 529 |
+
extract_dir = os.path.join(os.path.dirname(model_zip_path), "output")
|
| 530 |
+
zip_ref.extractall(extract_dir)
|
| 531 |
+
|
| 532 |
+
# Check nested folder
|
| 533 |
+
if len(os.listdir(extract_dir)) == 1:
|
| 534 |
+
return os.path.join(extract_dir, os.listdir(extract_dir)[0])
|
| 535 |
+
else:
|
| 536 |
+
return extract_dir
|
| 537 |
+
|
| 538 |
+
|
| 539 |
+
def train_model(
|
| 540 |
+
data_dir,
|
| 541 |
+
base_model_zip,
|
| 542 |
+
image_width,
|
| 543 |
+
image_height,
|
| 544 |
+
batch_size,
|
| 545 |
+
data_percent,
|
| 546 |
+
num_epochs,
|
| 547 |
+
learning_rate,
|
| 548 |
+
early_stopping,
|
| 549 |
+
validate_every,
|
| 550 |
+
):
|
| 551 |
+
|
| 552 |
+
model_path = extract_model_zip(base_model_zip)
|
| 553 |
+
|
| 554 |
+
# Load model and labels
|
| 555 |
+
model, id2label, id2color = load_model_and_labels(data_dir, model_path)
|
| 556 |
+
|
| 557 |
+
# Create datasets and dataloaders
|
| 558 |
+
train_dataloader, valid_dataloader = create_datasets_and_dataloaders(
|
| 559 |
+
image_width,
|
| 560 |
+
image_height,
|
| 561 |
+
data_dir,
|
| 562 |
+
batch_size,
|
| 563 |
+
data_percent,
|
| 564 |
+
)
|
| 565 |
+
|
| 566 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 567 |
+
print(f"Using device: {device}")
|
| 568 |
+
|
| 569 |
+
# Train the model
|
| 570 |
+
best_model, metrics, initial_dice = run_training(
|
| 571 |
+
model,
|
| 572 |
+
device,
|
| 573 |
+
train_dataloader,
|
| 574 |
+
valid_dataloader,
|
| 575 |
+
id2label,
|
| 576 |
+
num_epochs,
|
| 577 |
+
learning_rate,
|
| 578 |
+
early_stopping,
|
| 579 |
+
validate_every,
|
| 580 |
+
)
|
| 581 |
+
|
| 582 |
+
# Final validation
|
| 583 |
+
(
|
| 584 |
+
loss,
|
| 585 |
+
iou,
|
| 586 |
+
per_class_iou,
|
| 587 |
+
acc,
|
| 588 |
+
per_class_acc,
|
| 589 |
+
dice,
|
| 590 |
+
dice_per_class,
|
| 591 |
+
) = validate_model(
|
| 592 |
+
model=best_model,
|
| 593 |
+
dataloader=valid_dataloader,
|
| 594 |
+
device=device,
|
| 595 |
+
id2label=id2label,
|
| 596 |
+
calc_dice=True,
|
| 597 |
+
epoch=0,
|
| 598 |
+
)
|
| 599 |
+
|
| 600 |
+
final_dice = dice
|
| 601 |
+
|
| 602 |
+
return best_model, metrics, [initial_dice, final_dice]
|