Buckets:
| <meta charset="utf-8" /><meta name="hf:doc:metadata" content="{"title":"Отладка обучения","local":"debugging-the-training-pipeline","sections":[{"title":"Отладка обучающего пайплайна","local":"debugging-the-training-pipeline","sections":[{"title":"Проверка данных","local":"check-your-data","sections":[],"depth":3},{"title":"От датасетов к загрузчикам данных","local":"from-datasets-to-dataloaders","sections":[],"depth":3},{"title":"Проверьте данные","local":"going-through-the-model","sections":[],"depth":3},{"title":"Выполнение одного шага оптиимзации","local":"performing-one-optimization-step","sections":[],"depth":3},{"title":"Как справиться с ошибками нехватки памяти","local":"dealing-with-cuda-out-of-memory-errors","sections":[],"depth":3},{"title":"Валидация модели","local":"evaluating-the-model","sections":[],"depth":3}],"depth":2},{"title":"Отладка скрытых ошибок во время обучения","local":"debugging-silent-errors-during-training","sections":[{"title":"Проверьте свои данные (еще раз!)","local":"check-your-data-again","sections":[],"depth":3},{"title":"Переобучение модели на одном батче","local":"overfit-your-model-on-one-batch","sections":[],"depth":3},{"title":"Не обучайте ничего, пока не получите первый бейзлайн.","local":"dont-tune-anything-until-you-have-a-first-baseline","sections":[],"depth":3},{"title":"Попросите о помощи","local":"ask-for-help","sections":[],"depth":3}],"depth":2}],"depth":1}"> | |
| <link href="/docs/course/pr_1069/ru/_app/immutable/assets/0.e3b0c442.css" rel="modulepreload"> | |
| <link rel="modulepreload" href="/docs/course/pr_1069/ru/_app/immutable/entry/start.48687cc8.js"> | |
| <link rel="modulepreload" href="/docs/course/pr_1069/ru/_app/immutable/chunks/scheduler.37c15a92.js"> | |
| <link rel="modulepreload" href="/docs/course/pr_1069/ru/_app/immutable/chunks/singletons.6f259016.js"> | |
| <link rel="modulepreload" href="/docs/course/pr_1069/ru/_app/immutable/chunks/index.18351ede.js"> | |
| <link rel="modulepreload" href="/docs/course/pr_1069/ru/_app/immutable/chunks/paths.930ed261.js"> | |
| <link rel="modulepreload" href="/docs/course/pr_1069/ru/_app/immutable/entry/app.b79a803d.js"> | |
| <link rel="modulepreload" href="/docs/course/pr_1069/ru/_app/immutable/chunks/index.2bf4358c.js"> | |
| <link rel="modulepreload" href="/docs/course/pr_1069/ru/_app/immutable/nodes/0.e11366e4.js"> | |
| <link rel="modulepreload" href="/docs/course/pr_1069/ru/_app/immutable/chunks/each.e59479a4.js"> | |
| <link rel="modulepreload" href="/docs/course/pr_1069/ru/_app/immutable/nodes/64.0782aad4.js"> | |
| <link rel="modulepreload" href="/docs/course/pr_1069/ru/_app/immutable/chunks/Tip.363c041f.js"> | |
| <link rel="modulepreload" href="/docs/course/pr_1069/ru/_app/immutable/chunks/Youtube.1e50a667.js"> | |
| <link rel="modulepreload" href="/docs/course/pr_1069/ru/_app/immutable/chunks/CodeBlock.4e987730.js"> | |
| <link rel="modulepreload" href="/docs/course/pr_1069/ru/_app/immutable/chunks/CourseFloatingBanner.9ff4c771.js"> | |
| <link rel="modulepreload" href="/docs/course/pr_1069/ru/_app/immutable/chunks/DocNotebookDropdown.efc1fb7c.js"> | |
| <link rel="modulepreload" href="/docs/course/pr_1069/ru/_app/immutable/chunks/FrameworkSwitchCourse.8d4d4ab6.js"> | |
| <link rel="modulepreload" href="/docs/course/pr_1069/ru/_app/immutable/chunks/getInferenceSnippets.24b50994.js"><!-- HEAD_svelte-u9bgzb_START --><meta name="hf:doc:metadata" content="{"title":"Отладка обучения","local":"debugging-the-training-pipeline","sections":[{"title":"Отладка обучающего пайплайна","local":"debugging-the-training-pipeline","sections":[{"title":"Проверка данных","local":"check-your-data","sections":[],"depth":3},{"title":"От датасетов к загрузчикам данных","local":"from-datasets-to-dataloaders","sections":[],"depth":3},{"title":"Проверьте данные","local":"going-through-the-model","sections":[],"depth":3},{"title":"Выполнение одного шага оптиимзации","local":"performing-one-optimization-step","sections":[],"depth":3},{"title":"Как справиться с ошибками нехватки памяти","local":"dealing-with-cuda-out-of-memory-errors","sections":[],"depth":3},{"title":"Валидация модели","local":"evaluating-the-model","sections":[],"depth":3}],"depth":2},{"title":"Отладка скрытых ошибок во время обучения","local":"debugging-silent-errors-during-training","sections":[{"title":"Проверьте свои данные (еще раз!)","local":"check-your-data-again","sections":[],"depth":3},{"title":"Переобучение модели на одном батче","local":"overfit-your-model-on-one-batch","sections":[],"depth":3},{"title":"Не обучайте ничего, пока не получите первый бейзлайн.","local":"dont-tune-anything-until-you-have-a-first-baseline","sections":[],"depth":3},{"title":"Попросите о помощи","local":"ask-for-help","sections":[],"depth":3}],"depth":2}],"depth":1}"><!-- HEAD_svelte-u9bgzb_END --> <p></p> <div class="bg-white leading-none border border-gray-100 rounded-lg flex p-0.5 w-56 text-sm mb-4"><a class="flex justify-center flex-1 py-1.5 px-2.5 focus:outline-none !no-underline rounded-l bg-red-50 dark:bg-transparent text-red-600" href="?fw=pt"><svg class="mr-1.5" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><defs><clipPath id="a"><rect x="3.05" y="0.5" width="25.73" height="31" fill="none"></rect></clipPath></defs><g clip-path="url(#a)"><path d="M24.94,9.51a12.81,12.81,0,0,1,0,18.16,12.68,12.68,0,0,1-18,0,12.81,12.81,0,0,1,0-18.16l9-9V5l-.84.83-6,6a9.58,9.58,0,1,0,13.55,0ZM20.44,9a1.68,1.68,0,1,1,1.67-1.67A1.68,1.68,0,0,1,20.44,9Z" fill="#ee4c2c"></path></g></svg> Pytorch </a><a class="flex justify-center flex-1 py-1.5 px-2.5 focus:outline-none !no-underline rounded-r text-gray-500 filter grayscale" href="?fw=tf"><svg class="mr-1.5" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" focusable="false" role="img" width="0.94em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 274"><path d="M145.726 42.065v42.07l72.861 42.07v-42.07l-72.86-42.07zM0 84.135v42.07l36.43 21.03V105.17L0 84.135zm109.291 21.035l-36.43 21.034v126.2l36.43 21.035v-84.135l36.435 21.035v-42.07l-36.435-21.034V105.17z" fill="#E55B2D"></path><path d="M145.726 42.065L36.43 105.17v42.065l72.861-42.065v42.065l36.435-21.03v-84.14zM255.022 63.1l-36.435 21.035v42.07l36.435-21.035V63.1zm-72.865 84.135l-36.43 21.035v42.07l36.43-21.036v-42.07zm-36.43 63.104l-36.436-21.035v84.135l36.435-21.035V210.34z" fill="#ED8E24"></path><path d="M145.726 0L0 84.135l36.43 21.035l109.296-63.105l72.861 42.07L255.022 63.1L145.726 0zm0 126.204l-36.435 21.03l36.435 21.036l36.43-21.035l-36.43-21.03z" fill="#F8BF3C"></path></svg> TensorFlow </a></div> <h1 class="relative group"><a id="debugging-the-training-pipeline" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#debugging-the-training-pipeline"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Отладка обучения</span></h1> <div class="flex space-x-1 absolute z-10 right-0 top-0"><a href="https://discuss.huggingface.co/t/chapter-8-questions" target="_blank"><img alt="Ask a Question" class="!m-0" src="https://img.shields.io/badge/Ask%20a%20question-ffcb4c.svg?logo=data:image/svg+xml;base64,PHN2ZyB4bWxucz0iaHR0cDovL3d3dy53My5vcmcvMjAwMC9zdmciIHZpZXdCb3g9IjAgLTEgMTA0IDEwNiI+PGRlZnM+PHN0eWxlPi5jbHMtMXtmaWxsOiMyMzFmMjA7fS5jbHMtMntmaWxsOiNmZmY5YWU7fS5jbHMtM3tmaWxsOiMwMGFlZWY7fS5jbHMtNHtmaWxsOiMwMGE5NGY7fS5jbHMtNXtmaWxsOiNmMTVkMjI7fS5jbHMtNntmaWxsOiNlMzFiMjM7fTwvc3R5bGU+PC9kZWZzPjx0aXRsZT5EaXNjb3Vyc2VfbG9nbzwvdGl0bGU+PGcgaWQ9IkxheWVyXzIiPjxnIGlkPSJMYXllcl8zIj48cGF0aCBjbGFzcz0iY2xzLTEiIGQ9Ik01MS44NywwQzIzLjcxLDAsMCwyMi44MywwLDUxYzAsLjkxLDAsNTIuODEsMCw1Mi44MWw1MS44Ni0uMDVjMjguMTYsMCw1MS0yMy43MSw1MS01MS44N1M4MCwwLDUxLjg3LDBaIi8+PHBhdGggY2xhc3M9ImNscy0yIiBkPSJNNTIuMzcsMTkuNzRBMzEuNjIsMzEuNjIsMCwwLDAsMjQuNTgsNjYuNDFsLTUuNzIsMTguNEwzOS40LDgwLjE3YTMxLjYxLDMxLjYxLDAsMSwwLDEzLTYwLjQzWiIvPjxwYXRoIGNsYXNzPSJjbHMtMyIgZD0iTTc3LjQ1LDMyLjEyYTMxLjYsMzEuNiwwLDAsMS0zOC4wNSw0OEwxOC44Niw4NC44MmwyMC45MS0yLjQ3QTMxLjYsMzEuNiwwLDAsMCw3Ny40NSwzMi4xMloiLz48cGF0aCBjbGFzcz0iY2xzLTQiIGQ9Ik03MS42MywyNi4yOUEzMS42LDMxLjYsMCwwLDEsMzguOCw3OEwxOC44Niw4NC44MiwzOS40LDgwLjE3QTMxLjYsMzEuNiwwLDAsMCw3MS42MywyNi4yOVoiLz48cGF0aCBjbGFzcz0iY2xzLTUiIGQ9Ik0yNi40Nyw2Ny4xMWEzMS42MSwzMS42MSwwLDAsMSw1MS0zNUEzMS42MSwzMS42MSwwLDAsMCwyNC41OCw2Ni40MWwtNS43MiwxOC40WiIvPjxwYXRoIGNsYXNzPSJjbHMtNiIgZD0iTTI0LjU4LDY2LjQxQTMxLjYxLDMxLjYxLDAsMCwxLDcxLjYzLDI2LjI5YTMxLjYxLDMxLjYxLDAsMCwwLTQ5LDM5LjYzbC0zLjc2LDE4LjlaIi8+PC9nPjwvZz48L3N2Zz4="></a> <a href="https://colab.research.google.com/github/huggingface/notebooks/blob/master/course/en/chapter8/section4.ipynb" target="_blank"><img alt="Open In Colab" class="!m-0" src="https://colab.research.google.com/assets/colab-badge.svg"></a> <a href="https://studiolab.sagemaker.aws/import/github/huggingface/notebooks/blob/master/course/en/chapter8/section4.ipynb" target="_blank"><img alt="Open In Studio Lab" class="!m-0" src="https://studiolab.sagemaker.aws/studiolab.svg"></a></div> <p data-svelte-h="svelte-i74yku">Вы написали прекрасный сценарий для обучения или дообучения модели на заданной задаче, послушно следуя советам из <a href="/course/chapter7/1">Главы 7</a>. Но когда вы запускаете команду <code>model.fit()</code>, происходит нечто ужасное: вы получаете ошибку 😱! Или, что еще хуже, все вроде бы хорошо, обучение проходит без ошибок, но результирующая модель получается плохой. В этом разделе мы покажем вам, что можно сделать для отладки подобных проблем.</p> <h2 class="relative group"><a id="debugging-the-training-pipeline" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#debugging-the-training-pipeline"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Отладка обучающего пайплайна</span></h2> <iframe class="w-full xl:w-4/6 h-80" src="https://www.youtube-nocookie.com/embed/L-WSwUWde1U" title="YouTube video player" frameborder="0" allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture" allowfullscreen></iframe> <p data-svelte-h="svelte-16g84ru">Проблема, когда вы сталкиваетесь с ошибкой в <code>trainer.train()</code>, заключается в том, что она может возникнуть из нескольких источников, поскольку <code>Trainer</code> обычно собирает вместе множество вещей. Он преобразует наборы данных в загрузчики данных, поэтому проблема может заключаться в том, что в вашем наборе данных что-то не так, или в том, что вы пытаетесь объединить элементы наборов данных в батч. Затем он берет батч данных и подает его в модель, так что проблема может быть в коде модели. После этого вычисляются градиенты и выполняется этап оптимизации, так что проблема может быть и в вашем оптимизаторе. И даже если во время обучения все идет хорошо, во время валидации все равно что-то может пойти не так, если проблема в метрике.</p> <p data-svelte-h="svelte-j3otxp">Лучший способ отладить ошибку, возникшую в <code>trainer.train()</code>, - это вручную пройти весь пайплайн и посмотреть, где все пошло не так. В этом случае ошибку часто очень легко устранить.</p> <p data-svelte-h="svelte-7his7t">Чтобы продемонстрировать это, мы используем следующий скрипт, который (пытается) точно настроить модель DistilBERT на наборе данных <a href="https://huggingface.co/datasets/glue" rel="nofollow">MNLI dataset</a>:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> datasets <span class="hljs-keyword">import</span> load_dataset | |
| <span class="hljs-keyword">import</span> evaluate | |
| <span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> ( | |
| AutoTokenizer, | |
| AutoModelForSequenceClassification, | |
| TrainingArguments, | |
| Trainer, | |
| ) | |
| raw_datasets = load_dataset(<span class="hljs-string">"glue"</span>, <span class="hljs-string">"mnli"</span>) | |
| model_checkpoint = <span class="hljs-string">"distilbert-base-uncased"</span> | |
| tokenizer = AutoTokenizer.from_pretrained(model_checkpoint) | |
| <span class="hljs-keyword">def</span> <span class="hljs-title function_">preprocess_function</span>(<span class="hljs-params">examples</span>): | |
| <span class="hljs-keyword">return</span> tokenizer(examples[<span class="hljs-string">"premise"</span>], examples[<span class="hljs-string">"hypothesis"</span>], truncation=<span class="hljs-literal">True</span>) | |
| tokenized_datasets = raw_datasets.<span class="hljs-built_in">map</span>(preprocess_function, batched=<span class="hljs-literal">True</span>) | |
| model = AutoModelForSequenceClassification.from_pretrained(model_checkpoint) | |
| args = TrainingArguments( | |
| <span class="hljs-string">f"distilbert-finetuned-mnli"</span>, | |
| evaluation_strategy=<span class="hljs-string">"epoch"</span>, | |
| save_strategy=<span class="hljs-string">"epoch"</span>, | |
| learning_rate=<span class="hljs-number">2e-5</span>, | |
| num_train_epochs=<span class="hljs-number">3</span>, | |
| weight_decay=<span class="hljs-number">0.01</span>, | |
| ) | |
| metric = evaluate.load(<span class="hljs-string">"glue"</span>, <span class="hljs-string">"mnli"</span>) | |
| <span class="hljs-keyword">def</span> <span class="hljs-title function_">compute_metrics</span>(<span class="hljs-params">eval_pred</span>): | |
| predictions, labels = eval_pred | |
| <span class="hljs-keyword">return</span> metric.compute(predictions=predictions, references=labels) | |
| trainer = Trainer( | |
| model, | |
| args, | |
| train_dataset=raw_datasets[<span class="hljs-string">"train"</span>], | |
| eval_dataset=raw_datasets[<span class="hljs-string">"validation_matched"</span>], | |
| compute_metrics=compute_metrics, | |
| ) | |
| trainer.train()<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-fuew7a">Если вы попытаетесь выполнить его, то столкнетесь с довольно загадочной ошибкой:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-string">'ValueError: You have to specify either input_ids or inputs_embeds'</span><!-- HTML_TAG_END --></pre></div> <h3 class="relative group"><a id="check-your-data" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#check-your-data"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Проверка данных</span></h3> <p data-svelte-h="svelte-1ae2sud">Это может быть очевидно, но если ваши данные повреждены, <code>Trainer</code> не сможет сформировать батчи, не говоря уже об обучении вашей модели. Поэтому прежде всего необходимо посмотреть, что находится в вашем обучающем наборе.</p> <p data-svelte-h="svelte-1yz06m4">Чтобы избежать бесчисленных часов, потраченных на попытки исправить то, что не является источником ошибки, мы рекомендуем использовать <code>trainer.train_dataset</code> для проверок и ничего больше. Так что давайте сделаем это здесь:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->trainer.train_dataset[<span class="hljs-number">0</span>]<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->{<span class="hljs-string">'hypothesis'</span>: <span class="hljs-string">'Product and geography are what make cream skimming work. '</span>, | |
| <span class="hljs-string">'idx'</span>: <span class="hljs-number">0</span>, | |
| <span class="hljs-string">'label'</span>: <span class="hljs-number">1</span>, | |
| <span class="hljs-string">'premise'</span>: <span class="hljs-string">'Conceptually cream skimming has two basic dimensions - product and geography.'</span>}<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1fjzhww">Вы заметили что-то неладное? В сочетании с сообщением об ошибке <code>input_ids</code>, вы должны понять, что это тексты, а не числа, которые модель может понять. Здесь исходная ошибка вводит в заблуждение, потому что <code>Trainer</code> автоматически удаляет столбцы, которые не соответствуют сигнатуре модели (то есть аргументам, ожидаемым моделью). Это означает, что здесь было удалено все, кроме меток. Таким образом, не было никаких проблем с созданием батчей и отправкой их модели, которая, в свою очередь, жаловалась, что не получила нужных входных данных.</p> <p data-svelte-h="svelte-1x1vxdz">Почему данные не обрабатывались? Мы действительно использовали метод <code>Dataset.map()</code> для наборов данных, чтобы применить токенизатор к каждой выборке. Но если вы внимательно посмотрите на код, то увидите, что мы допустили ошибку при передаче обучающего и валидационного наборов в <code>Trainer</code>. Вместо того чтобы использовать <code>tokenized_datasets</code>, мы использовали <code>raw_datasets</code> 🤦. Так что давайте исправим это!</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> datasets <span class="hljs-keyword">import</span> load_dataset | |
| <span class="hljs-keyword">import</span> evaluate | |
| <span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> ( | |
| AutoTokenizer, | |
| AutoModelForSequenceClassification, | |
| TrainingArguments, | |
| Trainer, | |
| ) | |
| raw_datasets = load_dataset(<span class="hljs-string">"glue"</span>, <span class="hljs-string">"mnli"</span>) | |
| model_checkpoint = <span class="hljs-string">"distilbert-base-uncased"</span> | |
| tokenizer = AutoTokenizer.from_pretrained(model_checkpoint) | |
| <span class="hljs-keyword">def</span> <span class="hljs-title function_">preprocess_function</span>(<span class="hljs-params">examples</span>): | |
| <span class="hljs-keyword">return</span> tokenizer(examples[<span class="hljs-string">"premise"</span>], examples[<span class="hljs-string">"hypothesis"</span>], truncation=<span class="hljs-literal">True</span>) | |
| tokenized_datasets = raw_datasets.<span class="hljs-built_in">map</span>(preprocess_function, batched=<span class="hljs-literal">True</span>) | |
| model = AutoModelForSequenceClassification.from_pretrained(model_checkpoint) | |
| args = TrainingArguments( | |
| <span class="hljs-string">f"distilbert-finetuned-mnli"</span>, | |
| evaluation_strategy=<span class="hljs-string">"epoch"</span>, | |
| save_strategy=<span class="hljs-string">"epoch"</span>, | |
| learning_rate=<span class="hljs-number">2e-5</span>, | |
| num_train_epochs=<span class="hljs-number">3</span>, | |
| weight_decay=<span class="hljs-number">0.01</span>, | |
| ) | |
| metric = evaluate.load(<span class="hljs-string">"glue"</span>, <span class="hljs-string">"mnli"</span>) | |
| <span class="hljs-keyword">def</span> <span class="hljs-title function_">compute_metrics</span>(<span class="hljs-params">eval_pred</span>): | |
| predictions, labels = eval_pred | |
| <span class="hljs-keyword">return</span> metric.compute(predictions=predictions, references=labels) | |
| trainer = Trainer( | |
| model, | |
| args, | |
| train_dataset=tokenized_datasets[<span class="hljs-string">"train"</span>], | |
| eval_dataset=tokenized_datasets[<span class="hljs-string">"validation_matched"</span>], | |
| compute_metrics=compute_metrics, | |
| ) | |
| trainer.train()<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-kzrtgn">Теперь этот новый код будет выдавать другую ошибку (прогресс!):</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-string">'ValueError: expected sequence of length 43 at dim 1 (got 37)'</span><!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-gwq31x">Посмотрев на трассировку, мы видим, что ошибка происходит на этапе сборки данных:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->~/git/transformers/src/transformers/data/data_collator.py <span class="hljs-keyword">in</span> torch_default_data_collator(features) | |
| <span class="hljs-number">105</span> batch[k] = torch.stack([f[k] <span class="hljs-keyword">for</span> f <span class="hljs-keyword">in</span> features]) | |
| <span class="hljs-number">106</span> <span class="hljs-keyword">else</span>: | |
| --> <span class="hljs-number">107</span> batch[k] = torch.tensor([f[k] <span class="hljs-keyword">for</span> f <span class="hljs-keyword">in</span> features]) | |
| <span class="hljs-number">108</span> | |
| <span class="hljs-number">109</span> <span class="hljs-keyword">return</span> batch<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-gm88e5">Поэтому нам следует перейти к этому. Однако перед этим давайте закончим проверку наших данных, чтобы быть на 100% уверенными в их правильности.</p> <p data-svelte-h="svelte-2f8x3p">При отладке обучения всегда нужно смотреть на декодированные входы модели. Мы не можем понять смысл чисел, которые подаем ей напрямую, поэтому мы должны посмотреть, что эти числа представляют. В компьютерном зрении, например, это означает просмотр декодированных изображений пройденных пикселей, в речи - прослушивание декодированных образцов звука, а в нашем примере с NLP - использование нашего токенизатора для декодирования входных данных:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->tokenizer.decode(trainer.train_dataset[<span class="hljs-number">0</span>][<span class="hljs-string">"input_ids"</span>])<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-string">'[CLS] conceptually cream skimming has two basic dimensions - product and geography. [SEP] product and geography are what make cream skimming work. [SEP]'</span><!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-fhhvpq">Так что, похоже, все правильно. Вы должны сделать это для всех ключей во входах:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->trainer.train_dataset[<span class="hljs-number">0</span>].keys()<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->dict_keys([<span class="hljs-string">'attention_mask'</span>, <span class="hljs-string">'hypothesis'</span>, <span class="hljs-string">'idx'</span>, <span class="hljs-string">'input_ids'</span>, <span class="hljs-string">'label'</span>, <span class="hljs-string">'premise'</span>])<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1hcvxf">Обратите внимание, что ключи, не соответствующие входам, принимаемым моделью, будут автоматически отброшены, поэтому здесь мы оставим только <code>input_ids</code>, <code>attention_mask</code> и <code>label</code> (которая будет переименована в <code>labels</code>). Чтобы перепроверить сигнатуру модели, вы можете вывести класс вашей модели, а затем проверить ее документацию:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-built_in">type</span>(trainer.model)<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->transformers.models.distilbert.modeling_distilbert.DistilBertForSequenceClassification<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1sjbam">Итак, в нашем случае мы можем проверить принятые параметры на <a href="https://huggingface.co/transformers/model_doc/distilbert.html#distilbertforsequenceclassification" rel="nofollow">этой странице</a>. “Trainer” также будет регистрировать столбцы, которые он отбрасывает.</p> <p data-svelte-h="svelte-e8yr2f">Мы проверили правильность входных идентификаторов, декодировав их. Далее находится <code>attention_mask</code>:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->trainer.train_dataset[<span class="hljs-number">0</span>][<span class="hljs-string">"attention_mask"</span>]<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->[<span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>, <span class="hljs-number">1</span>]<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1ptp61">Так как мы не применяли в препроцессинге дополнение нулями, это кажется совершенно естественным. Чтобы убедиться в отсутствии проблем с этой маской внимания, давайте проверим, что она имеет ту же длину, что и наши входные идентификаторы:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-built_in">len</span>(trainer.train_dataset[<span class="hljs-number">0</span>][<span class="hljs-string">"attention_mask"</span>]) == <span class="hljs-built_in">len</span>( | |
| trainer.train_dataset[<span class="hljs-number">0</span>][<span class="hljs-string">"input_ids"</span>] | |
| )<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-literal">True</span><!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-41fi5j">Это хорошо! И наконец, давайте проверим нашу метку класса:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->trainer.train_dataset[<span class="hljs-number">0</span>][<span class="hljs-string">"label"</span>]<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-number">1</span><!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-3wlf7g">Как и входные идентификаторы, это число, которое само по себе не имеет смысла. Как мы уже видели, соответствие между целыми числами и именами меток хранится в атрибуте <code>names</code> соответствующей <em>функции</em> набора данных:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->trainer.train_dataset.features[<span class="hljs-string">"label"</span>].names<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->[<span class="hljs-string">'entailment'</span>, <span class="hljs-string">'neutral'</span>, <span class="hljs-string">'contradiction'</span>]<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1cvooet">Итак, <code>1</code> означает <code>нейтральный</code>, а значит, два предложения, которые мы видели выше, не противоречат друг другу, и из первого не следует второе. Это кажется правильным!</p> <p data-svelte-h="svelte-1nf3k9j">Здесь у нас нет идентификаторов типов токенов, поскольку DistilBERT их не ожидает; если в вашей модели они есть, вам также следует убедиться, что они правильно соответствуют месту первого и второго предложений во входных данных.</p> <div class="course-tip bg-gradient-to-br dark:bg-gradient-to-r before:border-green-500 dark:before:border-green-800 from-green-50 dark:from-gray-900 to-white dark:to-gray-950 border border-green-50 text-green-700 dark:text-gray-400"><p data-svelte-h="svelte-453sg9">✏️ <strong>Ваша очередь!</strong> Проверьте, все ли правильно со вторым элементом обучающего набора данных.</p></div> <p data-svelte-h="svelte-9ucqbl">В данном случае мы проверяем только обучающий набор, но, конечно, вы должны дважды проверить валидационный и тестовый наборы таким же образом.</p> <p data-svelte-h="svelte-8755tv">Теперь, когда мы знаем, что наши наборы данных выглядят хорошо, пришло время проверить следующий этап пайплайна обучения.</p> <h3 class="relative group"><a id="from-datasets-to-dataloaders" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#from-datasets-to-dataloaders"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>От датасетов к загрузчикам данных</span></h3> <p data-svelte-h="svelte-1ubc5pl">Следующее, что может пойти не так в пайплайне обучения, - это когда <code>Trainer</code> пытается сформировать батчи из обучающего или проверочного набора. Убедившись, что наборы данных <code>Trainer</code> корректны, можно попробовать вручную сформировать батч, выполнив следующие действия (замените <code>train</code> на <code>eval</code> для валидационного загрузчика данных):</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">for</span> batch <span class="hljs-keyword">in</span> trainer.get_train_dataloader(): | |
| <span class="hljs-keyword">break</span><!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-b63dqv">Этот код создает загрузчик данных для обучения, затем выполняет итерации по нему, останавливаясь на первой итерации. Если код выполняется без ошибок, у вас есть первый обучающий батч, который можно проверить, а если код выдает ошибку, вы точно знаете, что проблема в загрузчике данных, как в данном случае:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->~/git/transformers/src/transformers/data/data_collator.py <span class="hljs-keyword">in</span> torch_default_data_collator(features) | |
| <span class="hljs-number">105</span> batch[k] = torch.stack([f[k] <span class="hljs-keyword">for</span> f <span class="hljs-keyword">in</span> features]) | |
| <span class="hljs-number">106</span> <span class="hljs-keyword">else</span>: | |
| --> <span class="hljs-number">107</span> batch[k] = torch.tensor([f[k] <span class="hljs-keyword">for</span> f <span class="hljs-keyword">in</span> features]) | |
| <span class="hljs-number">108</span> | |
| <span class="hljs-number">109</span> <span class="hljs-keyword">return</span> batch | |
| ValueError: expected sequence of length <span class="hljs-number">45</span> at dim <span class="hljs-number">1</span> (got <span class="hljs-number">76</span>)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-o0yb3p">Осмотра последнего кадра трассировки должно быть достаточно, чтобы дать вам подсказку, но давайте покопаемся еще немного. Большинство проблем при создании батчей возникает из-за объединения примеров в один батч, поэтому первое, что нужно проверить при возникновении сомнений, это то, какой <code>collate_fn</code> использует ваш <code>DataLoader</code>:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->data_collator = trainer.get_train_dataloader().collate_fn | |
| data_collator<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><function transformers.data.data_collator.default_data_collator(features: <span class="hljs-type">List</span>[InputDataClass], return_tensors=<span class="hljs-string">'pt'</span>) -> <span class="hljs-type">Dict</span>[<span class="hljs-built_in">str</span>, <span class="hljs-type">Any</span>]><!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-s94t3r">Этот <code>default_data_collator</code> не то, что нам нужно в данном случае. Мы хотим разбить наши примеры на самые длинные предложения в пакете, что делает <code>DataCollatorWithPadding</code>. И этот класс данных должен использоваться по умолчанию в <code>Trainer</code>, так почему же он не используется здесь?</p> <p data-svelte-h="svelte-qlao8o">Ответ заключается в том, что мы не передали <code>tokenizer</code> в <code>Trainer</code>, поэтому он не смог создать нужный нам <code>DataCollatorWithPadding</code>. На практике вам никогда не следует стесняться явно передавать data collator, который вы хотите использовать, чтобы избежать подобных ошибок. Давайте адаптируем наш код, чтобы сделать именно это:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> datasets <span class="hljs-keyword">import</span> load_dataset | |
| <span class="hljs-keyword">import</span> evaluate | |
| <span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> ( | |
| AutoTokenizer, | |
| AutoModelForSequenceClassification, | |
| DataCollatorWithPadding, | |
| TrainingArguments, | |
| Trainer, | |
| ) | |
| raw_datasets = load_dataset(<span class="hljs-string">"glue"</span>, <span class="hljs-string">"mnli"</span>) | |
| model_checkpoint = <span class="hljs-string">"distilbert-base-uncased"</span> | |
| tokenizer = AutoTokenizer.from_pretrained(model_checkpoint) | |
| <span class="hljs-keyword">def</span> <span class="hljs-title function_">preprocess_function</span>(<span class="hljs-params">examples</span>): | |
| <span class="hljs-keyword">return</span> tokenizer(examples[<span class="hljs-string">"premise"</span>], examples[<span class="hljs-string">"hypothesis"</span>], truncation=<span class="hljs-literal">True</span>) | |
| tokenized_datasets = raw_datasets.<span class="hljs-built_in">map</span>(preprocess_function, batched=<span class="hljs-literal">True</span>) | |
| model = AutoModelForSequenceClassification.from_pretrained(model_checkpoint) | |
| args = TrainingArguments( | |
| <span class="hljs-string">f"distilbert-finetuned-mnli"</span>, | |
| evaluation_strategy=<span class="hljs-string">"epoch"</span>, | |
| save_strategy=<span class="hljs-string">"epoch"</span>, | |
| learning_rate=<span class="hljs-number">2e-5</span>, | |
| num_train_epochs=<span class="hljs-number">3</span>, | |
| weight_decay=<span class="hljs-number">0.01</span>, | |
| ) | |
| metric = evaluate.load(<span class="hljs-string">"glue"</span>, <span class="hljs-string">"mnli"</span>) | |
| <span class="hljs-keyword">def</span> <span class="hljs-title function_">compute_metrics</span>(<span class="hljs-params">eval_pred</span>): | |
| predictions, labels = eval_pred | |
| <span class="hljs-keyword">return</span> metric.compute(predictions=predictions, references=labels) | |
| data_collator = DataCollatorWithPadding(tokenizer=tokenizer) | |
| trainer = Trainer( | |
| model, | |
| args, | |
| train_dataset=tokenized_datasets[<span class="hljs-string">"train"</span>], | |
| eval_dataset=tokenized_datasets[<span class="hljs-string">"validation_matched"</span>], | |
| compute_metrics=compute_metrics, | |
| data_collator=data_collator, | |
| tokenizer=tokenizer, | |
| ) | |
| trainer.train()<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-kbq8q2">Хорошие новости? Мы не получаем ту же ошибку, что и раньше, что, безусловно, является прогрессом. Плохие новости? Вместо нее мы получаем печально известную ошибку CUDA:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->RuntimeError: CUDA error: CUBLAS_STATUS_ALLOC_FAILED when calling `cublasCreate(handle)`<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1u2oyow">Это плохо, потому что ошибки CUDA вообще очень трудно отлаживать. Через минуту мы увидим, как решить эту проблему, но сначала давайте закончим анализ создания батчей.</p> <p data-svelte-h="svelte-2bykvd">Если вы уверены, что ваш data collator правильный, попробуйте применить его на паре примеров вашего набора данных:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->data_collator = trainer.get_train_dataloader().collate_fn | |
| batch = data_collator([trainer.train_dataset[i] <span class="hljs-keyword">for</span> i <span class="hljs-keyword">in</span> <span class="hljs-built_in">range</span>(<span class="hljs-number">4</span>)])<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-fe3ado">Этот код не сработает, потому что <code>train_dataset</code> содержит строковые колонки, которые <code>Trainer</code> обычно удаляет. Вы можете удалить их вручную, или, если вы хотите в точности повторить то, что <code>Trainer</code> делает за кулисами: нужно вызвать приватный метод <code>Trainer._remove_unused_columns()</code>, который делает это:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->data_collator = trainer.get_train_dataloader().collate_fn | |
| actual_train_set = trainer._remove_unused_columns(trainer.train_dataset) | |
| batch = data_collator([actual_train_set[i] <span class="hljs-keyword">for</span> i <span class="hljs-keyword">in</span> <span class="hljs-built_in">range</span>(<span class="hljs-number">4</span>)])<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1k2z729">Если ошибка не исчезнет, вы сможете вручную отладить, что происходит внутри data collator.</p> <p data-svelte-h="svelte-1movyjg">Теперь, когда мы отладили процесс создания батча, пришло время пропустить его через модель!</p> <h3 class="relative group"><a id="going-through-the-model" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#going-through-the-model"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Проверьте данные</span></h3> <p data-svelte-h="svelte-a2nxrs">Вы должны иметь возможность получить батч, выполнив следующую команду:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">for</span> batch <span class="hljs-keyword">in</span> trainer.get_train_dataloader(): | |
| <span class="hljs-keyword">break</span><!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-5zx2a4">Если вы выполняете этот код в ноутбуке, вы можете получить ошибку CUDA, похожую на ту, что мы видели ранее, и в этом случае вам нужно перезапустить ноутбук и заново выполнить последний фрагмент без строки <code>trainer.train()</code>. Это вторая самая неприятная вещь в ошибках CUDA: они безвозвратно ломают ваше ядро. Самое неприятное в них то, что их трудно отлаживать.</p> <p data-svelte-h="svelte-cef22x">Почему? Это связано с тем, как работают графические процессоры. Они чрезвычайно эффективны при параллельном выполнении множества операций, но их недостаток в том, что когда одна из этих инструкций приводит к ошибке, вы не сразу об этом узнаете. Только когда программа вызовет синхронизацию нескольких процессов на GPU, она поймет, что что-то пошло не так, поэтому ошибка возникает в том месте, которое не имеет никакого отношения к тому, что ее создало. Например, если мы посмотрим на наш предыдущий трассировочный откат, ошибка была вызвана во время обратного прохода, но через минуту мы увидим, что на самом деле она возникла из-за чего-то в прямом проходе через модель.</p> <p data-svelte-h="svelte-btkayt">Так как же отладить эти ошибки? Ответ прост: никак. Если только ошибка CUDA не является ошибкой вне памяти (что означает, что в вашем GPU недостаточно памяти), вы всегда должны возвращаться к CPU, чтобы отладить ее.</p> <p data-svelte-h="svelte-1iaolg6">Чтобы сделать это в нашем случае, нам просто нужно вернуть модель на CPU и вызвать ее на нашем пакете - пакет, возвращаемый <code>DataLoader</code>, еще не был перемещен на GPU:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->outputs = trainer.model.cpu()(**batch)<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->~/.pyenv/versions/<span class="hljs-number">3.7</span><span class="hljs-number">.9</span>/envs/base/lib/python3<span class="hljs-number">.7</span>/site-packages/torch/nn/functional.py <span class="hljs-keyword">in</span> nll_loss(<span class="hljs-built_in">input</span>, target, weight, size_average, ignore_index, reduce, reduction) | |
| <span class="hljs-number">2386</span> ) | |
| <span class="hljs-number">2387</span> <span class="hljs-keyword">if</span> dim == <span class="hljs-number">2</span>: | |
| -> <span class="hljs-number">2388</span> ret = torch._C._nn.nll_loss(<span class="hljs-built_in">input</span>, target, weight, _Reduction.get_enum(reduction), ignore_index) | |
| <span class="hljs-number">2389</span> <span class="hljs-keyword">elif</span> dim == <span class="hljs-number">4</span>: | |
| <span class="hljs-number">2390</span> ret = torch._C._nn.nll_loss2d(<span class="hljs-built_in">input</span>, target, weight, _Reduction.get_enum(reduction), ignore_index) | |
| IndexError: Target <span class="hljs-number">2</span> <span class="hljs-keyword">is</span> out of bounds.<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1xr2p7k">Итак, картина проясняется. Вместо ошибки CUDA у нас теперь <code>IndexError</code> при вычислении функции потерь (так что обратный проход, как мы уже говорили, здесь ни при чем). Точнее, мы видим, что ошибка возникает именно в метке класса 2, так что это очень хороший момент для проверки количества меток нашей модели:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->trainer.model.config.num_labels<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-number">2</span><!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-yblc4i">При двух метках в качестве значений допускаются только 0 и 1, но, согласно сообщению об ошибке, мы получили 2. Получение 2 на самом деле нормально: если мы помним имена меток, которые мы извлекли ранее, их было три, поэтому в нашем наборе данных есть индексы 0, 1 и 2. Проблема в том, что мы не сообщили об этом нашей модели, которая должна была быть создана с тремя метками. Так что давайте это исправим!</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> datasets <span class="hljs-keyword">import</span> load_dataset | |
| <span class="hljs-keyword">import</span> evaluate | |
| <span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> ( | |
| AutoTokenizer, | |
| AutoModelForSequenceClassification, | |
| DataCollatorWithPadding, | |
| TrainingArguments, | |
| Trainer, | |
| ) | |
| raw_datasets = load_dataset(<span class="hljs-string">"glue"</span>, <span class="hljs-string">"mnli"</span>) | |
| model_checkpoint = <span class="hljs-string">"distilbert-base-uncased"</span> | |
| tokenizer = AutoTokenizer.from_pretrained(model_checkpoint) | |
| <span class="hljs-keyword">def</span> <span class="hljs-title function_">preprocess_function</span>(<span class="hljs-params">examples</span>): | |
| <span class="hljs-keyword">return</span> tokenizer(examples[<span class="hljs-string">"premise"</span>], examples[<span class="hljs-string">"hypothesis"</span>], truncation=<span class="hljs-literal">True</span>) | |
| tokenized_datasets = raw_datasets.<span class="hljs-built_in">map</span>(preprocess_function, batched=<span class="hljs-literal">True</span>) | |
| model = AutoModelForSequenceClassification.from_pretrained(model_checkpoint, num_labels=<span class="hljs-number">3</span>) | |
| args = TrainingArguments( | |
| <span class="hljs-string">f"distilbert-finetuned-mnli"</span>, | |
| evaluation_strategy=<span class="hljs-string">"epoch"</span>, | |
| save_strategy=<span class="hljs-string">"epoch"</span>, | |
| learning_rate=<span class="hljs-number">2e-5</span>, | |
| num_train_epochs=<span class="hljs-number">3</span>, | |
| weight_decay=<span class="hljs-number">0.01</span>, | |
| ) | |
| metric = evaluate.load(<span class="hljs-string">"glue"</span>, <span class="hljs-string">"mnli"</span>) | |
| <span class="hljs-keyword">def</span> <span class="hljs-title function_">compute_metrics</span>(<span class="hljs-params">eval_pred</span>): | |
| predictions, labels = eval_pred | |
| <span class="hljs-keyword">return</span> metric.compute(predictions=predictions, references=labels) | |
| data_collator = DataCollatorWithPadding(tokenizer=tokenizer) | |
| trainer = Trainer( | |
| model, | |
| args, | |
| train_dataset=tokenized_datasets[<span class="hljs-string">"train"</span>], | |
| eval_dataset=tokenized_datasets[<span class="hljs-string">"validation_matched"</span>], | |
| compute_metrics=compute_metrics, | |
| data_collator=data_collator, | |
| tokenizer=tokenizer, | |
| )<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1ififvv">Мы пока не включаем строку <code>trainer.train()</code>, чтобы потратить время на проверку того, что все выглядит хорошо. Если мы запросим батч и передадим ее нашей модели, то теперь она работает без ошибок!</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">for</span> batch <span class="hljs-keyword">in</span> trainer.get_train_dataloader(): | |
| <span class="hljs-keyword">break</span> | |
| outputs = trainer.model.cpu()(**batch)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1nhffwj">Следующим шагом будет возвращение к графическому процессору и проверка того, что все по-прежнему работает:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> torch | |
| device = torch.device(<span class="hljs-string">"cuda"</span>) <span class="hljs-keyword">if</span> torch.cuda.is_available() <span class="hljs-keyword">else</span> torch.device(<span class="hljs-string">"cpu"</span>) | |
| batch = {k: v.to(device) <span class="hljs-keyword">for</span> k, v <span class="hljs-keyword">in</span> batch.items()} | |
| outputs = trainer.model.to(device)(**batch)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-6b6rwu">Если вы все еще получаете ошибку, убедитесь, что перезагрузили ноутбук и выполнили только последнюю версию скрипта.</p> <h3 class="relative group"><a id="performing-one-optimization-step" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#performing-one-optimization-step"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Выполнение одного шага оптиимзации</span></h3> <p data-svelte-h="svelte-156whrd">Теперь, когда мы знаем, что можем создавать батчи, которые действительно проходят через модель без ошибок, мы готовы к следующему шагу пайплайна обучения: вычислению градиентов и выполнению шага оптимизации.</p> <p data-svelte-h="svelte-1o4uma2">Первая часть заключается в вызове метода <code>backward()</code> для функции потерь:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->loss = outputs.loss | |
| loss.backward()<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-16fqkcw">Ошибки на этом этапе возникают довольно редко, но если они все же возникнут, обязательно вернитесь к процессору, чтобы получить полезное сообщение об ошибке.</p> <p data-svelte-h="svelte-1vb7gyx">Чтобы выполнить шаг оптимизации, нам нужно просто создать <code>optimizer</code> и вызвать его метод <code>step()</code>:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->trainer.create_optimizer() | |
| trainer.optimizer.step()<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-7xu1gs">Опять же, если вы используете оптимизатор по умолчанию в <code>Trainer</code>, вы не должны получить ошибку на этом этапе, но если вы используете собственный оптимизатор, здесь могут возникнуть некоторые проблемы для отладки. Не забудьте вернуться к процессору, если на этом этапе вы получите странную ошибку CUDA. Говоря об ошибках CUDA, ранее мы упоминали особый случай. Давайте посмотрим на него сейчас.</p> <h3 class="relative group"><a id="dealing-with-cuda-out-of-memory-errors" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#dealing-with-cuda-out-of-memory-errors"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Как справиться с ошибками нехватки памяти</span></h3> <p data-svelte-h="svelte-15f70qk">Если вы получаете сообщение об ошибке, начинающееся с <code>RuntimeError: CUDA out of memory</code>, это означает, что вам не хватает памяти GPU. Это не связано напрямую с вашим кодом, и может произойти со скриптом, который работает совершенно нормально. Эта ошибка означает, что вы пытались поместить слишком много данных во внутреннюю память вашего GPU, и это привело к ошибке. Как и в случае с другими ошибками CUDA, вам придется перезапустить ядро, чтобы снова запустить обучение.</p> <p data-svelte-h="svelte-19lrb7">Чтобы решить эту проблему, нужно просто использовать меньше памяти на GPU - что зачастую легче сказать, чем сделать. Во-первых, убедитесь, что у вас нет двух моделей на GPU одновременно (если, конечно, это не требуется для решения вашей задачи). Затем, вероятно, следует уменьшить размер батча, поскольку он напрямую влияет на размеры всех промежуточных выходов модели и их градиентов. Если проблема сохраняется, подумайте о том, чтобы использовать меньшую версию модели.</p> <div class="course-tip bg-gradient-to-br dark:bg-gradient-to-r before:border-green-500 dark:before:border-green-800 from-green-50 dark:from-gray-900 to-white dark:to-gray-950 border border-green-50 text-green-700 dark:text-gray-400"><p data-svelte-h="svelte-3zd7su">В следующей части курса мы рассмотрим более продвинутые техники, которые помогут вам уменьшить объем занимаемой памяти и позволят точно настроить самые большие модели.</p></div> <h3 class="relative group"><a id="evaluating-the-model" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#evaluating-the-model"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Валидация модели</span></h3> <p data-svelte-h="svelte-1cp8phi">Теперь, когда мы решили все проблемы с нашим кодом, все идеально, и обучение должно пройти гладко, верно? Не так быстро! Если вы запустите команду <code>trainer.train()</code>, сначала все будет выглядеть хорошо, но через некоторое время вы получите следующее:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-comment"># Это займет много времени и приведет к ошибке, поэтому не стоит запускать эту ячейку</span> | |
| trainer.train()<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->TypeError: only size-<span class="hljs-number">1</span> arrays can be converted to Python scalars<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-c1eyic">Вы поймете, что эта ошибка появляется во время фазы валидации, так что это последнее, что нам нужно будет отладить.</p> <p data-svelte-h="svelte-18eb5gd">Вы можете запустить цикл оценки <code>Trainer</code> независимо от обучения следующим образом:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->trainer.evaluate()<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->TypeError: only size-<span class="hljs-number">1</span> arrays can be converted to Python scalars<!-- HTML_TAG_END --></pre></div> <div class="course-tip bg-gradient-to-br dark:bg-gradient-to-r before:border-green-500 dark:before:border-green-800 from-green-50 dark:from-gray-900 to-white dark:to-gray-950 border border-green-50 text-green-700 dark:text-gray-400"><p data-svelte-h="svelte-1942v49">💡 Перед запуском <code>trainer.train()</code> всегда следует убедиться, что вы можете запустить <code>trainer.evaluate()</code>, чтобы не тратить много вычислительных ресурсов до того, как столкнетесь с ошибкой.</p></div> <p data-svelte-h="svelte-mqpinc">Прежде чем пытаться отладить проблему в цикле валидации, нужно сначала убедиться, что вы посмотрели на данные, смогли правильно сформировать батч и запустить на нем свою модель. Мы выполнили все эти шаги, поэтому следующий код может быть выполнен без ошибок:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">for</span> batch <span class="hljs-keyword">in</span> trainer.get_eval_dataloader(): | |
| <span class="hljs-keyword">break</span> | |
| batch = {k: v.to(device) <span class="hljs-keyword">for</span> k, v <span class="hljs-keyword">in</span> batch.items()} | |
| <span class="hljs-keyword">with</span> torch.no_grad(): | |
| outputs = trainer.model(**batch)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-n1tv18">Ошибка возникает позже, в конце фазы валидации, и если мы посмотрим на трассировку, то увидим следующее:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->~/git/datasets/src/datasets/metric.py <span class="hljs-keyword">in</span> add_batch(self, predictions, references) | |
| <span class="hljs-number">431</span> <span class="hljs-string">""" | |
| 432 batch = {"predictions": predictions, "references": references} | |
| --> 433 batch = self.info.features.encode_batch(batch) | |
| 434 if self.writer is None: | |
| 435 self._init_writer()</span><!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-grkkrc">Это говорит нам о том, что ошибка возникает в модуле <code>datasets/metric.py</code> - так что это проблема с нашей функцией <code>compute_metrics()</code>. Она принимает кортеж с логитами и метками в виде массивов NumPy, так что давайте попробуем скормить ей это:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->predictions = outputs.logits.cpu().numpy() | |
| labels = batch[<span class="hljs-string">"labels"</span>].cpu().numpy() | |
| compute_metrics((predictions, labels))<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->TypeError: only size-<span class="hljs-number">1</span> arrays can be converted to Python scalars<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-12fcw8q">Мы получаем ту же ошибку, так что проблема определенно кроется в этой функции. Если мы посмотрим на ее код, то увидим, что она просто передает <code>predictions</code> и <code>labels</code> в <code>metric.compute()</code>. Так есть ли проблема в этом методе? На самом деле нет. Давайте посмотрим на размерности:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->predictions.shape, labels.shape<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->((<span class="hljs-number">8</span>, <span class="hljs-number">3</span>), (<span class="hljs-number">8</span>,))<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1794hwq">Наши предсказания все еще являются логитами, а не реальными предсказаниями, поэтому метрика возвращает эту (несколько непонятную) ошибку. Исправить это довольно просто: нужно просто добавить argmax в функцию <code>compute_metrics()</code>:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> numpy <span class="hljs-keyword">as</span> np | |
| <span class="hljs-keyword">def</span> <span class="hljs-title function_">compute_metrics</span>(<span class="hljs-params">eval_pred</span>): | |
| predictions, labels = eval_pred | |
| predictions = np.argmax(predictions, axis=<span class="hljs-number">1</span>) | |
| <span class="hljs-keyword">return</span> metric.compute(predictions=predictions, references=labels) | |
| compute_metrics((predictions, labels))<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->{<span class="hljs-string">'accuracy'</span>: <span class="hljs-number">0.625</span>}<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-wlyrc">Теперь наша ошибка исправлена! Она была последней, поэтому теперь наш скрипт будет правильно обучать модель.</p> <p data-svelte-h="svelte-16ewhvs">Для справки, вот полностью исправленный скрипт:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> numpy <span class="hljs-keyword">as</span> np | |
| <span class="hljs-keyword">from</span> datasets <span class="hljs-keyword">import</span> load_dataset | |
| <span class="hljs-keyword">import</span> evaluate | |
| <span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> ( | |
| AutoTokenizer, | |
| AutoModelForSequenceClassification, | |
| DataCollatorWithPadding, | |
| TrainingArguments, | |
| Trainer, | |
| ) | |
| raw_datasets = load_dataset(<span class="hljs-string">"glue"</span>, <span class="hljs-string">"mnli"</span>) | |
| model_checkpoint = <span class="hljs-string">"distilbert-base-uncased"</span> | |
| tokenizer = AutoTokenizer.from_pretrained(model_checkpoint) | |
| <span class="hljs-keyword">def</span> <span class="hljs-title function_">preprocess_function</span>(<span class="hljs-params">examples</span>): | |
| <span class="hljs-keyword">return</span> tokenizer(examples[<span class="hljs-string">"premise"</span>], examples[<span class="hljs-string">"hypothesis"</span>], truncation=<span class="hljs-literal">True</span>) | |
| tokenized_datasets = raw_datasets.<span class="hljs-built_in">map</span>(preprocess_function, batched=<span class="hljs-literal">True</span>) | |
| model = AutoModelForSequenceClassification.from_pretrained(model_checkpoint, num_labels=<span class="hljs-number">3</span>) | |
| args = TrainingArguments( | |
| <span class="hljs-string">f"distilbert-finetuned-mnli"</span>, | |
| evaluation_strategy=<span class="hljs-string">"epoch"</span>, | |
| save_strategy=<span class="hljs-string">"epoch"</span>, | |
| learning_rate=<span class="hljs-number">2e-5</span>, | |
| num_train_epochs=<span class="hljs-number">3</span>, | |
| weight_decay=<span class="hljs-number">0.01</span>, | |
| ) | |
| metric = evaluate.load(<span class="hljs-string">"glue"</span>, <span class="hljs-string">"mnli"</span>) | |
| <span class="hljs-keyword">def</span> <span class="hljs-title function_">compute_metrics</span>(<span class="hljs-params">eval_pred</span>): | |
| predictions, labels = eval_pred | |
| predictions = np.argmax(predictions, axis=<span class="hljs-number">1</span>) | |
| <span class="hljs-keyword">return</span> metric.compute(predictions=predictions, references=labels) | |
| data_collator = DataCollatorWithPadding(tokenizer=tokenizer) | |
| trainer = Trainer( | |
| model, | |
| args, | |
| train_dataset=tokenized_datasets[<span class="hljs-string">"train"</span>], | |
| eval_dataset=tokenized_datasets[<span class="hljs-string">"validation_matched"</span>], | |
| compute_metrics=compute_metrics, | |
| data_collator=data_collator, | |
| tokenizer=tokenizer, | |
| ) | |
| trainer.train()<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-14zkqr0">В этом случае проблем больше нет, и наш скрипт обучит модель, которая должна дать приемлемые результаты. Но что делать, если обучение проходит без ошибок, а обученная модель совсем не работает? Это самая сложная часть машинного обучения, и мы покажем вам несколько приемов, которые могут помочь.</p> <div class="course-tip bg-gradient-to-br dark:bg-gradient-to-r before:border-green-500 dark:before:border-green-800 from-green-50 dark:from-gray-900 to-white dark:to-gray-950 border border-green-50 text-green-700 dark:text-gray-400"><p data-svelte-h="svelte-1gym4dk">💡 Если вы используете ручной цикл обучения, для отладки пайплайна обучения применимы те же шаги, но их проще разделить. Убедитесь, что вы не забыли <code>model.eval()</code> или <code>model.train()</code> в нужных местах, или <code>zero_grad()</code> на каждом шаге!</p></div> <h2 class="relative group"><a id="debugging-silent-errors-during-training" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#debugging-silent-errors-during-training"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Отладка скрытых ошибок во время обучения</span></h2> <p data-svelte-h="svelte-11ajyay">Что можно сделать, чтобы отладить обучение, которое завершается без ошибок, но не дает хороших результатов? Мы дадим вам несколько советов, но имейте в виду, что такая отладка - самая сложная часть машинного обучения, и волшебного ответа на этот вопрос не существует.</p> <h3 class="relative group"><a id="check-your-data-again" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#check-your-data-again"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Проверьте свои данные (еще раз!)</span></h3> <p data-svelte-h="svelte-6w5kk9">Ваша модель научится чему-то только в том случае, если из ваших данных действительно можно чему-то научиться. Если в данных есть ошибка, которая портит их, или метки приписаны случайным образом, то, скорее всего, вы не сможете обучить модель на своем наборе данных. Поэтому всегда начинайте с перепроверки декодированных входных данных и меток и задавайте себе следующие вопросы:</p> <ul data-svelte-h="svelte-1d8vmua"><li>Понятны ли декодированные данные?</li> <li>Правильные ли метки?</li> <li>Есть ли одна метка, которая встречается чаще других?</li> <li>Каким должно быть значение функции потерь/метрики если модель предсказала случайный ответ/всегда один и тот же ответ?</li></ul> <div class="course-tip course-tip-orange bg-gradient-to-br dark:bg-gradient-to-r before:border-orange-500 dark:before:border-orange-800 from-orange-50 dark:from-gray-900 to-white dark:to-gray-950 border border-orange-50 text-orange-700 dark:text-gray-400"><p data-svelte-h="svelte-12z68zh">⚠️ Если вы проводите распределенное обучение, распечатайте образцы набора данных в каждом процессе и трижды проверьте, что вы получаете одно и то же. Одна из распространенных ошибок - наличие некоторого источника случайности при создании данных, из-за которого каждый процесс имеет свою версию набора данных.</p></div> <p data-svelte-h="svelte-1lctfpx">Просмотрев данные, проанализируйте несколько предсказаний модели и декодируйте их. Если модель постоянно предсказывает одно и то же, это может быть связано с тем, что ваш набор данных смещен в сторону одной категории (для проблем классификации); здесь могут помочь такие методы, как oversampling редких классов.</p> <p data-svelte-h="svelte-5eyk32">Если значения функции потерь/метрики, которые вы получаете на начальной модели, сильно отличаются от значений функции потерь/метрик, которые можно было бы ожидать для случайных предсказаний, перепроверьте способ вычисления потерь или метрик, так как, возможно, в них есть ошибка. Если вы используете несколько функций потерь, убедитесь, что они имеют одинаковый масштаб.</p> <p data-svelte-h="svelte-hhg4yn">Когда вы убедитесь, что ваши данные идеальны, вы можете проверить, способна ли модель обучаться на них, с помощью одного простого теста.</p> <h3 class="relative group"><a id="overfit-your-model-on-one-batch" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#overfit-your-model-on-one-batch"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Переобучение модели на одном батче</span></h3> <p data-svelte-h="svelte-1cslmaj">Обычно мы стараемся избегать переобучения при тренировке модели, поскольку это означает, что модель не учится распознавать общие характеристики, а просто запоминает обучающие выборки. Однако попытка обучить модель на одной выборке снова и снова - это хороший тест, позволяющий проверить, может ли проблема в том виде, в котором вы ее сформулировали, быть решена с помощью модели, которую вы пытаетесь обучить. Это также поможет вам понять, не слишком ли высока ваша начальная скорость обучения.</p> <p data-svelte-h="svelte-1slr0p1">Сделать это после того, как вы определили свой <code>Trainer</code>, очень просто: просто возьмите батч обучающих данных, а затем запустите небольшой цикл ручного обучения, используя только этот батч в течение примерно 20 шагов:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">for</span> batch <span class="hljs-keyword">in</span> trainer.get_train_dataloader(): | |
| <span class="hljs-keyword">break</span> | |
| batch = {k: v.to(device) <span class="hljs-keyword">for</span> k, v <span class="hljs-keyword">in</span> batch.items()} | |
| trainer.create_optimizer() | |
| <span class="hljs-keyword">for</span> _ <span class="hljs-keyword">in</span> <span class="hljs-built_in">range</span>(<span class="hljs-number">20</span>): | |
| outputs = trainer.model(**batch) | |
| loss = outputs.loss | |
| loss.backward() | |
| trainer.optimizer.step() | |
| trainer.optimizer.zero_grad()<!-- HTML_TAG_END --></pre></div> <div class="course-tip bg-gradient-to-br dark:bg-gradient-to-r before:border-green-500 dark:before:border-green-800 from-green-50 dark:from-gray-900 to-white dark:to-gray-950 border border-green-50 text-green-700 dark:text-gray-400"><p data-svelte-h="svelte-w8su75">💡 Если ваши обучающие данные несбалансированы, обязательно создайте батч обучающих данных, содержащий все метки.</p></div> <p data-svelte-h="svelte-90whek">Результирующая модель должна иметь близкие к идеальным результаты на одном и том же батче. Вычислим метрику по полученным предсказаниям:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">with</span> torch.no_grad(): | |
| outputs = trainer.model(**batch) | |
| preds = outputs.logits | |
| labels = batch[<span class="hljs-string">"labels"</span>] | |
| compute_metrics((preds.cpu().numpy(), labels.cpu().numpy()))<!-- HTML_TAG_END --></pre></div> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->{<span class="hljs-string">'accuracy'</span>: <span class="hljs-number">1.0</span>}<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1q8k6oi">Точность 100 %, вот это хороший пример переобучения (это значит, что если вы попробуете использовать модель на любом другом предложении, она, скорее всего, даст вам неправильный ответ)!</p> <p data-svelte-h="svelte-1rfk2rq">Если вам не удается добиться от модели таких идеальных результатов, значит, что-то не так с постановкой задачи или данными, и вам следует это исправить. Только когда вам удастся пройти тест на переобучение, вы сможете быть уверены, что ваша модель действительно способна чему-то научиться.</p> <div class="course-tip course-tip-orange bg-gradient-to-br dark:bg-gradient-to-r before:border-orange-500 dark:before:border-orange-800 from-orange-50 dark:from-gray-900 to-white dark:to-gray-950 border border-orange-50 text-orange-700 dark:text-gray-400"><p data-svelte-h="svelte-6bq32q">⚠️ Вам придется пересоздать модель и <code>Trainer</code> после этого теста на переобучение, поскольку полученная модель, вероятно, не сможет восстановиться и научиться чему-то полезному на полном наборе данных.</p></div> <h3 class="relative group"><a id="dont-tune-anything-until-you-have-a-first-baseline" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#dont-tune-anything-until-you-have-a-first-baseline"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Не обучайте ничего, пока не получите первый бейзлайн.</span></h3> <p data-svelte-h="svelte-amuojt">Настройка гиперпараметров всегда считается самой сложной частью машинного обучения, но это всего лишь последний шаг, который поможет вам немного улучшить метрику. В большинстве случаев гиперпараметры по умолчанию <code>Trainer</code> будут работать нормально и давать вам хорошие результаты, поэтому не приступайте к трудоемкому и дорогостоящему поиску гиперпараметров, пока у вас не будет чего-то, что превосходит базовый уровень, который у вас есть в вашем наборе данных.</p> <p data-svelte-h="svelte-scvkmh">Как только у вас будет достаточно хорошая модель, вы можете начать ее немного оптимизировать. Не пытайтесь запустить тысячу раз с разными гиперпараметрами, но сравните пару запусков с разными значениями одного гиперпараметра, чтобы получить представление о том, какой из них оказывает наибольшее влияние.</p> <p data-svelte-h="svelte-1xje688">Если вы настраиваете саму модель, будьте проще и не пробуйте то, что не можете обосновать. Всегда возвращайтесь к тесту на перебор, чтобы проверить, не привело ли ваше изменение к каким-либо непредвиденным последствиям.</p> <h3 class="relative group"><a id="ask-for-help" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#ask-for-help"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Попросите о помощи</span></h3> <p data-svelte-h="svelte-1hsg2d4">Надеемся, вы нашли в этом разделе советы, которые помогли вам решить вашу проблему, но если это не так, помните, что вы всегда можете спросить у сообщества на <a href="https://discuss.huggingface.co/" rel="nofollow">форумах</a>.</p> <p data-svelte-h="svelte-25fayq">Вот некоторые дополнительные ресурсы, которые могут оказаться полезными:</p> <ul data-svelte-h="svelte-s74dcg"><li><a href="https://docs.google.com/presentation/d/1yHLPvPhUs2KGI5ZWo0sU-PKU3GimAk3iTsI38Z-B5Gw/edit#slide=id.p" rel="nofollow">” Reproducibility as a vehicle for engineering best practices”</a> by Joel Grus</li> <li><a href="https://towardsdatascience.com/checklist-for-debugging-neural-networks-d8b2a9434f21" rel="nofollow">“Checklist for debugging neural networks”</a> by Cecelia Shao</li> <li><a href="https://medium.com/@keeper6928/how-to-unit-test-machine-learning-code-57cf6fd81765" rel="nofollow">” How to unit test machine learning code”</a> by Chase Roberts</li> <li><a href="http://karpathy.github.io/2019/04/25/recipe/" rel="nofollow">""A Recipe for Training Neural Networks”</a> by Andrej Karpathy</li></ul> <p data-svelte-h="svelte-t3dqmr">Конечно, не все проблемы, с которыми вы сталкиваетесь при обучении нейросетей, возникают по вашей вине! Если в библиотеке 🤗 Transformers или 🤗 Datasets вы столкнулись с чем-то, что кажется вам неправильным, возможно, вы обнаружили ошибку. Вам обязательно нужно рассказать нам об этом, и в следующем разделе мы объясним, как именно это сделать.</p> <a class="!text-gray-400 !no-underline text-sm flex items-center not-prose mt-4" href="https://github.com/huggingface/course/blob/main/chapters/ru/chapter8/4.mdx" target="_blank"><span data-svelte-h="svelte-1kd6by1"><</span> <span data-svelte-h="svelte-x0xyl0">></span> <span data-svelte-h="svelte-1dajgef"><span class="underline ml-1.5">Update</span> on GitHub</span></a> <p></p> | |
| <script> | |
| { | |
| __sveltekit_j8s7wf = { | |
| assets: "/docs/course/pr_1069/ru", | |
| base: "/docs/course/pr_1069/ru", | |
| env: {} | |
| }; | |
| const element = document.currentScript.parentElement; | |
| const data = [null,null]; | |
| Promise.all([ | |
| import("/docs/course/pr_1069/ru/_app/immutable/entry/start.48687cc8.js"), | |
| import("/docs/course/pr_1069/ru/_app/immutable/entry/app.b79a803d.js") | |
| ]).then(([kit, app]) => { | |
| kit.start(app, element, { | |
| node_ids: [0, 64], | |
| data, | |
| form: null, | |
| error: null | |
| }); | |
| }); | |
| } | |
| </script> | |
Xet Storage Details
- Size:
- 162 kB
- Xet hash:
- aed651ad5adb189e67aa51bc0a9d6636941de648ce79e2af1cac48af4330b30e
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.