YusufMesbah commited on
Commit
e4aef33
·
1 Parent(s): 7a8be68

Implement initial version of SegFormer training pipeline with dataset parsing and model training functionalities. Added Dockerfile for environment setup, utility scripts for parsing and training, and Gradio interface for user interaction.

Browse files
Dockerfile ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Base image: PyTorch 2.8.0 runtime with CUDA 12.9 / cuDNN 9 (no compiler toolchain for CUDA).
FROM pytorch/pytorch:2.8.0-cuda12.9-cudnn9-runtime

# Stream Python stdout/stderr unbuffered so container logs appear live.
ENV PYTHONUNBUFFERED=1

# Non-root user with UID 1000 (the UID expected by Hugging Face Spaces).
RUN useradd -m -u 1000 user
WORKDIR /app

# Install system dependencies for building Python packages
# (ffmpeg/libsm6/libxext6 are common image/video runtime deps; libmagic1 for file-type detection).
RUN apt-get update && apt-get install -y \
    build-essential \
    libffi-dev \
    libssl-dev \
    ffmpeg \
    libsm6 \
    libxext6 \
    libmagic1 \
    && rm -rf /var/lib/apt/lists/*


# Create virtual environment for Sly (keep isolated)
# The supervisely package is installed into its own venv so its pinned
# dependencies cannot conflict with the main Gradio/transformers environment.
RUN python -m venv /app/.venv-sly
RUN /app/.venv-sly/bin/pip install --upgrade pip
COPY --chown=user requirements-sly.txt requirements-sly.txt
RUN /app/.venv-sly/bin/pip install -r requirements-sly.txt

# Install Gradio and other dependencies (into the base conda/python env).
RUN pip install --upgrade pip
COPY --chown=user requirements.txt requirements.txt
RUN pip install -r requirements.txt


# Copy the rest of the app
COPY --chown=user . .

# Bind Gradio to all interfaces so the server is reachable from outside the container.
ENV GRADIO_SERVER_NAME="0.0.0.0"

CMD ["python", "app.py"]
app.py ADDED
@@ -0,0 +1,346 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import shutil
3
+ import gradio as gr
4
+ import pandas as pd
5
+
6
+ from utils import run_supervisely_parser, train_model
7
+
8
+
9
def run_pipeline(
    base_model_zip,
    supervisely_project_zip,
    train_ratio,
    seed,
    data_percent,
    batch_size,
    num_epochs,
    learning_rate,
    image_width,
    image_height,
    early_stopping,
    validate_every,
    pr=gr.Progress(track_tqdm=True),
):
    """Parse the dataset, fine-tune the model, and stream UI updates.

    Generator used as the Run-button click handler.  Every ``yield`` is a
    7-tuple of ``gr.update`` objects matching, in order, the click
    handler's ``outputs`` list: run button, status textbox, model
    download button, results tab, train IoU plot, val IoU plot, metrics
    table.

    ``pr`` attaches a Gradio progress bar that mirrors tqdm progress
    emitted by the parsing/training helpers.
    """

    # Phase 1: extract and convert the Supervisely project.
    yield (
        gr.update(interactive=False),  # run button
        gr.update(  # status textbox
            value="Parsing Supervisely project ...",
            visible=True,
        ),
        gr.update(visible=False),  # model download button
        gr.update(visible=False),  # output tab
        gr.update(value=None),  # train IoU plot
        gr.update(value=None),  # val IoU plot
        gr.update(value=None),  # metrics table
    )

    dataset_dir = run_supervisely_parser(
        project_path=supervisely_project_zip,
        train_ratio=train_ratio,
        seed=seed,
    )

    # Phase 2: fine-tune the model on the parsed dataset.
    yield (
        gr.update(interactive=False),  # run button
        gr.update(  # status textbox
            value="Starting model training...",
        ),
        gr.update(visible=False),  # model download button
        gr.update(visible=False),  # output tab
        gr.update(value=None),  # train IoU plot
        gr.update(value=None),  # val IoU plot
        gr.update(value=None),  # metrics table
    )

    best_model, metrics, dice = train_model(
        dataset_dir,
        base_model_zip,
        image_width,
        image_height,
        batch_size,
        data_percent,
        num_epochs,
        learning_rate,
        early_stopping,
        validate_every,
    )

    # Phase 3: persist the best checkpoint and package it for download.
    yield (
        gr.update(interactive=False),  # run button
        gr.update(  # status textbox
            value="Saving best model...",
        ),
        gr.update(visible=False),  # model download button
        gr.update(visible=False),  # output tab
        gr.update(value=None),  # train IoU plot
        gr.update(value=None),  # val IoU plot
        gr.update(value=None),  # metrics table
    )
    best_model_dir = os.path.join(
        os.path.dirname(base_model_zip),
        "best_model",
    )
    best_model.save_pretrained(best_model_dir)

    best_model_zip_path = shutil.make_archive(
        base_name=best_model_dir,
        format="zip",
        root_dir=best_model_dir,
    )

    metrics_df = pd.DataFrame(metrics)

    initial_epoch_metrics = metrics_df.iloc[0]
    final_epoch_metrics = metrics_df.iloc[-1]

    # metrics comparison table use epoch 0 as before and final as after
    metrics_comparison_df = pd.DataFrame(
        {
            "Metric": ["Accuracy", "IoU", "Loss", "Dice"],
            "Before": [
                initial_epoch_metrics["val_acc"],
                initial_epoch_metrics["val_iou"],
                initial_epoch_metrics["val_loss"],
                dice[0],  # dice appears to be a (before, after) pair from train_model
            ],
            "After": [
                final_epoch_metrics["val_acc"],
                final_epoch_metrics["val_iou"],
                final_epoch_metrics["val_loss"],
                dice[1],
            ],
        }
    )

    # Final update: re-enable the button and reveal the results tab.
    yield (
        gr.update(interactive=True),  # run button
        gr.update(visible=False),  # status textbox
        gr.update(  # model download button
            value=best_model_zip_path,
            visible=True,
        ),
        gr.update(visible=True),  # output tab
        gr.update(value=metrics_df),  # train IoU plot
        gr.update(value=metrics_df),  # val IoU plot
        gr.update(
            value=metrics_comparison_df,
            visible=True,
        ),
    )
+
135
+
136
+ def _toggle_run_btn(base_model, project):
137
+ """Enable run button only when both required files are selected."""
138
+ ready = bool(base_model and project)
139
+ return gr.update(interactive=ready)
140
+
141
+
142
# Gradio UI: two required file uploads, hyperparameter tabs, a help
# accordion, and a (initially hidden) results tab populated by run_pipeline.
with gr.Blocks(title="SegFormer Training & Dataset Pipeline") as demo:
    gr.Markdown(
        "# SegFormer Training Pipeline\n"
        "Upload your base model and Supervisely project, "
        "tweak parsing & training hyperparameters, then click "
        "**Run Training**."
    )

    # Required inputs — the Run button stays disabled until both are set.
    with gr.Row():
        base_model_zip = gr.File(
            label="Base PyTorch Model (.zip)",
            file_types=[".zip"],
            file_count="single",
        )
        supervisely_project_zip = gr.File(
            label="Supervisely Project (.zip)",
            file_types=[".zip"],
            file_count="single",
        )

    # Training hyperparameters.
    with gr.Tab("Training"):
        gr.Markdown("Adjust training hyperparameters.")
        with gr.Row():
            data_percent = gr.Slider(
                minimum=1,
                maximum=100,
                step=1,
                value=100,
                label="Data Percent (%) used for training",
            )
            batch_size = gr.Number(
                value=32,
                label="Batch Size (samples/step)",
                precision=0,
                minimum=1,
            )
            num_epochs = gr.Number(
                value=60,
                label="Epochs (max passes)",
                precision=0,
                minimum=1,
            )
        with gr.Row():
            learning_rate = gr.Number(
                value=5e-5,
                label="Learning Rate",
                minimum=0.0,
                maximum=1.0,
            )
            image_width = gr.Number(
                value=640,
                label="Image Width (px)",
                precision=0,
                minimum=1,
            )
            image_height = gr.Number(
                value=640,
                label="Image Height (px)",
                precision=0,
                minimum=1,
            )
        with gr.Row():
            early_stopping = gr.Number(
                value=3,
                label="Early Stopping Patience (epochs w/o improvement)",
                precision=0,
                minimum=0,
            )
            validate_every = gr.Number(
                value=1,
                label="Validate Every (epochs)",
                precision=0,
                minimum=0,
            )

    # Dataset parsing options forwarded to run_supervisely_parser.
    with gr.Tab("Dataset Parsing"):
        gr.Markdown("Configure how the dataset is split and seeded.")
        with gr.Row():
            train_ratio = gr.Slider(
                minimum=0.1,
                maximum=0.95,
                step=0.01,
                value=0.8,
                label="Train Split Ratio (rest used for validation)",
            )
            seed = gr.Number(
                value=42,
                label="Random Seed (reproducibility)",
                precision=0,
            )

    with gr.Accordion("Parameter Help", open=False):
        gr.Markdown(
            """
            **Base PyTorch Model (.zip)**: Archive containing a folder with
            weights and configuration file.\n
            **Supervisely Project (.zip)**: Archive containing Exported
            Supervisely project
            containing images and annotation JSONs.\n
            **Train Split Ratio**: Fraction of dataset used for training;
            remainder becomes validation.\n
            **Random Seed**: Controls shuffling for reproducible splits &
            training.\n
            **Data Percent**: Subsample percentage of training split (use
            <100 for quick experiments).\n
            **Batch Size**: Samples processed before each optimizer step.\n
            **Epochs**: Maximum complete passes over the (subsampled)
            training set.\n
            **Learning Rate**: Initial optimizer step size.\n
            **Image Width / Height**: Target spatial size for preprocessing
            (resize/crop).\n
            **Early Stopping Patience**: Stop after this many validation
            checks without improvement.\n
            **Validate Every**: Run validation after this many epochs.\n
            """
        )

    run_btn = gr.Button(
        "Run Training",
        variant="primary",
        interactive=False,
    )
    status = gr.Textbox(
        show_label=False,
        visible=False,
    )

    # Results tab — hidden until run_pipeline's final yield reveals it.
    with gr.Tab("Results", visible=False) as output_tab:

        model_download_btn = gr.DownloadButton(
            label="Download Trained Model (.zip)",
            value=None,
            visible=False,
        )

        # table to show before and after accuracy and iou
        metrics_table = gr.DataFrame(
            label="Metrics Comparison",
            interactive=False,
            wrap=True,
        )

        with gr.Row():
            train_iou_plot = gr.LinePlot(
                label="Training IoU",
                x="epoch",
                y="train_iou",
                x_title="Epoch",
                y_title="IoU",
                height=400,
            )
            val_iou_plot = gr.LinePlot(
                label="Validation IoU",
                x="epoch",
                y="val_iou",
                x_title="Epoch",
                y_title="IoU",
                height=400,
            )

    # Enable run button only when both archives provided
    base_model_zip.change(
        _toggle_run_btn,
        inputs=[base_model_zip, supervisely_project_zip],
        outputs=run_btn,
    )
    supervisely_project_zip.change(
        _toggle_run_btn,
        inputs=[base_model_zip, supervisely_project_zip],
        outputs=run_btn,
    )

    # Click handler — outputs order must match run_pipeline's yield tuples.
    run_btn.click(
        run_pipeline,
        inputs=[
            base_model_zip,
            supervisely_project_zip,
            train_ratio,
            seed,
            data_percent,
            batch_size,
            num_epochs,
            learning_rate,
            image_width,
            image_height,
            early_stopping,
            validate_every,
        ],
        outputs=[
            run_btn,
            status,
            model_download_btn,
            output_tab,
            train_iou_plot,
            val_iou_plot,
            metrics_table,
        ],
        show_progress_on=status,
        scroll_to_output=True,
    )


if __name__ == "__main__":
    demo.launch()
requirements-sly.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ supervisely>=6.73.418
2
+ numpy>=1.26.4
3
+ pillow>=10.2.0
4
+ tqdm>=4.67.1
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ torch==2.8.0
2
+ gradio>=5.42.0
3
+ pillow>=11.3.0
4
+ transformers>=4.55.2
5
+ tqdm>=4.67.1
6
+ evaluate>=0.4.5
scripts/supervisely_parser.py ADDED
@@ -0,0 +1,221 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Supervisely Parser Script
3
+
4
+ This script parses Supervisely projects and converts them to a format
5
+ suitable for training segmentation models. It extracts class information,
6
+ creates train/validation splits, and converts annotations to indexed
7
+ color masks.
8
+ """
9
+
10
+ import os
11
+ import json
12
+ import random
13
+ import shutil
14
+ import argparse
15
+ import numpy as np
16
+ from PIL import Image
17
+ from tqdm import tqdm
18
+
19
+ try:
20
+ import supervisely as sly
21
+ from supervisely import Annotation
22
+ except ImportError as e:
23
+ print(f"Failed to import supervisely: {e}")
24
+ print(
25
+ "Please ensure that the 'supervisely' package is installed and "
26
+ "compatible with your environment."
27
+ )
28
+ raise
29
+
30
+
31
def extract_class_info(project, output_dir):
    """Extract class index/label/color mappings from the project metadata.

    Object classes are expected to be named ``"<1-based-id>. <label>"``;
    classes not matching that pattern are skipped.  The id->label and
    id->color mappings are also persisted as JSON in *output_dir*.

    Returns:
        tuple: (id2label, id2color, label2id)
    """
    id2label, id2color = {}, {}
    for obj_class in project.meta.obj_classes:
        prefix, _, label = obj_class.name.partition(". ")
        if label and prefix.isdigit():
            class_index = int(prefix) - 1  # convert to 0-based index
            id2label[class_index] = label
            id2color[class_index] = obj_class.color

    # Persist both mappings for later inference / visualization.
    for filename, mapping in (
        ("id2label.json", id2label),
        ("id2color.json", id2color),
    ):
        with open(f"{output_dir}/{filename}", "w") as f:
            json.dump(mapping, f, sort_keys=True, indent=2)

    return id2label, id2color, {v: k for k, v in id2label.items()}
53
+
54
+
55
def create_output_directories(output_dir):
    """Create the images/annotations directory trees for both splits (idempotent)."""
    for kind in ("images", "annotations"):
        for split in ("training", "validation"):
            os.makedirs(f"{output_dir}/{kind}/{split}", exist_ok=True)
61
+
62
+
63
def calculate_split_counts(datasets, train_ratio=0.8):
    """Return (train_count, val_count) over the combined item count of *datasets*.

    The training count is floor(total * train_ratio); the validation
    count is the remainder.
    """
    total_items = sum(len(ds.get_items_names()) for ds in datasets)
    train_items = int(total_items * train_ratio)
    val_items = total_items - train_items

    print(
        f"Total items: {total_items}\n"
        f"Train items: {train_items}\n"
        f"Validation items: {val_items}"
    )

    return train_items, val_items
79
+
80
+
81
def to_class_index_mask(
    annotation: Annotation,
    label2id: dict,
    mask_path: str,
):
    """Convert annotation to class index mask and save as PNG.

    Builds a single-channel uint8 mask where each pixel holds the class
    index of the label covering it (0 = background / unlabeled) and
    writes it to ``mask_path``.  Only bitmap geometries are rasterized;
    all other geometry types are silently skipped.  Where bitmaps
    overlap, later labels overwrite earlier ones.
    """
    height, width = annotation.img_size
    class_mask = np.zeros((height, width), dtype=np.uint8)

    for label in annotation.labels:
        # Class names look like "<id>. <name>"; keep only the name part.
        class_name = label.obj_class.name.partition(". ")[2]
        if class_name not in label2id:
            tqdm.write(f"Skipping unrecognized label: {label}")
            continue  # skip unrecognized labels

        class_index = label2id[class_name]

        if label.geometry.geometry_name() == "bitmap":
            # Bitmaps are stored as a binary patch plus its top-left origin.
            origin = label.geometry.origin
            top = origin.row
            left = origin.col
            bitmap = label.geometry.data  # binary numpy array, shape (h, w)

            h, w = bitmap.shape
            if top + h > height or left + w > width:
                # Patch would fall outside the image; skip rather than crash.
                tqdm.write(f"Skipping label '{class_name}': size mismatch.")
                continue

            # Paint only the pixels where the bitmap is True.
            class_mask[top : top + h, left : left + w][bitmap] = class_index
        else:
            continue  # non-bitmap geometries are not supported

    Image.fromarray(class_mask).save(mask_path)
114
+
115
+
116
def process_datasets(
    project,
    datasets,
    output_dir,
    label2id,
    train_items,
):
    """Process all datasets and create train/validation splits.

    Copies each image into ``<output_dir>/images/<split>`` and writes a
    class-index PNG mask into ``<output_dir>/annotations/<split>``.  The
    first ``train_items`` items *across all datasets combined* go to the
    training split; the remainder become validation.

    Fix: ``train_items`` is computed over the combined datasets (see
    ``calculate_split_counts``), so the position counter must run across
    dataset boundaries.  The previous per-dataset ``enumerate`` index
    reset to 0 for every dataset, sending the first ``train_items``
    items of *each* dataset to training and skewing the split whenever
    more than one dataset is present.
    """
    global_index = 0  # item position across all datasets, compared to train_items
    for dataset in tqdm(datasets, desc="Processing datasets"):
        items = dataset.get_items_names()
        random.shuffle(items)

        for item in tqdm(
            items,
            desc=f"Processing dataset: {dataset.name}",
            total=len(items),
            leave=False,
        ):
            # Determine split by overall position, not per-dataset position.
            split = "training" if global_index < train_items else "validation"
            global_index += 1

            # Copy images
            item_paths = dataset.get_item_paths(item)
            img_path = item_paths.img_path
            img_filename = os.path.basename(img_path)
            dest_path = f"{output_dir}/images/{split}/{img_filename}"
            shutil.copy(img_path, dest_path)

            # Convert and copy annotations
            ann_path = item_paths.ann_path
            ann = sly.Annotation.load_json_file(ann_path, project.meta)
            mask_filename = f"{os.path.splitext(item)[0]}.png"
            mask_path = f"{output_dir}/annotations/{split}/{mask_filename}"
            to_class_index_mask(ann, label2id, mask_path)
150
+
151
+
152
def parse_arguments():
    """Parse command line arguments for the Supervisely parser CLI."""
    parser = argparse.ArgumentParser(
        description="Parse Supervisely project and convert to training format"
    )
    # (flag, type, default, required, help)
    argument_specs = (
        ("--project_dir", str, None, True,
         "Path to the Supervisely project directory"),
        ("--output_base_dir", str, None, True,
         "Base output directory for parsed data"),
        ("--train_ratio", float, 0.8, False,
         "Ratio of data to use for training (default: 0.8)"),
        ("--seed", int, 42, False,
         "Random seed for reproducible splits (default: 42)"),
    )
    for flag, arg_type, default, required, help_text in argument_specs:
        parser.add_argument(
            flag,
            type=arg_type,
            default=default,
            required=required,
            help=help_text,
        )
    return parser.parse_args()
182
+
183
+
184
def main():
    """Main function to parse Supervisely project.

    Pipeline: parse CLI args, seed the RNG, load the project, create the
    output layout under ``<output_base_dir>/<project.name>``, dump class
    mappings, then copy images and write class-index masks into
    train/validation splits.
    """
    # Parse arguments
    args = parse_arguments()

    # Set random seed for reproducible splits
    random.seed(args.seed)

    # Load project
    project = sly.Project(args.project_dir, sly.OpenMode.READ)
    print(f"Project: {project.name}")

    # Setup output directory
    output_dir = os.path.join(args.output_base_dir, project.name)
    create_output_directories(output_dir)

    # Extract class information (also writes id2label.json / id2color.json).
    # id2label and id2color are not used further here; only label2id is.
    id2label, id2color, label2id = extract_class_info(project, output_dir)

    # Get datasets and calculate splits
    datasets = project.datasets
    print(f"Datasets: {len(datasets)}")
    # val_items is informational only; the split is driven by train_items.
    train_items, val_items = calculate_split_counts(datasets, args.train_ratio)

    # Process datasets
    process_datasets(
        project,
        datasets,
        output_dir,
        label2id,
        train_items,
    )

    print("Processing completed successfully!")
218
+
219
+
220
+ if __name__ == "__main__":
221
+ main()
utils/__init__.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ from .parse import run_supervisely_parser
2
+ from .train import train_model
3
+
4
+ __all__ = [
5
+ "run_supervisely_parser",
6
+ "train_model",
7
+ ]
utils/parse.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import tempfile
3
+ import zipfile
4
+ import subprocess
5
+ from pathlib import Path
6
+
7
+
8
def run_supervisely_parser(
    project_path: str,
    train_ratio: float,
    seed: int,
) -> str:
    """Extract a Supervisely project zip and run the parser script inside .venv-sly.

    Parameters
    ----------
    project_path : str
        Path to the uploaded Supervisely project .zip.
    train_ratio : float
        Portion of data to allocate to training (remainder is validation).
    seed : int
        Random seed forwarded to the parser for reproducible splits.

    Returns
    -------
    str
        Path to the parsed dataset directory produced by the parser script.

    Raises
    ------
    FileNotFoundError
        If the zip, the parser script, or the .venv-sly interpreter is missing.
    ValueError
        If *project_path* is not a .zip archive.
    RuntimeError
        If the parser subprocess fails or its output layout is ambiguous.
    """

    project_zip = Path(project_path)
    if not project_zip.exists():
        raise FileNotFoundError(
            f"Provided project zip not found: {project_zip}"
        )
    if project_zip.suffix.lower() != ".zip":
        raise ValueError("Supervisely project must be a .zip archive")

    # Work next to the upload so extraction and output share a filesystem.
    # NOTE(review): these temp dirs are never removed here — confirm the
    # hosting platform reaps them, or add explicit cleanup.
    project_dir = project_zip.parent
    extract_dir = Path(tempfile.mkdtemp(dir=project_dir))
    output_base_dir = Path(tempfile.mkdtemp(dir=project_dir))

    with zipfile.ZipFile(project_zip, "r") as zf:
        zf.extractall(extract_dir)

    def find_project_root(root: Path) -> Path:
        # A Supervisely export has meta.json at its top level; the archive
        # may wrap the project in one extra directory, so check one level down.
        if (root / "meta.json").exists():
            return root
        for child in root.iterdir():
            if child.is_dir() and (child / "meta.json").exists():
                return child
        raise FileNotFoundError(
            f"Could not locate 'meta.json' inside extracted archive at {root}"
        )

    project_root = find_project_root(extract_dir)

    repo_root = Path(__file__).resolve().parent.parent
    parser_script = repo_root / "scripts" / "supervisely_parser.py"
    # The parser runs under the isolated .venv-sly interpreter because the
    # supervisely package lives in its own virtual environment.
    venv_python = repo_root / ".venv-sly" / "bin" / "python"

    if not parser_script.exists():
        raise FileNotFoundError(
            f"Parser script not found: {parser_script}",
        )
    if not venv_python.exists():
        raise FileNotFoundError(
            "Expected .venv-sly Python interpreter at: " f"{venv_python}",
        )

    cmd = [
        str(venv_python),
        str(parser_script),
        "--project_dir",
        str(project_root),
        "--output_base_dir",
        str(output_base_dir),
        "--train_ratio",
        str(train_ratio),
        "--seed",
        str(seed),
    ]

    # Capture output so a failure can surface the parser's own logs.
    result = subprocess.run(
        cmd,
        capture_output=True,
        text=True,
        env={**os.environ},
    )
    if result.returncode != 0:
        raise RuntimeError(
            "Supervisely parser failed.\n"
            f"STDOUT:\n{result.stdout}\nSTDERR:\n{result.stderr}"
        )
    # The parser writes exactly one directory (named after the project).
    produced_dirs = [p for p in output_base_dir.iterdir() if p.is_dir()]
    if len(produced_dirs) != 1:
        raise RuntimeError(
            "Could not unambiguously determine parsed dataset directory in "
            f"{output_base_dir}. Found: {produced_dirs}"
        )
    dataset_dir = produced_dirs[0]
    return str(dataset_dir)
utils/train.py ADDED
@@ -0,0 +1,602 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ SegFormer Fine-tuning Script
3
+
4
+ This script fine-tunes a SegFormer model on a custom semantic segmentation
5
+ dataset. It provides configurable parameters for training hyperparameters
6
+ and dataset settings.
7
+ """
8
+
9
+ import json
10
+ import os
11
+ import zipfile
12
+ import numpy as np
13
+ import torch
14
+ import torch.nn as nn
15
+ from torch.utils.data import Dataset, DataLoader
16
+ from PIL import Image
17
+ from pathlib import Path
18
+ from datetime import datetime
19
+ from transformers import (
20
+ SegformerImageProcessor,
21
+ SegformerForSemanticSegmentation,
22
+ )
23
+ import evaluate
24
+ from tqdm import tqdm
25
+
26
+
27
class SemanticSegmentationDataset(Dataset):
    """Image (semantic) segmentation dataset."""

    def __init__(
        self,
        root_dir,
        image_processor,
        train=True,
        data_percent=100,
    ):
        """
        Args:
            root_dir (string): Dataset root containing the ``images/`` and
                ``annotations/`` trees.
            image_processor (SegformerImageProcessor): prepares
                image + segmentation map pairs for the model.
            train (bool): load the "training" split when True, otherwise
                the "validation" split.
            data_percent (int): percentage of the split to keep
                (100 keeps everything, 50 keeps the first half).
        """
        self.root_dir = root_dir
        self.image_processor = image_processor
        self.train = train

        split = "training" if train else "validation"
        self.img_dir = os.path.join(root_dir, "images", split)
        self.ann_dir = os.path.join(root_dir, "annotations", split)

        self.images = self._collect_file_names(self.img_dir)
        self.annotations = self._collect_file_names(self.ann_dir)

        assert len(self.images) == len(
            self.annotations
        ), "There must be as many images as there are segmentation maps"

        # Optionally keep only the leading fraction of the (sorted) split.
        fraction = data_percent / 100.0
        if fraction < 1.0:
            self.images = self.images[: int(len(self.images) * fraction)]
            self.annotations = self.annotations[
                : int(len(self.annotations) * fraction)
            ]

    @staticmethod
    def _collect_file_names(directory):
        """Return the sorted file names found anywhere under *directory*."""
        names = []
        for _root, _dirs, files in os.walk(directory):
            names.extend(files)
        return sorted(names)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """Return processor-encoded tensors for sample *idx*, batch dim removed."""
        image = Image.open(os.path.join(self.img_dir, self.images[idx]))
        segmentation_map = Image.open(
            os.path.join(self.ann_dir, self.annotations[idx]),
        )
        encoded = self.image_processor(
            image,
            segmentation_map,
            return_tensors="pt",
        )

        for key in encoded:
            encoded[key].squeeze_()  # drop the singleton batch dimension

        return encoded
101
+
102
+
103
class MeanDice:
    """Accumulates prediction/reference batches and computes mean Dice."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Reset stored predictions and references."""
        self.predictions = []
        self.references = []

    def add_batch(self, predictions, references):
        """
        Add a batch of predictions and references.

        Args:
            predictions (np.ndarray): Predicted class indices
            references (np.ndarray): Ground truth class indices
        """
        self.predictions.append(predictions)
        self.references.append(references)

    def compute(self, num_labels, ignore_index=None):
        """Compute mean Dice score across all stored batches."""
        preds = np.concatenate([batch.flatten() for batch in self.predictions])
        refs = np.concatenate([batch.flatten() for batch in self.references])
        # Pixels whose reference equals ignore_index are excluded everywhere.
        valid = None if ignore_index is None else refs != ignore_index

        per_class = []
        for class_id in range(num_labels):
            in_pred = preds == class_id
            in_ref = refs == class_id
            if valid is not None:
                in_pred &= valid
                in_ref &= valid

            overlap = np.sum(in_pred & in_ref)
            size_sum = np.sum(in_pred) + np.sum(in_ref)

            # A class absent from both prediction and reference counts as
            # a perfect match (overlap is necessarily 0 when size_sum is 0).
            per_class.append(2.0 * overlap / size_sum if size_sum else 1.0)

        return {
            "mean_dice": float(np.mean(per_class)),
            "per_class_dice": per_class,
        }
154
+
155
+
156
def get_latest_model_dir(base_path: str = "./segformer_finetuned") -> Path:
    """
    Return the Path of the most recent model directory under *base_path*.

    Candidate folder names must follow the format YYYY-MM-DD_HH-MM-SS;
    directories with any other name are ignored.

    Raises:
        FileNotFoundError: if *base_path* is missing, or no directory
            with a valid timestamp name exists under it.
    """
    base = Path(base_path)
    if not (base.exists() and base.is_dir()):
        raise FileNotFoundError(f"Directory not found: {base_path}")

    def parse_stamp(directory: Path):
        # None signals a name that is not a timestamp.
        try:
            return datetime.strptime(directory.name, "%Y-%m-%d_%H-%M-%S")
        except ValueError:
            return None

    candidates = [
        (stamp, d)
        for d in base.iterdir()
        if d.is_dir() and (stamp := parse_stamp(d)) is not None
    ]

    if not candidates:
        raise FileNotFoundError(
            "No model directories found with valid timestamp format."
        )

    return max(candidates, key=lambda pair: pair[0])[1]
183
+
184
+
185
def load_model_and_labels(data_dir, model_path):
    """Load the SegFormer model together with its label/color mappings.

    Args:
        data_dir: directory holding ``id2label.json`` and ``id2color.json``
            (as written by the dataset parser).
        model_path: path or identifier passed to
            ``SegformerForSemanticSegmentation.from_pretrained``.

    Returns:
        tuple: (model, id2label, id2color) where id2label maps int class
        index -> name and id2color keeps the JSON string keys.

    Fix: the JSON files were opened via ``json.load(open(...))``, which
    never closes the file handles; use context managers instead.
    """
    # Load id2label mapping; JSON keys are strings, convert back to ints.
    with open(f"{data_dir}/id2label.json", mode="r") as f:
        id2label = json.load(f)
    id2label = {int(k): v for k, v in id2label.items()}
    label2id = {v: k for k, v in id2label.items()}

    # Load id2color mapping (keys stay strings, matching the saved JSON).
    with open(f"{data_dir}/id2color.json", "r") as f:
        id2color = json.load(f)

    print(f"Loaded {len(id2label)} classes:")
    for i, label in id2label.items():
        print(f"  {i}: {label}")

    # Instantiate the model with the dataset's class set.
    model = SegformerForSemanticSegmentation.from_pretrained(
        model_path,
        num_labels=len(id2label),
        id2label=id2label,
        label2id=label2id,
    )
    return model, id2label, id2color
207
+
208
+
209
def create_datasets_and_dataloaders(
    image_width,
    image_height,
    data_dir,
    batch_size,
    data_percent,
):
    """Create datasets and dataloaders.

    Builds a SegFormer image processor that resizes samples to
    (image_height, image_width), wraps the training and validation splits
    under *data_dir*, and returns (train_dataloader, valid_dataloader).
    Only the training loader shuffles.

    NOTE(review): data_percent subsamples the validation split as well as
    the training split — confirm that is intended.
    """
    image_processor = SegformerImageProcessor(
        size={"height": image_height, "width": image_width},
    )

    train_dataset = SemanticSegmentationDataset(
        root_dir=data_dir,
        image_processor=image_processor,
        train=True,
        data_percent=data_percent,
    )

    valid_dataset = SemanticSegmentationDataset(
        root_dir=data_dir,
        image_processor=image_processor,
        train=False,
        data_percent=data_percent,
    )

    print(f"Number of training examples: {len(train_dataset)}")
    print(f"Number of validation examples: {len(valid_dataset)}")

    train_dataloader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
    )
    valid_dataloader = DataLoader(
        valid_dataset,
        batch_size=batch_size,
    )

    return train_dataloader, valid_dataloader
249
+
250
+
251
def class_indices_to_rgb(class_indices, id2color):
    """Map an (H, W) array of class IDs to an (H, W, 3) uint8 RGB image.

    Pixels whose class has no entry in *id2color* stay black (0, 0, 0).
    """
    height, width = class_indices.shape
    rgb = np.zeros((height, width, 3), dtype=np.uint8)

    for class_id, color in id2color.items():
        mask = class_indices == class_id
        rgb[mask] = color

    return rgb
261
+
262
+
263
def validate_model(
    model: SegformerForSemanticSegmentation,
    dataloader,
    device,
    id2label,
    calc_dice=False,
    epoch=None,
):
    """
    Validate the model on a validation set and return loss, IoU, accuracy.

    Returns a 7-tuple:
        (avg_loss, mean_iou, per_category_iou, mean_accuracy,
        per_category_accuracy, mean_dice, per_class_dice)
    where the last two entries are None unless *calc_dice* is True.
    *epoch* is only used to label the progress bar.
    """
    model.eval()
    metric = evaluate.load("mean_iou")
    dice = MeanDice()
    total_loss = 0.0
    num_batches = 0

    with torch.no_grad():
        for batch in tqdm(
            dataloader,
            desc="Validating Epoch " + str(epoch if epoch is not None else ""),
            leave=False,
            unit="batches",
        ):
            pixel_values = batch["pixel_values"].to(device)
            labels = batch["labels"].to(device)

            # Passing labels makes the model also return the loss.
            outputs = model(pixel_values=pixel_values, labels=labels)
            logits = outputs.logits
            loss = outputs.loss

            total_loss += loss.item()
            num_batches += 1

            # Upsample logits to label resolution before the per-pixel argmax.
            upsampled_logits = nn.functional.interpolate(
                logits,
                size=labels.shape[-2:],
                mode="bilinear",
                align_corners=False,
            )
            predicted = upsampled_logits.argmax(dim=1)

            # Store predictions and references for additional metrics
            pred_np = predicted.detach().cpu().numpy()
            ref_np = labels.detach().cpu().numpy()

            metric.add_batch(
                predictions=pred_np,
                references=ref_np,
            )
            if calc_dice:
                dice.add_batch(
                    predictions=pred_np,
                    references=ref_np,
                )

    # Calculate IoU and accuracy
    # NOTE(review): ignore_index=10 is hard-coded here and below — confirm
    # class index 10 is the intended "ignore" label for every dataset.
    result = metric.compute(
        num_labels=len(id2label),
        ignore_index=10,
        reduce_labels=False,
    )
    if calc_dice:
        dice_result = dice.compute(
            num_labels=len(id2label),
            ignore_index=10,
        )

    avg_loss = total_loss / num_batches if num_batches > 0 else 0.0
    return (
        avg_loss,
        result["mean_iou"],
        result["per_category_iou"],
        result["mean_accuracy"],
        result["per_category_accuracy"],
        dice_result["mean_dice"] if calc_dice else None,
        dice_result["per_class_dice"] if calc_dice else None,
    )
341
+
342
+
343
def run_training(
    model: SegformerForSemanticSegmentation,
    device,
    train_dataloader,
    valid_dataloader,
    id2label,
    num_epochs,
    learning_rate,
    early_stopping,
    validate_every,
):
    """Fine-tune the model with AdamW and early stopping on validation IoU.

    Parameters
    ----------
    model : SegformerForSemanticSegmentation
        Model to train (moved to ``device`` and trained in place).
    device : torch.device
    train_dataloader, valid_dataloader : DataLoader
    id2label : dict
        Class-id -> name mapping; only ``len(id2label)`` is used.
    num_epochs : int
        Maximum number of training epochs.
    learning_rate : float
        AdamW learning rate.
    early_stopping : int
        Patience: validation rounds without IoU improvement before stopping.
    validate_every : int
        Run validation every N epochs.

    Returns
    -------
    tuple(best_model, metrics, initial_dice)
        best_model : nn.Module carrying the best-IoU weights
        metrics : dict with lists for keys: 'epoch', 'train_loss', 'train_iou',
            'train_acc', 'val_loss', 'val_iou', 'val_acc'
        initial_dice : mean Dice before any training (baseline)
    """
    import copy  # local import: only needed to snapshot best weights

    model.to(device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

    metrics = {
        "epoch": [],
        "train_loss": [],
        "train_iou": [],
        "train_acc": [],
        "val_loss": [],
        "val_iou": [],
        "val_acc": [],
    }

    metric = evaluate.load("mean_iou")

    model.train()

    # Initial validation (recorded as epoch 0) to get a pre-training baseline.
    (
        loss,
        iou,
        per_class_iou,
        acc,
        per_class_acc,
        dice,
        dice_per_class,
    ) = validate_model(
        model=model,
        dataloader=valid_dataloader,
        device=device,
        id2label=id2label,
        calc_dice=True,
        epoch=0,
    )
    metrics["epoch"].append(int(0))
    metrics["val_loss"].append(loss)
    metrics["val_iou"].append(iou)
    metrics["val_acc"].append(acc)
    metrics["train_loss"].append(None)
    metrics["train_iou"].append(None)
    metrics["train_acc"].append(None)

    initial_dice = dice

    # BUGFIX: the original kept `best_model = model`, which is only an alias
    # of the in-place-trained model — later epochs silently overwrote the
    # "best" weights. Snapshot the state dict instead and restore it at the end.
    best_state = copy.deepcopy(model.state_dict())

    best_iou = iou
    patience = early_stopping
    epochs_without_improvement = 0
    for epoch in tqdm(
        range(num_epochs),
        desc="Training Epochs",
        unit="epochs",
    ):
        epoch_loss = 0.0
        num_batches = 0
        model.train()  # validation switched to eval mode; restore train mode

        progress_bar = tqdm(
            train_dataloader,
            desc=f"Training Epoch {epoch + 1}",
            leave=True,
            unit="batches",
        )

        for idx, batch in enumerate(progress_bar):
            pixel_values = batch["pixel_values"].to(device)
            labels = batch["labels"].to(device)

            optimizer.zero_grad()

            # Forward + backward + optimize
            outputs = model(pixel_values=pixel_values, labels=labels)
            loss, logits = outputs.loss, outputs.logits

            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()
            num_batches += 1

            # Track training metrics on the fly from the just-computed logits.
            with torch.no_grad():
                upsampled_logits = nn.functional.interpolate(
                    logits,
                    size=labels.shape[-2:],
                    mode="bilinear",
                    align_corners=False,
                )
                predicted = upsampled_logits.argmax(dim=1)

                pred_np = predicted.detach().cpu().numpy()
                ref_np = labels.detach().cpu().numpy()

                # Metric expects predictions + labels as numpy arrays.
                metric.add_batch(
                    predictions=pred_np,
                    references=ref_np,
                )

        # compute() also resets the metric's accumulated batches.
        train_metrics = metric.compute(
            num_labels=len(id2label),
            ignore_index=10,  # must stay in sync with validate_model's default
            reduce_labels=False,
        )
        train_loss = epoch_loss / num_batches if num_batches else 0.0

        # Validation
        if (epoch + 1) % validate_every == 0:
            (
                val_loss,
                val_iou,
                val_per_class_iou,
                val_acc,
                val_per_class_acc,
                val_dice,
                val_dice_per_class,
            ) = validate_model(
                model=model,
                dataloader=valid_dataloader,
                device=device,
                id2label=id2label,
                epoch=epoch + 1,
            )

            # Record metrics
            metrics["epoch"].append(int(epoch + 1))
            metrics["train_loss"].append(train_loss)
            metrics["train_iou"].append(train_metrics["mean_iou"])
            metrics["train_acc"].append(train_metrics["mean_accuracy"])
            metrics["val_loss"].append(val_loss)
            metrics["val_iou"].append(val_iou)
            metrics["val_acc"].append(val_acc)

            # Snapshot the best weights
            if val_iou > best_iou:
                best_state = copy.deepcopy(model.state_dict())
                best_iou = val_iou
                epochs_without_improvement = 0
            else:
                epochs_without_improvement += 1

            if epochs_without_improvement >= patience:
                tqdm.write(
                    f"Early stopping after {patience} epochs with no improvement",
                )
                break

    # Restore the best-performing weights before returning.
    model.load_state_dict(best_state)
    best_model = model

    return best_model, metrics, initial_dice
520
+
521
+
522
def extract_model_zip(model_zip_path):
    """Extract a model zip and return the directory holding the model files.

    The archive is unpacked into an ``output`` directory next to the zip.
    If the archive wraps everything in a single top-level directory (the
    common case of a zipped folder), that inner directory is returned.

    Parameters
    ----------
    model_zip_path : str
        Path to the model zip file.

    Returns
    -------
    str
        Path to the extracted model directory.

    Raises
    ------
    FileNotFoundError
        If ``model_zip_path`` does not exist.
    """
    if not os.path.exists(model_zip_path):
        raise FileNotFoundError(f"Model zip file not found: {model_zip_path}")

    extract_dir = os.path.join(os.path.dirname(model_zip_path), "output")
    # NOTE(security): extractall trusts archive member paths (zip-slip risk);
    # only use with zips produced by this app, not arbitrary uploads.
    with zipfile.ZipFile(model_zip_path, "r") as zip_ref:
        zip_ref.extractall(extract_dir)

    # Descend into a single nested folder — but only if that lone entry is
    # actually a directory (the original returned a file path in that case).
    entries = os.listdir(extract_dir)
    if len(entries) == 1:
        nested = os.path.join(extract_dir, entries[0])
        if os.path.isdir(nested):
            return nested
    return extract_dir
537
+
538
+
539
def train_model(
    data_dir,
    base_model_zip,
    image_width,
    image_height,
    batch_size,
    data_percent,
    num_epochs,
    learning_rate,
    early_stopping,
    validate_every,
):
    """End-to-end training entry point.

    Unpacks the base model archive, builds the dataloaders, fine-tunes the
    model, then re-validates the best checkpoint to report a final Dice score
    alongside the pre-training baseline.

    Returns
    -------
    tuple(best_model, metrics, [initial_dice, final_dice])
    """
    extracted_model_dir = extract_model_zip(base_model_zip)

    # Load model plus the label/color maps derived from the dataset.
    model, id2label, id2color = load_model_and_labels(data_dir, extracted_model_dir)

    train_loader, valid_loader = create_datasets_and_dataloaders(
        image_width,
        image_height,
        data_dir,
        batch_size,
        data_percent,
    )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    best_model, metrics, initial_dice = run_training(
        model,
        device,
        train_loader,
        valid_loader,
        id2label,
        num_epochs,
        learning_rate,
        early_stopping,
        validate_every,
    )

    # Full validation (with Dice) on the best checkpoint.
    (
        _final_loss,
        _final_iou,
        _per_class_iou,
        _final_acc,
        _per_class_acc,
        final_dice,
        _per_class_dice,
    ) = validate_model(
        model=best_model,
        dataloader=valid_loader,
        device=device,
        id2label=id2label,
        calc_dice=True,
        epoch=0,
    )

    return best_model, metrics, [initial_dice, final_dice]