Buckets:

rtrm's picture
download
raw
73.3 kB
<meta charset="utf-8" /><meta name="hf:doc:metadata" content="{&quot;title&quot;:&quot;Trainer&quot;,&quot;local&quot;:&quot;trainer&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;기본 사용법&quot;,&quot;local&quot;:&quot;basic-usage&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;체크포인트&quot;,&quot;local&quot;:&quot;checkpoints&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3}],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Trainer 맞춤 설정&quot;,&quot;local&quot;:&quot;customize-the-trainer&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;콜백&quot;,&quot;local&quot;:&quot;callbacks&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3}],&quot;depth&quot;:2},{&quot;title&quot;:&quot;로깅&quot;,&quot;local&quot;:&quot;logging&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;NEFTune&quot;,&quot;local&quot;:&quot;neftune&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;GaLore&quot;,&quot;local&quot;:&quot;galore&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;LOMO 옵티마이저&quot;,&quot;local&quot;:&quot;lomo-optimizer&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Accelerate와 Trainer&quot;,&quot;local&quot;:&quot;accelerate-and-trainer&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2}],&quot;depth&quot;:1}">
<link href="/docs/transformers/main/ko/_app/immutable/assets/0.e3b0c442.css" rel="modulepreload">
<link rel="modulepreload" href="/docs/transformers/main/ko/_app/immutable/entry/start.9aa88961.js">
<link rel="modulepreload" href="/docs/transformers/main/ko/_app/immutable/chunks/scheduler.9bc65507.js">
<link rel="modulepreload" href="/docs/transformers/main/ko/_app/immutable/chunks/singletons.9eec45c3.js">
<link rel="modulepreload" href="/docs/transformers/main/ko/_app/immutable/chunks/index.3b203c72.js">
<link rel="modulepreload" href="/docs/transformers/main/ko/_app/immutable/chunks/paths.566078f7.js">
<link rel="modulepreload" href="/docs/transformers/main/ko/_app/immutable/entry/app.84fb67c3.js">
<link rel="modulepreload" href="/docs/transformers/main/ko/_app/immutable/chunks/index.707bf1b6.js">
<link rel="modulepreload" href="/docs/transformers/main/ko/_app/immutable/nodes/0.1c99376b.js">
<link rel="modulepreload" href="/docs/transformers/main/ko/_app/immutable/chunks/each.e59479a4.js">
<link rel="modulepreload" href="/docs/transformers/main/ko/_app/immutable/nodes/93.2358f2d1.js">
<link rel="modulepreload" href="/docs/transformers/main/ko/_app/immutable/chunks/Tip.c2ecdbf4.js">
<link rel="modulepreload" href="/docs/transformers/main/ko/_app/immutable/chunks/CodeBlock.54a9f38d.js">
<link rel="modulepreload" href="/docs/transformers/main/ko/_app/immutable/chunks/EditOnGithub.922df6ba.js">
<link rel="modulepreload" href="/docs/transformers/main/ko/_app/immutable/chunks/HfOption.6d864328.js">
<link rel="modulepreload" href="/docs/transformers/main/ko/_app/immutable/chunks/stores.c16bc1a5.js"><!-- HEAD_svelte-u9bgzb_START --><meta name="hf:doc:metadata" content="{&quot;title&quot;:&quot;Trainer&quot;,&quot;local&quot;:&quot;trainer&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;기본 사용법&quot;,&quot;local&quot;:&quot;basic-usage&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;체크포인트&quot;,&quot;local&quot;:&quot;checkpoints&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3}],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Trainer 맞춤 설정&quot;,&quot;local&quot;:&quot;customize-the-trainer&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;콜백&quot;,&quot;local&quot;:&quot;callbacks&quot;,&quot;sections&quot;:[],&quot;depth&quot;:3}],&quot;depth&quot;:2},{&quot;title&quot;:&quot;로깅&quot;,&quot;local&quot;:&quot;logging&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;NEFTune&quot;,&quot;local&quot;:&quot;neftune&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;GaLore&quot;,&quot;local&quot;:&quot;galore&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;LOMO 옵티마이저&quot;,&quot;local&quot;:&quot;lomo-optimizer&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;Accelerate와 Trainer&quot;,&quot;local&quot;:&quot;accelerate-and-trainer&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2}],&quot;depth&quot;:1}"><!-- HEAD_svelte-u9bgzb_END --> <p></p> <h1 class="relative group"><a id="trainer" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#trainer"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 
1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Trainer</span></h1> <p data-svelte-h="svelte-1r4aiss"><code>Trainer</code>는 Transformers 라이브러리에 구현된 PyTorch 모델을 위한 완전한 훈련 및 평가 루프입니다. 훈련에 필요한 요소(모델, 토크나이저, 데이터셋, 평가 함수, 훈련 하이퍼파라미터 등)만 제공하면 <code>Trainer</code>가 필요한 나머지 작업을 처리합니다. 이를 통해 직접 훈련 루프를 작성하지 않고도 빠르게 훈련을 시작할 수 있습니다. 또한 <code>Trainer</code>는 강력한 맞춤 설정과 다양한 훈련 옵션을 제공하여 사용자 맞춤 훈련이 가능합니다.</p> <div class="course-tip bg-gradient-to-br dark:bg-gradient-to-r before:border-green-500 dark:before:border-green-800 from-green-50 dark:from-gray-900 to-white dark:to-gray-950 border border-green-50 text-green-700 dark:text-gray-400"><p data-svelte-h="svelte-26hx6t">Transformers는 <code>Trainer</code> 클래스 외에도 번역이나 요약과 같은 시퀀스-투-시퀀스 작업을 위한 <code>Seq2SeqTrainer</code> 클래스도 제공합니다. 또한 <a href="https://hf.co/docs/trl" rel="nofollow">TRL</a> 라이브러리에는 <code>Trainer</code> 클래스를 감싸고 Llama-2 및 Mistral과 같은 언어 모델을 자동 회귀 기법으로 훈련하는 데 최적화된 <code>SFTTrainer</code> 클래스가 있습니다. <code>SFTTrainer</code>는 시퀀스 패킹, LoRA, 양자화 및 DeepSpeed와 같은 기능을 지원하여 모델 크기에 상관없이 효율적으로 확장할 수 있습니다.</p> <br> <p data-svelte-h="svelte-7uuu5x">이들 다른 <code>Trainer</code> 유형 클래스에 대해 더 알고 싶다면 <a href="./main_classes/trainer">API 참조</a>를 확인하여 언제 어떤 클래스가 적합할지 얼마든지 확인하세요. 일반적으로 <code>Trainer</code>는 가장 다재다능한 옵션으로, 다양한 작업에 적합합니다. 
<code>Seq2SeqTrainer</code>는 시퀀스-투-시퀀스 작업을 위해 설계되었고, <code>SFTTrainer</code>는 언어 모델 훈련을 위해 설계되었습니다.</p></div> <p data-svelte-h="svelte-1cp0146">시작하기 전에, 분산 환경에서 PyTorch 훈련과 실행을 할 수 있게 <a href="https://hf.co/docs/accelerate" rel="nofollow">Accelerate</a> 라이브러리가 설치되었는지 확인하세요.</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->pip install accelerate
<span class="hljs-comment"># 업그레이드</span>
pip install accelerate --upgrade<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1eiyo2g">이 가이드는 <code>Trainer</code> 클래스에 대한 개요를 제공합니다.</p> <h2 class="relative group"><a id="basic-usage" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#basic-usage"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>기본 사용법</span></h2> <p data-svelte-h="svelte-1acln4k"><code>Trainer</code>는 기본적인 훈련 루프에 필요한 모든 코드를 포함하고 있습니다.</p> <ol data-svelte-h="svelte-1paahjl"><li>손실을 계산하는 훈련 단계를 수행합니다.</li> <li><code>backward</code> 메소드로 그레이디언트를 계산합니다.</li> <li>그레이디언트를 기반으로 가중치를 업데이트합니다.</li> <li>정해진 에폭 수에 도달할 때까지 이 과정을 반복합니다.</li></ol> <p data-svelte-h="svelte-3o0je1"><code>Trainer</code> 클래스는 PyTorch와 훈련 과정에 익숙하지 않거나 막 시작한 경우에도 훈련이 가능하도록 필요한 모든 코드를 추상화하였습니다. 또한 매번 훈련 루프를 손수 작성하지 않아도 되며, 훈련에 필요한 모델과 데이터셋 같은 필수 구성 요소만 제공하면, <code>Trainer</code> 클래스가 나머지를 처리합니다.</p> <p data-svelte-h="svelte-4zmz0l">훈련 옵션이나 하이퍼파라미터를 지정하려면, <code>TrainingArguments</code> 클래스에서 확인할 수 있습니다. 
예를 들어, 모델을 저장할 디렉토리를 <code>output_dir</code>에 정의하고, 훈련 후에 Hub로 모델을 푸시하려면 <code>push_to_hub=True</code>로 설정합니다.</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> TrainingArguments
training_args = TrainingArguments(
output_dir=<span class="hljs-string">&quot;your-model&quot;</span>,
learning_rate=<span class="hljs-number">2e-5</span>,
per_device_train_batch_size=<span class="hljs-number">16</span>,
per_device_eval_batch_size=<span class="hljs-number">16</span>,
num_train_epochs=<span class="hljs-number">2</span>,
weight_decay=<span class="hljs-number">0.01</span>,
eval_strategy=<span class="hljs-string">&quot;epoch&quot;</span>,
save_strategy=<span class="hljs-string">&quot;epoch&quot;</span>,
load_best_model_at_end=<span class="hljs-literal">True</span>,
push_to_hub=<span class="hljs-literal">True</span>,
)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-rthqdp"><code>training_args</code>를 <code>Trainer</code>에 모델, 데이터셋, 데이터셋 전처리 도구(데이터 유형에 따라 토크나이저, 특징 추출기 또는 이미지 프로세서일 수 있음), 데이터 수집기 및 훈련 중 확인할 지표를 계산할 함수를 함께 전달하세요.</p> <p data-svelte-h="svelte-a31jzj">마지막으로, <code>train()</code>를 호출하여 훈련을 시작하세요!</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> Trainer
trainer = Trainer(
model=model,
args=training_args,
train_dataset=dataset[<span class="hljs-string">&quot;train&quot;</span>],
eval_dataset=dataset[<span class="hljs-string">&quot;test&quot;</span>],
tokenizer=tokenizer,
data_collator=data_collator,
compute_metrics=compute_metrics,
)
trainer.train()<!-- HTML_TAG_END --></pre></div> <h3 class="relative group"><a id="checkpoints" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#checkpoints"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>체크포인트</span></h3> <p data-svelte-h="svelte-b83gro"><code>Trainer</code> 클래스는 <code>TrainingArguments</code>의 <code>output_dir</code> 매개변수에 지정된 디렉토리에 모델 체크포인트를 저장합니다. 체크포인트는 <code>checkpoint-000</code> 하위 폴더에 저장되며, 여기서 끝의 숫자는 훈련 단계에 해당합니다. 
체크포인트를 저장하면 나중에 훈련을 재개할 때 유용합니다.</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-comment"># 최신 체크포인트에서 재개</span>
trainer.train(resume_from_checkpoint=<span class="hljs-literal">True</span>)
<span class="hljs-comment"># 출력 디렉토리에 저장된 특정 체크포인트에서 재개</span>
trainer.train(resume_from_checkpoint=<span class="hljs-string">&quot;your-model/checkpoint-1000&quot;</span>)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1qy235s">체크포인트를 Hub에 푸시하려면 <code>TrainingArguments</code>에서 <code>push_to_hub=True</code>로 설정하여 커밋하고 푸시할 수 있습니다. 체크포인트 저장 방법을 결정하는 다른 옵션은 <a href="https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments.hub_strategy" rel="nofollow"><code>hub_strategy</code></a> 매개변수에서 설정합니다:</p> <ul data-svelte-h="svelte-1biqkkg"><li><code>hub_strategy=&quot;checkpoint&quot;</code>는 최신 체크포인트를 “last-checkpoint”라는 하위 폴더에 푸시하여 훈련을 재개할 수 있습니다.</li> <li><code>hub_strategy=&quot;all_checkpoints&quot;</code>는 모든 체크포인트를 <code>output_dir</code>에 정의된 디렉토리에 푸시합니다(모델 리포지토리에서 폴더당 하나의 체크포인트를 볼 수 있습니다).</li></ul> <p data-svelte-h="svelte-iwkjnt">체크포인트에서 훈련을 재개할 때, <code>Trainer</code>는 체크포인트가 저장될 때와 동일한 Python, NumPy 및 PyTorch RNG 상태를 유지하려고 합니다. 하지만 PyTorch에는 기본적으로 비결정적(non-deterministic)으로 동작하는 설정이 많기 때문에, RNG 상태가 동일할 것이라고 보장할 수 없습니다. 따라서 완전한 결정성을 원한다면, <a href="https://pytorch.org/docs/stable/notes/randomness#controlling-sources-of-randomness" rel="nofollow">랜덤성 제어</a> 가이드를 참고하여 훈련을 완전히 결정적으로 만들기 위해 활성화할 수 있는 항목을 확인하세요. 
다만, 특정 설정을 결정적으로 만들면 훈련이 느려질 수 있습니다.</p> <h2 class="relative group"><a id="customize-the-trainer" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#customize-the-trainer"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Trainer 맞춤 설정</span></h2> <p data-svelte-h="svelte-tbyxn3"><code>Trainer</code> 클래스는 접근성과 용이성을 염두에 두고 설계되었지만, 더 다양한 기능을 원하는 사용자들을 위해 다양한 맞춤 설정 옵션을 제공합니다. <code>Trainer</code>의 많은 메소드는 서브클래스화 및 오버라이드하여 원하는 기능을 제공할 수 있으며, 이를 통해 전체 훈련 루프를 다시 작성할 필요 없이 원하는 기능을 추가할 수 있습니다. 이러한 메소드에는 다음이 포함됩니다:</p> <ul data-svelte-h="svelte-2gbdwc"><li><code>get_train_dataloader()</code>는 훈련 데이터로더를 생성합니다.</li> <li><code>get_eval_dataloader()</code>는 평가 데이터로더를 생성합니다.</li> <li><code>get_test_dataloader()</code>는 테스트 데이터로더를 생성합니다.</li> <li><code>log()</code>는 훈련을 모니터링하는 다양한 객체에 대한 정보를 로그로 남깁니다.</li> <li><code>create_optimizer_and_scheduler()</code>는 <code>__init__</code>에서 전달되지 않은 경우 옵티마이저와 학습률 스케줄러를 생성합니다. 
이들은 각각 <code>create_optimizer()</code>와 <code>create_scheduler()</code>로 별도로 맞춤 설정할 수 있습니다.</li> <li><code>compute_loss()</code>는 훈련 입력 배치에 대한 손실을 계산합니다.</li> <li><code>training_step()</code>는 훈련 단계를 수행합니다.</li> <li><code>prediction_step()</code>는 예측 및 테스트 단계를 수행합니다.</li> <li><code>evaluate()</code>는 모델을 평가하고 평가 지표를 반환합니다.</li> <li><code>predict()</code>는 테스트 세트에 대한 예측(레이블이 있는 경우 지표 포함)을 수행합니다.</li></ul> <p data-svelte-h="svelte-vw04s9">예를 들어, <code>compute_loss()</code> 메소드를 맞춤 설정하여 가중 손실을 사용하려는 경우:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> torch <span class="hljs-keyword">import</span> nn
<span class="hljs-keyword">import</span> torch
<span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> Trainer
<span class="hljs-keyword">class</span> <span class="hljs-title class_">CustomTrainer</span>(<span class="hljs-title class_ inherited__">Trainer</span>):
<span class="hljs-keyword">def</span> <span class="hljs-title function_">compute_loss</span>(<span class="hljs-params">self,
model, inputs, return_outputs=<span class="hljs-literal">False</span></span>):
labels = inputs.pop(<span class="hljs-string">&quot;labels&quot;</span>)
<span class="hljs-comment"># 순방향 전파</span>
outputs = model(**inputs)
logits = outputs.get(<span class="hljs-string">&quot;logits&quot;</span>)
<span class="hljs-comment"># 서로 다른 가중치로 3개의 레이블에 대한 사용자 정의 손실을 계산</span>
loss_fct = nn.CrossEntropyLoss(weight=torch.tensor([<span class="hljs-number">1.0</span>, <span class="hljs-number">2.0</span>, <span class="hljs-number">3.0</span>], device=model.device))
loss = loss_fct(logits.view(-<span class="hljs-number">1</span>, self.model.config.num_labels), labels.view(-<span class="hljs-number">1</span>))
<span class="hljs-keyword">return</span> (loss, outputs) <span class="hljs-keyword">if</span> return_outputs <span class="hljs-keyword">else</span> loss<!-- HTML_TAG_END --></pre></div> <h3 class="relative group"><a id="callbacks" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#callbacks"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>콜백</span></h3> <p data-svelte-h="svelte-gudmov"><code>Trainer</code>를 맞춤 설정하는 또 다른 방법은 <a href="callbacks">콜백</a>을 사용하는 것입니다. 콜백은 훈련 루프에서 <em>변화를 주지 않습니다</em>. 훈련 루프의 상태를 검사한 후 상태에 따라 일부 작업(조기 종료, 결과 로그 등)을 실행합니다. 
즉, 콜백은 사용자 정의 손실 함수와 같은 것을 구현하는 데 사용할 수 없으며, 이를 위해서는 <code>compute_loss()</code> 메소드를 서브클래스화하고 오버라이드해야 합니다.</p> <p data-svelte-h="svelte-1sr1d63">예를 들어, 훈련 루프에 10단계 후 조기 종료 콜백을 추가하려면 다음과 같이 합니다.</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> TrainerCallback
<span class="hljs-keyword">class</span> <span class="hljs-title class_">EarlyStoppingCallback</span>(<span class="hljs-title class_ inherited__">TrainerCallback</span>):
<span class="hljs-keyword">def</span> <span class="hljs-title function_">__init__</span>(<span class="hljs-params">self, num_steps=<span class="hljs-number">10</span></span>):
self.num_steps = num_steps
<span class="hljs-keyword">def</span> <span class="hljs-title function_">on_step_end</span>(<span class="hljs-params">self, args, state, control, **kwargs</span>):
<span class="hljs-keyword">if</span> state.global_step &gt;= self.num_steps:
<span class="hljs-keyword">return</span> {<span class="hljs-string">&quot;should_training_stop&quot;</span>: <span class="hljs-literal">True</span>}
<span class="hljs-keyword">else</span>:
<span class="hljs-keyword">return</span> {}<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-c4o7g1">그런 다음, 이를 <code>Trainer</code>의 <code>callbacks</code> 매개변수에 전달합니다.</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> Trainer
trainer = Trainer(
model=model,
args=training_args,
train_dataset=dataset[<span class="hljs-string">&quot;train&quot;</span>],
eval_dataset=dataset[<span class="hljs-string">&quot;test&quot;</span>],
tokenizer=tokenizer,
data_collator=data_collator,
compute_metrics=compute_metrics,
callbacks=[EarlyStoppingCallback()],
)<!-- HTML_TAG_END --></pre></div> <h2 class="relative group"><a id="logging" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#logging"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>로깅</span></h2> <div class="course-tip bg-gradient-to-br dark:bg-gradient-to-r before:border-green-500 dark:before:border-green-800 from-green-50 dark:from-gray-900 to-white dark:to-gray-950 border border-green-50 text-green-700 dark:text-gray-400"><p data-svelte-h="svelte-1djg64a">로깅 API에 대한 자세한 내용은 <a href="./main_classes/logging">로깅</a> API 레퍼런스를 확인하세요.</p></div> <p data-svelte-h="svelte-1igbvzt"><code>Trainer</code>는 기본적으로 <code>logging.INFO</code>로 설정되어 있어 오류, 경고 및 기타 기본 정보를 보고합니다. 분산 환경에서는 <code>Trainer</code> 복제본이 <code>logging.WARNING</code>으로 설정되어 오류와 경고만 보고합니다. 
<code>TrainingArguments</code>의 <a href="https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments.log_level" rel="nofollow"><code>log_level</code></a>과 <a href="https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments.log_level_replica" rel="nofollow"><code>log_level_replica</code></a> 매개변수로 로그 레벨을 변경할 수 있습니다.</p> <p data-svelte-h="svelte-1xaqk62">각 노드의 로그 레벨 설정을 구성하려면 <a href="https://huggingface.co/docs/transformers/main/en/main_classes/trainer#transformers.TrainingArguments.log_on_each_node" rel="nofollow"><code>log_on_each_node</code></a> 매개변수를 사용하여 각 노드에서 로그 레벨을 사용할지 아니면 주 노드에서만 사용할지 결정하세요.</p> <div class="course-tip bg-gradient-to-br dark:bg-gradient-to-r before:border-green-500 dark:before:border-green-800 from-green-50 dark:from-gray-900 to-white dark:to-gray-950 border border-green-50 text-green-700 dark:text-gray-400"><p data-svelte-h="svelte-rdkluh"><code>Trainer</code>는 <code>Trainer.__init__()</code> 메소드에서 각 노드에 대해 로그 레벨을 별도로 설정하므로, 다른 Transformers 기능을 사용할 경우 <code>Trainer</code> 객체를 생성하기 전에 이를 미리 설정하는 것이 좋습니다.</p></div> <p data-svelte-h="svelte-1dt04b5">예를 들어, 메인 코드와 모듈이 각 노드에 따라 동일한 로그 레벨을 사용하도록 설정하려면 다음과 같이 합니다.</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute 
pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->logger = logging.getLogger(__name__)
logging.basicConfig(
<span class="hljs-built_in">format</span>=<span class="hljs-string">&quot;%(asctime)s - %(levelname)s - %(name)s - %(message)s&quot;</span>,
datefmt=<span class="hljs-string">&quot;%m/%d/%Y %H:%M:%S&quot;</span>,
handlers=[logging.StreamHandler(sys.stdout)],
)
log_level = training_args.get_process_log_level()
logger.setLevel(log_level)
datasets.utils.logging.set_verbosity(log_level)
transformers.utils.logging.set_verbosity(log_level)
trainer = Trainer(...)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1dk6uay">각 노드에서 기록될 내용을 구성하기 위해 <code>log_level</code><code>log_level_replica</code>를 다양한 조합으로 사용해보세요.</p> <div class="flex space-x-2 items-center my-1.5 mr-8 h-7 !pl-0 -mx-3 md:mx-0"><div class="flex items-center border rounded-lg px-1.5 py-1 leading-none select-none text-smd border-gray-800 bg-black dark:bg-gray-700 text-white">single node </div><div class="flex items-center border rounded-lg px-1.5 py-1 leading-none select-none text-smd text-gray-500 cursor-pointer opacity-90 hover:text-gray-700 dark:hover:text-gray-200 hover:shadow-sm">multi-node </div></div> <div class="language-select"><div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->my_app.py ... 
--log_level warning --log_level_replica error<!-- HTML_TAG_END --></pre></div> </div> <h2 class="relative group"><a id="neftune" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#neftune"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>NEFTune</span></h2> <p data-svelte-h="svelte-1hypqsm"><a href="https://hf.co/papers/2310.05914" rel="nofollow">NEFTune</a>은 훈련 중 임베딩 벡터에 노이즈를 추가하여 성능을 향상시킬 수 있는 기술입니다. 
<code>Trainer</code>에서 이를 활성화하려면 <code>TrainingArguments</code><code>neftune_noise_alpha</code> 매개변수를 설정하여 노이즈의 양을 조절합니다.</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> TrainingArguments, Trainer
training_args = TrainingArguments(..., neftune_noise_alpha=<span class="hljs-number">0.1</span>)
trainer = Trainer(..., args=training_args)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1ft34jp">NEFTune은 예상치 못한 동작을 피할 목적으로 처음 임베딩 레이어로 복원하기 위해 훈련 후 비활성화 됩니다.</p> <h2 class="relative group"><a id="galore" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#galore"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>GaLore</span></h2> <p data-svelte-h="svelte-zig1fw">Gradient Low-Rank Projection (GaLore)은 전체 매개변수를 학습하면서도 LoRA와 같은 일반적인 저계수 적응 방법보다 더 메모리 효율적인 저계수 학습 전략입니다.</p> <p data-svelte-h="svelte-15aiu37">먼저 GaLore 공식 리포지토리를 설치합니다:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" 
transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->pip install galore-torch<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-13s90m9">그런 다음 <code>optim</code><code>[&quot;galore_adamw&quot;, &quot;galore_adafactor&quot;, &quot;galore_adamw_8bit&quot;]</code> 중 하나와 함께 <code>optim_target_modules</code>를 추가합니다. 이는 적용하려는 대상 모듈 이름에 해당하는 문자열, 정규 표현식 또는 전체 경로의 목록일 수 있습니다. 아래는 end-to-end 예제 스크립트입니다(필요한 경우 <code>pip install trl datasets</code>를 실행):</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute 
bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> torch
<span class="hljs-keyword">import</span> datasets
<span class="hljs-keyword">import</span> trl
<span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> TrainingArguments, AutoConfig, AutoTokenizer, AutoModelForCausalLM
train_dataset = datasets.load_dataset(<span class="hljs-string">&#x27;imdb&#x27;</span>, split=<span class="hljs-string">&#x27;train&#x27;</span>)
args = TrainingArguments(
output_dir=<span class="hljs-string">&quot;./test-galore&quot;</span>,
max_steps=<span class="hljs-number">100</span>,
per_device_train_batch_size=<span class="hljs-number">2</span>,
optim=<span class="hljs-string">&quot;galore_adamw&quot;</span>,
optim_target_modules=[<span class="hljs-string">&quot;attn&quot;</span>, <span class="hljs-string">&quot;mlp&quot;</span>]
)
model_id = <span class="hljs-string">&quot;google/gemma-2b&quot;</span>
config = AutoConfig.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_config(config).to(<span class="hljs-number">0</span>)
trainer = trl.SFTTrainer(
model=model,
args=args,
train_dataset=train_dataset,
dataset_text_field=<span class="hljs-string">&#x27;text&#x27;</span>,
max_seq_length=<span class="hljs-number">512</span>,
)
trainer.train()<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-q0mtk0">GaLore가 지원하는 추가 매개변수를 전달하려면 <code>optim_args</code>를 설정합니다. 예를 들어:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> torch
<span class="hljs-keyword">import</span> datasets
<span class="hljs-keyword">import</span> trl
<span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> TrainingArguments, AutoConfig, AutoTokenizer, AutoModelForCausalLM
train_dataset = datasets.load_dataset(<span class="hljs-string">&#x27;imdb&#x27;</span>, split=<span class="hljs-string">&#x27;train&#x27;</span>)
args = TrainingArguments(
output_dir=<span class="hljs-string">&quot;./test-galore&quot;</span>,
max_steps=<span class="hljs-number">100</span>,
per_device_train_batch_size=<span class="hljs-number">2</span>,
optim=<span class="hljs-string">&quot;galore_adamw&quot;</span>,
optim_target_modules=[<span class="hljs-string">&quot;attn&quot;</span>, <span class="hljs-string">&quot;mlp&quot;</span>],
optim_args=<span class="hljs-string">&quot;rank=64, update_proj_gap=100, scale=0.10&quot;</span>,
)
model_id = <span class="hljs-string">&quot;google/gemma-2b&quot;</span>
config = AutoConfig.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_config(config).to(<span class="hljs-number">0</span>)
trainer = trl.SFTTrainer(
model=model,
args=args,
train_dataset=train_dataset,
dataset_text_field=<span class="hljs-string">&#x27;text&#x27;</span>,
max_seq_length=<span class="hljs-number">512</span>,
)
trainer.train()<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-9wx27h">해당 방법에 대한 자세한 내용은 <a href="https://github.com/jiaweizzhao/GaLore" rel="nofollow">원본 리포지토리</a> 또는 <a href="https://arxiv.org/abs/2403.03507" rel="nofollow">논문</a>을 참고하세요.</p> <p data-svelte-h="svelte-3m2n0g">현재는 GaLore 레이어로 간주되는 Linear 레이어만 훈련할 수 있습니다. 이 레이어들은 저계수 분해를 사용하여 훈련되고, 나머지 레이어는 기존 방식으로 최적화됩니다.</p> <p data-svelte-h="svelte-o488v2">훈련 시작 전에 시간이 약간 걸릴 수 있지만(NVIDIA A100에서 2B 모델의 경우 약 3분), 이후 훈련은 원활하게 진행됩니다.</p> <p data-svelte-h="svelte-6sty29">다음과 같이 옵티마이저 이름에 <code>layerwise</code>를 추가하여 레이어별 최적화를 수행할 수도 있습니다:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> torch
<span class="hljs-keyword">import</span> datasets
<span class="hljs-keyword">import</span> trl
<span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> TrainingArguments, AutoConfig, AutoTokenizer, AutoModelForCausalLM
train_dataset = datasets.load_dataset(<span class="hljs-string">&#x27;imdb&#x27;</span>, split=<span class="hljs-string">&#x27;train&#x27;</span>)
args = TrainingArguments(
output_dir=<span class="hljs-string">&quot;./test-galore&quot;</span>,
max_steps=<span class="hljs-number">100</span>,
per_device_train_batch_size=<span class="hljs-number">2</span>,
optim=<span class="hljs-string">&quot;galore_adamw_layerwise&quot;</span>,
optim_target_modules=[<span class="hljs-string">&quot;attn&quot;</span>, <span class="hljs-string">&quot;mlp&quot;</span>]
)
model_id = <span class="hljs-string">&quot;google/gemma-2b&quot;</span>
config = AutoConfig.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_config(config).to(<span class="hljs-number">0</span>)
trainer = trl.SFTTrainer(
model=model,
args=args,
train_dataset=train_dataset,
dataset_text_field=<span class="hljs-string">&#x27;text&#x27;</span>,
max_seq_length=<span class="hljs-number">512</span>,
)
trainer.train()<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1h6bbfe">레이어별 최적화는 다소 실험적이며 DDP(분산 데이터 병렬)를 지원하지 않으므로, 단일 GPU에서만 훈련 스크립트를 실행할 수 있습니다. 자세한 내용은 <a href="https://github.com/jiaweizzhao/GaLore?tab=readme-ov-file#train-7b-model-with-a-single-gpu-with-24gb-memory" rel="nofollow">이 문서</a>를 참조하세요. gradient clipping, DeepSpeed 등 다른 기능은 기본적으로 지원되지 않을 수 있습니다. 이러한 문제가 발생하면 <a href="https://github.com/huggingface/transformers/issues" rel="nofollow">GitHub에 이슈를 올려주세요</a>.</p> <h2 class="relative group"><a id="lomo-optimizer" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#lomo-optimizer"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>LOMO 옵티마이저</span></h2> <p data-svelte-h="svelte-19j518g">LOMO 옵티마이저는 <a href="https://hf.co/papers/2306.09782" rel="nofollow">제한된 자원으로 대형 언어 모델의 전체 매개변수 미세 조정</a>과 <a href="https://hf.co/papers/2310.10195" rel="nofollow">적응형 학습률을 통한 저메모리 최적화(AdaLomo)</a>에서 도입되었습니다.
이들은 모두 효율적인 전체 매개변수 미세 조정 방법으로 구성되어 있습니다. 이러한 옵티마이저들은 메모리 사용량을 줄이기 위해 그레이디언트 계산과 매개변수 업데이트를 하나의 단계로 융합합니다. LOMO에서 지원되는 옵티마이저는 <code>&quot;lomo&quot;</code><code>&quot;adalomo&quot;</code>입니다. 먼저 pypi에서 <code>pip install lomo-optim</code>를 통해 <code>lomo</code>를 설치하거나, GitHub 소스에서 <code>pip install git+https://github.com/OpenLMLab/LOMO.git</code>로 설치하세요.</p> <div class="course-tip bg-gradient-to-br dark:bg-gradient-to-r before:border-green-500 dark:before:border-green-800 from-green-50 dark:from-gray-900 to-white dark:to-gray-950 border border-green-50 text-green-700 dark:text-gray-400"><p data-svelte-h="svelte-733wbz">저자에 따르면, <code>grad_norm</code> 없이 <code>AdaLomo</code>를 사용하는 것이 더 나은 성능과 높은 처리량을 제공한다고 합니다.</p></div> <p data-svelte-h="svelte-16ffc3o">다음은 IMDB 데이터셋에서 <a href="https://huggingface.co/google/gemma-2b" rel="nofollow">google/gemma-2b</a>를 최대 정밀도로 미세 조정하는 간단한 스크립트입니다:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; 
border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-keyword">import</span> torch
<span class="hljs-keyword">import</span> datasets
<span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> TrainingArguments, AutoTokenizer, AutoModelForCausalLM
<span class="hljs-keyword">import</span> trl
train_dataset = datasets.load_dataset(<span class="hljs-string">&#x27;imdb&#x27;</span>, split=<span class="hljs-string">&#x27;train&#x27;</span>)
args = TrainingArguments(
output_dir=<span class="hljs-string">&quot;./test-lomo&quot;</span>,
max_steps=<span class="hljs-number">1000</span>,
per_device_train_batch_size=<span class="hljs-number">4</span>,
optim=<span class="hljs-string">&quot;adalomo&quot;</span>,
gradient_checkpointing=<span class="hljs-literal">True</span>,
logging_strategy=<span class="hljs-string">&quot;steps&quot;</span>,
logging_steps=<span class="hljs-number">1</span>,
learning_rate=<span class="hljs-number">2e-6</span>,
save_strategy=<span class="hljs-string">&quot;no&quot;</span>,
run_name=<span class="hljs-string">&quot;lomo-imdb&quot;</span>,
)
model_id = <span class="hljs-string">&quot;google/gemma-2b&quot;</span>
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=<span class="hljs-literal">True</span>).to(<span class="hljs-number">0</span>)
trainer = trl.SFTTrainer(
model=model,
args=args,
train_dataset=train_dataset,
dataset_text_field=<span class="hljs-string">&#x27;text&#x27;</span>,
max_seq_length=<span class="hljs-number">1024</span>,
)
trainer.train()<!-- HTML_TAG_END --></pre></div> <h2 class="relative group"><a id="accelerate-and-trainer" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#accelerate-and-trainer"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Accelerate와 Trainer</span></h2> <p data-svelte-h="svelte-145u0tn"><code>Trainer</code> 클래스는 <a href="https://hf.co/docs/accelerate" rel="nofollow">Accelerate</a>로 구동되며, 이는 <a href="https://pytorch.org/blog/introducing-pytorch-fully-sharded-data-parallel-api/" rel="nofollow">FullyShardedDataParallel (FSDP)</a><a href="https://www.deepspeed.ai/" rel="nofollow">DeepSpeed</a>와 같은 통합을 지원하는 분산 환경에서 PyTorch 모델을 쉽게 훈련할 수 있는 라이브러리입니다.</p> <div class="course-tip bg-gradient-to-br dark:bg-gradient-to-r before:border-green-500 dark:before:border-green-800 from-green-50 dark:from-gray-900 to-white dark:to-gray-950 border border-green-50 text-green-700 dark:text-gray-400"><p data-svelte-h="svelte-u7bkxx">FSDP 샤딩 전략, CPU 오프로드 및 <code>Trainer</code>와 함께 사용할 수 있는 더 많은 기능을 알아보려면 <a href="fsdp">Fully Sharded Data Parallel</a> 가이드를 확인하세요.</p></div> <p data-svelte-h="svelte-1j99ily"><code>Trainer</code>와 Accelerate를 사용하려면 
<a href="https://huggingface.co/docs/accelerate/package_reference/cli#accelerate-config" rel="nofollow"><code>accelerate config</code></a> 명령을 실행하여 훈련 환경을 설정하세요. 이 명령은 훈련 스크립트를 실행할 때 사용할 <code>config_file.yaml</code>을 생성합니다. 다음은 설정할 수 있는 구성의 몇 가지 예시입니다.</p> <div class="flex space-x-2 items-center my-1.5 mr-8 h-7 !pl-0 -mx-3 md:mx-0"><div class="flex items-center border rounded-lg px-1.5 py-1 leading-none select-none text-smd border-gray-800 bg-black dark:bg-gray-700 text-white">DistributedDataParallel </div><div class="flex items-center border rounded-lg px-1.5 py-1 leading-none select-none text-smd text-gray-500 cursor-pointer opacity-90 hover:text-gray-700 dark:hover:text-gray-200 hover:shadow-sm">FSDP </div><div class="flex items-center border rounded-lg px-1.5 py-1 leading-none select-none text-smd text-gray-500 cursor-pointer opacity-90 hover:text-gray-700 dark:hover:text-gray-200 hover:shadow-sm">DeepSpeed </div><div class="flex items-center border rounded-lg px-1.5 py-1 leading-none select-none text-smd text-gray-500 cursor-pointer opacity-90 hover:text-gray-700 dark:hover:text-gray-200 hover:shadow-sm">DeepSpeed with Accelerate plugin </div></div> <div class="language-select"><div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none 
transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-attr">compute_environment:</span> <span class="hljs-string">LOCAL_MACHINE</span>
<span class="hljs-attr">distributed_type:</span> <span class="hljs-string">MULTI_GPU</span>
<span class="hljs-attr">downcast_bf16:</span> <span class="hljs-string">&#x27;no&#x27;</span>
<span class="hljs-attr">gpu_ids:</span> <span class="hljs-string">all</span>
<span class="hljs-attr">machine_rank:</span> <span class="hljs-number">0</span> <span class="hljs-comment"># 노드에 따라 순위를 변경하세요</span>
<span class="hljs-attr">main_process_ip:</span> <span class="hljs-number">192.168</span><span class="hljs-number">.20</span><span class="hljs-number">.1</span>
<span class="hljs-attr">main_process_port:</span> <span class="hljs-number">9898</span>
<span class="hljs-attr">main_training_function:</span> <span class="hljs-string">main</span>
<span class="hljs-attr">mixed_precision:</span> <span class="hljs-string">fp16</span>
<span class="hljs-attr">num_machines:</span> <span class="hljs-number">2</span>
<span class="hljs-attr">num_processes:</span> <span class="hljs-number">8</span>
<span class="hljs-attr">rdzv_backend:</span> <span class="hljs-string">static</span>
<span class="hljs-attr">same_network:</span> <span class="hljs-literal">true</span>
<span class="hljs-attr">tpu_env:</span> []
<span class="hljs-attr">tpu_use_cluster:</span> <span class="hljs-literal">false</span>
<span class="hljs-attr">tpu_use_sudo:</span> <span class="hljs-literal">false</span>
<span class="hljs-attr">use_cpu:</span> <span class="hljs-literal">false</span><!-- HTML_TAG_END --></pre></div> </div> <p data-svelte-h="svelte-y960yk"><a href="https://huggingface.co/docs/accelerate/package_reference/cli#accelerate-launch" rel="nofollow"><code>accelerate launch</code></a> 명령은 Accelerate와 <code>Trainer</code>를 사용하여 분산 시스템에서 훈련 스크립트를 실행하는 권장 방법이며, <code>config_file.yaml</code>에 지정된 매개변수를 사용합니다. 이 파일은 Accelerate 캐시 폴더에 저장되며 <code>accelerate launch</code>를 실행할 때 자동으로 로드됩니다.</p> <p data-svelte-h="svelte-uvyyka">예를 들어, FSDP 구성을 사용하여 <a href="https://github.com/huggingface/transformers/blob/f4db565b695582891e43a5e042e5d318e28f20b8/examples/pytorch/text-classification/run_glue.py#L4" rel="nofollow">run_glue.py</a> 훈련 스크립트를 실행하려면 다음과 같이 합니다:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->accelerate launch \
./examples/pytorch/text-classification/run_glue.py \
--model_name_or_path google-bert/bert-base-cased \
--task_name <span class="hljs-variable">$TASK_NAME</span> \
--do_train \
--do_eval \
--max_seq_length 128 \
--per_device_train_batch_size 16 \
--learning_rate 5e-5 \
--num_train_epochs 3 \
--output_dir /tmp/<span class="hljs-variable">$TASK_NAME</span>/ \
--overwrite_output_dir<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1nohodh"><code>config_file.yaml</code> 파일의 매개변수를 직접 지정할 수도 있습니다:</p> <div class="code-block relative"><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->accelerate launch --num_processes=2 \
--use_fsdp \
--mixed_precision=bf16 \
--fsdp_auto_wrap_policy=TRANSFORMER_BASED_WRAP \
--fsdp_transformer_layer_cls_to_wrap=<span class="hljs-string">&quot;BertLayer&quot;</span> \
--fsdp_sharding_strategy=1 \
--fsdp_state_dict_type=FULL_STATE_DICT \
./examples/pytorch/text-classification/run_glue.py \
--model_name_or_path google-bert/bert-base-cased \
--task_name <span class="hljs-variable">$TASK_NAME</span> \
--do_train \
--do_eval \
--max_seq_length 128 \
--per_device_train_batch_size 16 \
--learning_rate 5e-5 \
--num_train_epochs 3 \
--output_dir /tmp/<span class="hljs-variable">$TASK_NAME</span>/ \
--overwrite_output_dir<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1xu5fm6"><code>accelerate launch</code>와 사용자 정의 구성에 대해 더 알아보려면 <a href="https://huggingface.co/docs/accelerate/basic_tutorials/launch" rel="nofollow">Accelerate 스크립트 실행</a> 튜토리얼을 확인하세요.</p> <a class="!text-gray-400 !no-underline text-sm flex items-center not-prose mt-4" href="https://github.com/huggingface/transformers/blob/main/docs/source/ko/trainer.md" target="_blank"><span data-svelte-h="svelte-1kd6by1">&lt;</span> <span data-svelte-h="svelte-x0xyl0">&gt;</span> <span data-svelte-h="svelte-1dajgef"><span class="underline ml-1.5">Update</span> on GitHub</span></a> <p></p>
<script>
{
// Page bootstrap: publishes the asset/base paths on a global object, then
// loads the two client entry modules and hands control to kit.start().

// Assigned without a declaration, so in this classic (non-module) script it
// becomes a global. Presumably read by the imported bundles to resolve asset
// URLs — confirm against the generated start/app chunks.
__sveltekit_1hrx8 = {
assets: "/docs/transformers/main/ko",
base: "/docs/transformers/main/ko",
env: {}
};
// Parent element of this <script> tag; passed to kit.start() below as the
// target element. Captured here, before the asynchronous imports resolve.
// NOTE(review): document.currentScript is presumably not available inside
// the .then() callback, so this read must stay synchronous — confirm.
const element = document.currentScript.parentElement;
// Two null entries, forwarded to kit.start() as `data`.
// NOTE(review): looks like one slot per entry in node_ids — confirm.
const data = [null,null];
// Fetch both entry modules in parallel; start only once both have loaded.
Promise.all([
import("/docs/transformers/main/ko/_app/immutable/entry/start.9aa88961.js"),
import("/docs/transformers/main/ko/_app/immutable/entry/app.84fb67c3.js")
]).then(([kit, app]) => {
// Start the client app on `element` with the initial state for this page.
kit.start(app, element, {
node_ids: [0, 93],
data,
form: null,
error: null
});
});
}
</script>

Xet Storage Details

Size:
73.3 kB
·
Xet hash:
f470fc686d02902f790cd4d560e7fb64cde9be366e2d7e56f90b910c047b440d

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.