{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "EXQz7Vp8ehqb"
      },
      "source": [
        "# 🚀 Getting started\n",
        "\n",
        "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.sandbox.google.com/github/google-deepmind/optax/blob/main/docs/optax-101.ipynb)\n",
        "\n",
        "Optax is a simple optimization library for [JAX](https://jax.readthedocs.io/). The main object is the {py:class}`GradientTransformation <optax.GradientTransformation>`, which can be chained with other transformations to obtain the final update operation and the optimizer state. Optax also contains some simple loss functions and utilities to help you write the full optimization steps. This notebook walks you through a few examples of how to use Optax."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "vEIU3POrGiE5"
      },
      "source": [
        "## Example: Fitting a Linear Model\n",
        "\n",
        "Begin by importing the necessary packages:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Jr7_e_ZJ_hky"
      },
      "outputs": [],
      "source": [
        "import jax.numpy as jnp\n",
        "import jax\n",
        "import optax\n",
        "import functools"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "n7kMS9kyM8vM"
      },
      "source": [
        "In this example, we begin by setting up a simple linear model and a loss function. You can use any other library, such as [Haiku](https://github.com/deepmind/dm-haiku) or [Flax](https://github.com/google/flax), to construct your networks. Here, we keep it simple and write it ourselves. The loss function (L2 loss) comes from Optax's {doc}`losses <api/losses>` via {py:class}`l2_loss <optax.l2_loss>`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "0-8XwoQF_AO2"
      },
      "outputs": [],
      "source": [
        "@functools.partial(jax.vmap, in_axes=(None, 0))\n",
        "def network(params, x):\n",
        "  return jnp.dot(params, x)\n",
        "\n",
        "def compute_loss(params, x, y):\n",
        "  y_pred = network(params, x)\n",
        "  loss = jnp.mean(optax.l2_loss(y_pred, y))\n",
        "  return loss"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "EZviuSmuNFsC"
      },
      "source": [
        "Here we generate data under a known linear model (with `target_params=0.5`):"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "H-_pwBx6_keL"
      },
      "outputs": [],
      "source": [
        "key = jax.random.PRNGKey(42)\n",
        "target_params = 0.5\n",
        "\n",
        "# Generate some data.\n",
        "xs = jax.random.normal(key, (16, 2))\n",
        "ys = jnp.sum(xs * target_params, axis=-1)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Td4Lp3qDNsL3"
      },
      "source": [
        "### Basic usage of Optax\n",
        "\n",
        "Optax contains implementations of {doc}`many popular optimizers <api/optimizers>` that are simple to use. For example, the gradient transform for the Adam optimizer is available as {py:class}`optax.adam`. For now, let's call the {py:class}`GradientTransformation <optax.GradientTransformation>` object for Adam `optimizer`. We then initialize the optimizer state using its `init` function and the `params` of the network."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "rsLXLb5wBeY2"
      },
      "outputs": [],
      "source": [
        "start_learning_rate = 1e-1\n",
        "optimizer = optax.adam(start_learning_rate)\n",
        "\n",
        "# Initialize parameters of the model + optimizer.\n",
        "params = jnp.array([0.0, 0.0])\n",
        "opt_state = optimizer.init(params)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "CpAvP1WSnsyM"
      },
      "source": [
        "Next, we write the update loop. The {py:class}`GradientTransformation <optax.GradientTransformation>` object contains an `update` function that takes the gradients and the current optimizer state, and returns the `updates` to apply to the parameters along with the new optimizer state: `updates, new_opt_state = optimizer.update(grads, opt_state)`.\n",
        "\n",
        "Optax comes with a few simple {doc}`update rules <api/apply_updates>` that apply the updates from the gradient transforms to the current parameters to return new ones: `new_params = optax.apply_updates(params, updates)`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "TNkhz_nrB2lx"
      },
      "outputs": [],
      "source": [
        "# A simple update loop.\n",
        "for _ in range(1000):\n",
        "  grads = jax.grad(compute_loss)(params, xs, ys)\n",
        "  updates, opt_state = optimizer.update(grads, opt_state)\n",
        "  params = optax.apply_updates(params, updates)\n",
        "\n",
        "assert jnp.allclose(params, target_params), \\\n",
        "'Optimization should retrieve the target params used to generate the data.'"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "XXEz3j7wPZUH"
      },
      "source": [
        "### Custom optimizers\n",
        "\n",
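        "Optax makes it easy to create custom optimizers by {py:class}`chain <optax.chain>`ing gradient transforms. For example, the following creates an optimizer based on Adam, with gradient clipping and an exponentially decaying learning rate. Note that the last transform scales the updates by `-1`, so that each step moves by `-learning_rate` times the Adam update; this is an important detail since {py:class}`apply_updates <optax.apply_updates>` is additive."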
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "KQNI2P3YEEgP"
      },
      "outputs": [],
      "source": [
        "# Exponential decay of the learning rate.\n",
        "scheduler = optax.exponential_decay(\n",
        "    init_value=start_learning_rate,\n",
        "    transition_steps=1000,\n",
        "    decay_rate=0.99)\n",
        "\n",
        "# Combining gradient transforms using `optax.chain`.\n",
        "gradient_transform = optax.chain(\n",
        "    optax.clip_by_global_norm(1.0),  # Clip the gradient by its global norm.\n",
        "    optax.scale_by_adam(),  # Use the updates from adam.\n",
        "    optax.scale_by_schedule(scheduler),  # Use the learning rate from the scheduler.\n",
        "    # Scale updates by -1 since optax.apply_updates is additive and we want to descend on the loss.\n",
        "    optax.scale(-1.0)\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "XGUrLKxAEO3j"
      },
      "outputs": [],
      "source": [
        "# Initialize parameters of the model + optimizer.\n",
        "params = jnp.array([0.0, 0.0])  # Recall target_params=0.5.\n",
        "opt_state = gradient_transform.init(params)\n",
        "\n",
        "# A simple update loop.\n",
        "for _ in range(1000):\n",
        "  grads = jax.grad(compute_loss)(params, xs, ys)\n",
        "  updates, opt_state = gradient_transform.update(grads, opt_state)\n",
        "  params = optax.apply_updates(params, updates)\n",
        "\n",
        "assert jnp.allclose(params, target_params), \\\n",
        "'Optimization should retrieve the target params used to generate the data.'"
      ]
    },
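    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a quick sanity check (a sketch with illustrative values, not part of the original walkthrough), the cell below compares the {py:class}`optax.adam` alias with a hand-built chain of {py:class}`scale_by_adam <optax.scale_by_adam>` followed by a constant {py:class}`scale <optax.scale>` of `-learning_rate`. With default hyperparameters the two should produce the same updates, and on the very first step Adam's bias-corrected update is close to `-learning_rate * sign(grads)`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import jax.numpy as jnp\n",
        "import optax\n",
        "\n",
        "demo_lr = 0.1  # Illustrative learning rate.\n",
        "demo_params = jnp.array([1.0, -2.0])\n",
        "demo_grads = jnp.array([0.3, -0.7])  # Illustrative gradients.\n",
        "\n",
        "adam_alias = optax.adam(demo_lr)\n",
        "adam_chain = optax.chain(optax.scale_by_adam(), optax.scale(-demo_lr))\n",
        "\n",
        "alias_updates, _ = adam_alias.update(demo_grads, adam_alias.init(demo_params))\n",
        "chain_updates, _ = adam_chain.update(demo_grads, adam_chain.init(demo_params))\n",
        "\n",
        "# Both optimizers should agree; on the first step each entry is close to -lr * sign(grad).\n",
        "assert jnp.allclose(alias_updates, chain_updates)\n",
        "print(alias_updates)"
      ]
    },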
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "pIxKL7WsXFl8"
      },
      "source": [
        "### Advanced usage of Optax"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "nCtNiVTsZVt2"
      },
      "source": [
        "#### Modifying hyperparameters of optimizers in a schedule.\n",
        "\n",
        "In some scenarios, changing the hyperparameters (other than the learning rate) of an optimizer can be useful to ensure training reliability. We can do this easily by using {py:class}`inject_hyperparams <optax.inject_hyperparams>`. For example, this piece of code decays the `max_norm` of the {py:class}`clip_by_global_norm <optax.clip_by_global_norm>` gradient transform as training progresses:\n",
        "\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "NR9Flsj7ZdpC"
      },
      "outputs": [],
      "source": [
        "decaying_global_norm_tx = optax.inject_hyperparams(optax.clip_by_global_norm)(\n",
        "    max_norm=optax.linear_schedule(1.0, 0.0, transition_steps=99))\n",
        "\n",
        "opt_state = decaying_global_norm_tx.init(None)\n",
        "assert opt_state.hyperparams['max_norm'] == 1.0, 'Max norm should start at 1.0'\n",
        "\n",
        "for _ in range(100):\n",
        "  _, opt_state = decaying_global_norm_tx.update(None, opt_state)\n",
        "\n",
        "assert opt_state.hyperparams['max_norm'] == 0.0, 'Max norm should end at 0.0'"
      ]
    },
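    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Because {py:class}`inject_hyperparams <optax.inject_hyperparams>` stores the hyperparameters in the optimizer state, it can also be used to change a hyperparameter by hand between steps. The sketch below (illustrative values) lowers Adam's learning rate from `0.1` to `0.01` by editing `state.hyperparams` before taking a step; since this is Adam's first step, the resulting update should be close to `-0.01 * sign(grads)`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import jax.numpy as jnp\n",
        "import optax\n",
        "\n",
        "# The learning rate now lives in `state.hyperparams` instead of being fixed in the transform.\n",
        "injected_adam = optax.inject_hyperparams(optax.adam)(learning_rate=0.1)\n",
        "demo_params = jnp.array([0.0, 0.0])\n",
        "demo_state = injected_adam.init(demo_params)\n",
        "print(demo_state.hyperparams['learning_rate'])\n",
        "\n",
        "# Manually lower the learning rate, then take one step with illustrative gradients.\n",
        "demo_state.hyperparams['learning_rate'] = 0.01\n",
        "demo_grads = jnp.array([1.0, -1.0])\n",
        "demo_updates, demo_state = injected_adam.update(demo_grads, demo_state, demo_params)\n",
        "print(demo_updates)"
      ]
    },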
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "tKcocLxEyYf2"
      },
      "source": [
        "## Example: Fitting an MLP\n",
        "\n",
        "Let's use Optax to fit a parametrized function. We will consider the problem of learning to identify when a value is odd or even.\n",
        "\n",
        "We will begin by creating a dataset that consists of batches of random 8-bit integers (represented in binary), with each value labelled as \"odd\" or \"even\" using one-hot encoding (i.e. `[1, 0]` means even and `[0, 1]` means odd).\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Gg6zyMBqydty"
      },
      "outputs": [],
      "source": [
        "import optax\n",
        "import jax.numpy as jnp\n",
        "import jax\n",
        "import numpy as np\n",
        "\n",
        "BATCH_SIZE = 5\n",
        "NUM_TRAIN_STEPS = 1_000\n",
        "# Random integers in [0, 255], i.e. every possible 8-bit value.\n",
        "RAW_TRAINING_DATA = np.random.randint(256, size=(NUM_TRAIN_STEPS, BATCH_SIZE, 1))\n",
        "\n",
        "TRAINING_DATA = np.unpackbits(RAW_TRAINING_DATA.astype(np.uint8), axis=-1)\n",
        "LABELS = jax.nn.one_hot(RAW_TRAINING_DATA % 2, 2).astype(jnp.float32).reshape(NUM_TRAIN_STEPS, BATCH_SIZE, 2)"
      ]
    },
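    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To double-check the encoding, the sketch below pushes a few hand-picked integers (illustrative values) through the same `np.unpackbits` and `jax.nn.one_hot` calls. `np.unpackbits` lists the most significant bit first, and because the label index is `value % 2`, even values map to `[1, 0]` and odd values to `[0, 1]`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import jax\n",
        "import numpy as np\n",
        "\n",
        "# Encode a few hand-picked integers exactly like the training data above.\n",
        "example_values = np.array([[3], [4], [255]], dtype=np.uint8)\n",
        "example_bits = np.unpackbits(example_values, axis=-1)  # Most significant bit first.\n",
        "example_labels = jax.nn.one_hot(example_values[:, 0] % 2, 2)  # Even -> [1, 0], odd -> [0, 1].\n",
        "\n",
        "for value, bits, label in zip(example_values[:, 0], example_bits, example_labels):\n",
        "  print(value, bits, label)"
      ]
    },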
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "nV79rjQK8tvC"
      },
      "source": [
        "We may now define a parametrized function using JAX. This will allow us to efficiently compute gradients.\n",
        "\n",
        "There are a number of libraries that provide common building blocks for parametrized functions (such as Flax and Haiku). Here, however, we implement our function from scratch.\n",
        "\n",
        "Our function will be a 1-layer MLP (multi-layer perceptron) with a single hidden layer and a single output layer. We initialize all parameters using a standard Gaussian {math}`\\mathcal{N}(0,1)` distribution."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Syp9LJ338h9-"
      },
      "outputs": [],
      "source": [
        "initial_params = {\n",
        "    'hidden': jax.random.normal(shape=[8, 32], key=jax.random.PRNGKey(0)),\n",
        "    'output': jax.random.normal(shape=[32, 2], key=jax.random.PRNGKey(1)),\n",
        "}\n",
        "\n",
        "\n",
        "def net(x: jnp.ndarray, params: optax.Params) -> jnp.ndarray:\n",
        "  x = jnp.dot(x, params['hidden'])\n",
        "  x = jax.nn.relu(x)\n",
        "  x = jnp.dot(x, params['output'])\n",
        "  return x\n",
        "\n",
        "\n",
        "def loss(params: optax.Params, batch: jnp.ndarray, labels: jnp.ndarray) -> jnp.ndarray:\n",
        "  y_hat = net(batch, params)\n",
        "\n",
        "  # optax also provides a number of common loss functions.\n",
        "  loss_value = optax.sigmoid_binary_cross_entropy(y_hat, labels).sum(axis=-1)\n",
        "\n",
        "  return loss_value.mean()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "2LVHrJyH9vDe"
      },
      "source": [
        "We will use {py:class}`optax.adam` to compute the parameter updates from their gradients on each optimizer step.\n",
        "\n",
        "Note that since Optax optimizers are implemented using pure functions, we also need to keep track of the optimizer state. For the Adam optimizer, this state contains the running estimates of the first and second moments of the gradients, along with a step count."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "executionInfo": {
          "elapsed": 791,
          "status": "ok",
          "timestamp": 1706783722580,
          "user": {
            "displayName": "",
            "userId": ""
          },
          "user_tz": -60
        },
        "id": "JsbPBTF09FGY",
        "outputId": "8ed2b4d2-5d60-4f48-bba0-aeb8bf575cf2"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "step 0, loss: 5.830467224121094\n",
            "step 100, loss: 0.11273854970932007\n",
            "step 200, loss: 0.07734161615371704\n",
            "step 300, loss: 0.04497973248362541\n",
            "step 400, loss: 0.006239257287234068\n",
            "step 500, loss: 0.006817951332777739\n",
            "step 600, loss: 0.0020168141927570105\n",
            "step 700, loss: 0.010563415475189686\n",
            "step 800, loss: 0.0004556340572889894\n",
            "step 900, loss: 0.011806081980466843\n"
          ]
        }
      ],
      "source": [
        "def fit(params: optax.Params, optimizer: optax.GradientTransformation) -> optax.Params:\n",
        "  opt_state = optimizer.init(params)\n",
        "\n",
        "  @jax.jit\n",
        "  def step(params, opt_state, batch, labels):\n",
        "    loss_value, grads = jax.value_and_grad(loss)(params, batch, labels)\n",
        "    updates, opt_state = optimizer.update(grads, opt_state, params)\n",
        "    params = optax.apply_updates(params, updates)\n",
        "    return params, opt_state, loss_value\n",
        "\n",
        "  for i, (batch, labels) in enumerate(zip(TRAINING_DATA, LABELS)):\n",
        "    params, opt_state, loss_value = step(params, opt_state, batch, labels)\n",
        "    if i % 100 == 0:\n",
        "      print(f'step {i}, loss: {loss_value}')\n",
        "\n",
        "  return params\n",
        "\n",
        "# Finally, we can fit our parametrized function using the Adam optimizer\n",
        "# provided by optax.\n",
        "optimizer = optax.adam(learning_rate=1e-2)\n",
        "params = fit(initial_params, optimizer)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "kTaBLYL8_Ppz"
      },
      "source": [
        "The loss drops by more than two orders of magnitude during training, which suggests that we have found much better parameters for our network."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "qT_Uaei5Dv_3"
      },
      "source": [
        "### Weight Decay, Schedules and Clipping\n",
        "\n",
        "Many research models make use of techniques such as learning rate scheduling, weight decay and gradient clipping. These may be achieved by chaining together gradient transformations such as {py:class}`optax.adam` and {py:class}`optax.clip`.\n",
        "\n",
        "In the following, we will use Adam with weight decay ({py:class}`optax.adamw`), a cosine learning rate schedule (with warmup) and also gradient clipping."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "executionInfo": {
          "elapsed": 257,
          "status": "ok",
          "timestamp": 1706783722955,
          "user": {
            "displayName": "",
            "userId": ""
          },
          "user_tz": -60
        },
        "id": "SZegYQajDtLi",
        "outputId": "00d3b1f8-d44d-4ab8-c47a-dc557dd327c3"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "step 0, loss: 5.830467224121094\n",
            "step 100, loss: 0.0\n",
            "step 200, loss: 1.2290412909483864e-18\n",
            "step 300, loss: 0.0\n",
            "step 400, loss: 0.0\n",
            "step 500, loss: 0.0\n",
            "step 600, loss: 0.0\n",
            "step 700, loss: 0.0\n",
            "step 800, loss: 0.0\n",
            "step 900, loss: 0.0\n"
          ]
        }
      ],
      "source": [
        "schedule = optax.warmup_cosine_decay_schedule(\n",
        "  init_value=0.0,\n",
        "  peak_value=1.0,\n",
        "  warmup_steps=50,\n",
        "  decay_steps=1_000,\n",
        "  end_value=0.0,\n",
        ")\n",
        "\n",
        "optimizer = optax.chain(\n",
        "  optax.clip(1.0),\n",
        "  optax.adamw(learning_rate=schedule),\n",
        ")\n",
        "\n",
        "params = fit(initial_params, optimizer)"
      ]
    },
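    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To see what the schedule above does, the sketch below evaluates it at a few steps. The learning rate should ramp linearly from `0.0` to the peak of `1.0` over the first 50 steps, then follow a cosine curve down to `0.0` at step 1,000 (in this schedule `decay_steps` includes the warmup steps), passing through `0.5` halfway through the decay phase."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import optax\n",
        "\n",
        "# Same hyperparameters as the schedule used above.\n",
        "demo_schedule = optax.warmup_cosine_decay_schedule(\n",
        "    init_value=0.0,\n",
        "    peak_value=1.0,\n",
        "    warmup_steps=50,\n",
        "    decay_steps=1_000,\n",
        "    end_value=0.0,\n",
        ")\n",
        "\n",
        "# Step 25 is halfway through warmup; step 525 is halfway through the cosine phase.\n",
        "for demo_step in [0, 25, 50, 525, 1_000]:\n",
        "  print(demo_step, float(demo_schedule(demo_step)))"
      ]
    },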
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "qf53Y6mT1Vwl"
      },
      "source": [
        "## Components\n",
        "\n",
        "See the {doc}`docs <index>` for a detailed list of available Optax components. Here, we highlight the main categories of building blocks provided by Optax."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "WZFpEKi82TGx"
      },
      "source": [
        "### Gradient Transformations ([transform.py](https://github.com/google-deepmind/optax/blob/main/optax/_src/transform.py))\n",
        "\n",
        "One of the key building blocks of Optax is a {py:class}`GradientTransformation <optax.GradientTransformation>`. Each transformation is defined by two functions:\n",
        "\n",
        "`state = init(params)`\n",
        "\n",
        "`updates, state = update(grads, state, params=None)`"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "n6SsC9lNGiE-"
      },
      "source": [
        "The `init` function initializes a (possibly empty) set of statistics (aka state) and the `update` function transforms a candidate gradient given some statistics, and (optionally) the current value of the parameters.\n",
        "\n",
        "For example:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "_yCQbSCc2KhJ"
      },
      "outputs": [],
      "source": [
        "tx = optax.scale_by_rms()\n",
        "state = tx.init(params)  # init stats\n",
        "grads = jax.grad(loss)(params, TRAINING_DATA, LABELS)\n",
        "updates, state = tx.update(grads, state, params)  # transform & update stats."
      ]
    },
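    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Because a transformation is just a pair of `init` and `update` functions, you can write your own by wrapping them in a {py:class}`GradientTransformation <optax.GradientTransformation>`. The sketch below defines a hypothetical `my_sign_transform` (not part of Optax) that keeps no state and replaces every gradient entry by its sign, then chains it with {py:class}`scale <optax.scale>` to obtain a sign-SGD style optimizer. The parameter and gradient values are illustrative."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import jax\n",
        "import jax.numpy as jnp\n",
        "import optax\n",
        "\n",
        "\n",
        "def my_sign_transform() -> optax.GradientTransformation:\n",
        "  # Hypothetical transformation: maps every gradient entry to its sign.\n",
        "\n",
        "  def init_fn(params):\n",
        "    del params  # No statistics are needed, so the state is empty.\n",
        "    return optax.EmptyState()\n",
        "\n",
        "  def update_fn(updates, state, params=None):\n",
        "    del params  # The current parameter values are not used.\n",
        "    updates = jax.tree_util.tree_map(jnp.sign, updates)\n",
        "    return updates, state\n",
        "\n",
        "  return optax.GradientTransformation(init_fn, update_fn)\n",
        "\n",
        "\n",
        "sign_sgd = optax.chain(my_sign_transform(), optax.scale(-0.1))\n",
        "demo_params = {'w': jnp.array([1.0, 2.0]), 'b': jnp.array(0.5)}\n",
        "demo_grads = {'w': jnp.array([0.3, -4.0]), 'b': jnp.array(-0.2)}\n",
        "demo_state = sign_sgd.init(demo_params)\n",
        "demo_updates, demo_state = sign_sgd.update(demo_grads, demo_state, demo_params)\n",
        "print(demo_updates)"
      ]
    },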
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "TyxJmbBq2xT6"
      },
      "source": [
        "### Composing Gradient Transformations ([combine.py](https://github.com/google-deepmind/optax/blob/main/optax/_src/combine.py))\n",
        "\n",
        "The fact that transformations take candidate gradients as input and return processed gradients as output (in contrast to returning the updated parameters) is critical: it makes it possible to combine arbitrary transformations into a custom optimiser / gradient processor, and to combine transformations for different gradients that operate on a shared set of variables.\n",
        "\n",
        "For instance, {py:class}`chain <optax.chain>` returns a new {py:class}`GradientTransformation <optax.GradientTransformation>` that applies several transformations in sequence.\n",
        "\n",
        "For example:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "TNPC9e7I28m8"
      },
      "outputs": [],
      "source": [
        "max_norm = 100.\n",
        "learning_rate = 1e-3\n",
        "\n",
        "my_optimiser = optax.chain(\n",
        "    optax.clip_by_global_norm(max_norm),\n",
        "    optax.scale_by_adam(eps=1e-4),\n",
        "    optax.scale(-learning_rate))"
      ]
    },
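    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Since {py:class}`chain <optax.chain>` applies the transformations in the order they are listed, that order matters. As an illustration with arbitrary values, clipping each gradient entry to `[-1, 1]` with {py:class}`clip <optax.clip>` and then scaling by `10` gives a different result from scaling first and clipping afterwards."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import jax.numpy as jnp\n",
        "import optax\n",
        "\n",
        "clip_then_scale = optax.chain(optax.clip(1.0), optax.scale(10.0))\n",
        "scale_then_clip = optax.chain(optax.scale(10.0), optax.clip(1.0))\n",
        "\n",
        "demo_grads = jnp.array([0.5, -3.0])  # Illustrative gradients.\n",
        "# clip and scale keep no state, so any array works as the `params` argument of init.\n",
        "u_clip_first, _ = clip_then_scale.update(demo_grads, clip_then_scale.init(demo_grads))\n",
        "u_scale_first, _ = scale_then_clip.update(demo_grads, scale_then_clip.init(demo_grads))\n",
        "\n",
        "print(u_clip_first)   # Clipping gives [0.5, -1.0]; scaling by 10 gives [5.0, -10.0].\n",
        "print(u_scale_first)  # Scaling gives [5.0, -30.0]; clipping gives [1.0, -1.0]."
      ]
    },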
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "JmV92-PI2_pS"
      },
      "source": [
        "### Wrapping Gradient Transformations ([wrappers.py](https://github.com/google-deepmind/optax/blob/main/optax/_src/wrappers.py))\n",
        "\n",
        "Optax also provides several wrappers that take a {py:class}`GradientTransformation <optax.GradientTransformation>` as input and return a new {py:class}`GradientTransformation <optax.GradientTransformation>` that modifies the behaviour of the inner transformation in a specific way.\n",
        "\n",
        "For instance, the {py:class}`flatten <optax.flatten>` wrapper flattens gradients into a single large vector before applying the inner {py:class}`GradientTransformation <optax.GradientTransformation>`. The transformed updates are then unflattened before being returned to the user. This can be used to reduce the overhead of performing many calculations on lots of small variables, at the cost of increasing memory usage.\n",
        "\n",
        "For example:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "b1TlMbAk3Jbo"
      },
      "outputs": [],
      "source": [
        "my_optimiser = optax.flatten(optax.adam(learning_rate))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "IUCIMymV3M2n"
      },
      "source": [
        "Other examples of wrappers include accumulating gradients over multiple steps or applying the inner transformation only to specific parameters or at specific steps."
      ]
    },
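    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As an example of gradient accumulation, the sketch below wraps plain SGD in {py:class}`MultiSteps <optax.MultiSteps>` with `every_k_schedule=2`. Assuming the default behaviour of averaging the accumulated gradients, the first call should return zero updates and the second should apply a single SGD step with the mean of the two (illustrative) gradients."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import jax.numpy as jnp\n",
        "import optax\n",
        "\n",
        "# Accumulate gradients over 2 calls before applying one SGD step.\n",
        "accumulating_sgd = optax.MultiSteps(optax.sgd(learning_rate=0.1), every_k_schedule=2)\n",
        "\n",
        "demo_params = jnp.array([1.0, 1.0])\n",
        "demo_state = accumulating_sgd.init(demo_params)\n",
        "\n",
        "for demo_grads in [jnp.array([1.0, 0.0]), jnp.array([3.0, 2.0])]:\n",
        "  demo_updates, demo_state = accumulating_sgd.update(demo_grads, demo_state, demo_params)\n",
        "  demo_params = optax.apply_updates(demo_params, demo_updates)\n",
        "  print(demo_params)\n",
        "\n",
        "# The mean gradient is [2.0, 1.0], so one step of size 0.1 moves params to [0.8, 0.9]."
      ]
    },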
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "AGAmqST33PkO"
      },
      "source": [
        "### Schedules ([schedule.py](https://github.com/google-deepmind/optax/blob/main/optax/_src/schedule.py))\n",
        "\n",
        "Many popular transformations use time-dependent components, e.g. to anneal a hyper-parameter such as the learning rate. For this purpose, Optax provides schedules that can be used to decay scalars as a function of a `step` count.\n",
        "\n",
        "For example, you may use a {py:class}`polynomial_schedule <optax.polynomial_schedule>` (with `power=1`) to decay a hyper-parameter linearly over a number of steps:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "executionInfo": {
          "elapsed": 131,
          "status": "ok",
          "timestamp": 1706783724487,
          "user": {
            "displayName": "",
            "userId": ""
          },
          "user_tz": -60
        },
        "id": "Zbr61DLP3ecy",
        "outputId": "20beae28-32bf-4030-f877-fcd4a815a0c2"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "1.0\n",
            "0.8\n",
            "0.6\n",
            "0.39999998\n",
            "0.19999999\n",
            "0.0\n"
          ]
        }
      ],
      "source": [
        "schedule_fn = optax.polynomial_schedule(\n",
        "    init_value=1., end_value=0., power=1, transition_steps=5)\n",
        "\n",
        "for step_count in range(6):\n",
        "  print(schedule_fn(step_count))  # [1., 0.8, 0.6, 0.4, 0.2, 0.]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "LGt0AzHF3fjR"
      },
      "source": [
        "Schedules can be combined with other transforms as follows."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "W9oCb0Kw3igG"
      },
      "outputs": [],
      "source": [
        "schedule_fn = optax.polynomial_schedule(\n",
        "    init_value=-learning_rate, end_value=0., power=1, transition_steps=5)\n",
        "optimiser = optax.chain(\n",
        "    optax.clip_by_global_norm(max_norm),\n",
        "    optax.scale_by_adam(eps=1e-4),\n",
        "    optax.scale_by_schedule(schedule_fn))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "sDSXlRAN_B2F"
      },
      "source": [
        "Schedules can also be passed in place of a constant `learning_rate` to optimizer aliases such as {py:class}`optax.adam`:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "zyvlGLDw_BKk"
      },
      "outputs": [],
      "source": [
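        "optimiser = optax.adam(learning_rate=optax.polynomial_schedule(\n",
        "    init_value=learning_rate, end_value=0., power=1, transition_steps=5))  # Positive: adam negates it."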
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "cKHZrM203kx4"
      },
      "source": [
        "### Popular optimisers ([alias.py](https://github.com/google-deepmind/optax/blob/main/optax/_src/alias.py))\n",
        "\n",
        "In addition to the low-level building blocks, we also provide aliases for popular optimisers built using these components (e.g. RMSProp, Adam, AdamW, etc.). These are all still instances of a {py:class}`GradientTransformation <optax.GradientTransformation>`, and can therefore be further combined with any of the individual building blocks.\n",
        "\n",
        "For example:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Czk49AQz3w1J"
      },
      "outputs": [],
      "source": [
        "def adamw(learning_rate, b1, b2, eps, weight_decay):\n",
        "  # Adam rescaling, then decoupled weight decay, then a step of size -learning_rate.\n",
        "  return optax.chain(\n",
        "      optax.scale_by_adam(b1=b1, b2=b2, eps=eps),\n",
        "      optax.add_decayed_weights(weight_decay),\n",
        "      optax.scale(-learning_rate))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "j0tD_jWC3zar"
      },
      "source": [
        "### Applying updates ([update.py](https://github.com/google-deepmind/optax/blob/main/optax/_src/update.py))\n",
        "\n",
        "After transforming an update using a {py:class}`GradientTransformation <optax.GradientTransformation>` or any custom manipulation of the update, you will typically apply the update to a set of parameters. This can be done trivially using `jax.tree_util.tree_map`.\n",
        "\n",
        "For convenience, we expose an {py:class}`apply_updates <optax.apply_updates>` function to apply updates to parameters. The function just adds the updates and the parameters together, i.e. `tree_map(lambda p, u: p + u, params, updates)`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "YG-TNzYm4CHt"
      },
      "outputs": [],
      "source": [
        "updates, state = tx.update(grads, state, params)  # transform & update stats.\n",
        "new_params = optax.apply_updates(params, updates)  # update the parameters."
      ]
    },
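    {
      "cell_type": "markdown",
      "metadata": {
        "id": "applyUpdEqMd"
      },
      "source": [
        "A quick sanity check of the equivalence described above, using small hypothetical pytrees (`example_params` and `example_updates` are made-up names for this sketch):"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "applyUpdEqCd"
      },
      "outputs": [],
      "source": [
        "import jax\n",
        "import jax.numpy as jnp\n",
        "import optax\n",
        "\n",
        "# Hypothetical parameter and update pytrees with matching structure.\n",
        "example_params = {'w': jnp.array([1.0, 2.0]), 'b': jnp.array(0.5)}\n",
        "example_updates = {'w': jnp.array([-0.1, 0.1]), 'b': jnp.array(-0.5)}\n",
        "\n",
        "via_optax = optax.apply_updates(example_params, example_updates)\n",
        "# The same leaf-wise sum written out by hand.\n",
        "via_tree_map = jax.tree_util.tree_map(\n",
        "    lambda p, u: p + u, example_params, example_updates)\n",
        "\n",
        "# Compare leaf by leaf; tree_all reduces the per-leaf booleans.\n",
        "same_result = jax.tree_util.tree_all(\n",
        "    jax.tree_util.tree_map(jnp.allclose, via_optax, via_tree_map))\n",
        "print(same_result)"
      ]
    },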
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "eg85y6_s4C2c"
      },
      "source": [
        "Note that separating gradient transformations from the parameter update is critical to support composing a sequence of transformations (e.g. {py:class}`chain <optax.chain>`), as well as combining multiple updates to the same parameters (e.g. in multi-task settings where different tasks need different sets of gradient transformations)."
      ]
    },
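    {
      "cell_type": "markdown",
      "metadata": {
        "id": "multiTaskMd"
      },
      "source": [
        "A minimal sketch of the multi-task case. It assumes two hypothetical tasks whose gradients with respect to the same parameters are already available: each task has its own transformation and optimiser state, the two updates are summed as ordinary pytrees, and the parameters are updated once. Summation is only one possible way of combining updates, chosen here for simplicity."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "multiTaskCd"
      },
      "outputs": [],
      "source": [
        "import jax\n",
        "import jax.numpy as jnp\n",
        "import optax\n",
        "\n",
        "shared_params = {'w': jnp.array([1.0, -1.0])}\n",
        "# Hypothetical gradients of two task losses w.r.t. the same parameters.\n",
        "task_a_grads = {'w': jnp.array([0.2, 0.4])}\n",
        "task_b_grads = {'w': jnp.array([-1.0, 3.0])}\n",
        "\n",
        "# Each task gets its own transformation and its own optimiser state.\n",
        "tx_a = optax.sgd(learning_rate=0.1)\n",
        "tx_b = optax.chain(optax.clip(0.5), optax.sgd(learning_rate=0.1))\n",
        "state_a = tx_a.init(shared_params)\n",
        "state_b = tx_b.init(shared_params)\n",
        "\n",
        "updates_a, state_a = tx_a.update(task_a_grads, state_a, shared_params)\n",
        "updates_b, state_b = tx_b.update(task_b_grads, state_b, shared_params)\n",
        "\n",
        "# Updates are plain pytrees, so they can be combined before a single apply.\n",
        "combined_updates = jax.tree_util.tree_map(\n",
        "    lambda a, b: a + b, updates_a, updates_b)\n",
        "new_shared_params = optax.apply_updates(shared_params, combined_updates)"
      ]
    },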
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "dJzW0Flw4FP5"
      },
      "source": [
        "### Losses ([loss.py](https://github.com/google-deepmind/optax/tree/main/optax/losses))\n",
        "\n",
        "Optax provides a number of standard losses used in deep learning, such as {py:class}`l2_loss <optax.l2_loss>`, {py:class}`softmax_cross_entropy <optax.softmax_cross_entropy>`, {py:class}`cosine_distance <optax.cosine_distance>`, etc."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "8JCWgHhJ4PMc"
      },
      "outputs": [],
      "source": [
        "# `net`, `TRAINING_DATA` and `LABELS` are placeholders for your own model and data.\n",
        "predictions = net(TRAINING_DATA, params)\n",
        "loss = optax.huber_loss(predictions, LABELS)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "gAlaEpgQ4QyD"
      },
      "source": [
        "The losses accept batches as inputs; however, they perform no reduction across the batch dimension(s). Reducing is straightforward in JAX, for example:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "45svU6Qr4ThD"
      },
      "outputs": [],
      "source": [
        "avg_loss = jnp.mean(optax.huber_loss(predictions, LABELS))\n",
        "sum_loss = jnp.sum(optax.huber_loss(predictions, LABELS))"
      ]
    },
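    {
      "cell_type": "markdown",
      "metadata": {
        "id": "lossShapesMd"
      },
      "source": [
        "The other losses named above follow the same convention and return unreduced values. The sketch below uses made-up inputs; the comments describe what each loss is assumed to compute."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "lossShapesCd"
      },
      "outputs": [],
      "source": [
        "import jax.numpy as jnp\n",
        "import optax\n",
        "\n",
        "# l2_loss computes 0.5 * (predictions - targets) ** 2 element-wise.\n",
        "l2 = optax.l2_loss(jnp.array([1.0, 2.0]), jnp.array([0.0, 0.0]))\n",
        "\n",
        "# softmax_cross_entropy takes logits and one-hot (or soft) labels of shape\n",
        "# [batch, num_classes] and returns one loss per example, shape [batch].\n",
        "logits = jnp.zeros((2, 3))\n",
        "one_hot_labels = jnp.array([[1.0, 0.0, 0.0], [0.0, 0.0, 1.0]])\n",
        "xent = optax.softmax_cross_entropy(logits, one_hot_labels)\n",
        "\n",
        "# cosine_distance is 1 - cosine similarity along the last axis, shape [batch].\n",
        "cos = optax.cosine_distance(\n",
        "    jnp.array([[1.0, 0.0], [0.0, 1.0]]),\n",
        "    jnp.array([[1.0, 0.0], [1.0, 0.0]]))\n",
        "\n",
        "# Reduce over the batch explicitly, as with huber_loss above.\n",
        "mean_xent = jnp.mean(xent)"
      ]
    },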
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MepQR-Cr4VaB"
      },
      "source": [
        "### Second order ([second_order.py](https://github.com/google-deepmind/optax/tree/main/optax/second_order))\n",
        "\n",
        "Computing the full Hessian or Fisher information matrix of a neural network is typically intractable, because its memory requirements grow quadratically with the number of parameters. Computing only the diagonals of these matrices is often a better option, and the library offers functions for computing them with sub-quadratic memory requirements."
      ]
    },
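    {
      "cell_type": "markdown",
      "metadata": {
        "id": "diagSketchMd"
      },
      "source": [
        "The library's own second-order functions are not shown here. Instead, this is a plain-JAX sketch of the underlying idea, assuming a scalar function of a flat parameter vector. The exact Hessian diagonal is recovered one Hessian-vector product at a time, so only O(n) memory is live at once (at the cost of n products). The diagonal of the *empirical* Fisher, a common approximation to the Fisher, is the mean of squared per-example gradients. Both helper names are hypothetical."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "diagSketchCd"
      },
      "outputs": [],
      "source": [
        "import jax\n",
        "import jax.numpy as jnp\n",
        "\n",
        "\n",
        "def hessian_diagonal(f, x):\n",
        "  # Hypothetical helper: exact diag(H) of scalar f at flat vector x.\n",
        "  grad_f = jax.grad(f)\n",
        "\n",
        "  def diag_entry(i):\n",
        "    e_i = jnp.zeros_like(x).at[i].set(1.0)\n",
        "    # Forward-over-reverse: jvp of grad along e_i is the Hessian column H @ e_i.\n",
        "    _, h_col = jax.jvp(grad_f, (x,), (e_i,))\n",
        "    return h_col[i]\n",
        "\n",
        "  # lax.map evaluates the entries sequentially, one HVP at a time.\n",
        "  return jax.lax.map(diag_entry, jnp.arange(x.shape[0]))\n",
        "\n",
        "\n",
        "def empirical_fisher_diagonal(per_example_loss, params, inputs, targets):\n",
        "  # Hypothetical helper: mean of squared per-example gradients w.r.t. params.\n",
        "  per_example_grads = jax.vmap(\n",
        "      jax.grad(per_example_loss), in_axes=(None, 0, 0))(params, inputs, targets)\n",
        "  return jnp.mean(per_example_grads ** 2, axis=0)\n",
        "\n",
        "\n",
        "# Usage on small made-up problems.\n",
        "cubic = lambda x: jnp.sum(x ** 3) + x[0] * x[1]\n",
        "x0 = jnp.array([1.0, 2.0, 3.0])\n",
        "h_diag = hessian_diagonal(cubic, x0)\n",
        "\n",
        "sq_err = lambda w, x, y: 0.5 * (jnp.dot(w, x) - y) ** 2\n",
        "fisher_diag = empirical_fisher_diagonal(\n",
        "    sq_err, jnp.zeros(2), jnp.array([[1.0, 2.0], [3.0, 4.0]]), jnp.array([1.0, 1.0]))"
      ]
    },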
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "fcJiCWSQ4gPP"
      },
      "source": [
        "### Stochastic gradient estimators ([stochastic_gradient_estimators.py](https://github.com/google-deepmind/optax/blob/main/optax/monte_carlo/stochastic_gradient_estimators.py))\n",
        "\n",
        "Stochastic gradient estimators compute Monte Carlo estimates of gradients of the expectation of a function under a distribution with respect to the distribution's parameters.\n",
        "\n",
        "Unbiased estimators, such as the score function estimator (REINFORCE), pathwise estimator (reparameterization trick) or measure valued estimator, are implemented: {py:class}`score_function_jacobians <optax.monte_carlo.score_function_jacobians>`, {py:class}`pathwise_jacobians <optax.monte_carlo.pathwise_jacobians>` and {py:class}`measure_valued_jacobians <optax.monte_carlo.measure_valued_jacobians>`. Their applicability (both in terms of functions and distributions) is discussed in their respective documentation.\n",
        "\n",
        "Stochastic gradient estimators can be combined with common control variates for variance reduction via {py:class}`control_variates_jacobians <optax.monte_carlo.control_variates_jacobians>`. For the control variates provided, see {py:class}`control_delta_method <optax.monte_carlo.control_delta_method>` and {py:class}`moving_avg_baseline <optax.monte_carlo.moving_avg_baseline>`.\n",
        "\n",
        "The result of a gradient estimator or {py:class}`control_variates_jacobians <optax.monte_carlo.control_variates_jacobians>` contains, for each sample drawn from the input distribution, the Jacobian of the function with respect to the distribution parameters. These can then be used to update the distributional parameters or to assess gradient variance.\n",
        "\n",
        "Example of how to use the {py:class}`pathwise_jacobians <optax.monte_carlo.pathwise_jacobians>` estimator:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "NYJOV6Vv4uDb"
      },
      "outputs": [],
      "source": [
        "mean, log_scale, rng, num_samples = 0., 1., jax.random.PRNGKey(0), 100\n",
        "dist_params = [mean, log_scale]\n",
        "function = lambda x: jnp.sum(x)\n",
        "jacobians = optax.monte_carlo.pathwise_jacobians(\n",
        "      function, dist_params,\n",
        "      optax.multi_normal, rng, num_samples)\n",
        "\n",
        "mean_grads = jnp.mean(jacobians[0], axis=0)\n",
        "log_scale_grads = jnp.mean(jacobians[1], axis=0)\n",
        "grads = [mean_grads, log_scale_grads]\n",
        "optim = optax.adam(1e-3)\n",
        "optim_state = optim.init(grads)\n",
        "optim_update, optim_state = optim.update(grads, optim_state)\n",
        "updated_dist_params = optax.apply_updates(dist_params, optim_update)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "x5mXAl_H4xCH"
      },
      "source": [
        "Here `optim` can be any Optax optimiser; Adam is used only as an example."
      ]
    }
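    ,
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "pathwiseSkMd"
      },
      "source": [
        "To make the pathwise estimator less of a black box, here is a plain-JAX sketch of the reparameterisation trick for a scalar Gaussian, assuming x ~ N(mean, exp(log_scale)^2). Unlike {py:class}`pathwise_jacobians <optax.monte_carlo.pathwise_jacobians>`, which returns one Jacobian per sample, this hypothetical helper averages over the samples directly. For f(x) = x^2 the exact gradients are 2 * mean with respect to `mean` and 2 * exp(2 * log_scale) with respect to `log_scale`, so at mean = 1 and log_scale = 0 both estimates should be close to 2."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "pathwiseSkCd"
      },
      "outputs": [],
      "source": [
        "import jax\n",
        "import jax.numpy as jnp\n",
        "\n",
        "\n",
        "def pathwise_grad_estimate(f, mean, log_scale, rng, num_samples):\n",
        "  # Hypothetical helper: Monte Carlo estimate of the gradient of\n",
        "  # E_{x ~ N(mean, exp(log_scale)^2)}[f(x)] w.r.t. (mean, log_scale).\n",
        "  eps = jax.random.normal(rng, (num_samples,))\n",
        "\n",
        "  def sample_average(mean, log_scale):\n",
        "    # Reparameterise: each sample is a differentiable function of the params.\n",
        "    x = mean + jnp.exp(log_scale) * eps\n",
        "    return jnp.mean(jax.vmap(f)(x))\n",
        "\n",
        "  return jax.grad(sample_average, argnums=(0, 1))(mean, log_scale)\n",
        "\n",
        "\n",
        "grad_mean, grad_log_scale = pathwise_grad_estimate(\n",
        "    lambda x: x ** 2, 1.0, 0.0, jax.random.PRNGKey(0), 100_000)"
      ]
    }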
  ],
  "metadata": {
    "colab": {
      "last_runtime": {
        "build_target": "//learning/grp/tools/ml_python:ml_notebook",
        "kind": "private"
      },
      "name": "Optax 101",
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python",
      "version": "3.11.7"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}