{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "Tce3stUlHN0L"
},
"source": [
"##### Copyright 2020 The TensorFlow Authors."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "tuOe1ymfHZPu"
},
"outputs": [],
"source": [
"#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
"# you may not use this file except in compliance with the License.\n",
"# You may obtain a copy of the License at\n",
"#\n",
"# https://www.apache.org/licenses/LICENSE-2.0\n",
"#\n",
"# Unless required by applicable law or agreed to in writing, software\n",
"# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
"# See the License for the specific language governing permissions and\n",
"# limitations under the License."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "qFdPvlXBOdUN"
},
"source": [
"# Training with Orbit"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "MfBg1C5NB3X0"
},
"source": [
"\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n",
" \u003ctd\u003e\n",
" \u003ca target=\"_blank\" href=\"https://www.tensorflow.org/tfmodels/orbit\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" /\u003eView on TensorFlow.org\u003c/a\u003e\n",
" \u003c/td\u003e\n",
" \u003ctd\u003e\n",
" \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/models/blob/master/docs/orbit/index.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n",
" \u003c/td\u003e\n",
" \u003ctd\u003e\n",
" \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/models/blob/master/docs/orbit/index.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView on GitHub\u003c/a\u003e\n",
" \u003c/td\u003e\n",
" \u003ctd\u003e\n",
" \u003ca href=\"https://storage.googleapis.com/tensorflow_docs/models/docs/orbit/index.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/download_logo_32px.png\" /\u003eDownload notebook\u003c/a\u003e\n",
" \u003c/td\u003e\n",
"\n",
"\u003c/table\u003e"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "456h0idS2Xcq"
},
"source": [
"This example will work through fine-tuning a BERT model using the [Orbit](https://www.tensorflow.org/api_docs/python/orbit) training library.\n",
"\n",
"Orbit is a flexible, lightweight library designed to make it easy to write [custom training loops](https://www.tensorflow.org/tutorials/distribute/custom_training) in TensorFlow. Orbit handles common model training tasks such as saving checkpoints, running model evaluations, and setting up summary writing, while giving users full control over implementing the inner training loop. It integrates with `tf.distribute` and supports running on different device types (CPU, GPU, and TPU).\n",
"\n",
"Most examples on [tensorflow.org](https://www.tensorflow.org/) use custom training loops or [model.fit()](https://www.tensorflow.org/api_docs/python/tf/keras/Model) from Keras. Orbit is a good alternative to `model.fit` if your model is complex and your training loop requires more flexibility, control, or customization. Also, using Orbit can simplify the code when there are many different model architectures that all use the same custom training loop.\n",
"\n",
"This tutorial focuses on setting up and using Orbit, rather than details about BERT, model construction, and data processing. For more in-depth tutorials on these topics, refer to the following tutorials:\n",
"\n",
"* [Fine tune BERT](https://www.tensorflow.org/text/tutorials/fine_tune_bert) - which goes into detail on these sub-topics.\n",
"* [Fine tune BERT for GLUE on TPU](https://www.tensorflow.org/text/tutorials/bert_glue) - which generalizes the code to run any BERT configuration on any [GLUE](https://www.tensorflow.org/datasets/catalog/glue) sub-task, and runs on TPU."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "TJ4m3khW3p_W"
},
"source": [
"## Install the TensorFlow Models package\n",
"\n",
"Install and import the necessary packages, then configure all the objects necessary for training a model.\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "FZlj0U8Aq9Gt"
},
"outputs": [],
"source": [
"!pip install -q opencv-python\n",
"!pip install \"tensorflow>=2.9.0\" tf-models-official"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "MEJkRrmapr16"
},
"source": [
"The `tf-models-official` package contains both the `orbit` and `tensorflow_models` modules."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "dUVPW84Zucuq"
},
"outputs": [],
"source": [
"import tensorflow_models as tfm\n",
"import orbit"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "18Icocf3lwYD"
},
"source": [
"## Setup for training\n",
"\n",
"This tutorial does not focus on configuring the environment, building the model and optimizer, or loading data. These techniques are covered in more detail in the [Fine tune BERT](https://www.tensorflow.org/text/tutorials/fine_tune_bert) and [Fine tune BERT for GLUE on TPU](https://www.tensorflow.org/text/tutorials/bert_glue) tutorials.\n",
"\n",
"To view how the training is set up for this tutorial, expand the rest of this section.\n",
"\n",
" \u003c!-- \u003cdiv class=\"tfo-display-only-on-site\"\u003e\u003cdevsite-expandable\u003e\n",
" \u003cbutton type=\"button\" class=\"button-red button expand-control\"\u003eExpand Section\u003c/button\u003e --\u003e"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Ljy0z-i3okCS"
},
"source": [
"### Import the necessary packages\n",
"\n",
"Import the BERT model and dataset building library from the [TensorFlow Model Garden](https://github.com/tensorflow/models)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "gCBo6wxA2b5n"
},
"outputs": [],
"source": [
"import glob\n",
"import os\n",
"import pathlib\n",
"import tempfile\n",
"import time\n",
"\n",
"import numpy as np\n",
"\n",
"import tensorflow as tf"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "PG1kwhnvq3VC"
},
"outputs": [],
"source": [
"from official.nlp.data import sentence_prediction_dataloader\n",
"from official.nlp import optimization"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "PsbhUV_p3wxN"
},
"source": [
"### Configure the distribution strategy\n",
"\n",
"While `tf.distribute` won't speed up training if you're running on a single CPU or GPU, it's necessary for TPUs. Setting up a distribution strategy allows you to use the same code regardless of the configuration."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "PG702dqstXIk"
},
"outputs": [],
"source": [
"logical_device_names = [logical_device.name for logical_device in tf.config.list_logical_devices()]\n",
"\n",
"# Pick a distribution strategy based on the available devices.\n",
"if 'GPU' in ''.join(logical_device_names):\n",
"  strategy = tf.distribute.MirroredStrategy()\n",
"elif 'TPU' in ''.join(logical_device_names):\n",
"  resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')\n",
"  tf.config.experimental_connect_to_cluster(resolver)\n",
"  tf.tpu.experimental.initialize_tpu_system(resolver)\n",
"  strategy = tf.distribute.TPUStrategy(resolver)\n",
"else:\n",
"  strategy = tf.distribute.OneDeviceStrategy(logical_device_names[0])"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "eaQgM98deAMu"
},
"source": [
"For more information about the TPU setup, refer to the [TPU guide](https://www.tensorflow.org/guide/tpu)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "7aOxMLLV32Zm"
},
"source": [
"### Create a model and an optimizer"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "YRdWzOfK3_56"
},
"outputs": [],
"source": [
"max_seq_length = 128\n",
"learning_rate = 3e-5\n",
"num_train_epochs = 3\n",
"train_batch_size = 32\n",
"eval_batch_size = 64\n",
"\n",
"train_data_size = 3668\n",
"steps_per_epoch = int(train_data_size / train_batch_size)\n",
"\n",
"train_steps = steps_per_epoch * num_train_epochs\n",
"warmup_steps = int(train_steps * 0.1)\n",
"\n",
"print(\"train batch size: \", train_batch_size)\n",
"print(\"train epochs: \", num_train_epochs)\n",
"print(\"steps_per_epoch: \", steps_per_epoch)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "BVw3886Ysse6"
},
"outputs": [],
"source": [
"model_dir = pathlib.Path(tempfile.mkdtemp())\n",
"print(model_dir)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "mu9cV7ew-cVe"
},
"source": [
"\n",
"Create a BERT Classifier model and a simple optimizer. They must be created inside `strategy.scope` so that the variables can be distributed. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "gmwtX0cp-mj5"
},
"outputs": [],
"source": [
"with strategy.scope():\n",
"  encoder_network = tfm.nlp.encoders.build_encoder(\n",
"      tfm.nlp.encoders.EncoderConfig(type=\"bert\"))\n",
"  classifier_model = tfm.nlp.models.BertClassifier(\n",
"      network=encoder_network, num_classes=2)\n",
"\n",
"  # Reuse the hyperparameters defined above instead of repeating literals.\n",
"  optimizer = optimization.create_optimizer(\n",
"      init_lr=learning_rate,\n",
"      num_train_steps=train_steps,\n",
"      num_warmup_steps=warmup_steps,\n",
"      end_lr=0.0,\n",
"      optimizer_type='adamw')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "jwJSfewG5jVV"
},
"outputs": [],
"source": [
"tf.keras.utils.plot_model(classifier_model)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "IQy5pYgAf8Ft"
},
"source": [
"### Initialize from a Checkpoint"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "6CE14GEybgRR"
},
"outputs": [],
"source": [
"bert_dir = 'gs://cloud-tpu-checkpoints/bert/v3/uncased_L-12_H-768_A-12/'\n",
"tf.io.gfile.listdir(bert_dir)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "x7fwxz9xidKt"
},
"outputs": [],
"source": [
"bert_checkpoint = bert_dir + 'bert_model.ckpt'"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "q7EfwVCRe7N_"
},
"outputs": [],
"source": [
"def init_from_ckpt_fn():\n",
"  init_checkpoint = tf.train.Checkpoint(**classifier_model.checkpoint_items)\n",
"  with strategy.scope():\n",
"    (init_checkpoint\n",
"     .read(bert_checkpoint)\n",
"     .expect_partial()\n",
"     .assert_existing_objects_matched())"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "M0LUMlsde-2f"
},
"outputs": [],
"source": [
"with strategy.scope():\n",
"  init_from_ckpt_fn()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "gAuns4vN_IYV"
},
"source": [
"\n",
"To use Orbit, create a `tf.train.CheckpointManager` object."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "i7NwM1Jq_MX7"
},
"outputs": [],
"source": [
"checkpoint = tf.train.Checkpoint(model=classifier_model, optimizer=optimizer)\n",
"checkpoint_manager = tf.train.CheckpointManager(\n",
"    checkpoint,\n",
"    directory=model_dir,\n",
"    max_to_keep=5,\n",
"    step_counter=optimizer.iterations,\n",
"    checkpoint_interval=steps_per_epoch,\n",
"    init_fn=init_from_ckpt_fn)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "nzeiAFhcCOAo"
},
"source": [
"### Create distributed datasets\n",
"\n",
"As a shortcut for this tutorial, the [GLUE/MRPC dataset](https://www.tensorflow.org/datasets/catalog/glue#gluemrpc) has been converted to a pair of [TFRecord](https://www.tensorflow.org/tutorials/load_data/tfrecord) files containing serialized `tf.train.Example` protos.\n",
"\n",
"The data was converted using [this script](https://github.com/tensorflow/models/blob/r2.9.0/official/nlp/data/create_finetuning_data.py).\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ZVfbiT1dCnDk"
},
"outputs": [],
"source": [
"train_data_path = \"gs://download.tensorflow.org/data/model_garden_colab/mrpc_train.tf_record\"\n",
"eval_data_path = \"gs://download.tensorflow.org/data/model_garden_colab/mrpc_eval.tf_record\"\n",
"\n",
"def _dataset_fn(input_file_pattern,\n",
"                global_batch_size,\n",
"                is_training,\n",
"                input_context=None):\n",
"  data_config = sentence_prediction_dataloader.SentencePredictionDataConfig(\n",
"      input_path=input_file_pattern,\n",
"      seq_length=max_seq_length,\n",
"      global_batch_size=global_batch_size,\n",
"      is_training=is_training)\n",
"  return sentence_prediction_dataloader.SentencePredictionDataLoader(\n",
"      data_config).load(input_context=input_context)\n",
"\n",
"train_dataset = orbit.utils.make_distributed_dataset(\n",
"    strategy, _dataset_fn, input_file_pattern=train_data_path,\n",
"    global_batch_size=train_batch_size, is_training=True)\n",
"eval_dataset = orbit.utils.make_distributed_dataset(\n",
"    strategy, _dataset_fn, input_file_pattern=eval_data_path,\n",
"    global_batch_size=eval_batch_size, is_training=False)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "dPgiDBQCjsXW"
},
"source": [
"### Create a loss function\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "7MCUmmo2jvXl"
},
"outputs": [],
"source": [
"def loss_fn(labels, logits):\n",
"  \"\"\"Classification loss.\"\"\"\n",
"  labels = tf.squeeze(labels)\n",
"  log_probs = tf.nn.log_softmax(logits, axis=-1)\n",
"  one_hot_labels = tf.one_hot(\n",
"      tf.cast(labels, dtype=tf.int32), depth=2, dtype=tf.float32)\n",
"  per_example_loss = -tf.reduce_sum(\n",
"      tf.cast(one_hot_labels, dtype=tf.float32) * log_probs, axis=-1)\n",
"  return tf.reduce_mean(per_example_loss)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ohlO-8FQkwsr"
},
"source": [
" \u003c/devsite-expandable\u003e\u003c/div\u003e"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ymhbvPaEJ96T"
},
"source": [
"## Controllers, Trainers and Evaluators\n",
"\n",
"When using Orbit, the `orbit.Controller` class drives the training. The Controller handles the details of distribution strategies, step counting, TensorBoard summaries, and checkpointing.\n",
"\n",
"To implement the training and evaluation, pass a `trainer` and `evaluator`, which are subclass instances of `orbit.AbstractTrainer` and `orbit.AbstractEvaluator`. In keeping with Orbit's lightweight design, these two classes have a minimal interface.\n",
"\n",
"The Controller drives training and evaluation by calling `trainer.train(num_steps)` and `evaluator.evaluate(num_steps)`. These `train` and `evaluate` methods return a dictionary of results for logging.\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "a6sU2vBeyXtu"
},
"source": [
"Training is broken into chunks of length `num_steps`. This is set by the Controller's [`steps_per_loop`](https://tensorflow.org/api_docs/python/orbit/Controller#args) argument. With the trainer and evaluator abstract base classes, the meaning of `num_steps` is entirely determined by the implementer.\n",
"\n",
"Some common examples include:\n",
"\n",
"* Having the chunks represent dataset-epoch boundaries, like the default Keras setup.\n",
"* Using it to more efficiently dispatch a number of training steps to an accelerator with a single `tf.function` call (like the `steps_per_execution` argument to `Model.compile`). \n",
"* Subdividing into smaller chunks as needed.\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "p4mXGIRJsf1j"
},
"source": [
"### StandardTrainer and StandardEvaluator\n",
"\n",
"Orbit provides two additional classes, `orbit.StandardTrainer` and `orbit.StandardEvaluator`, to give more structure around the training and evaluation loops.\n",
"\n",
"With StandardTrainer, you only need to implement `train_loop_begin`, `train_step`, and `train_loop_end`. The base class handles the loops, dataset logic, and `tf.function` (according to the options set in its `orbit.StandardTrainerOptions`). This is simpler than `orbit.AbstractTrainer`, which requires you to handle the entire loop. StandardEvaluator offers a similar structure and simplification.\n",
"\n",
"This is effectively an implementation of the `steps_per_execution` approach used by Keras."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "-hvZ8PvohmR5"
},
"source": [
"Contrast this with Keras, where training is divided both into epochs (a single pass over the dataset) and `steps_per_execution` (set within [`Model.compile`](https://www.tensorflow.org/api_docs/python/tf/keras/Model#compile)). In Keras, metric averages are typically accumulated over an epoch, then reported and reset between epochs. For efficiency, `steps_per_execution` only controls the number of training steps made per call.\n",
"\n",
"In this simple case, `steps_per_loop` (within `StandardTrainer`) will handle both the metric resets and the number of steps per call. \n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "NoDFN1L-1jIu"
},
"source": [
"The minimal setup when using these base classes is to implement the methods as follows:\n",
"\n",
"1. `StandardTrainer.train_loop_begin` - Reset your training metrics.\n",
"2. `StandardTrainer.train_step` - Apply a single gradient update.\n",
"3. `StandardTrainer.train_loop_end` - Report your training metrics.\n",
"\n",
"and\n",
"\n",
"4. `StandardEvaluator.eval_begin` - Reset your evaluation metrics.\n",
"5. `StandardEvaluator.eval_step` - Run a single evaluation step.\n",
"6. `StandardEvaluator.eval_reduce` - This is not necessary in this simple setup.\n",
"7. `StandardEvaluator.eval_end` - Report your evaluation metrics.\n",
"\n",
"Depending on the settings, the base class may wrap the `train_step` and `eval_step` code in `tf.function` or `tf.while_loop`, which has some limitations compared to standard Python."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "3KPA0NDZt2JD"
},
"source": [
"### Define the trainer class"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "6LDPsvJwfuPR"
},
"source": [
"In this section you'll create a subclass of `orbit.StandardTrainer` for this task. \n",
"\n",
"Note: To better explain the `BertClassifierTrainer` class, this section defines each method as a stand-alone function and assembles them into a class at the end.\n",
"\n",
"The trainer needs access to the training data, model, optimizer, and distribution strategy. Pass these as arguments to the initializer.\n",
"\n",
"Define a single training metric, `training_loss`, using `tf.keras.metrics.Mean`. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "6DQYZN5ax-MG"
},
"outputs": [],
"source": [
"def trainer_init(self,\n",
"                 train_dataset,\n",
"                 model,\n",
"                 optimizer,\n",
"                 strategy):\n",
"  self.strategy = strategy\n",
"  with self.strategy.scope():\n",
"    self.model = model\n",
"    self.optimizer = optimizer\n",
"    self.global_step = self.optimizer.iterations\n",
"\n",
"    # Metric that averages the loss over one run of the training loop.\n",
"    self.train_loss = tf.keras.metrics.Mean(\n",
"        'training_loss', dtype=tf.float32)\n",
"    orbit.StandardTrainer.__init__(self, train_dataset)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "QOwHD7U5hVue"
},
"source": [
"Before starting a run of the training loop, the `train_loop_begin` method will reset the `train_loss` metric."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "AkpcHqXShWL0"
},
"outputs": [],
"source": [
"def train_loop_begin(self):\n",
"  self.train_loss.reset_state()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "UjtFOFyxn2BB"
},
"source": [
"The `train_step` method is a straightforward loss calculation and gradient update that is run by the distribution strategy. This is accomplished by defining the gradient step as a nested function (`step_fn`).\n",
"\n",
"The method receives a `tf.distribute.DistributedIterator` to handle the [distributed input](https://www.tensorflow.org/tutorials/distribute/input). It uses `Strategy.run` to execute `step_fn` and feeds it from the distributed iterator.\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "QuPwNnT5I-GP"
},
"outputs": [],
"source": [
"def train_step(self, iterator):\n",
"\n",
"  def step_fn(inputs):\n",
"    labels = inputs.pop(\"label_ids\")\n",
"    with tf.GradientTape() as tape:\n",
"      model_outputs = self.model(inputs, training=True)\n",
"      # Raw loss is used for reporting in metrics/logs.\n",
"      raw_loss = loss_fn(labels, model_outputs)\n",
"      # Scale the loss down so the summed gradients do not grow with the replica count.\n",
"      loss = raw_loss / self.strategy.num_replicas_in_sync\n",
"\n",
"    grads = tape.gradient(loss, self.model.trainable_variables)\n",
"    self.optimizer.apply_gradients(zip(grads, self.model.trainable_variables))\n",
"    # For reporting, the metric takes the mean of losses.\n",
"    self.train_loss.update_state(raw_loss)\n",
"\n",
"  self.strategy.run(step_fn, args=(next(iterator),))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "VmQNwx5QpyDt"
},
"source": [
"The `orbit.StandardTrainer` handles the `@tf.function` and loops.\n",
"\n",
"After running through `num_steps` of training, `StandardTrainer` calls `train_loop_end`. The function returns the metric results:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "GqCyVk1zzGod"
},
"outputs": [],
"source": [
"def train_loop_end(self):\n",
"  return {\n",
"      self.train_loss.name: self.train_loss.result(),\n",
"  }"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "xvmLONl80KUv"
},
"source": [
"Build a subclass of `orbit.StandardTrainer` with those methods."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "oRoL7VE6xt1G"
},
"outputs": [],
"source": [
"class BertClassifierTrainer(orbit.StandardTrainer):\n",
"  __init__ = trainer_init\n",
"  train_loop_begin = train_loop_begin\n",
"  train_step = train_step\n",
"  train_loop_end = train_loop_end"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "yjG4QAWj1B00"
},
"source": [
"### Define the evaluator class\n",
"\n",
"Note: Like the previous section, this section defines each method as a stand-alone function and assembles them into a `BertClassifierEvaluator` class at the end.\n",
"\n",
"The evaluator is even simpler for this task. It needs access to the evaluation dataset, the model, and the strategy. After saving references to those objects, the constructor just needs to create the metrics."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "cvX7seCY1CWj"
},
"outputs": [],
"source": [
"def evaluator_init(self,\n",
"                   eval_dataset,\n",
"                   model,\n",
"                   strategy):\n",
"  self.strategy = strategy\n",
"  with self.strategy.scope():\n",
"    self.model = model\n",
"\n",
"    self.eval_loss = tf.keras.metrics.Mean(\n",
"        'evaluation_loss', dtype=tf.float32)\n",
"    self.eval_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(\n",
"        name='accuracy', dtype=tf.float32)\n",
"    orbit.StandardEvaluator.__init__(self, eval_dataset)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0r-z-XK7ybyX"
},
"source": [
"Similar to the trainer, the `eval_begin` and `eval_end` methods just need to reset the metrics before the loop and then report the results after the loop."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "7VVb0Tg6yZjI"
},
"outputs": [],
"source": [
"def eval_begin(self):\n",
"  self.eval_accuracy.reset_state()\n",
"  self.eval_loss.reset_state()\n",
"\n",
"def eval_end(self):\n",
"  return {\n",
"      self.eval_accuracy.name: self.eval_accuracy.result(),\n",
"      self.eval_loss.name: self.eval_loss.result(),\n",
"  }"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "iDOZcQvttdmZ"
},
"source": [
"The `eval_step` method works like `train_step`. The inner `step_fn` defines the actual work of calculating the loss and accuracy and updating the metrics. The outer `eval_step` receives a `tf.distribute.DistributedIterator` as input, and uses `Strategy.run` to launch the distributed execution of `step_fn`, feeding it from the distributed iterator."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "JLJnYuuGJjvd"
},
"outputs": [],
"source": [
"def eval_step(self, iterator):\n",
"\n",
"  def step_fn(inputs):\n",
"    labels = inputs.pop(\"label_ids\")\n",
"    # Run the model in inference mode so that dropout is disabled.\n",
"    model_outputs = self.model(inputs, training=False)\n",
"    loss = loss_fn(labels, model_outputs)\n",
"    self.eval_loss.update_state(loss)\n",
"    self.eval_accuracy.update_state(labels, model_outputs)\n",
"\n",
"  self.strategy.run(step_fn, args=(next(iterator),))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Gt3hh0V30QcP"
},
"source": [
"Build a subclass of `orbit.StandardEvaluator` with those methods."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "3zqyLxfNyCgA"
},
"outputs": [],
"source": [
"class BertClassifierEvaluator(orbit.StandardEvaluator):\n",
"  __init__ = evaluator_init\n",
"  eval_begin = eval_begin\n",
"  eval_end = eval_end\n",
"  eval_step = eval_step"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "aK9gEja9qPOc"
},
"source": [
"### End-to-end training and evaluation\n",
"\n",
"To run the training and evaluation, simply create the trainer, evaluator, and `orbit.Controller` instances. Then call the `Controller.train_and_evaluate` method."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "PqQetxyXqRA9"
},
"outputs": [],
"source": [
"trainer = BertClassifierTrainer(\n",
"    train_dataset, classifier_model, optimizer, strategy)\n",
"\n",
"evaluator = BertClassifierEvaluator(\n",
"    eval_dataset, classifier_model, strategy)\n",
"\n",
"controller = orbit.Controller(\n",
"    trainer=trainer,\n",
"    evaluator=evaluator,\n",
"    global_step=trainer.global_step,\n",
"    steps_per_loop=20,\n",
"    checkpoint_manager=checkpoint_manager)\n",
"\n",
"result = controller.train_and_evaluate(\n",
"    train_steps=steps_per_epoch * num_train_epochs,\n",
"    eval_steps=-1,\n",
"    eval_interval=steps_per_epoch)"
]
}
],
"metadata": {
"colab": {
"collapsed_sections": [
"Tce3stUlHN0L"
],
"name": "Orbit Tutorial.ipynb",
"provenance": [],
"toc_visible": true
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 0
}