.. _training_api:

Training API (experimental)
===========================

Kornia provides a Training API with the specific purpose of training and fine-tuning the
deep learning algorithms supported within the library.

.. sidebar:: **Deep Alchemy**

    .. image:: https://github.com/kornia/data/raw/main/pixie_alchemist.png
       :width: 100%
       :align: center

    A seemingly magical process of transforming, creating, and combining data into usable deep learning models.


.. important::
	In order to use our Training API you must install the extra dependencies: ``pip install kornia[x]``

Why a Training API?
--------------------

Kornia includes deep learning models that eventually need to be updated through fine-tuning.
Our aim is an API flexible enough to be used across our vision models, letting us
override methods or dynamically pass callbacks to ease debugging and experimentation.

.. admonition:: **Disclaimer**
	:class: seealso

	We do not aim to be a general-purpose training library; instead, we let Kornia users
	experiment with training our models.

Design Principles
-----------------

- `kornia`'s golden rule is to avoid heavy dependencies.
- Our models are simple enough that a light training API can fulfill our needs.
- Give flexible, full control over the training/validation loops and let users customize the pipeline.
- Decouple the model definition from the training pipeline.
- Use plain PyTorch abstractions and recipes to write your own routines.
- Use the `accelerate <https://github.com/huggingface/accelerate/>`_ library to scale the training.

Trainer Usage
-------------

The entry point to start training with Kornia is the :py:class:`~kornia.x.Trainer` class.

The main API is a self-contained module that relies heavily on `accelerate <https://github.com/huggingface/accelerate/>`_
to easily scale training over multiple GPUs, TPUs, or with fp16 `(see more) <https://github.com/huggingface/accelerate#supported-integrations/>`_,
following standard PyTorch recipes. Our API expects to consume standard PyTorch components, and you decide whether `kornia`
makes the magic for you.
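
The snippets below all reference a ``config`` object that is never defined in this section. A minimal sketch of such an object (the field names are hypothetical, inferred from how ``config`` is used below) could be:

```python
from types import SimpleNamespace

# Hypothetical configuration holder; the field names are inferred from the
# snippets in this section (your actual config schema may differ).
config = SimpleNamespace(
    data_path="./data",  # where torchvision downloads CIFAR-10
    batch_size=32,
    num_epochs=50,
    lr=1e-3,
)
```

Any attribute-style container (an ``argparse.Namespace``, a hydra/OmegaConf config, etc.) works the same way.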

1. Define your model

.. code:: python

	model = nn.Sequential(
	  kornia.contrib.VisionTransformer(image_size=32, patch_size=16),
	  kornia.contrib.ClassificationHead(num_classes=10),
	)

2. Create the datasets and dataloaders for training and validation

.. code:: python

	# datasets
	train_dataset = torchvision.datasets.CIFAR10(
	  root=config.data_path, train=True, download=True, transform=T.ToTensor())

	valid_dataset = torchvision.datasets.CIFAR10(
	  root=config.data_path, train=False, download=True, transform=T.ToTensor())

	# dataloaders
	train_dataloader = torch.utils.data.DataLoader(
	  train_dataset, batch_size=config.batch_size, shuffle=True)

	valid_dataloader = torch.utils.data.DataLoader(
	  valid_dataset, batch_size=config.batch_size, shuffle=False)  # no need to shuffle validation data

3. Create your loss function, optimizer and scheduler

.. code:: python

	# loss function
	criterion = nn.CrossEntropyLoss()

	# optimizer and scheduler
	optimizer = torch.optim.AdamW(model.parameters(), lr=config.lr)
	scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
	  optimizer, config.num_epochs * len(train_dataloader)
	)

4. Create the Trainer and execute the training pipeline

.. code:: python

	trainer = kornia.x.Trainer(
	  model, train_dataloader, valid_dataloader, criterion, optimizer, scheduler, config,
	)
	trainer.fit()  # execute your training!


Customize [callbacks]
---------------------

At this point you might think - *Is this API generic enough?*

	Of course not! What's next? Let's have fun and **customize**.

The :py:class:`~kornia.x.Trainer` internals are defined in such a way that, for example, you can
subclass it and override just the :py:func:`~kornia.x.Trainer.evaluate` method, adjusting it
to your needs. We provide predefined classes for common problems, such as
:py:class:`~kornia.x.ImageClassifierTrainer` and :py:class:`~kornia.x.SemanticSegmentationTrainer`.

.. note::
	More trainers will come as soon as we include more models.

You can easily customize by creating your own class, or even through ``callbacks`` as follows:

.. code:: python

    @torch.no_grad()
    def my_evaluate(self) -> dict:
      self.model.eval()
      stats = {}
      for sample_id, sample in enumerate(self.valid_dataloader):
        source, target = sample  # this might change with a new PyTorch dataset structure

        # perform the preprocess and augmentations in batch
        img = self.preprocess(source)
        # Forward
        out = self.model(img)
        # Loss computation
        val_loss = self.criterion(out, target)

        # measure accuracy and record the loss
        acc1, acc5 = accuracy(out.detach(), target, topk=(1, 5))
        stats = {"loss": val_loss.item(), "top1": acc1.item(), "top5": acc5.item()}
      return stats  # stats from the last batch

    # create the trainer and pass the evaluate method as a callback
    trainer = K.x.Trainer(..., callbacks={"evaluate": my_evaluate})
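
The ``accuracy`` helper used above is not defined in the snippet. A minimal top-k accuracy sketch, whose name and signature simply mirror the call above (adapted from the common PyTorch classification recipe, not a kornia API), could be:

```python
import torch

# Hypothetical top-k accuracy helper matching `accuracy(out, target, topk=(1, 5))`.
def accuracy(output: torch.Tensor, target: torch.Tensor, topk=(1,)):
    """Compute the percentage of samples whose target is in the top-k predictions."""
    maxk = max(topk)
    batch_size = target.size(0)
    # indices of the top-k predictions per sample: (batch_size, maxk)
    _, pred = output.topk(maxk, dim=1, largest=True, sorted=True)
    pred = pred.t()  # (maxk, batch_size)
    # compare each of the top-k predictions against the target
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    res = []
    for k in topk:
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res
```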

**Still not convinced ?**

	You can even override the whole :py:func:`~kornia.x.ImageClassifierTrainer.fit`
	method and implement your own custom loops; through the Accelerator, the trainer will move all
	the data to the device for you, and the rest of the story is just PyTorch :)

.. code:: python

    def my_fit(self):  # this is a custom pytorch training loop
      self.model.train()
      for epoch in range(self.num_epochs):
        for source, targets in self.train_dataloader:
          self.optimizer.zero_grad()

          output = self.model(source)
          loss = self.criterion(output, targets)

          self.backward(loss)
          self.optimizer.step()

          stats = self.evaluate()  # do whatever you want with validation

    # create the trainer and pass the fit method as a callback
    trainer = K.x.Trainer(..., callbacks={"fit": my_fit})

.. note::
  The following hooks are available to override: ``preprocess``, ``augmentations``, ``evaluate``, ``fit``,
  ``on_checkpoint``, ``on_epoch_end``, ``on_before_model``


Preprocess and augmentations
----------------------------

Taking a pre-trained model from an external source and assuming that you can fine-tune it on your
data by just changing a few things in your model is usually a bad assumption in practice.

Fine-tuning a model needs a lot of tricks, which usually means designing a good augmentation
or preprocessing strategy before you execute the training pipeline. For this reason, we let you
pass pointers to the ``preprocess`` and ``augmentations`` functions through callbacks, to ease
debugging and experimentation.

.. code:: python

	def preprocess(x):
	  return x.float() / 255.

	augmentations = nn.Sequential(
	  K.augmentation.RandomHorizontalFlip(p=0.75),
	  K.augmentation.RandomVerticalFlip(p=0.75),
	  K.augmentation.RandomAffine(degrees=10.),
	  K.augmentation.PatchSequential(
		K.augmentation.ColorJitter(0.1, 0.1, 0.1, 0.1, p=0.8),
		grid_size=(2, 2),  # cifar-10 is 32x32 and vit is patch 16
		patchwise_apply=False,
	  ),
	)

	# create the trainer and pass the preprocess and augmentations callbacks
	trainer = K.x.ImageClassifierTrainer(...,
	  callbacks={"preprocess": preprocess, "augmentations": augmentations})

Callbacks utilities
-------------------

We also provide utilities to save checkpoints of the model or to stop the training early. Pass
instances of the :py:class:`~kornia.x.ModelCheckpoint` and :py:class:`~kornia.x.EarlyStopping`
classes as ``callbacks``, as follows:

.. code:: python

	model_checkpoint = ModelCheckpoint(
	  filepath="./outputs", monitor="top5",
	)

	early_stop = EarlyStopping(monitor="top5")

	trainer = K.x.ImageClassifierTrainer(...,
	  callbacks={"on_checkpoint": model_checkpoint, "on_epoch_end": early_stop})

Hyperparameter sweeps
---------------------

Use `hydra <https://hydra.cc>`_ to implement a simple search strategy over your hyper-parameters, as follows:

.. note::

  Check out the toy example `here <https://github.com/kornia/kornia/tree/master/examples/train/image_classifier>`__

.. code:: bash

  python ./train/image_classifier/main.py num_epochs=50 batch_size=32

.. code:: bash

  python ./train/image_classifier/main.py --multirun lr=1e-3,1e-4

Distributed Training
--------------------

Kornia :py:class:`~kornia.x.Trainer` heavily relies on `accelerate <https://github.com/huggingface/accelerate/>`_ to
decouple the process of running your training scripts in a distributed environment.

.. note::

	We haven't tested yet all the possibilities for distributed training.
	Expect some adventures or `join us <https://join.slack.com/t/kornia/shared_invite/zt-csobk21g-CnydWe5fmvkcktIeRFGCEQ>`_ and help to iterate :)

The recipes below are taken from the `accelerate` library examples `here <https://github.com/huggingface/accelerate/tree/main/examples#simple-vision-example>`__:

- single CPU:

  * from a server without GPU

    .. code:: bash

      python ./train/image_classifier/main.py

  * from any server by passing `cpu=True` to the `Accelerator`.

    .. code:: bash

      python ./train/image_classifier/main.py --data_path path_to_data --cpu

  * from any server with Accelerate launcher

    .. code:: bash

      accelerate launch --cpu ./train/image_classifier/main.py --data_path path_to_data

- single GPU:

  .. code:: bash

    python ./train/image_classifier/main.py  # from a server with a GPU

- with fp16 (mixed-precision)

  * from any server by passing `fp16=True` to the `Accelerator`.

    .. code:: bash

      python ./train/image_classifier/main.py --data_path path_to_data --fp16

  * from any server with Accelerate launcher

    .. code:: bash

      accelerate launch --fp16 ./train/image_classifier/main.py --data_path path_to_data

- multi GPUs (using PyTorch distributed mode)

  * With Accelerate config and launcher

    .. code:: bash

      accelerate config  # This will create a config file on your server
      accelerate launch ./train/image_classifier/main.py --data_path path_to_data  # This will run the script on your server

  * With traditional PyTorch launcher

    .. code:: bash

      python -m torch.distributed.launch --nproc_per_node 2 --use_env ./train/image_classifier/main.py --data_path path_to_data

- multi GPUs, multi node (several machines, using PyTorch distributed mode)

  * With Accelerate config and launcher, on each machine:

    .. code:: bash

      accelerate config  # This will create a config file on each server
      accelerate launch ./train/image_classifier/main.py --data_path path_to_data  # This will run the script on each server

  * With PyTorch launcher only

    .. code:: bash

      python -m torch.distributed.launch --nproc_per_node 2 \
        --use_env \
        --node_rank 0 \
        --master_addr master_node_ip_address \
        ./train/image_classifier/main.py --data_path path_to_data  # On the first server

      python -m torch.distributed.launch --nproc_per_node 2 \
        --use_env \
        --node_rank 1 \
        --master_addr master_node_ip_address \
        ./train/image_classifier/main.py --data_path path_to_data  # On the second server

- (multi) TPUs

  * With Accelerate config and launcher

    .. code:: bash

      accelerate config  # This will create a config file on your TPU server
      accelerate launch ./train/image_classifier/main.py --data_path path_to_data  # This will run the script on each server

  * In PyTorch:
    Add an `xmp.spawn` line in your script as you usually do.