{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "MpkYHwCqk7W-"
   },
   "source": [
    "![MuJoCo banner](https://raw.githubusercontent.com/google-deepmind/mujoco/main/banner.png)\n",
    "\n",
    "# <h1><center>Tutorial  <a href=\"https://colab.research.google.com/github/google-deepmind/mujoco/blob/main/mjx/tutorial.ipynb\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" width=\"140\" align=\"center\"/></a></center></h1>\n",
    "\n",
    "This notebook provides an introductory tutorial for [**MuJoCo XLA (MJX)**](https://github.com/google-deepmind/mujoco/blob/main/mjx), a JAX-based implementation of MuJoCo useful for RL training workloads.\n",
    "\n",
    "**A Colab runtime with GPU acceleration is required.** If you're using a CPU-only runtime, you can switch using the menu \"Runtime > Change runtime type\".\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "xBSdkbmGN2K-"
   },
   "source": [
    "### Copyright notice"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "_UbO9uhtBSX5"
   },
   "source": [
    "> <p><small><small>Copyright 2023 DeepMind Technologies Limited.</small></small></p>\n",
    "> <p><small><small>Licensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with the License. You may obtain a copy of the License at <a href=\"http://www.apache.org/licenses/LICENSE-2.0\">http://www.apache.org/licenses/LICENSE-2.0</a>.</small></small></p>\n",
    "> <p><small><small>Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.</small></small></p>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "YvyGCsgSCxHQ"
   },
   "source": [
    "# Install MuJoCo, MJX, and Brax"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "id": "Xqo7pyX-n72M"
   },
   "outputs": [],
   "source": [
    "!pip install mujoco\n",
    "!pip install mujoco_mjx\n",
    "!pip install brax\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "cellView": "form",
    "id": "IbZxYDxzoz5R"
   },
   "outputs": [],
   "source": [
    "#@title Check if MuJoCo installation was successful\n",
    "\n",
    "from google.colab import files\n",
    "\n",
    "import os\n",
    "import subprocess\n",
    "if subprocess.run('nvidia-smi').returncode:\n",
    "  raise RuntimeError(\n",
    "      'Cannot communicate with GPU. '\n",
    "      'Make sure you are using a GPU Colab runtime. '\n",
    "      'Go to the Runtime menu and select Change runtime type.')\n",
    "\n",
    "# Add an ICD config so that glvnd can pick up the Nvidia EGL driver.\n",
    "# This is usually installed as part of an Nvidia driver package, but the Colab\n",
    "# kernel doesn't install its driver via APT, and as a result the ICD is missing.\n",
    "# (https://github.com/NVIDIA/libglvnd/blob/master/src/EGL/icd_enumeration.md)\n",
    "NVIDIA_ICD_CONFIG_PATH = '/usr/share/glvnd/egl_vendor.d/10_nvidia.json'\n",
    "if not os.path.exists(NVIDIA_ICD_CONFIG_PATH):\n",
    "  with open(NVIDIA_ICD_CONFIG_PATH, 'w') as f:\n",
    "    f.write(\"\"\"{\n",
    "    \"file_format_version\" : \"1.0.0\",\n",
    "    \"ICD\" : {\n",
    "        \"library_path\" : \"libEGL_nvidia.so.0\"\n",
    "    }\n",
    "}\n",
    "\"\"\")\n",
    "\n",
    "# Configure MuJoCo to use the EGL rendering backend (requires GPU)\n",
    "print('Setting environment variable to use GPU rendering:')\n",
    "%env MUJOCO_GL=egl\n",
    "\n",
    "try:\n",
    "  print('Checking that the installation succeeded:')\n",
    "  import mujoco\n",
    "  mujoco.MjModel.from_xml_string('<mujoco/>')\n",
    "except Exception as e:\n",
    "  raise RuntimeError(\n",
    "      'Something went wrong during installation. Check the shell output above '\n",
    "      'for more information.\\n'\n",
    "      'If using a hosted Colab runtime, make sure you enable GPU acceleration '\n",
    "      'by going to the Runtime menu and selecting \"Change runtime type\".') from e\n",
    "\n",
    "print('Installation successful.')\n",
    "\n",
    "# Tell XLA to use Triton GEMM; this improves steps/sec by ~30% on some GPUs.\n",
    "xla_flags = os.environ.get('XLA_FLAGS', '')\n",
    "xla_flags += ' --xla_gpu_triton_gemm_any=True'\n",
    "os.environ['XLA_FLAGS'] = xla_flags\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "cellView": "form",
    "id": "T5f4w3Kq2X14"
   },
   "outputs": [],
   "source": [
    "#@title Import packages for plotting and creating graphics\n",
    "import time\n",
    "import itertools\n",
    "import numpy as np\n",
    "from typing import Callable, NamedTuple, Optional, Union, List\n",
    "\n",
    "# Graphics and plotting.\n",
    "print('Installing mediapy:')\n",
    "!command -v ffmpeg >/dev/null || (apt update && apt install -y ffmpeg)\n",
    "!pip install -q mediapy\n",
    "import mediapy as media\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# More legible printing from numpy.\n",
    "np.set_printoptions(precision=3, suppress=True, linewidth=100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "cellView": "form",
    "id": "ObF1UXrkb0Nd"
   },
   "outputs": [],
   "source": [
    "#@title Import MuJoCo, MJX, and Brax\n",
    "from datetime import datetime\n",
    "from etils import epath\n",
    "import functools\n",
    "from IPython.display import HTML\n",
    "from typing import Any, Dict, Sequence, Tuple, Union\n",
    "import os\n",
    "from ml_collections import config_dict\n",
    "\n",
    "\n",
    "import jax\n",
    "from jax import numpy as jp\n",
    "import numpy as np\n",
    "from flax.training import orbax_utils\n",
    "from flax import struct\n",
    "from matplotlib import pyplot as plt\n",
    "import mediapy as media\n",
    "from orbax import checkpoint as ocp\n",
    "\n",
    "import mujoco\n",
    "from mujoco import mjx\n",
    "\n",
    "from brax import base\n",
    "from brax import envs\n",
    "from brax import math\n",
    "from brax.base import Base, Motion, Transform\n",
    "from brax.base import State as PipelineState\n",
    "from brax.envs.base import Env, PipelineEnv, State\n",
    "from brax.mjx.base import State as MjxState\n",
    "from brax.training.agents.ppo import train as ppo\n",
    "from brax.training.agents.ppo import networks as ppo_networks\n",
    "from brax.io import html, mjcf, model\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Nj4-Xmx4DFaq"
   },
   "source": [
    "# Introduction to MJX\n",
    "\n",
    "MJX is an implementation of MuJoCo written in [JAX](https://jax.readthedocs.io/en/latest/index.html), enabling large batch training on GPU/TPU. In this notebook, we will demonstrate how to train RL policies with MJX.\n",
    "\n",
    "Before we get into hefty RL workloads, let's get started with a simpler example! The entrypoint into MJX is through MuJoCo, so first we load a MuJoCo model:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "id": "bNus3mbbDz6a"
   },
   "outputs": [],
   "source": [
    "xml = \"\"\"\n",
    "<mujoco>\n",
    "  <worldbody>\n",
    "    <light name=\"top\" pos=\"0 0 1\"/>\n",
    "    <body name=\"box_and_sphere\" euler=\"0 0 -30\">\n",
    "      <joint name=\"swing\" type=\"hinge\" axis=\"1 -1 0\" pos=\"-.2 -.2 -.2\"/>\n",
    "      <geom name=\"red_box\" type=\"box\" size=\".2 .2 .2\" rgba=\"1 0 0 1\"/>\n",
    "      <geom name=\"green_sphere\" pos=\".2 .2 .2\" size=\".1\" rgba=\"0 1 0 1\"/>\n",
    "    </body>\n",
    "  </worldbody>\n",
    "</mujoco>\n",
    "\"\"\"\n",
    "\n",
    "# Make model, data, and renderer\n",
    "mj_model = mujoco.MjModel.from_xml_string(xml)\n",
    "mj_data = mujoco.MjData(mj_model)\n",
    "renderer = mujoco.Renderer(mj_model)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Po5oykJbFQbj"
   },
   "source": [
    "Next we take the MuJoCo model and data, and place them on the GPU device using MJX."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "id": "TSpoOWqeEC3P"
   },
   "outputs": [],
   "source": [
    "mjx_model = mjx.put_model(mj_model)\n",
    "mjx_data = mjx.put_data(mj_model, mj_data)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "6rxMMSs4OJJf"
   },
   "source": [
    "Below, we print the `qpos` from MuJoCo and MJX. Notice that the `qpos` of `mjData` is a NumPy array living on the CPU, while the `qpos` of `mjx.Data` is a JAX Array living on the GPU device."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "id": "ZOD582pfOLP-"
   },
   "outputs": [],
   "source": [
    "print(mj_data.qpos, type(mj_data.qpos))\n",
    "print(mjx_data.qpos, type(mjx_data.qpos), mjx_data.qpos.devices())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "ZShF9-o_JLm3"
   },
   "source": [
    "Let's run the simulation in MuJoCo and render the trajectory. This example is taken from the [MuJoCo tutorial](https://colab.sandbox.google.com/github/google-deepmind/mujoco/blob/main/python/tutorial.ipynb)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "id": "HDlPlX05I3m-"
   },
   "outputs": [],
   "source": [
    "# enable joint visualization option:\n",
    "scene_option = mujoco.MjvOption()\n",
    "scene_option.flags[mujoco.mjtVisFlag.mjVIS_JOINT] = True\n",
    "\n",
    "duration = 3.8  # (seconds)\n",
    "framerate = 60  # (Hz)\n",
    "\n",
    "frames = []\n",
    "mujoco.mj_resetData(mj_model, mj_data)\n",
    "while mj_data.time < duration:\n",
    "  mujoco.mj_step(mj_model, mj_data)\n",
    "  if len(frames) < mj_data.time * framerate:\n",
    "    renderer.update_scene(mj_data, scene_option=scene_option)\n",
    "    pixels = renderer.render()\n",
    "    frames.append(pixels)\n",
    "\n",
    "# Display the rendered frames as a video.\n",
    "media.show_video(frames, fps=framerate)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "m70b_RxBJOyd"
   },
   "source": [
    "Now let's run the exact same simulation on the GPU device using MJX!\n",
    "\n",
    "In the example below, we use `mjx.step` instead of `mujoco.mj_step`, and we also [`jax.jit`](https://jax.readthedocs.io/en/latest/jax-101/02-jitting.html) the `mjx.step` so that it runs efficiently on the GPU. For each frame, we convert the `mjx.Data` back to `mjData` so that we can use the MuJoCo renderer.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "id": "Pr29xq0-JRQv"
   },
   "outputs": [],
   "source": [
    "\n",
    "jit_step = jax.jit(mjx.step)\n",
    "\n",
    "frames = []\n",
    "mujoco.mj_resetData(mj_model, mj_data)\n",
    "mjx_data = mjx.put_data(mj_model, mj_data)\n",
    "while mjx_data.time < duration:\n",
    "  mjx_data = jit_step(mjx_model, mjx_data)\n",
    "  if len(frames) < mjx_data.time * framerate:\n",
    "    mj_data = mjx.get_data(mj_model, mjx_data)\n",
    "    renderer.update_scene(mj_data, scene_option=scene_option)\n",
    "    pixels = renderer.render()\n",
    "    frames.append(pixels)\n",
    "\n",
    "media.show_video(frames, fps=framerate)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "wXsQ4qO2KO3Q"
   },
   "source": [
    "Running single-threaded physics simulation on the GPU is not very [efficient](https://mujoco.readthedocs.io/en/stable/mjx.html#mjx-the-sharp-bits). The advantage of MJX is that we can run environments in parallel on a hardware-accelerated device. Let's try it out!\n",
    "\n",
    "In the example below, we create 4096 copies of the `mjx.Data` and we run the `mjx.step` over the batched data. Since MJX is implemented in JAX, we take advantage of [`jax.vmap`](https://jax.readthedocs.io/en/latest/_autosummary/jax.vmap.html) to run the `mjx.step` in parallel over all `mjx.Data`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "id": "rrdrcKRVK6w9"
   },
   "outputs": [],
   "source": [
    "rng = jax.random.PRNGKey(0)\n",
    "rng = jax.random.split(rng, 4096)\n",
    "batch = jax.vmap(lambda rng: mjx_data.replace(qpos=jax.random.uniform(rng, (1,))))(rng)\n",
    "\n",
    "jit_step = jax.jit(jax.vmap(mjx.step, in_axes=(None, 0)))\n",
    "batch = jit_step(mjx_model, batch)\n",
    "\n",
    "print(batch.qpos)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "x4lL220cOj0q"
   },
   "source": [
    "We can copy the batched `mjx.Data` back to MuJoCo like we did before:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "id": "Jtz7j1PDOnw5"
   },
   "outputs": [],
   "source": [
    "batched_mj_data = mjx.get_data(mj_model, batch)\n",
    "print([d.qpos for d in batched_mj_data])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "RAv6WUVUm78k"
   },
   "source": [
    "# Training a Policy with MJX\n",
    "\n",
    "Running large-batch physics simulation is useful for training RL policies. Here we demonstrate training RL policies with MJX using the RL library from [Brax](https://github.com/google/brax).\n",
    "\n",
    "Below, we implement the classic Humanoid environment using MJX and Brax. We inherit from the `PipelineEnv` implementation in Brax so that we can step the physics with MJX while training with Brax RL implementations.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "id": "mtGMYNLE3QJN"
   },
   "outputs": [],
   "source": [
    "#@title Humanoid Env\n",
    "\n",
    "HUMANOID_ROOT_PATH = epath.Path(epath.resource_path('mujoco')) / 'mjx/test_data/humanoid'\n",
    "\n",
    "class Humanoid(PipelineEnv):\n",
    "\n",
    "  def __init__(\n",
    "      self,\n",
    "      forward_reward_weight=1.25,\n",
    "      ctrl_cost_weight=0.1,\n",
    "      healthy_reward=5.0,\n",
    "      terminate_when_unhealthy=True,\n",
    "      healthy_z_range=(1.0, 2.0),\n",
    "      reset_noise_scale=1e-2,\n",
    "      exclude_current_positions_from_observation=True,\n",
    "      **kwargs,\n",
    "  ):\n",
    "    # Load the humanoid model and set the solver options.\n",
    "    mj_model = mujoco.MjModel.from_xml_path(\n",
    "        (HUMANOID_ROOT_PATH / 'humanoid.xml').as_posix())\n",
    "    mj_model.opt.solver = mujoco.mjtSolver.mjSOL_CG\n",
    "    mj_model.opt.iterations = 6\n",
    "    mj_model.opt.ls_iterations = 6\n",
    "\n",
    "    sys = mjcf.load_model(mj_model)\n",
    "\n",
    "    physics_steps_per_control_step = 5\n",
    "    kwargs['n_frames'] = kwargs.get(\n",
    "        'n_frames', physics_steps_per_control_step)\n",
    "    kwargs['backend'] = 'mjx'\n",
    "\n",
    "    super().__init__(sys, **kwargs)\n",
    "\n",
    "    self._forward_reward_weight = forward_reward_weight\n",
    "    self._ctrl_cost_weight = ctrl_cost_weight\n",
    "    self._healthy_reward = healthy_reward\n",
    "    self._terminate_when_unhealthy = terminate_when_unhealthy\n",
    "    self._healthy_z_range = healthy_z_range\n",
    "    self._reset_noise_scale = reset_noise_scale\n",
    "    self._exclude_current_positions_from_observation = (\n",
    "        exclude_current_positions_from_observation\n",
    "    )\n",
    "\n",
    "  def reset(self, rng: jp.ndarray) -> State:\n",
    "    \"\"\"Resets the environment to an initial state.\"\"\"\n",
    "    rng, rng1, rng2 = jax.random.split(rng, 3)\n",
    "\n",
    "    low, hi = -self._reset_noise_scale, self._reset_noise_scale\n",
    "    qpos = self.sys.qpos0 + jax.random.uniform(\n",
    "        rng1, (self.sys.nq,), minval=low, maxval=hi\n",
    "    )\n",
    "    qvel = jax.random.uniform(\n",
    "        rng2, (self.sys.nv,), minval=low, maxval=hi\n",
    "    )\n",
    "\n",
    "    data = self.pipeline_init(qpos, qvel)\n",
    "\n",
    "    obs = self._get_obs(data, jp.zeros(self.sys.nu))\n",
    "    reward, done, zero = jp.zeros(3)\n",
    "    metrics = {\n",
    "        'forward_reward': zero,\n",
    "        'reward_linvel': zero,\n",
    "        'reward_quadctrl': zero,\n",
    "        'reward_alive': zero,\n",
    "        'x_position': zero,\n",
    "        'y_position': zero,\n",
    "        'distance_from_origin': zero,\n",
    "        'x_velocity': zero,\n",
    "        'y_velocity': zero,\n",
    "    }\n",
    "    return State(data, obs, reward, done, metrics)\n",
    "\n",
    "  def step(self, state: State, action: jp.ndarray) -> State:\n",
    "    \"\"\"Runs one timestep of the environment's dynamics.\"\"\"\n",
    "    data0 = state.pipeline_state\n",
    "    data = self.pipeline_step(data0, action)\n",
    "\n",
    "    com_before = data0.subtree_com[1]\n",
    "    com_after = data.subtree_com[1]\n",
    "    velocity = (com_after - com_before) / self.dt\n",
    "    forward_reward = self._forward_reward_weight * velocity[0]\n",
    "\n",
    "    min_z, max_z = self._healthy_z_range\n",
    "    is_healthy = jp.where(data.q[2] < min_z, 0.0, 1.0)\n",
    "    is_healthy = jp.where(data.q[2] > max_z, 0.0, is_healthy)\n",
    "    if self._terminate_when_unhealthy:\n",
    "      healthy_reward = self._healthy_reward\n",
    "    else:\n",
    "      healthy_reward = self._healthy_reward * is_healthy\n",
    "\n",
    "    ctrl_cost = self._ctrl_cost_weight * jp.sum(jp.square(action))\n",
    "\n",
    "    obs = self._get_obs(data, action)\n",
    "    reward = forward_reward + healthy_reward - ctrl_cost\n",
    "    done = 1.0 - is_healthy if self._terminate_when_unhealthy else 0.0\n",
    "    state.metrics.update(\n",
    "        forward_reward=forward_reward,\n",
    "        reward_linvel=forward_reward,\n",
    "        reward_quadctrl=-ctrl_cost,\n",
    "        reward_alive=healthy_reward,\n",
    "        x_position=com_after[0],\n",
    "        y_position=com_after[1],\n",
    "        distance_from_origin=jp.linalg.norm(com_after),\n",
    "        x_velocity=velocity[0],\n",
    "        y_velocity=velocity[1],\n",
    "    )\n",
    "\n",
    "    return state.replace(\n",
    "        pipeline_state=data, obs=obs, reward=reward, done=done\n",
    "    )\n",
    "\n",
    "  def _get_obs(\n",
    "      self, data: mjx.Data, action: jp.ndarray\n",
    "  ) -> jp.ndarray:\n",
    "    \"\"\"Observes humanoid body position, velocities, and angles.\"\"\"\n",
    "    position = data.qpos\n",
    "    if self._exclude_current_positions_from_observation:\n",
    "      position = position[2:]\n",
    "\n",
    "    # external_contact_forces are excluded\n",
    "    return jp.concatenate([\n",
    "        position,\n",
    "        data.qvel,\n",
    "        data.cinert[1:].ravel(),\n",
    "        data.cvel[1:].ravel(),\n",
    "        data.qfrc_actuator,\n",
    "    ])\n",
    "\n",
    "\n",
    "envs.register_environment('humanoid', Humanoid)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "P1K6IznI2y83"
   },
   "source": [
    "## Visualize a Rollout\n",
    "\n",
    "Let's instantiate the environment and visualize a short rollout.\n",
    "\n",
    "NOTE: Since episodes terminate early if the torso is below the healthy z-range, the only relevant contacts for this task are between the feet and the plane. We turn off other contacts."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "id": "EhKLFK54C1CH"
   },
   "outputs": [],
   "source": [
    "# instantiate the environment\n",
    "env_name = 'humanoid'\n",
    "env = envs.get_environment(env_name)\n",
    "\n",
    "# define the jit reset/step functions\n",
    "jit_reset = jax.jit(env.reset)\n",
    "jit_step = jax.jit(env.step)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "id": "Ph8u-v2Q2xLS"
   },
   "outputs": [],
   "source": [
    "# initialize the state\n",
    "state = jit_reset(jax.random.PRNGKey(0))\n",
    "rollout = [state.pipeline_state]\n",
    "\n",
    "# grab a trajectory\n",
    "for i in range(10):\n",
    "  ctrl = -0.1 * jp.ones(env.sys.nu)\n",
    "  state = jit_step(state, ctrl)\n",
    "  rollout.append(state.pipeline_state)\n",
    "\n",
    "media.show_video(env.render(rollout, camera='side'), fps=1.0 / env.dt)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "BQDG6NQ1CbZD"
   },
   "source": [
    "## Train Humanoid Policy\n",
    "\n",
    "Let's now train a policy with PPO to make the Humanoid run forwards. Training takes about 6 minutes on an NVIDIA A100 GPU."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "id": "xLiddQYPApBw"
   },
   "outputs": [],
   "source": [
    "train_fn = functools.partial(\n",
    "    ppo.train, num_timesteps=20_000_000, num_evals=5, reward_scaling=0.1,\n",
    "    episode_length=1000, normalize_observations=True, action_repeat=1,\n",
    "    unroll_length=10, num_minibatches=24, num_updates_per_batch=8,\n",
    "    discounting=0.97, learning_rate=3e-4, entropy_cost=1e-3, num_envs=3072,\n",
    "    batch_size=512, seed=0)\n",
    "\n",
    "\n",
    "x_data = []\n",
    "y_data = []\n",
    "ydataerr = []\n",
    "times = [datetime.now()]\n",
    "\n",
    "max_y, min_y = 13000, 0\n",
    "def progress(num_steps, metrics):\n",
    "  times.append(datetime.now())\n",
    "  x_data.append(num_steps)\n",
    "  y_data.append(metrics['eval/episode_reward'])\n",
    "  ydataerr.append(metrics['eval/episode_reward_std'])\n",
    "\n",
    "  plt.xlim([0, train_fn.keywords['num_timesteps'] * 1.25])\n",
    "  plt.ylim([min_y, max_y])\n",
    "\n",
    "  plt.xlabel('# environment steps')\n",
    "  plt.ylabel('reward per episode')\n",
    "  plt.title(f'y={y_data[-1]:.3f}')\n",
    "\n",
    "  plt.errorbar(\n",
    "      x_data, y_data, yerr=ydataerr)\n",
    "  plt.show()\n",
    "\n",
    "make_inference_fn, params, _ = train_fn(environment=env, progress_fn=progress)\n",
    "\n",
    "print(f'time to jit: {times[1] - times[0]}')\n",
    "print(f'time to train: {times[-1] - times[1]}')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "YYIch0HEApBx"
   },
   "source": [
    "<!-- ## Save and Load Policy -->\n",
    "\n",
    "We can save and load the policy using the brax model API."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "id": "Z8gI6qH6ApBx"
   },
   "outputs": [],
   "source": [
    "#@title Save Model\n",
    "model_path = '/tmp/mjx_brax_policy'\n",
    "model.save_params(model_path, params)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "id": "h4reaWgxApBx"
   },
   "outputs": [],
   "source": [
    "#@title Load Model and Define Inference Function\n",
    "params = model.load_params(model_path)\n",
    "\n",
    "inference_fn = make_inference_fn(params)\n",
    "jit_inference_fn = jax.jit(inference_fn)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "0G357XIfApBy"
   },
   "source": [
    "## Visualize Policy\n",
    "\n",
    "Finally we can visualize the policy."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "id": "osYasMw4ApBy"
   },
   "outputs": [],
   "source": [
    "eval_env = envs.get_environment(env_name)\n",
    "\n",
    "jit_reset = jax.jit(eval_env.reset)\n",
    "jit_step = jax.jit(eval_env.step)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "id": "d-UhypudApBy"
   },
   "outputs": [],
   "source": [
    "# initialize the state\n",
    "rng = jax.random.PRNGKey(0)\n",
    "state = jit_reset(rng)\n",
    "rollout = [state.pipeline_state]\n",
    "\n",
    "# grab a trajectory\n",
    "n_steps = 500\n",
    "render_every = 2\n",
    "\n",
    "for i in range(n_steps):\n",
    "  act_rng, rng = jax.random.split(rng)\n",
    "  ctrl, _ = jit_inference_fn(state.obs, act_rng)\n",
    "  state = jit_step(state, ctrl)\n",
    "  rollout.append(state.pipeline_state)\n",
    "\n",
    "  if state.done:\n",
    "    break\n",
    "\n",
    "media.show_video(env.render(rollout[::render_every], camera='side'), fps=1.0 / env.dt / render_every)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "zR-heox6LARK"
   },
   "source": [
    "# MJX Policy in MuJoCo\n",
    "\n",
    "We can also perform the physics step using the original MuJoCo python bindings to show that the policy trained in MJX works in MuJoCo."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "id": "w6ixFi4dApBy"
   },
   "outputs": [],
   "source": [
    "mj_model = eval_env.sys.mj_model\n",
    "mj_data = mujoco.MjData(mj_model)\n",
    "\n",
    "renderer = mujoco.Renderer(mj_model)\n",
    "ctrl = jp.zeros(mj_model.nu)\n",
    "\n",
    "images = []\n",
    "for i in range(n_steps):\n",
    "  act_rng, rng = jax.random.split(rng)\n",
    "\n",
    "  obs = eval_env._get_obs(mjx.put_data(mj_model, mj_data), ctrl)\n",
    "  ctrl, _ = jit_inference_fn(obs, act_rng)\n",
    "\n",
    "  mj_data.ctrl = ctrl\n",
    "  for _ in range(eval_env._n_frames):\n",
    "    mujoco.mj_step(mj_model, mj_data)  # Physics step using MuJoCo mj_step.\n",
    "\n",
    "  if i % render_every == 0:\n",
    "    renderer.update_scene(mj_data, camera='side')\n",
    "    images.append(renderer.render())\n",
    "\n",
    "media.show_video(images, fps=1.0 / eval_env.dt / render_every)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "65mIPj6DQNNa"
   },
   "source": [
    "# Training a Policy with Domain Randomization\n",
    "\n",
    "We might also want to include randomization over certain `mjModel` parameters while training a policy. In MJX, we can easily create a batch of environments with randomized values populated in `mjx.Model`. Below, we show a function that randomizes friction and actuator gain/bias."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "id": "h8mhzKjHQuoL"
   },
   "outputs": [],
   "source": [
    "def domain_randomize(sys, rng):\n",
    "  \"\"\"Randomizes the mjx.Model.\"\"\"\n",
    "  @jax.vmap\n",
    "  def rand(rng):\n",
    "    _, key = jax.random.split(rng, 2)\n",
    "    # friction\n",
    "    friction = jax.random.uniform(key, (1,), minval=0.6, maxval=1.4)\n",
    "    friction = sys.geom_friction.at[:, 0].set(friction)\n",
    "    # actuator\n",
    "    _, key = jax.random.split(key, 2)\n",
    "    gain_range = (-5, 5)\n",
    "    param = jax.random.uniform(\n",
    "        key, (1,), minval=gain_range[0], maxval=gain_range[1]\n",
    "    ) + sys.actuator_gainprm[:, 0]\n",
    "    gain = sys.actuator_gainprm.at[:, 0].set(param)\n",
    "    bias = sys.actuator_biasprm.at[:, 1].set(-param)\n",
    "    return friction, gain, bias\n",
    "\n",
    "  friction, gain, bias = rand(rng)\n",
    "\n",
    "  in_axes = jax.tree_util.tree_map(lambda x: None, sys)\n",
    "  in_axes = in_axes.tree_replace({\n",
    "      'geom_friction': 0,\n",
    "      'actuator_gainprm': 0,\n",
    "      'actuator_biasprm': 0,\n",
    "  })\n",
    "\n",
    "  sys = sys.tree_replace({\n",
    "      'geom_friction': friction,\n",
    "      'actuator_gainprm': gain,\n",
    "      'actuator_biasprm': bias,\n",
    "  })\n",
    "\n",
    "  return sys, in_axes"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "gnsZo-GWSYYj"
   },
   "source": [
    "If we wanted 10 environments with randomized friction and actuator params, we can call `domain_randomize`, which returns a batched `mjx.Model` along with a dictionary specifying the axes that are batched."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "id": "1K45Kp2ASV9s"
   },
   "outputs": [],
   "source": [
    "rng = jax.random.PRNGKey(0)\n",
    "rng = jax.random.split(rng, 10)\n",
    "batched_sys, _ = domain_randomize(env.sys, rng)\n",
    "\n",
    "print('Single env friction shape: ', env.sys.geom_friction.shape)\n",
    "print('Batched env friction shape: ', batched_sys.geom_friction.shape)\n",
    "\n",
    "print('Friction on geom 0: ', env.sys.geom_friction[0, 0])\n",
    "print('Random frictions on geom 0: ', batched_sys.geom_friction[:, 0, 0])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "efnxNOnpQFuC"
   },
   "source": [
    "## Quadruped Env\n",
    "\n",
    "Let's define a quadruped environment that takes advantage of the domain randomization function. Here we use the [Barkour vb Quadruped](https://github.com/google-deepmind/mujoco_menagerie/tree/main/google_barkour_vb) from [MuJoCo Menagerie](https://github.com/google-deepmind/mujoco_menagerie). We implement an environment that trains a joystick policy with Brax.\n",
    "\n",
    "NOTE: for a full suite of robotic environments, many of which were transferred onto robots, check out [MuJoCo Playground](https://github.com/google-deepmind/mujoco_playground)!\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "id": "VfyK73gtRXid"
   },
   "outputs": [],
   "source": [
    "!git clone https://github.com/google-deepmind/mujoco_menagerie\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "id": "y79PoJOCIl-O"
   },
   "outputs": [],
   "source": [
    "#@title Barkour vb Quadruped Env\n",
    "\n",
    "BARKOUR_ROOT_PATH = epath.Path('mujoco_menagerie/google_barkour_vb')\n",
    "\n",
    "\n",
    "def get_config():\n",
    "  \"\"\"Returns reward config for barkour quadruped environment.\"\"\"\n",
    "\n",
    "  def get_default_rewards_config():\n",
    "    default_config = config_dict.ConfigDict(\n",
    "        dict(\n",
    "            # The coefficients for all reward terms used for training. All\n",
    "            # physical quantities are in SI units, if no otherwise specified,\n",
    "            # i.e. joint positions are in rad, positions are measured in meters,\n",
    "            # torques in Nm, and time in seconds, and forces in Newtons.\n",
    "            scales=config_dict.ConfigDict(\n",
    "                dict(\n",
    "                    # Tracking rewards are computed using exp(-delta^2/sigma)\n",
    "                    # sigma can be a hyperparameters to tune.\n",
    "                    # Track the base x-y velocity (no z-velocity tracking.)\n",
    "                    tracking_lin_vel=1.5,\n",
    "                    # Track the angular velocity along z-axis, i.e. yaw rate.\n",
    "                    tracking_ang_vel=0.8,\n",
    "                    # Below are regularization terms, we roughly divide the\n",
    "                    # terms to base state regularizations, joint\n",
    "                    # regularizations, and other behavior regularizations.\n",
    "                    # Penalize the base velocity in z direction, L2 penalty.\n",
    "                    lin_vel_z=-2.0,\n",
    "                    # Penalize the base roll and pitch rate. L2 penalty.\n",
    "                    ang_vel_xy=-0.05,\n",
    "                    # Penalize non-zero roll and pitch angles. L2 penalty.\n",
    "                    orientation=-5.0,\n",
    "                    # L2 regularization of joint torques, |tau|^2.\n",
    "                    torques=-0.0002,\n",
    "                    # Penalize the change in the action and encourage smooth\n",
    "                    # actions. L2 regularization |action - last_action|^2\n",
    "                    action_rate=-0.01,\n",
    "                    # Encourage long swing steps.  However, it does not\n",
    "                    # encourage high clearances.\n",
    "                    feet_air_time=0.2,\n",
    "                    # Encourage no motion at zero command, L2 regularization\n",
    "                    # |q - q_default|^2.\n",
    "                    stand_still=-0.5,\n",
    "                    # Early termination penalty.\n",
    "                    termination=-1.0,\n",
    "                    # Penalizing foot slipping on the ground.\n",
    "                    foot_slip=-0.1,\n",
    "                )\n",
    "            ),\n",
    "            # Tracking reward = exp(-error^2/sigma).\n",
    "            tracking_sigma=0.25,\n",
    "        )\n",
    "    )\n",
    "    return default_config\n",
    "\n",
    "  default_config = config_dict.ConfigDict(\n",
    "      dict(\n",
    "          rewards=get_default_rewards_config(),\n",
    "      )\n",
    "  )\n",
    "\n",
    "  return default_config\n",
    "\n",
    "\n",
    "class BarkourEnv(PipelineEnv):\n",
    "  \"\"\"Environment for training the barkour quadruped joystick policy in MJX.\"\"\"\n",
    "\n",
    "  def __init__(\n",
    "      self,\n",
    "      obs_noise: float = 0.05,\n",
    "      action_scale: float = 0.3,\n",
    "      kick_vel: float = 0.05,\n",
    "      scene_file: str = 'scene_mjx.xml',\n",
    "      **kwargs,\n",
    "  ):\n",
    "    path = BARKOUR_ROOT_PATH / scene_file\n",
    "    sys = mjcf.load(path.as_posix())\n",
    "    self._dt = 0.02  # this environment is 50 fps\n",
    "    sys = sys.tree_replace({'opt.timestep': 0.004})\n",
    "\n",
    "    # override menagerie params for smoother policy\n",
    "    sys = sys.replace(\n",
    "        dof_damping=sys.dof_damping.at[6:].set(0.5239),\n",
    "        actuator_gainprm=sys.actuator_gainprm.at[:, 0].set(35.0),\n",
    "        actuator_biasprm=sys.actuator_biasprm.at[:, 1].set(-35.0),\n",
    "    )\n",
    "\n",
    "    n_frames = kwargs.pop('n_frames', int(self._dt / sys.opt.timestep))\n",
    "    super().__init__(sys, backend='mjx', n_frames=n_frames)\n",
    "\n",
    "    self.reward_config = get_config()\n",
    "    # set custom from kwargs\n",
    "    for k, v in kwargs.items():\n",
    "      if k.endswith('_scale'):\n",
    "        self.reward_config.rewards.scales[k[:-6]] = v\n",
    "\n",
    "    self._torso_idx = mujoco.mj_name2id(\n",
    "        sys.mj_model, mujoco.mjtObj.mjOBJ_BODY.value, 'torso'\n",
    "    )\n",
    "    self._action_scale = action_scale\n",
    "    self._obs_noise = obs_noise\n",
    "    self._kick_vel = kick_vel\n",
    "    self._init_q = jp.array(sys.mj_model.keyframe('home').qpos)\n",
    "    self._default_pose = sys.mj_model.keyframe('home').qpos[7:]\n",
    "    self.lowers = jp.array([-0.7, -1.0, 0.05] * 4)\n",
    "    self.uppers = jp.array([0.52, 2.1, 2.1] * 4)\n",
    "    feet_site = [\n",
    "        'foot_front_left',\n",
    "        'foot_hind_left',\n",
    "        'foot_front_right',\n",
    "        'foot_hind_right',\n",
    "    ]\n",
    "    feet_site_id = [\n",
    "        mujoco.mj_name2id(sys.mj_model, mujoco.mjtObj.mjOBJ_SITE.value, f)\n",
    "        for f in feet_site\n",
    "    ]\n",
    "    assert not any(id_ == -1 for id_ in feet_site_id), 'Site not found.'\n",
    "    self._feet_site_id = np.array(feet_site_id)\n",
    "    lower_leg_body = [\n",
    "        'lower_leg_front_left',\n",
    "        'lower_leg_hind_left',\n",
    "        'lower_leg_front_right',\n",
    "        'lower_leg_hind_right',\n",
    "    ]\n",
    "    lower_leg_body_id = [\n",
    "        mujoco.mj_name2id(sys.mj_model, mujoco.mjtObj.mjOBJ_BODY.value, l)\n",
    "        for l in lower_leg_body\n",
    "    ]\n",
    "    assert not any(id_ == -1 for id_ in lower_leg_body_id), 'Body not found.'\n",
    "    self._lower_leg_body_id = np.array(lower_leg_body_id)\n",
    "    self._foot_radius = 0.0175\n",
    "    self._nv = sys.nv\n",
    "\n",
    "  def sample_command(self, rng: jax.Array) -> jax.Array:\n",
    "    lin_vel_x = [-0.6, 1.5]  # min max [m/s]\n",
    "    lin_vel_y = [-0.8, 0.8]  # min max [m/s]\n",
    "    ang_vel_yaw = [-0.7, 0.7]  # min max [rad/s]\n",
    "\n",
    "    _, key1, key2, key3 = jax.random.split(rng, 4)\n",
    "    lin_vel_x = jax.random.uniform(\n",
    "        key1, (1,), minval=lin_vel_x[0], maxval=lin_vel_x[1]\n",
    "    )\n",
    "    lin_vel_y = jax.random.uniform(\n",
    "        key2, (1,), minval=lin_vel_y[0], maxval=lin_vel_y[1]\n",
    "    )\n",
    "    ang_vel_yaw = jax.random.uniform(\n",
    "        key3, (1,), minval=ang_vel_yaw[0], maxval=ang_vel_yaw[1]\n",
    "    )\n",
    "    new_cmd = jp.array([lin_vel_x[0], lin_vel_y[0], ang_vel_yaw[0]])\n",
    "    return new_cmd\n",
    "\n",
    "  def reset(self, rng: jax.Array) -> State:  # pytype: disable=signature-mismatch\n",
    "    rng, key = jax.random.split(rng)\n",
    "\n",
    "    pipeline_state = self.pipeline_init(self._init_q, jp.zeros(self._nv))\n",
    "\n",
    "    state_info = {\n",
    "        'rng': rng,\n",
    "        'last_act': jp.zeros(12),\n",
    "        'last_vel': jp.zeros(12),\n",
    "        'command': self.sample_command(key),\n",
    "        'last_contact': jp.zeros(4, dtype=bool),\n",
    "        'feet_air_time': jp.zeros(4),\n",
    "        'rewards': {k: 0.0 for k in self.reward_config.rewards.scales.keys()},\n",
    "        'kick': jp.array([0.0, 0.0]),\n",
    "        'step': 0,\n",
    "    }\n",
    "\n",
    "    obs_history = jp.zeros(15 * 31)  # store 15 steps of history\n",
    "    obs = self._get_obs(pipeline_state, state_info, obs_history)\n",
    "    reward, done = jp.zeros(2)\n",
    "    metrics = {'total_dist': 0.0}\n",
    "    for k in state_info['rewards']:\n",
    "      metrics[k] = state_info['rewards'][k]\n",
    "    state = State(pipeline_state, obs, reward, done, metrics, state_info)  # pytype: disable=wrong-arg-types\n",
    "    return state\n",
    "\n",
    "  def step(self, state: State, action: jax.Array) -> State:  # pytype: disable=signature-mismatch\n",
    "    rng, cmd_rng, kick_noise_2 = jax.random.split(state.info['rng'], 3)\n",
    "\n",
    "    # kick\n",
    "    push_interval = 10\n",
    "    kick_theta = jax.random.uniform(kick_noise_2, maxval=2 * jp.pi)\n",
    "    kick = jp.array([jp.cos(kick_theta), jp.sin(kick_theta)])\n",
    "    kick *= jp.mod(state.info['step'], push_interval) == 0\n",
    "    qvel = state.pipeline_state.qvel  # pytype: disable=attribute-error\n",
    "    qvel = qvel.at[:2].set(kick * self._kick_vel + qvel[:2])\n",
    "    state = state.tree_replace({'pipeline_state.qvel': qvel})\n",
    "\n",
    "    # physics step\n",
    "    motor_targets = self._default_pose + action * self._action_scale\n",
    "    motor_targets = jp.clip(motor_targets, self.lowers, self.uppers)\n",
    "    pipeline_state = self.pipeline_step(state.pipeline_state, motor_targets)\n",
    "    x, xd = pipeline_state.x, pipeline_state.xd\n",
    "\n",
    "    # observation data\n",
    "    obs = self._get_obs(pipeline_state, state.info, state.obs)\n",
    "    joint_angles = pipeline_state.q[7:]\n",
    "    joint_vel = pipeline_state.qd[6:]\n",
    "\n",
    "    # foot contact data based on z-position\n",
    "    foot_pos = pipeline_state.site_xpos[self._feet_site_id]  # pytype: disable=attribute-error\n",
    "    foot_contact_z = foot_pos[:, 2] - self._foot_radius\n",
    "    contact = foot_contact_z < 1e-3  # a mm or less off the floor\n",
    "    contact_filt_mm = contact | state.info['last_contact']\n",
    "    contact_filt_cm = (foot_contact_z < 3e-2) | state.info['last_contact']\n",
    "    first_contact = (state.info['feet_air_time'] > 0) * contact_filt_mm\n",
    "    state.info['feet_air_time'] += self.dt\n",
    "\n",
    "    # done if joint limits are reached or robot is falling\n",
    "    up = jp.array([0.0, 0.0, 1.0])\n",
    "    done = jp.dot(math.rotate(up, x.rot[self._torso_idx - 1]), up) < 0\n",
    "    done |= jp.any(joint_angles < self.lowers)\n",
    "    done |= jp.any(joint_angles > self.uppers)\n",
    "    done |= pipeline_state.x.pos[self._torso_idx - 1, 2] < 0.18\n",
    "\n",
    "    # reward\n",
    "    rewards = {\n",
    "        'tracking_lin_vel': (\n",
    "            self._reward_tracking_lin_vel(state.info['command'], x, xd)\n",
    "        ),\n",
    "        'tracking_ang_vel': (\n",
    "            self._reward_tracking_ang_vel(state.info['command'], x, xd)\n",
    "        ),\n",
    "        'lin_vel_z': self._reward_lin_vel_z(xd),\n",
    "        'ang_vel_xy': self._reward_ang_vel_xy(xd),\n",
    "        'orientation': self._reward_orientation(x),\n",
    "        'torques': self._reward_torques(pipeline_state.qfrc_actuator),  # pytype: disable=attribute-error\n",
    "        'action_rate': self._reward_action_rate(action, state.info['last_act']),\n",
    "        'stand_still': self._reward_stand_still(\n",
    "            state.info['command'], joint_angles,\n",
    "        ),\n",
    "        'feet_air_time': self._reward_feet_air_time(\n",
    "            state.info['feet_air_time'],\n",
    "            first_contact,\n",
    "            state.info['command'],\n",
    "        ),\n",
    "        'foot_slip': self._reward_foot_slip(pipeline_state, contact_filt_cm),\n",
    "        'termination': self._reward_termination(done, state.info['step']),\n",
    "    }\n",
    "    rewards = {\n",
    "        k: v * self.reward_config.rewards.scales[k] for k, v in rewards.items()\n",
    "    }\n",
    "    reward = jp.clip(sum(rewards.values()) * self.dt, 0.0, 10000.0)\n",
    "\n",
    "    # state management\n",
    "    state.info['kick'] = kick\n",
    "    state.info['last_act'] = action\n",
    "    state.info['last_vel'] = joint_vel\n",
    "    state.info['feet_air_time'] *= ~contact_filt_mm\n",
    "    state.info['last_contact'] = contact\n",
    "    state.info['rewards'] = rewards\n",
    "    state.info['step'] += 1\n",
    "    state.info['rng'] = rng\n",
    "\n",
    "    # sample new command if more than 500 timesteps achieved\n",
    "    state.info['command'] = jp.where(\n",
    "        state.info['step'] > 500,\n",
    "        self.sample_command(cmd_rng),\n",
    "        state.info['command'],\n",
    "    )\n",
    "    # reset the step counter when done\n",
    "    state.info['step'] = jp.where(\n",
    "        done | (state.info['step'] > 500), 0, state.info['step']\n",
    "    )\n",
    "\n",
    "    # log total displacement as a proxy metric\n",
    "    state.metrics['total_dist'] = math.normalize(x.pos[self._torso_idx - 1])[1]\n",
    "    state.metrics.update(state.info['rewards'])\n",
    "\n",
    "    done = jp.float32(done)\n",
    "    state = state.replace(\n",
    "        pipeline_state=pipeline_state, obs=obs, reward=reward, done=done\n",
    "    )\n",
    "    return state\n",
    "\n",
    "  def _get_obs(\n",
    "      self,\n",
    "      pipeline_state: base.State,\n",
    "      state_info: dict[str, Any],\n",
    "      obs_history: jax.Array,\n",
    "  ) -> jax.Array:\n",
    "    inv_torso_rot = math.quat_inv(pipeline_state.x.rot[0])\n",
    "    local_rpyrate = math.rotate(pipeline_state.xd.ang[0], inv_torso_rot)\n",
    "\n",
    "    obs = jp.concatenate([\n",
    "        jp.array([local_rpyrate[2]]) * 0.25,                 # yaw rate\n",
    "        math.rotate(jp.array([0, 0, -1]), inv_torso_rot),    # projected gravity\n",
    "        state_info['command'] * jp.array([2.0, 2.0, 0.25]),  # command\n",
    "        pipeline_state.q[7:] - self._default_pose,           # motor angles\n",
    "        state_info['last_act'],                              # last action\n",
    "    ])\n",
    "\n",
    "    # clip, noise\n",
    "    obs = jp.clip(obs, -100.0, 100.0) + self._obs_noise * jax.random.uniform(\n",
    "        state_info['rng'], obs.shape, minval=-1, maxval=1\n",
    "    )\n",
    "    # stack observations through time\n",
    "    obs = jp.roll(obs_history, obs.size).at[:obs.size].set(obs)\n",
    "\n",
    "    return obs\n",
    "\n",
    "  # ------------ reward functions----------------\n",
    "  def _reward_lin_vel_z(self, xd: Motion) -> jax.Array:\n",
    "    # Penalize z axis base linear velocity\n",
    "    return jp.square(xd.vel[0, 2])\n",
    "\n",
    "  def _reward_ang_vel_xy(self, xd: Motion) -> jax.Array:\n",
    "    # Penalize xy axes base angular velocity\n",
    "    return jp.sum(jp.square(xd.ang[0, :2]))\n",
    "\n",
    "  def _reward_orientation(self, x: Transform) -> jax.Array:\n",
    "    # Penalize non flat base orientation\n",
    "    up = jp.array([0.0, 0.0, 1.0])\n",
    "    rot_up = math.rotate(up, x.rot[0])\n",
    "    return jp.sum(jp.square(rot_up[:2]))\n",
    "\n",
    "  def _reward_torques(self, torques: jax.Array) -> jax.Array:\n",
    "    # Penalize torques\n",
    "    return jp.sqrt(jp.sum(jp.square(torques))) + jp.sum(jp.abs(torques))\n",
    "\n",
    "  def _reward_action_rate(\n",
    "      self, act: jax.Array, last_act: jax.Array\n",
    "  ) -> jax.Array:\n",
    "    # Penalize changes in actions\n",
    "    return jp.sum(jp.square(act - last_act))\n",
    "\n",
    "  def _reward_tracking_lin_vel(\n",
    "      self, commands: jax.Array, x: Transform, xd: Motion\n",
    "  ) -> jax.Array:\n",
    "    # Tracking of linear velocity commands (xy axes)\n",
    "    local_vel = math.rotate(xd.vel[0], math.quat_inv(x.rot[0]))\n",
    "    lin_vel_error = jp.sum(jp.square(commands[:2] - local_vel[:2]))\n",
    "    lin_vel_reward = jp.exp(\n",
    "        -lin_vel_error / self.reward_config.rewards.tracking_sigma\n",
    "    )\n",
    "    return lin_vel_reward\n",
    "\n",
    "  def _reward_tracking_ang_vel(\n",
    "      self, commands: jax.Array, x: Transform, xd: Motion\n",
    "  ) -> jax.Array:\n",
    "    # Tracking of angular velocity commands (yaw)\n",
    "    base_ang_vel = math.rotate(xd.ang[0], math.quat_inv(x.rot[0]))\n",
    "    ang_vel_error = jp.square(commands[2] - base_ang_vel[2])\n",
    "    return jp.exp(-ang_vel_error / self.reward_config.rewards.tracking_sigma)\n",
    "\n",
    "  def _reward_feet_air_time(\n",
    "      self, air_time: jax.Array, first_contact: jax.Array, commands: jax.Array\n",
    "  ) -> jax.Array:\n",
    "    # Reward air time.\n",
    "    rew_air_time = jp.sum((air_time - 0.1) * first_contact)\n",
    "    rew_air_time *= (\n",
    "        math.normalize(commands[:2])[1] > 0.05\n",
    "    )  # no reward for zero command\n",
    "    return rew_air_time\n",
    "\n",
    "  def _reward_stand_still(\n",
    "      self,\n",
    "      commands: jax.Array,\n",
    "      joint_angles: jax.Array,\n",
    "  ) -> jax.Array:\n",
    "    # Penalize motion at zero commands\n",
    "    return jp.sum(jp.abs(joint_angles - self._default_pose)) * (\n",
    "        math.normalize(commands[:2])[1] < 0.1\n",
    "    )\n",
    "\n",
    "  def _reward_foot_slip(\n",
    "      self, pipeline_state: base.State, contact_filt: jax.Array\n",
    "  ) -> jax.Array:\n",
    "    # get velocities at feet which are offset from lower legs\n",
    "    # pytype: disable=attribute-error\n",
    "    pos = pipeline_state.site_xpos[self._feet_site_id]  # feet position\n",
    "    feet_offset = pos - pipeline_state.xpos[self._lower_leg_body_id]\n",
    "    # pytype: enable=attribute-error\n",
    "    offset = base.Transform.create(pos=feet_offset)\n",
    "    foot_indices = self._lower_leg_body_id - 1  # we got rid of the world body\n",
    "    foot_vel = offset.vmap().do(pipeline_state.xd.take(foot_indices)).vel\n",
    "\n",
    "    # Penalize large feet velocity for feet that are in contact with the ground.\n",
    "    return jp.sum(jp.square(foot_vel[:, :2]) * contact_filt.reshape((-1, 1)))\n",
    "\n",
    "  def _reward_termination(self, done: jax.Array, step: jax.Array) -> jax.Array:\n",
    "    return done & (step < 500)\n",
    "\n",
    "  def render(\n",
    "      self, trajectory: List[base.State], camera: str | None = None,\n",
    "      width: int = 240, height: int = 320,\n",
    "  ) -> Sequence[np.ndarray]:\n",
    "    camera = camera or 'track'\n",
    "    return super().render(trajectory, camera=camera, width=width, height=height)\n",
    "\n",
    "envs.register_environment('barkour', BarkourEnv)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "id": "pi_yrcz-Qp3W"
   },
   "outputs": [],
   "source": [
    "env_name = 'barkour'\n",
    "env = envs.get_environment(env_name)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "nxaNFP9mA23H"
   },
   "source": [
    "## Train Policy\n",
    "\n",
    "To train a policy with domain randomization, we pass in the domain randomization function into the brax train function; brax will call the domain randomization function when rolling out episodes. Training the quadruped takes 6 minutes on a Tesla A100 GPU."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "id": "cHJCbESGA7Rk"
   },
   "outputs": [],
   "source": [
    "ckpt_path = epath.Path('/tmp/quadrupred_joystick/ckpts')\n",
    "ckpt_path.mkdir(parents=True, exist_ok=True)\n",
    "\n",
    "def policy_params_fn(current_step, make_policy, params):\n",
    "  # save checkpoints\n",
    "  orbax_checkpointer = ocp.PyTreeCheckpointer()\n",
    "  save_args = orbax_utils.save_args_from_target(params)\n",
    "  path = ckpt_path / f'{current_step}'\n",
    "  orbax_checkpointer.save(path, params, force=True, save_args=save_args)\n",
    "\n",
    "\n",
    "make_networks_factory = functools.partial(\n",
    "    ppo_networks.make_ppo_networks,\n",
    "        policy_hidden_layer_sizes=(128, 128, 128, 128))\n",
    "train_fn = functools.partial(\n",
    "      ppo.train, num_timesteps=100_000_000, num_evals=10,\n",
    "      reward_scaling=1, episode_length=1000, normalize_observations=True,\n",
    "      action_repeat=1, unroll_length=20, num_minibatches=32,\n",
    "      num_updates_per_batch=4, discounting=0.97, learning_rate=3.0e-4,\n",
    "      entropy_cost=1e-2, num_envs=8192, batch_size=256,\n",
    "      network_factory=make_networks_factory,\n",
    "      randomization_fn=domain_randomize,\n",
    "      policy_params_fn=policy_params_fn,\n",
    "      seed=0)\n",
    "\n",
    "x_data = []\n",
    "y_data = []\n",
    "ydataerr = []\n",
    "times = [datetime.now()]\n",
    "max_y, min_y = 40, 0\n",
    "\n",
    "# Reset environments since internals may be overwritten by tracers from the\n",
    "# domain randomization function.\n",
    "env = envs.get_environment(env_name)\n",
    "eval_env = envs.get_environment(env_name)\n",
    "make_inference_fn, params, _= train_fn(environment=env,\n",
    "                                       progress_fn=progress,\n",
    "                                       eval_env=eval_env)\n",
    "\n",
    "print(f'time to jit: {times[1] - times[0]}')\n",
    "print(f'time to train: {times[-1] - times[1]}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "id": "8Yge-CGP5JoO"
   },
   "outputs": [],
   "source": [
    "# Save and reload params.\n",
    "model_path = '/tmp/mjx_brax_quadruped_policy'\n",
    "model.save_params(model_path, params)\n",
    "params = model.load_params(model_path)\n",
    "\n",
    "inference_fn = make_inference_fn(params)\n",
    "jit_inference_fn = jax.jit(inference_fn)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "L01IrN4oCIkC"
   },
   "source": [
    "## Visualize Policy\n",
    "\n",
    "For the Barkour Quadruped, the joystick commands can be set through `x_vel`, `y_vel`, and `ang_vel`. `x_vel` and `y_vel` define the linear forward and sideways velocities with respect to the quadruped torso. `ang_vel` defines the angular velocity of the torso in the z direction."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "id": "VTbpEtXnEecd"
   },
   "outputs": [],
   "source": [
    "eval_env = envs.get_environment(env_name)\n",
    "\n",
    "jit_reset = jax.jit(eval_env.reset)\n",
    "jit_step = jax.jit(eval_env.step)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "id": "HRRN-8L-BivZ"
   },
   "outputs": [],
   "source": [
    "# @markdown Commands **only used for Barkour Env**:\n",
    "x_vel = 1.0  #@param {type: \"number\"}\n",
    "y_vel = 0.0  #@param {type: \"number\"}\n",
    "ang_vel = -0.5  #@param {type: \"number\"}\n",
    "\n",
    "the_command = jp.array([x_vel, y_vel, ang_vel])\n",
    "\n",
    "# initialize the state\n",
    "rng = jax.random.PRNGKey(0)\n",
    "state = jit_reset(rng)\n",
    "state.info['command'] = the_command\n",
    "rollout = [state.pipeline_state]\n",
    "\n",
    "# grab a trajectory\n",
    "n_steps = 500\n",
    "render_every = 2\n",
    "\n",
    "for i in range(n_steps):\n",
    "  act_rng, rng = jax.random.split(rng)\n",
    "  ctrl, _ = jit_inference_fn(state.obs, act_rng)\n",
    "  state = jit_step(state, ctrl)\n",
    "  rollout.append(state.pipeline_state)\n",
    "\n",
    "media.show_video(\n",
    "    eval_env.render(rollout[::render_every], camera='track'),\n",
    "    fps=1.0 / eval_env.dt / render_every)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "aD6H6WD0915X"
   },
   "source": [
    "We can also render the rollout using the Brax renderer."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "id": "V7jqv08X95u4"
   },
   "outputs": [],
   "source": [
    "HTML(html.render(eval_env.sys.tree_replace({'opt.timestep': eval_env.dt}), rollout))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "gNagGnBODotY"
   },
   "source": [
    "## Train Policy with Height Field\n",
    "\n",
    "We may also want the quadruped to learn to walk on rough terrain. Let's take the latest checkpoint from the joystick policy above and fine-tune it on a height field terrain."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "id": "wlT3xLouKxqT"
   },
   "outputs": [],
   "source": [
    "# use the height field scene\n",
    "scene_file = 'scene_hfield_mjx.xml'\n",
    "\n",
    "env = envs.get_environment(env_name, scene_file=scene_file)\n",
    "jit_reset = jax.jit(env.reset)\n",
    "state = jit_reset(jax.random.PRNGKey(0))\n",
    "plt.imshow(env.render([state.pipeline_state], camera='track')[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "id": "jOTx-OyPDqPW"
   },
   "outputs": [],
   "source": [
    "# grab the latest checkpoint from the flat terrain joystick policy\n",
    "latest_ckpts = list(ckpt_path.glob('*'))\n",
    "# checkpoint names are integers, so sort numerically rather than lexicographically\n",
    "latest_ckpts.sort(key=lambda x: int(x.name))\n",
    "latest_ckpt = latest_ckpts[-1]\n",
    "\n",
    "train_fn = functools.partial(\n",
    "      ppo.train, num_timesteps=40_000_000, num_evals=5,\n",
    "      reward_scaling=1, episode_length=1000, normalize_observations=True,\n",
    "      action_repeat=1, unroll_length=20, num_minibatches=32,\n",
    "      num_updates_per_batch=4, discounting=0.97, learning_rate=3.0e-4,\n",
    "      entropy_cost=1e-2, num_envs=8192, batch_size=256,\n",
    "      network_factory=make_networks_factory,\n",
    "      randomization_fn=domain_randomize, seed=0,\n",
    "      restore_checkpoint_path=latest_ckpt)\n",
    "\n",
    "x_data = []\n",
    "y_data = []\n",
    "ydataerr = []\n",
    "times = [datetime.now()]\n",
    "max_y, min_y = 40, 0\n",
    "\n",
    "# Re-create the environments, since their internals may have been overwritten\n",
    "# by tracers from the domain randomization function.\n",
    "env = envs.get_environment(env_name, scene_file=scene_file)\n",
    "eval_env = envs.get_environment(env_name, scene_file=scene_file)\n",
    "make_inference_fn, params, _ = train_fn(environment=env,\n",
    "                                        progress_fn=progress,\n",
    "                                        eval_env=eval_env)\n",
    "\n",
    "print(f'time to jit: {times[1] - times[0]}')\n",
    "print(f'time to train: {times[-1] - times[1]}')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "0wHRJyjLFww6"
   },
   "source": [
    "## Visualize Policy with Height Field"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "id": "1X57XkaVFu-v"
   },
   "outputs": [],
   "source": [
    "eval_env = envs.get_environment(env_name, scene_file=scene_file)\n",
    "\n",
    "jit_reset = jax.jit(eval_env.reset)\n",
    "jit_step = jax.jit(eval_env.step)\n",
    "inference_fn = make_inference_fn(params)\n",
    "jit_inference_fn = jax.jit(inference_fn)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "id": "nAxexZcVFu-v"
   },
   "outputs": [],
   "source": [
    "# @markdown Commands **only used for Barkour Env**:\n",
    "x_vel = 1.0  #@param {type: \"number\"}\n",
    "y_vel = 0.0  #@param {type: \"number\"}\n",
    "ang_vel = -0.5  #@param {type: \"number\"}\n",
    "\n",
    "the_command = jp.array([x_vel, y_vel, ang_vel])\n",
    "\n",
    "# initialize the state\n",
    "rng = jax.random.PRNGKey(0)\n",
    "state = jit_reset(rng)\n",
    "state.info['command'] = the_command\n",
    "rollout = [state.pipeline_state]\n",
    "\n",
    "# grab a trajectory\n",
    "n_steps = 500\n",
    "render_every = 2\n",
    "\n",
    "for i in range(n_steps):\n",
    "  act_rng, rng = jax.random.split(rng)\n",
    "  ctrl, _ = jit_inference_fn(state.obs, act_rng)\n",
    "  state = jit_step(state, ctrl)\n",
    "  rollout.append(state.pipeline_state)\n",
    "\n",
    "media.show_video(\n",
    "    eval_env.render(rollout[::render_every], camera='track'),\n",
    "    fps=1.0 / eval_env.dt / render_every)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "-Q5gsOOaBYd1"
   },
   "source": [
    "# MuJoCo Playground: Robotics locomotion and manipulation environments + Sim-to-Real!"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "gluTlHURuC6i"
   },
   "source": [
    "By now, we have shown how MJX can be used to train policies for classic control and robotic locomotion. For a full suite of robotic locomotion and manipulation environments, we encourage you to check out [MuJoCo Playground](https://github.com/google-deepmind/mujoco_playground). Policies trained in many of these environments have been transferred to real robots, as described on the [website](https://playground.mujoco.org/) and in the [technical report](https://playground.mujoco.org/assets/playground_technical_report.pdf).\n"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "gpuClass": "premium",
   "gpuType": "V100",
   "machine_shape": "hm",
   "private_outputs": true
  },
  "kernelspec": {
   "display_name": "Python 3",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}