Buckets:

rtrm's picture
download
raw
44.5 kB
import{s as js,o as Us,n as Ts}from"../chunks/scheduler.49e4e380.js";import{S as fs,i as gs,g as m,s as n,r as i,A as Cs,h as o,f as s,c as a,j as ws,u as r,x as u,k as bs,y as Zs,a as t,v as p,d,t as M,w as c,m as ks,n as Bs}from"../chunks/index.fb15006d.js";import{T as hs}from"../chunks/Tip.f590f2e1.js";import{Y as Js}from"../chunks/Youtube.42918e4e.js";import{C as y}from"../chunks/CodeBlock.3f4fbe91.js";import{C as Is}from"../chunks/CourseFloatingBanner.c832fd1e.js";import{H as Ge,E as Gs}from"../chunks/getInferenceSnippets.ea935248.js";function $s(ke){let w,b="✏️ <strong>Probier es selbt!</strong> Ändere die vorherige Trainingsschleife, um dein Modell auf dem SST-2-Datensatz fein zu tunen.";return{c(){w=m("p"),w.innerHTML=b},l(h){w=o(h,"P",{"data-svelte-h":!0}),u(w)!=="svelte-10uj33w"&&(w.innerHTML=b)},m(h,Be){t(h,w,Be)},p:Ts,d(h){h&&s(w)}}}function vs(ke){let w;return{c(){w=ks('⚠️ Um von dem Geschwindigkeitsvorteil der Cloud TPUs zu profitieren, empfehlen wir, deine Samples mit den Argumenten `padding="max_length"` und `max_length` des Tokenizers auf eine feste Länge aufzufüllen.')},l(b){w=Bs(b,'⚠️ Um von dem Geschwindigkeitsvorteil der Cloud TPUs zu profitieren, empfehlen wir, deine Samples mit den Argumenten `padding="max_length"` und `max_length` des Tokenizers auf eine feste Länge aufzufüllen.')},m(b,h){t(b,w,h)},d(b){b&&s(w)}}}function Ws(ke){let w,b,h,Be,U,$e,T,ve,f,We,g,Vl="In diesem Abschnitt befassen wir uns damit, wie wir die gleichen Ergebnisse wie im letzten Abschnitt erzielen können, ohne die Klasse <code>Trainer</code> zu verwenden. Auch hier gehen wir davon aus, dass du die Datenverarbeitung in Abschnitt 2 durchgeführt hast. Hier ist eine kurze Zusammenfassung mit allem, was du brauchst:",Xe,C,ze,Z,Re,k,Nl="Bevor wir unsere Trainingsschleife schreiben, müssen wir noch einige Objekte definieren. Zunächst müssen wir die Datalader definieren, mit denen wir über die Batches iterieren werden. Doch bevor wir diese Dataloader definieren können, müssen wir unsere <code>tokenized_datasets</code> nachbearbeiten, um einige Dinge zu erledigen, die der <code>Trainer</code> automatisch für uns erledigt hat. Konkret heißt das, dass wir:",Ae,B,El="<li>Die Spalten entfernen, die Werte enthalten, die das Modell nicht erwartet (wie die Spalten <code>sentence1</code> und <code>sentence2</code>).</li> <li>Die Spalte <code>label</code> in <code>labels</code> umbenennen (weil das Modell erwartet, dass das Argument <code>labels</code> heißt).</li> <li>Das Format der Datensätze anpassen, so dass sie PyTorch-Tensoren statt Listen zurückgeben.</li>",_e,I,Fl="Das <code>tokenized_datasets</code> hat eine Methode für jeden dieser Schritte:",Ye,G,Ve,$,Ql="Anschließend können wir überprüfen, ob der Output nur Spalten enthält, die unser Modell akzeptiert:",Ne,v,Ee,W,Hl="Jetzt können wir ganz einfach unsere Dataloader definieren:",Fe,X,Qe,z,xl="Um sicher zu gehen, überprüfen wir ein Batch auf Fehler in der Datenverarbeitung:",He,R,xe,A,Se,_,Sl="Beachte, dass die Dimensionen der Tensoren wahrscheinlich etwas anders aussehen werden, da wir für den Trainingsdatenlader <code>shuffle=True</code> eingestellt haben und innerhalb des Batches auf die maximale Länge auffüllen.",Le,Y,Ll="Da wir nun mit der Datenvorverarbeitung fertig sind (ein zufriedenstellendes aber schwer erreichbares Ziel für jeden ML-Experten), können wir uns nun dem Modell zuwenden. Wir instanziieren es genauso wie im vorherigen Abschnitt:",De,V,Ke,N,Dl="Als weitere Sicherheitsmaßnahme übergeben wir unseren Batch an das Modell, um sicherzustellen, dass beim Training alles glatt läuft:",qe,E,Pe,F,Oe,Q,Kl="Alle 🤗 Transformer Modelle geben den Verlust zurück, wenn <code>labels</code> angegeben werden, und wir erhalten zusätzlich die Logits (zwei für jede Eingabe in unserem Batch, also einen Tensor der Größe 8 x 2).",el,H,ql='Wir sind fast so weit, unsere Trainingsschleife zu schreiben! Es fehlen nur noch zwei Dinge: ein Optimierer und ein Scheduler für die Lernrate. Da wir versuchen, das zu wiederholen, was der <code>Trainer</code> automatisch gemacht hat, werden wir die gleichen Standardwerte verwenden. Der Optimierer, den der <code>Trainer</code> verwendet, heißt “AdamW” und ist größtenteils derselbe wie Adam, abgesehen von einer Abwandlung für die “Weight Decay Regularization” (siehe [“Decoupled Weight Decay Regularization”] (<a href="https://arxiv.org/abs/1711.05101" rel="nofollow">https://arxiv.org/abs/1711.05101</a>) von Ilya Loshchilov und Frank Hutter):',ll,x,sl,S,Pl="Der standardmäßig verwendete Scheduler für die Lernrate ist ein linearer Abstieg vom Maximalwert (5e-5) auf 0. Um ihn richtig zu definieren, müssen wir die Anzahl der Trainingsschritte kennen, d.h. die Anzahl der Epochen, die die Trainingsschleife durchlaufen soll, multipliziert mit der Anzahl der Trainingsbatches (der Länge unseres Trainingsdatenordners). Der <code>Trainer</code> verwendet standardmäßig drei Epochen, woran wir uns hier orientieren werden:",tl,L,nl,D,al,K,il,q,Ol="Ein letzter Hinweis: Wir wollen die GPU zum Training nutzen, wenn wir Zugang zu einer haben (auf einer CPU kann das Training mehrere Stunden statt ein paar Minuten dauern). Dazu definieren wir <code>device</code> als Gerät auf dem wir unser Modell und unsere Batches speichern:",rl,P,pl,O,dl,ee,es="Wir sind jetzt bereit für das Training! Um ein Gefühl dafür zu bekommen, wann das Training abgeschlossen sein wird, fügen wir mit der Bibliothek <code>tqdm</code> einen Fortschrittsbalken über die Anzahl der Trainingsschritte ein:",Ml,le,cl,se,ls="Der Kern der Trainingsschleife sieht ähnlich aus wie in der Einleitung. Da wir keine Berichte angefordert haben, gibt die Trainingsschleife nichts über die Performance des Modells zurück. Dafür müssen wir eine Evaluationsschleife einfügen.",ml,te,ol,ne,ss="Wie schon zuvor verwenden wir eine Metrik, die von der 🤗 Evaluate-Bibliothek bereitgestellt wird. Wir haben bereits die Methode <code>metric.compute()</code> gesehen, aber Metriken können auch Batches für uns akkumulieren, wenn wir die Vorhersageschleife mit der Methode <code>add_batch()</code> durchlaufen. Sobald wir alle Batches gesammelt haben, können wir das Endergebnis mit der Methode <code>metric.compute()</code> ermitteln. So implementierst du all das in eine Evaluationsschleife:",ul,ae,yl,ie,wl,re,ts="Auch hier werden deine Ergebnisse wegen der Zufälligkeit bei der Initialisierung des Modellkopfes und der Datenverteilung etwas anders ausfallen, aber sie sollten in etwa gleich sein.",bl,J,hl,pe,Jl,de,jl,Me,ns='Die Trainingsschleife, die wir zuvor definiert haben, funktioniert gut auf einer einzelnen CPU oder GPU. Aber mit der Bibliothek <a href="https://github.com/huggingface/accelerate" rel="nofollow">🤗 Accelerate</a> können wir mit wenigen Anpassungen verteiltes Training auf mehreren GPUs oder TPUs implementieren. Beginnend mit der Erstellung der Trainings- und Validierungsdaten, sieht unsere manuelle Trainingsschleife nun folgendermaßen aus:',Ul,ce,Tl,me,as="Und hier sind die Änderungen:",fl,oe,gl,ue,is="Die erste Zeile, die hinzugefügt werden muss, ist die Import-Zeile. Die zweite Zeile instanziiert ein <code>Accelerator</code>-Objekt, das die Hardware analysiert und die richtige verteilte Umgebung initialisiert. Accelerate kümmert sich um die Anordnung der Geräte, du kannst also die Zeilen entfernen, die das Modell auf dem Gerät platzieren (oder, wenn du das möchtest, sie so ändern, dass sie <code>accelerator.device</code> anstelle von <code>device</code> verwenden).",Cl,ye,rs="Der Hauptteil der Arbeit wird dann in der Zeile erledigt, die die Dataloader, das Modell und den Optimierer an <code>accelerator.prepare()</code> sendet. Dadurch werden diese Objekte in den richtigen Container verpackt, damit das verteilte Training wie vorgesehen funktioniert. Die verbleibenden Änderungen sind das Entfernen der Zeile, die das Batch auf dem Gerät mit <code>device</code> ablegt (wenn du das beibehalten willst, kannst du es einfach in <code>accelerator.device</code> ändern) und das Ersetzen von <code>loss.backward()</code> durch <code>accelerator.backward(loss)</code>.",Zl,j,kl,we,ps="Wenn du damit experimentieren möchtest, siehst du hier, wie die komplette Trainingsschleife mit 🤗 Accelerate aussieht:",Bl,be,Il,he,ds="Wenn dies in das Script <code>train.py</code> eingefügt wird, kann das Script auf jeder Art von verteilter Hardware ausgeführt werden. Um es auf deiner verteilten Hardware auszuprobieren, führe den folgenden Befehl aus:",Gl,Je,$l,je,Ms="Du wirst dann aufgefordert werden, einige Fragen zu beantworten und die Antworten in eine Konfigurationsdatei zu schreiben, die von diesem Befehl verwendet wird:",vl,Ue,Wl,Te,cs="Damit wird das verteilte Training gestartet.",Xl,fe,ms="Wenn du das in einem Notebook ausprobieren möchtest (z. B. um es mit TPUs auf Colab zu testen), füge den Code einfach in eine <code>training_function()</code> ein und führe eine letzte Zelle mit aus:",zl,ge,Rl,Ce,os='Weitere Beispiele findest du in dem <a href="https://github.com/huggingface/accelerate/tree/main/examples" rel="nofollow">🤗 Accelerate Repo</a>.',Al,Ze,_l,Ie,Yl;return U=new Ge({props:{title:"Komplettes Training",local:"komplettes-training",headingTag:"h1"}}),T=new Is({props:{chapter:3,classNames:"absolute z-10 right-0 top-0",notebooks:[{label:"Google Colab",value:"https://colab.research.google.com/github/huggingface/notebooks/blob/master/course/de/chapter3/section4.ipynb"},{label:"Aws Studio",value:"https://studiolab.sagemaker.aws/import/github/huggingface/notebooks/blob/master/course/de/chapter3/section4.ipynb"}]}}),f=new Js({props:{id:"Dh9CL8fyG80"}}),C=new y({props:{code:"ZnJvbSUyMGRhdGFzZXRzJTIwaW1wb3J0JTIwbG9hZF9kYXRhc2V0JTBBZnJvbSUyMHRyYW5zZm9ybWVycyUyMGltcG9ydCUyMEF1dG9Ub2tlbml6ZXIlMkMlMjBEYXRhQ29sbGF0b3JXaXRoUGFkZGluZyUwQSUwQXJhd19kYXRhc2V0cyUyMCUzRCUyMGxvYWRfZGF0YXNldCglMjJnbHVlJTIyJTJDJTIwJTIybXJwYyUyMiklMEFjaGVja3BvaW50JTIwJTNEJTIwJTIyYmVydC1iYXNlLXVuY2FzZWQlMjIlMEF0b2tlbml6ZXIlMjAlM0QlMjBBdXRvVG9rZW5pemVyLmZyb21fcHJldHJhaW5lZChjaGVja3BvaW50KSUwQSUwQSUwQWRlZiUyMHRva2VuaXplX2Z1bmN0aW9uKGV4YW1wbGUpJTNBJTBBJTIwJTIwJTIwJTIwcmV0dXJuJTIwdG9rZW5pemVyKGV4YW1wbGUlNUIlMjJzZW50ZW5jZTElMjIlNUQlMkMlMjBleGFtcGxlJTVCJTIyc2VudGVuY2UyJTIyJTVEJTJDJTIwdHJ1bmNhdGlvbiUzRFRydWUpJTBBJTBBJTBBdG9rZW5pemVkX2RhdGFzZXRzJTIwJTNEJTIwcmF3X2RhdGFzZXRzLm1hcCh0b2tlbml6ZV9mdW5jdGlvbiUyQyUyMGJhdGNoZWQlM0RUcnVlKSUwQWRhdGFfY29sbGF0b3IlMjAlM0QlMjBEYXRhQ29sbGF0b3JXaXRoUGFkZGluZyh0b2tlbml6ZXIlM0R0b2tlbml6ZXIp",highlighted:`<span class="hljs-keyword">from</span> datasets <span class="hljs-keyword">import</span> load_dataset
<span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> AutoTokenizer, DataCollatorWithPadding
raw_datasets = load_dataset(<span class="hljs-string">&quot;glue&quot;</span>, <span class="hljs-string">&quot;mrpc&quot;</span>)
checkpoint = <span class="hljs-string">&quot;bert-base-uncased&quot;</span>
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
<span class="hljs-keyword">def</span> <span class="hljs-title function_">tokenize_function</span>(<span class="hljs-params">example</span>):
<span class="hljs-keyword">return</span> tokenizer(example[<span class="hljs-string">&quot;sentence1&quot;</span>], example[<span class="hljs-string">&quot;sentence2&quot;</span>], truncation=<span class="hljs-literal">True</span>)
tokenized_datasets = raw_datasets.<span class="hljs-built_in">map</span>(tokenize_function, batched=<span class="hljs-literal">True</span>)
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)`,wrap:!1}}),Z=new Ge({props:{title:"Vorbereitung für das Training",local:"vorbereitung-für-das-training",headingTag:"h3"}}),G=new y({props:{code:"dG9rZW5pemVkX2RhdGFzZXRzJTIwJTNEJTIwdG9rZW5pemVkX2RhdGFzZXRzLnJlbW92ZV9jb2x1bW5zKCU1QiUyMnNlbnRlbmNlMSUyMiUyQyUyMCUyMnNlbnRlbmNlMiUyMiUyQyUyMCUyMmlkeCUyMiU1RCklMEF0b2tlbml6ZWRfZGF0YXNldHMlMjAlM0QlMjB0b2tlbml6ZWRfZGF0YXNldHMucmVuYW1lX2NvbHVtbiglMjJsYWJlbCUyMiUyQyUyMCUyMmxhYmVscyUyMiklMEF0b2tlbml6ZWRfZGF0YXNldHMuc2V0X2Zvcm1hdCglMjJ0b3JjaCUyMiklMEF0b2tlbml6ZWRfZGF0YXNldHMlNUIlMjJ0cmFpbiUyMiU1RC5jb2x1bW5fbmFtZXM=",highlighted:`tokenized_datasets = tokenized_datasets.remove_columns([<span class="hljs-string">&quot;sentence1&quot;</span>, <span class="hljs-string">&quot;sentence2&quot;</span>, <span class="hljs-string">&quot;idx&quot;</span>])
tokenized_datasets = tokenized_datasets.rename_column(<span class="hljs-string">&quot;label&quot;</span>, <span class="hljs-string">&quot;labels&quot;</span>)
tokenized_datasets.set_format(<span class="hljs-string">&quot;torch&quot;</span>)
tokenized_datasets[<span class="hljs-string">&quot;train&quot;</span>].column_names`,wrap:!1}}),v=new y({props:{code:"JTVCJTIyYXR0ZW50aW9uX21hc2slMjIlMkMlMjAlMjJpbnB1dF9pZHMlMjIlMkMlMjAlMjJsYWJlbHMlMjIlMkMlMjAlMjJ0b2tlbl90eXBlX2lkcyUyMiU1RA==",highlighted:'[<span class="hljs-string">&quot;attention_mask&quot;</span>, <span class="hljs-string">&quot;input_ids&quot;</span>, <span class="hljs-string">&quot;labels&quot;</span>, <span class="hljs-string">&quot;token_type_ids&quot;</span>]',wrap:!1}}),X=new y({props:{code:"ZnJvbSUyMHRvcmNoLnV0aWxzLmRhdGElMjBpbXBvcnQlMjBEYXRhTG9hZGVyJTBBJTBBdHJhaW5fZGF0YWxvYWRlciUyMCUzRCUyMERhdGFMb2FkZXIoJTBBJTIwJTIwJTIwJTIwdG9rZW5pemVkX2RhdGFzZXRzJTVCJTIydHJhaW4lMjIlNUQlMkMlMjBzaHVmZmxlJTNEVHJ1ZSUyQyUyMGJhdGNoX3NpemUlM0Q4JTJDJTIwY29sbGF0ZV9mbiUzRGRhdGFfY29sbGF0b3IlMEEpJTBBZXZhbF9kYXRhbG9hZGVyJTIwJTNEJTIwRGF0YUxvYWRlciglMEElMjAlMjAlMjAlMjB0b2tlbml6ZWRfZGF0YXNldHMlNUIlMjJ2YWxpZGF0aW9uJTIyJTVEJTJDJTIwYmF0Y2hfc2l6ZSUzRDglMkMlMjBjb2xsYXRlX2ZuJTNEZGF0YV9jb2xsYXRvciUwQSk=",highlighted:`<span class="hljs-keyword">from</span> torch.utils.data <span class="hljs-keyword">import</span> DataLoader
train_dataloader = DataLoader(
tokenized_datasets[<span class="hljs-string">&quot;train&quot;</span>], shuffle=<span class="hljs-literal">True</span>, batch_size=<span class="hljs-number">8</span>, collate_fn=data_collator
)
eval_dataloader = DataLoader(
tokenized_datasets[<span class="hljs-string">&quot;validation&quot;</span>], batch_size=<span class="hljs-number">8</span>, collate_fn=data_collator
)`,wrap:!1}}),R=new y({props:{code:"Zm9yJTIwYmF0Y2glMjBpbiUyMHRyYWluX2RhdGFsb2FkZXIlM0ElMEElMjAlMjAlMjAlMjBicmVhayUwQSU3QmslM0ElMjB2LnNoYXBlJTIwZm9yJTIwayUyQyUyMHYlMjBpbiUyMGJhdGNoLml0ZW1zKCklN0Q=",highlighted:`<span class="hljs-keyword">for</span> batch <span class="hljs-keyword">in</span> train_dataloader:
<span class="hljs-keyword">break</span>
{k: v.shape <span class="hljs-keyword">for</span> k, v <span class="hljs-keyword">in</span> batch.items()}`,wrap:!1}}),A=new y({props:{code:"JTdCJ2F0dGVudGlvbl9tYXNrJyUzQSUyMHRvcmNoLlNpemUoJTVCOCUyQyUyMDY1JTVEKSUyQyUwQSUyMCdpbnB1dF9pZHMnJTNBJTIwdG9yY2guU2l6ZSglNUI4JTJDJTIwNjUlNUQpJTJDJTBBJTIwJ2xhYmVscyclM0ElMjB0b3JjaC5TaXplKCU1QjglNUQpJTJDJTBBJTIwJ3Rva2VuX3R5cGVfaWRzJyUzQSUyMHRvcmNoLlNpemUoJTVCOCUyQyUyMDY1JTVEKSU3RA==",highlighted:`{<span class="hljs-string">&#x27;attention_mask&#x27;</span>: torch.Size([<span class="hljs-number">8</span>, <span class="hljs-number">65</span>]),
<span class="hljs-string">&#x27;input_ids&#x27;</span>: torch.Size([<span class="hljs-number">8</span>, <span class="hljs-number">65</span>]),
<span class="hljs-string">&#x27;labels&#x27;</span>: torch.Size([<span class="hljs-number">8</span>]),
<span class="hljs-string">&#x27;token_type_ids&#x27;</span>: torch.Size([<span class="hljs-number">8</span>, <span class="hljs-number">65</span>])}`,wrap:!1}}),V=new y({props:{code:"ZnJvbSUyMHRyYW5zZm9ybWVycyUyMGltcG9ydCUyMEF1dG9Nb2RlbEZvclNlcXVlbmNlQ2xhc3NpZmljYXRpb24lMEElMEFtb2RlbCUyMCUzRCUyMEF1dG9Nb2RlbEZvclNlcXVlbmNlQ2xhc3NpZmljYXRpb24uZnJvbV9wcmV0cmFpbmVkKGNoZWNrcG9pbnQlMkMlMjBudW1fbGFiZWxzJTNEMik=",highlighted:`<span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> AutoModelForSequenceClassification
model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=<span class="hljs-number">2</span>)`,wrap:!1}}),E=new y({props:{code:"b3V0cHV0cyUyMCUzRCUyMG1vZGVsKCoqYmF0Y2gpJTBBcHJpbnQob3V0cHV0cy5sb3NzJTJDJTIwb3V0cHV0cy5sb2dpdHMuc2hhcGUp",highlighted:`outputs = model(**batch)
<span class="hljs-built_in">print</span>(outputs.loss, outputs.logits.shape)`,wrap:!1}}),F=new y({props:{code:"dGVuc29yKDAuNTQ0MSUyQyUyMGdyYWRfZm4lM0QlM0NObGxMb3NzQmFja3dhcmQlM0UpJTIwdG9yY2guU2l6ZSglNUI4JTJDJTIwMiU1RCk=",highlighted:'tensor(<span class="hljs-number">0.5441</span>, grad_fn=&lt;NllLossBackward&gt;) torch.Size([<span class="hljs-number">8</span>, <span class="hljs-number">2</span>])',wrap:!1}}),x=new y({props:{code:"ZnJvbSUyMHRvcmNoLm9wdGltJTIwaW1wb3J0JTIwQWRhbVclMEElMEFvcHRpbWl6ZXIlMjAlM0QlMjBBZGFtVyhtb2RlbC5wYXJhbWV0ZXJzKCklMkMlMjBsciUzRDVlLTUp",highlighted:`<span class="hljs-keyword">from</span> torch.optim <span class="hljs-keyword">import</span> AdamW
optimizer = AdamW(model.parameters(), lr=<span class="hljs-number">5e-5</span>)`,wrap:!1}}),L=new y({props:{code:"ZnJvbSUyMHRyYW5zZm9ybWVycyUyMGltcG9ydCUyMGdldF9zY2hlZHVsZXIlMEElMEFudW1fZXBvY2hzJTIwJTNEJTIwMyUwQW51bV90cmFpbmluZ19zdGVwcyUyMCUzRCUyMG51bV9lcG9jaHMlMjAqJTIwbGVuKHRyYWluX2RhdGFsb2FkZXIpJTBBbHJfc2NoZWR1bGVyJTIwJTNEJTIwZ2V0X3NjaGVkdWxlciglMEElMjAlMjAlMjAlMjAlMjJsaW5lYXIlMjIlMkMlMEElMjAlMjAlMjAlMjBvcHRpbWl6ZXIlM0RvcHRpbWl6ZXIlMkMlMEElMjAlMjAlMjAlMjBudW1fd2FybXVwX3N0ZXBzJTNEMCUyQyUwQSUyMCUyMCUyMCUyMG51bV90cmFpbmluZ19zdGVwcyUzRG51bV90cmFpbmluZ19zdGVwcyUyQyUwQSklMEFwcmludChudW1fdHJhaW5pbmdfc3RlcHMp",highlighted:`<span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> get_scheduler
num_epochs = <span class="hljs-number">3</span>
num_training_steps = num_epochs * <span class="hljs-built_in">len</span>(train_dataloader)
lr_scheduler = get_scheduler(
<span class="hljs-string">&quot;linear&quot;</span>,
optimizer=optimizer,
num_warmup_steps=<span class="hljs-number">0</span>,
num_training_steps=num_training_steps,
)
<span class="hljs-built_in">print</span>(num_training_steps)`,wrap:!1}}),D=new y({props:{code:"MTM3Nw==",highlighted:'<span class="hljs-number">1377</span>',wrap:!1}}),K=new Ge({props:{title:"Die Trainingsschleife",local:"die-trainingsschleife",headingTag:"h3"}}),P=new y({props:{code:"aW1wb3J0JTIwdG9yY2glMEElMEFkZXZpY2UlMjAlM0QlMjB0b3JjaC5kZXZpY2UoJTIyY3VkYSUyMiklMjBpZiUyMHRvcmNoLmN1ZGEuaXNfYXZhaWxhYmxlKCklMjBlbHNlJTIwdG9yY2guZGV2aWNlKCUyMmNwdSUyMiklMEFtb2RlbC50byhkZXZpY2UpJTBBZGV2aWNl",highlighted:`<span class="hljs-keyword">import</span> torch
device = torch.device(<span class="hljs-string">&quot;cuda&quot;</span>) <span class="hljs-keyword">if</span> torch.cuda.is_available() <span class="hljs-keyword">else</span> torch.device(<span class="hljs-string">&quot;cpu&quot;</span>)
model.to(device)
device`,wrap:!1}}),O=new y({props:{code:"ZGV2aWNlKHR5cGUlM0QnY3VkYScp",highlighted:'device(<span class="hljs-built_in">type</span>=<span class="hljs-string">&#x27;cuda&#x27;</span>)',wrap:!1}}),le=new y({props:{code:"ZnJvbSUyMHRxZG0uYXV0byUyMGltcG9ydCUyMHRxZG0lMEElMEFwcm9ncmVzc19iYXIlMjAlM0QlMjB0cWRtKHJhbmdlKG51bV90cmFpbmluZ19zdGVwcykpJTBBJTBBbW9kZWwudHJhaW4oKSUwQWZvciUyMGVwb2NoJTIwaW4lMjByYW5nZShudW1fZXBvY2hzKSUzQSUwQSUyMCUyMCUyMCUyMGZvciUyMGJhdGNoJTIwaW4lMjB0cmFpbl9kYXRhbG9hZGVyJTNBJTBBJTIwJTIwJTIwJTIwJTIwJTIwJTIwJTIwYmF0Y2glMjAlM0QlMjAlN0JrJTNBJTIwdi50byhkZXZpY2UpJTIwZm9yJTIwayUyQyUyMHYlMjBpbiUyMGJhdGNoLml0ZW1zKCklN0QlMEElMjAlMjAlMjAlMjAlMjAlMjAlMjAlMjBvdXRwdXRzJTIwJTNEJTIwbW9kZWwoKipiYXRjaCklMEElMjAlMjAlMjAlMjAlMjAlMjAlMjAlMjBsb3NzJTIwJTNEJTIwb3V0cHV0cy5sb3NzJTBBJTIwJTIwJTIwJTIwJTIwJTIwJTIwJTIwbG9zcy5iYWNrd2FyZCgpJTBBJTBBJTIwJTIwJTIwJTIwJTIwJTIwJTIwJTIwb3B0aW1pemVyLnN0ZXAoKSUwQSUyMCUyMCUyMCUyMCUyMCUyMCUyMCUyMGxyX3NjaGVkdWxlci5zdGVwKCklMEElMjAlMjAlMjAlMjAlMjAlMjAlMjAlMjBvcHRpbWl6ZXIuemVyb19ncmFkKCklMEElMjAlMjAlMjAlMjAlMjAlMjAlMjAlMjBwcm9ncmVzc19iYXIudXBkYXRlKDEp",highlighted:`<span class="hljs-keyword">from</span> tqdm.auto <span class="hljs-keyword">import</span> tqdm
progress_bar = tqdm(<span class="hljs-built_in">range</span>(num_training_steps))
model.train()
<span class="hljs-keyword">for</span> epoch <span class="hljs-keyword">in</span> <span class="hljs-built_in">range</span>(num_epochs):
<span class="hljs-keyword">for</span> batch <span class="hljs-keyword">in</span> train_dataloader:
batch = {k: v.to(device) <span class="hljs-keyword">for</span> k, v <span class="hljs-keyword">in</span> batch.items()}
outputs = model(**batch)
loss = outputs.loss
loss.backward()
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
progress_bar.update(<span class="hljs-number">1</span>)`,wrap:!1}}),te=new Ge({props:{title:"Die Evaluationsschleife",local:"die-evaluationsschleife",headingTag:"h3"}}),ae=new y({props:{code:"aW1wb3J0JTIwZXZhbHVhdGUlMEElMEFtZXRyaWMlMjAlM0QlMjBldmFsdWF0ZS5sb2FkKCUyMmdsdWUlMjIlMkMlMjAlMjJtcnBjJTIyKSUwQW1vZGVsLmV2YWwoKSUwQWZvciUyMGJhdGNoJTIwaW4lMjBldmFsX2RhdGFsb2FkZXIlM0ElMEElMjAlMjAlMjAlMjBiYXRjaCUyMCUzRCUyMCU3QmslM0ElMjB2LnRvKGRldmljZSklMjBmb3IlMjBrJTJDJTIwdiUyMGluJTIwYmF0Y2guaXRlbXMoKSU3RCUwQSUyMCUyMCUyMCUyMHdpdGglMjB0b3JjaC5ub19ncmFkKCklM0ElMEElMjAlMjAlMjAlMjAlMjAlMjAlMjAlMjBvdXRwdXRzJTIwJTNEJTIwbW9kZWwoKipiYXRjaCklMEElMEElMjAlMjAlMjAlMjBsb2dpdHMlMjAlM0QlMjBvdXRwdXRzLmxvZ2l0cyUwQSUyMCUyMCUyMCUyMHByZWRpY3Rpb25zJTIwJTNEJTIwdG9yY2guYXJnbWF4KGxvZ2l0cyUyQyUyMGRpbSUzRC0xKSUwQSUyMCUyMCUyMCUyMG1ldHJpYy5hZGRfYmF0Y2gocHJlZGljdGlvbnMlM0RwcmVkaWN0aW9ucyUyQyUyMHJlZmVyZW5jZXMlM0RiYXRjaCU1QiUyMmxhYmVscyUyMiU1RCklMEElMEFtZXRyaWMuY29tcHV0ZSgp",highlighted:`<span class="hljs-keyword">import</span> evaluate
metric = evaluate.load(<span class="hljs-string">&quot;glue&quot;</span>, <span class="hljs-string">&quot;mrpc&quot;</span>)
model.<span class="hljs-built_in">eval</span>()
<span class="hljs-keyword">for</span> batch <span class="hljs-keyword">in</span> eval_dataloader:
batch = {k: v.to(device) <span class="hljs-keyword">for</span> k, v <span class="hljs-keyword">in</span> batch.items()}
<span class="hljs-keyword">with</span> torch.no_grad():
outputs = model(**batch)
logits = outputs.logits
predictions = torch.argmax(logits, dim=-<span class="hljs-number">1</span>)
metric.add_batch(predictions=predictions, references=batch[<span class="hljs-string">&quot;labels&quot;</span>])
metric.compute()`,wrap:!1}}),ie=new y({props:{code:"JTdCJ2FjY3VyYWN5JyUzQSUyMDAuODQzMTM3MjU0OTAxOTYwOCUyQyUyMCdmMSclM0ElMjAwLjg5MDc4NDk4MjkzNTE1MzUlN0Q=",highlighted:'{<span class="hljs-string">&#x27;accuracy&#x27;</span>: <span class="hljs-number">0.8431372549019608</span>, <span class="hljs-string">&#x27;f1&#x27;</span>: <span class="hljs-number">0.8907849829351535</span>}',wrap:!1}}),J=new hs({props:{$$slots:{default:[$s]},$$scope:{ctx:ke}}}),pe=new Ge({props:{title:"Verbessere deine Trainingsschleife mit 🤗 Accelerate",local:"verbessere-deine-trainingsschleife-mit--accelerate",headingTag:"h3"}}),de=new Js({props:{id:"s7dy8QRgjJ0"}}),ce=new y({props:{code:"ZnJvbSUyMHRvcmNoLm9wdGltJTIwaW1wb3J0JTIwQWRhbVclMEFmcm9tJTIwdHJhbnNmb3JtZXJzJTIwaW1wb3J0JTIwQXV0b01vZGVsRm9yU2VxdWVuY2VDbGFzc2lmaWNhdGlvbiUyQyUyMGdldF9zY2hlZHVsZXIlMEElMEFtb2RlbCUyMCUzRCUyMEF1dG9Nb2RlbEZvclNlcXVlbmNlQ2xhc3NpZmljYXRpb24uZnJvbV9wcmV0cmFpbmVkKGNoZWNrcG9pbnQlMkMlMjBudW1fbGFiZWxzJTNEMiklMEFvcHRpbWl6ZXIlMjAlM0QlMjBBZGFtVyhtb2RlbC5wYXJhbWV0ZXJzKCklMkMlMjBsciUzRDNlLTUpJTBBJTBBZGV2aWNlJTIwJTNEJTIwdG9yY2guZGV2aWNlKCUyMmN1ZGElMjIpJTIwaWYlMjB0b3JjaC5jdWRhLmlzX2F2YWlsYWJsZSgpJTIwZWxzZSUyMHRvcmNoLmRldmljZSglMjJjcHUlMjIpJTBBbW9kZWwudG8oZGV2aWNlKSUwQSUwQW51bV9lcG9jaHMlMjAlM0QlMjAzJTBBbnVtX3RyYWluaW5nX3N0ZXBzJTIwJTNEJTIwbnVtX2Vwb2NocyUyMColMjBsZW4odHJhaW5fZGF0YWxvYWRlciklMEFscl9zY2hlZHVsZXIlMjAlM0QlMjBnZXRfc2NoZWR1bGVyKCUwQSUyMCUyMCUyMCUyMCUyMmxpbmVhciUyMiUyQyUwQSUyMCUyMCUyMCUyMG9wdGltaXplciUzRG9wdGltaXplciUyQyUwQSUyMCUyMCUyMCUyMG51bV93YXJtdXBfc3RlcHMlM0QwJTJDJTBBJTIwJTIwJTIwJTIwbnVtX3RyYWluaW5nX3N0ZXBzJTNEbnVtX3RyYWluaW5nX3N0ZXBzJTJDJTBBKSUwQSUwQXByb2dyZXNzX2JhciUyMCUzRCUyMHRxZG0ocmFuZ2UobnVtX3RyYWluaW5nX3N0ZXBzKSklMEElMEFtb2RlbC50cmFpbigpJTBBZm9yJTIwZXBvY2glMjBpbiUyMHJhbmdlKG51bV9lcG9jaHMpJTNBJTBBJTIwJTIwJTIwJTIwZm9yJTIwYmF0Y2glMjBpbiUyMHRyYWluX2RhdGFsb2FkZXIlM0ElMEElMjAlMjAlMjAlMjAlMjAlMjAlMjAlMjBiYXRjaCUyMCUzRCUyMCU3QmslM0ElMjB2LnRvKGRldmljZSklMjBmb3IlMjBrJTJDJTIwdiUyMGluJTIwYmF0Y2guaXRlbXMoKSU3RCUwQSUyMCUyMCUyMCUyMCUyMCUyMCUyMCUyMG91dHB1dHMlMjAlM0QlMjBtb2RlbCgqKmJhdGNoKSUwQSUyMCUyMCUyMCUyMCUyMCUyMCUyMCUyMGxvc3MlMjAlM0QlMjBvdXRwdXRzLmxvc3MlMEElMjAlMjAlMjAlMjAlMjAlMjAlMjAlMjBsb3NzLmJhY2t3YXJkKCklMEElMEElMjAlMjAlMjAlMjAlMjAlMjAlMjAlMjBvcHRpbWl6ZXIuc3RlcCgpJTBBJTIwJTIwJTIwJTIwJTIwJTIwJTIwJTIwbHJfc2NoZWR1bGVyLnN0ZXAoKSUwQSUyMCUyMCUyMCUyMCUyMCUyMCUyMCUyMG9wdGltaXplci56ZXJvX2dyYWQoKSUwQSUyMCUyMCUyMCUyMCUyMCUyMCUyMCUyMHByb2dyZXNzX2Jhci51cGRhdGUoMSk=",highlighted:`<span class="hljs-keyword">from</span> torch.optim <span class="hljs-keyword">import</span> AdamW
<span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> AutoModelForSequenceClassification, get_scheduler
model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=<span class="hljs-number">2</span>)
optimizer = AdamW(model.parameters(), lr=<span class="hljs-number">3e-5</span>)
device = torch.device(<span class="hljs-string">&quot;cuda&quot;</span>) <span class="hljs-keyword">if</span> torch.cuda.is_available() <span class="hljs-keyword">else</span> torch.device(<span class="hljs-string">&quot;cpu&quot;</span>)
model.to(device)
num_epochs = <span class="hljs-number">3</span>
num_training_steps = num_epochs * <span class="hljs-built_in">len</span>(train_dataloader)
lr_scheduler = get_scheduler(
<span class="hljs-string">&quot;linear&quot;</span>,
optimizer=optimizer,
num_warmup_steps=<span class="hljs-number">0</span>,
num_training_steps=num_training_steps,
)
progress_bar = tqdm(<span class="hljs-built_in">range</span>(num_training_steps))
model.train()
<span class="hljs-keyword">for</span> epoch <span class="hljs-keyword">in</span> <span class="hljs-built_in">range</span>(num_epochs):
<span class="hljs-keyword">for</span> batch <span class="hljs-keyword">in</span> train_dataloader:
batch = {k: v.to(device) <span class="hljs-keyword">for</span> k, v <span class="hljs-keyword">in</span> batch.items()}
outputs = model(**batch)
loss = outputs.loss
loss.backward()
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
progress_bar.update(<span class="hljs-number">1</span>)`,wrap:!1}}),oe=new y({props:{code:"JTJCJTIwZnJvbSUyMGFjY2VsZXJhdGUlMjBpbXBvcnQlMjBBY2NlbGVyYXRvciUwQSUyMCUyMGZyb20lMjB0b3JjaC5vcHRpbSUyMGltcG9ydCUyMEFkYW1XJTBBJTIwJTIwZnJvbSUyMHRyYW5zZm9ybWVycyUyMGltcG9ydCUyMEF1dG9Nb2RlbEZvclNlcXVlbmNlQ2xhc3NpZmljYXRpb24lMkMlMjBnZXRfc2NoZWR1bGVyJTBBJTBBJTJCJTIwYWNjZWxlcmF0b3IlMjAlM0QlMjBBY2NlbGVyYXRvcigpJTBBJTBBJTIwJTIwbW9kZWwlMjAlM0QlMjBBdXRvTW9kZWxGb3JTZXF1ZW5jZUNsYXNzaWZpY2F0aW9uLmZyb21fcHJldHJhaW5lZChjaGVja3BvaW50JTJDJTIwbnVtX2xhYmVscyUzRDIpJTBBJTIwJTIwb3B0aW1pemVyJTIwJTNEJTIwQWRhbVcobW9kZWwucGFyYW1ldGVycygpJTJDJTIwbHIlM0QzZS01KSUwQSUwQS0lMjBkZXZpY2UlMjAlM0QlMjB0b3JjaC5kZXZpY2UoJTIyY3VkYSUyMiklMjBpZiUyMHRvcmNoLmN1ZGEuaXNfYXZhaWxhYmxlKCklMjBlbHNlJTIwdG9yY2guZGV2aWNlKCUyMmNwdSUyMiklMEEtJTIwbW9kZWwudG8oZGV2aWNlKSUwQSUwQSUyQiUyMHRyYWluX2RhdGFsb2FkZXIlMkMlMjBldmFsX2RhdGFsb2FkZXIlMkMlMjBtb2RlbCUyQyUyMG9wdGltaXplciUyMCUzRCUyMGFjY2VsZXJhdG9yLnByZXBhcmUoJTBBJTJCJTIwJTIwJTIwJTIwJTIwdHJhaW5fZGF0YWxvYWRlciUyQyUyMGV2YWxfZGF0YWxvYWRlciUyQyUyMG1vZGVsJTJDJTIwb3B0aW1pemVyJTBBJTJCJTIwKSUwQSUwQSUyMCUyMG51bV9lcG9jaHMlMjAlM0QlMjAzJTBBJTIwJTIwbnVtX3RyYWluaW5nX3N0ZXBzJTIwJTNEJTIwbnVtX2Vwb2NocyUyMColMjBsZW4odHJhaW5fZGF0YWxvYWRlciklMEElMjAlMjBscl9zY2hlZHVsZXIlMjAlM0QlMjBnZXRfc2NoZWR1bGVyKCUwQSUyMCUyMCUyMCUyMCUyMCUyMCUyMmxpbmVhciUyMiUyQyUwQSUyMCUyMCUyMCUyMCUyMCUyMG9wdGltaXplciUzRG9wdGltaXplciUyQyUwQSUyMCUyMCUyMCUyMCUyMCUyMG51bV93YXJtdXBfc3RlcHMlM0QwJTJDJTBBJTIwJTIwJTIwJTIwJTIwJTIwbnVtX3RyYWluaW5nX3N0ZXBzJTNEbnVtX3RyYWluaW5nX3N0ZXBzJTBBJTIwJTIwKSUwQSUwQSUyMCUyMHByb2dyZXNzX2JhciUyMCUzRCUyMHRxZG0ocmFuZ2UobnVtX3RyYWluaW5nX3N0ZXBzKSklMEElMEElMjAlMjBtb2RlbC50cmFpbigpJTBBJTIwJTIwZm9yJTIwZXBvY2glMjBpbiUyMHJhbmdlKG51bV9lcG9jaHMpJTNBJTBBJTIwJTIwJTIwJTIwJTIwJTIwZm9yJTIwYmF0Y2glMjBpbiUyMHRyYWluX2RhdGFsb2FkZXIlM0ElMEEtJTIwJTIwJTIwJTIwJTIwJTIwJTIwJTIwJTIwYmF0Y2glMjAlM0QlMjAlN0JrJTNBJTIwdi50byhkZXZpY2UpJTIwZm9yJTIwayUyQyUyMHYlMjBpbiUyMGJhdGNoLml0ZW1zKCklN0QlMEElMjAlMjAlMjAlMjAlMjAlMjAlMjAlMjAlMjAlMjBvdXRwdXRzJTIwJTNEJTIwbW9kZWwoKipiYXRjaCklMEElMjAlMjAlMjAlMjAlMjAlMjAlMjAlMjAlMjAlMjBsb3NzJTIwJTNEJTIwb3V0cHV0cy5sb3NzJTBBLSUyMCUyMCUyMCUyMCUyMCUyMCUyMCUyMCUyMGxvc3MuYmFja3dhcmQoKSUwQSUyQiUyMCUyMCUyMCUyMCUyMCUyMCUyMCUyMCUyMGFjY2VsZXJhdG9yLmJhY2t3YXJkKGxvc3MpJTBBJTBBJTIwJTIwJTIwJTIwJTIwJTIwJTIwJTIwJTIwJTIwb3B0aW1pemVyLnN0ZXAoKSUwQSUyMCUyMCUyMCUyMCUyMCUyMCUyMCUyMCUyMCUyMGxyX3NjaGVkdWxlci5zdGVwKCklMEElMjAlMjAlMjAlMjAlMjAlMjAlMjAlMjAlMjAlMjBvcHRpbWl6ZXIuemVyb19ncmFkKCklMEElMjAlMjAlMjAlMjAlMjAlMjAlMjAlMjAlMjAlMjBwcm9ncmVzc19iYXIudXBkYXRlKDEp",highlighted:`<span class="hljs-addition">+ from accelerate import Accelerator</span>
from torch.optim import AdamW
from transformers import AutoModelForSequenceClassification, get_scheduler
<span class="hljs-addition">+ accelerator = Accelerator()</span>
model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=2)
optimizer = AdamW(model.parameters(), lr=3e-5)
<span class="hljs-deletion">- device = torch.device(&quot;cuda&quot;) if torch.cuda.is_available() else torch.device(&quot;cpu&quot;)</span>
<span class="hljs-deletion">- model.to(device)</span>
<span class="hljs-addition">+ train_dataloader, eval_dataloader, model, optimizer = accelerator.prepare(</span>
<span class="hljs-addition">+ train_dataloader, eval_dataloader, model, optimizer</span>
<span class="hljs-addition">+ )</span>
num_epochs = 3
num_training_steps = num_epochs * len(train_dataloader)
lr_scheduler = get_scheduler(
&quot;linear&quot;,
optimizer=optimizer,
num_warmup_steps=0,
num_training_steps=num_training_steps
)
progress_bar = tqdm(range(num_training_steps))
model.train()
for epoch in range(num_epochs):
for batch in train_dataloader:
<span class="hljs-deletion">- batch = {k: v.to(device) for k, v in batch.items()}</span>
outputs = model(**batch)
loss = outputs.loss
<span class="hljs-deletion">- loss.backward()</span>
<span class="hljs-addition">+ accelerator.backward(loss)</span>
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
progress_bar.update(1)`,wrap:!1}}),j=new hs({props:{$$slots:{default:[vs]},$$scope:{ctx:ke}}}),be=new y({props:{code:"ZnJvbSUyMGFjY2VsZXJhdGUlMjBpbXBvcnQlMjBBY2NlbGVyYXRvciUwQWZyb20lMjB0b3JjaC5vcHRpbSUyMGltcG9ydCUyMEFkYW1XJTBBZnJvbSUyMHRyYW5zZm9ybWVycyUyMGltcG9ydCUyMEF1dG9Nb2RlbEZvclNlcXVlbmNlQ2xhc3NpZmljYXRpb24lMkMlMjBnZXRfc2NoZWR1bGVyJTBBJTBBYWNjZWxlcmF0b3IlMjAlM0QlMjBBY2NlbGVyYXRvcigpJTBBJTBBbW9kZWwlMjAlM0QlMjBBdXRvTW9kZWxGb3JTZXF1ZW5jZUNsYXNzaWZpY2F0aW9uLmZyb21fcHJldHJhaW5lZChjaGVja3BvaW50JTJDJTIwbnVtX2xhYmVscyUzRDIpJTBBb3B0aW1pemVyJTIwJTNEJTIwQWRhbVcobW9kZWwucGFyYW1ldGVycygpJTJDJTIwbHIlM0QzZS01KSUwQSUwQXRyYWluX2RsJTJDJTIwZXZhbF9kbCUyQyUyMG1vZGVsJTJDJTIwb3B0aW1pemVyJTIwJTNEJTIwYWNjZWxlcmF0b3IucHJlcGFyZSglMEElMjAlMjAlMjAlMjB0cmFpbl9kYXRhbG9hZGVyJTJDJTIwZXZhbF9kYXRhbG9hZGVyJTJDJTIwbW9kZWwlMkMlMjBvcHRpbWl6ZXIlMEEpJTBBJTBBbnVtX2Vwb2NocyUyMCUzRCUyMDMlMEFudW1fdHJhaW5pbmdfc3RlcHMlMjAlM0QlMjBudW1fZXBvY2hzJTIwKiUyMGxlbih0cmFpbl9kbCklMEFscl9zY2hlZHVsZXIlMjAlM0QlMjBnZXRfc2NoZWR1bGVyKCUwQSUyMCUyMCUyMCUyMCUyMmxpbmVhciUyMiUyQyUwQSUyMCUyMCUyMCUyMG9wdGltaXplciUzRG9wdGltaXplciUyQyUwQSUyMCUyMCUyMCUyMG51bV93YXJtdXBfc3RlcHMlM0QwJTJDJTBBJTIwJTIwJTIwJTIwbnVtX3RyYWluaW5nX3N0ZXBzJTNEbnVtX3RyYWluaW5nX3N0ZXBzJTJDJTBBKSUwQSUwQXByb2dyZXNzX2JhciUyMCUzRCUyMHRxZG0ocmFuZ2UobnVtX3RyYWluaW5nX3N0ZXBzKSklMEElMEFtb2RlbC50cmFpbigpJTBBZm9yJTIwZXBvY2glMjBpbiUyMHJhbmdlKG51bV9lcG9jaHMpJTNBJTBBJTIwJTIwJTIwJTIwZm9yJTIwYmF0Y2glMjBpbiUyMHRyYWluX2RsJTNBJTBBJTIwJTIwJTIwJTIwJTIwJTIwJTIwJTIwb3V0cHV0cyUyMCUzRCUyMG1vZGVsKCoqYmF0Y2gpJTBBJTIwJTIwJTIwJTIwJTIwJTIwJTIwJTIwbG9zcyUyMCUzRCUyMG91dHB1dHMubG9zcyUwQSUyMCUyMCUyMCUyMCUyMCUyMCUyMCUyMGFjY2VsZXJhdG9yLmJhY2t3YXJkKGxvc3MpJTBBJTBBJTIwJTIwJTIwJTIwJTIwJTIwJTIwJTIwb3B0aW1pemVyLnN0ZXAoKSUwQSUyMCUyMCUyMCUyMCUyMCUyMCUyMCUyMGxyX3NjaGVkdWxlci5zdGVwKCklMEElMjAlMjAlMjAlMjAlMjAlMjAlMjAlMjBvcHRpbWl6ZXIuemVyb19ncmFkKCklMEElMjAlMjAlMjAlMjAlMjAlMjAlMjAlMjBwcm9ncmVzc19iYXIudXBkYXRlKDEp",highlighted:`<span class="hljs-keyword">from</span> accelerate <span class="hljs-keyword">import</span> Accelerator
<span class="hljs-keyword">from</span> torch.optim <span class="hljs-keyword">import</span> AdamW
<span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> AutoModelForSequenceClassification, get_scheduler
accelerator = Accelerator()
model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=<span class="hljs-number">2</span>)
optimizer = AdamW(model.parameters(), lr=<span class="hljs-number">3e-5</span>)
train_dl, eval_dl, model, optimizer = accelerator.prepare(
train_dataloader, eval_dataloader, model, optimizer
)
num_epochs = <span class="hljs-number">3</span>
num_training_steps = num_epochs * <span class="hljs-built_in">len</span>(train_dl)
lr_scheduler = get_scheduler(
<span class="hljs-string">&quot;linear&quot;</span>,
optimizer=optimizer,
num_warmup_steps=<span class="hljs-number">0</span>,
num_training_steps=num_training_steps,
)
progress_bar = tqdm(<span class="hljs-built_in">range</span>(num_training_steps))
model.train()
<span class="hljs-keyword">for</span> epoch <span class="hljs-keyword">in</span> <span class="hljs-built_in">range</span>(num_epochs):
<span class="hljs-keyword">for</span> batch <span class="hljs-keyword">in</span> train_dl:
outputs = model(**batch)
loss = outputs.loss
accelerator.backward(loss)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
progress_bar.update(<span class="hljs-number">1</span>)`,wrap:!1}}),Je=new y({props:{code:"YWNjZWxlcmF0ZSUyMGNvbmZpZw==",highlighted:"accelerate config",wrap:!1}}),Ue=new y({props:{code:"YWNjZWxlcmF0ZSUyMGxhdW5jaCUyMHRyYWluLnB5",highlighted:'accelerate <span class="hljs-built_in">launch</span> train.py',wrap:!1}}),ge=new y({props:{code:"ZnJvbSUyMGFjY2VsZXJhdGUlMjBpbXBvcnQlMjBub3RlYm9va19sYXVuY2hlciUwQSUwQW5vdGVib29rX2xhdW5jaGVyKHRyYWluaW5nX2Z1bmN0aW9uKQ==",highlighted:`<span class="hljs-keyword">from</span> accelerate <span class="hljs-keyword">import</span> notebook_launcher
notebook_launcher(training_function)`,wrap:!1}}),Ze=new Gs({props:{source:"https://github.com/huggingface/course/blob/main/chapters/de/chapter3/4.mdx"}}),{c(){w=m("meta"),b=n(),h=m("p"),Be=n(),i(U.$$.fragment),$e=n(),i(T.$$.fragment),ve=n(),i(f.$$.fragment),We=n(),g=m("p"),g.innerHTML=Vl,Xe=n(),i(C.$$.fragment),ze=n(),i(Z.$$.fragment),Re=n(),k=m("p"),k.innerHTML=Nl,Ae=n(),B=m("ul"),B.innerHTML=El,_e=n(),I=m("p"),I.innerHTML=Fl,Ye=n(),i(G.$$.fragment),Ve=n(),$=m("p"),$.textContent=Ql,Ne=n(),i(v.$$.fragment),Ee=n(),W=m("p"),W.textContent=Hl,Fe=n(),i(X.$$.fragment),Qe=n(),z=m("p"),z.textContent=xl,He=n(),i(R.$$.fragment),xe=n(),i(A.$$.fragment),Se=n(),_=m("p"),_.innerHTML=Sl,Le=n(),Y=m("p"),Y.textContent=Ll,De=n(),i(V.$$.fragment),Ke=n(),N=m("p"),N.textContent=Dl,qe=n(),i(E.$$.fragment),Pe=n(),i(F.$$.fragment),Oe=n(),Q=m("p"),Q.innerHTML=Kl,el=n(),H=m("p"),H.innerHTML=ql,ll=n(),i(x.$$.fragment),sl=n(),S=m("p"),S.innerHTML=Pl,tl=n(),i(L.$$.fragment),nl=n(),i(D.$$.fragment),al=n(),i(K.$$.fragment),il=n(),q=m("p"),q.innerHTML=Ol,rl=n(),i(P.$$.fragment),pl=n(),i(O.$$.fragment),dl=n(),ee=m("p"),ee.innerHTML=es,Ml=n(),i(le.$$.fragment),cl=n(),se=m("p"),se.textContent=ls,ml=n(),i(te.$$.fragment),ol=n(),ne=m("p"),ne.innerHTML=ss,ul=n(),i(ae.$$.fragment),yl=n(),i(ie.$$.fragment),wl=n(),re=m("p"),re.textContent=ts,bl=n(),i(J.$$.fragment),hl=n(),i(pe.$$.fragment),Jl=n(),i(de.$$.fragment),jl=n(),Me=m("p"),Me.innerHTML=ns,Ul=n(),i(ce.$$.fragment),Tl=n(),me=m("p"),me.textContent=as,fl=n(),i(oe.$$.fragment),gl=n(),ue=m("p"),ue.innerHTML=is,Cl=n(),ye=m("p"),ye.innerHTML=rs,Zl=n(),i(j.$$.fragment),kl=n(),we=m("p"),we.textContent=ps,Bl=n(),i(be.$$.fragment),Il=n(),he=m("p"),he.innerHTML=ds,Gl=n(),i(Je.$$.fragment),$l=n(),je=m("p"),je.textContent=Ms,vl=n(),i(Ue.$$.fragment),Wl=n(),Te=m("p"),Te.textContent=cs,Xl=n(),fe=m("p"),fe.innerHTML=ms,zl=n(),i(ge.$$.fragment),Rl=n(),Ce=m("p"),Ce.innerHTML=os,Al=n(),i(Ze.$$.fragment),_l=n(),Ie=m("p"),this.h()},l(e){const l=Cs("svelte-u9bgzb",document.head);w=o(l,"META",{name:!0,content:!0}),l.forEach(s),b=a(e),h=o(e,"P",{}),ws(h).forEach(s),Be=a(e),r(U.$$.fragment,e),$e=a(e),r(T.$$.fragment,e),ve=a(e),r(f.$$.fragment,e),We=a(e),g=o(e,"P",{"data-svelte-h":!0}),u(g)!=="svelte-u9dlj6"&&(g.innerHTML=Vl),Xe=a(e),r(C.$$.fragment,e),ze=a(e),r(Z.$$.fragment,e),Re=a(e),k=o(e,"P",{"data-svelte-h":!0}),u(k)!=="svelte-10uwzhv"&&(k.innerHTML=Nl),Ae=a(e),B=o(e,"UL",{"data-svelte-h":!0}),u(B)!=="svelte-97ia1"&&(B.innerHTML=El),_e=a(e),I=o(e,"P",{"data-svelte-h":!0}),u(I)!=="svelte-1xre5mi"&&(I.innerHTML=Fl),Ye=a(e),r(G.$$.fragment,e),Ve=a(e),$=o(e,"P",{"data-svelte-h":!0}),u($)!=="svelte-1ne07l6"&&($.textContent=Ql),Ne=a(e),r(v.$$.fragment,e),Ee=a(e),W=o(e,"P",{"data-svelte-h":!0}),u(W)!=="svelte-1qzexy9"&&(W.textContent=Hl),Fe=a(e),r(X.$$.fragment,e),Qe=a(e),z=o(e,"P",{"data-svelte-h":!0}),u(z)!=="svelte-58ge5g"&&(z.textContent=xl),He=a(e),r(R.$$.fragment,e),xe=a(e),r(A.$$.fragment,e),Se=a(e),_=o(e,"P",{"data-svelte-h":!0}),u(_)!=="svelte-tuv6h6"&&(_.innerHTML=Sl),Le=a(e),Y=o(e,"P",{"data-svelte-h":!0}),u(Y)!=="svelte-qcm9iy"&&(Y.textContent=Ll),De=a(e),r(V.$$.fragment,e),Ke=a(e),N=o(e,"P",{"data-svelte-h":!0}),u(N)!=="svelte-1o9btcm"&&(N.textContent=Dl),qe=a(e),r(E.$$.fragment,e),Pe=a(e),r(F.$$.fragment,e),Oe=a(e),Q=o(e,"P",{"data-svelte-h":!0}),u(Q)!=="svelte-1xdm3p"&&(Q.innerHTML=Kl),el=a(e),H=o(e,"P",{"data-svelte-h":!0}),u(H)!=="svelte-n59cj8"&&(H.innerHTML=ql),ll=a(e),r(x.$$.fragment,e),sl=a(e),S=o(e,"P",{"data-svelte-h":!0}),u(S)!=="svelte-1d7k25q"&&(S.innerHTML=Pl),tl=a(e),r(L.$$.fragment,e),nl=a(e),r(D.$$.fragment,e),al=a(e),r(K.$$.fragment,e),il=a(e),q=o(e,"P",{"data-svelte-h":!0}),u(q)!=="svelte-10laz2u"&&(q.innerHTML=Ol),rl=a(e),r(P.$$.fragment,e),pl=a(e),r(O.$$.fragment,e),dl=a(e),ee=o(e,"P",{"data-svelte-h":!0}),u(ee)!=="svelte-5rg5wc"&&(ee.innerHTML=es),Ml=a(e),r(le.$$.fragment,e),cl=a(e),se=o(e,"P",{"data-svelte-h":!0}),u(se)!=="svelte-p31dwq"&&(se.textContent=ls),ml=a(e),r(te.$$.fragment,e),ol=a(e),ne=o(e,"P",{"data-svelte-h":!0}),u(ne)!=="svelte-1ho21o0"&&(ne.innerHTML=ss),ul=a(e),r(ae.$$.fragment,e),yl=a(e),r(ie.$$.fragment,e),wl=a(e),re=o(e,"P",{"data-svelte-h":!0}),u(re)!=="svelte-14ibogl"&&(re.textContent=ts),bl=a(e),r(J.$$.fragment,e),hl=a(e),r(pe.$$.fragment,e),Jl=a(e),r(de.$$.fragment,e),jl=a(e),Me=o(e,"P",{"data-svelte-h":!0}),u(Me)!=="svelte-105b2zv"&&(Me.innerHTML=ns),Ul=a(e),r(ce.$$.fragment,e),Tl=a(e),me=o(e,"P",{"data-svelte-h":!0}),u(me)!=="svelte-1b5atwv"&&(me.textContent=as),fl=a(e),r(oe.$$.fragment,e),gl=a(e),ue=o(e,"P",{"data-svelte-h":!0}),u(ue)!=="svelte-1rigc9c"&&(ue.innerHTML=is),Cl=a(e),ye=o(e,"P",{"data-svelte-h":!0}),u(ye)!=="svelte-2rrnhi"&&(ye.innerHTML=rs),Zl=a(e),r(j.$$.fragment,e),kl=a(e),we=o(e,"P",{"data-svelte-h":!0}),u(we)!=="svelte-1la63xh"&&(we.textContent=ps),Bl=a(e),r(be.$$.fragment,e),Il=a(e),he=o(e,"P",{"data-svelte-h":!0}),u(he)!=="svelte-1l337r5"&&(he.innerHTML=ds),Gl=a(e),r(Je.$$.fragment,e),$l=a(e),je=o(e,"P",{"data-svelte-h":!0}),u(je)!=="svelte-134ovcr"&&(je.textContent=Ms),vl=a(e),r(Ue.$$.fragment,e),Wl=a(e),Te=o(e,"P",{"data-svelte-h":!0}),u(Te)!=="svelte-1m2v744"&&(Te.textContent=cs),Xl=a(e),fe=o(e,"P",{"data-svelte-h":!0}),u(fe)!=="svelte-umh1t9"&&(fe.innerHTML=ms),zl=a(e),r(ge.$$.fragment,e),Rl=a(e),Ce=o(e,"P",{"data-svelte-h":!0}),u(Ce)!=="svelte-1vh5l4m"&&(Ce.innerHTML=os),Al=a(e),r(Ze.$$.fragment,e),_l=a(e),Ie=o(e,"P",{}),ws(Ie).forEach(s),this.h()},h(){bs(w,"name","hf:doc:metadata"),bs(w,"content",Xs)},m(e,l){Zs(document.head,w),t(e,b,l),t(e,h,l),t(e,Be,l),p(U,e,l),t(e,$e,l),p(T,e,l),t(e,ve,l),p(f,e,l),t(e,We,l),t(e,g,l),t(e,Xe,l),p(C,e,l),t(e,ze,l),p(Z,e,l),t(e,Re,l),t(e,k,l),t(e,Ae,l),t(e,B,l),t(e,_e,l),t(e,I,l),t(e,Ye,l),p(G,e,l),t(e,Ve,l),t(e,$,l),t(e,Ne,l),p(v,e,l),t(e,Ee,l),t(e,W,l),t(e,Fe,l),p(X,e,l),t(e,Qe,l),t(e,z,l),t(e,He,l),p(R,e,l),t(e,xe,l),p(A,e,l),t(e,Se,l),t(e,_,l),t(e,Le,l),t(e,Y,l),t(e,De,l),p(V,e,l),t(e,Ke,l),t(e,N,l),t(e,qe,l),p(E,e,l),t(e,Pe,l),p(F,e,l),t(e,Oe,l),t(e,Q,l),t(e,el,l),t(e,H,l),t(e,ll,l),p(x,e,l),t(e,sl,l),t(e,S,l),t(e,tl,l),p(L,e,l),t(e,nl,l),p(D,e,l),t(e,al,l),p(K,e,l),t(e,il,l),t(e,q,l),t(e,rl,l),p(P,e,l),t(e,pl,l),p(O,e,l),t(e,dl,l),t(e,ee,l),t(e,Ml,l),p(le,e,l),t(e,cl,l),t(e,se,l),t(e,ml,l),p(te,e,l),t(e,ol,l),t(e,ne,l),t(e,ul,l),p(ae,e,l),t(e,yl,l),p(ie,e,l),t(e,wl,l),t(e,re,l),t(e,bl,l),p(J,e,l),t(e,hl,l),p(pe,e,l),t(e,Jl,l),p(de,e,l),t(e,jl,l),t(e,Me,l),t(e,Ul,l),p(ce,e,l),t(e,Tl,l),t(e,me,l),t(e,fl,l),p(oe,e,l),t(e,gl,l),t(e,ue,l),t(e,Cl,l),t(e,ye,l),t(e,Zl,l),p(j,e,l),t(e,kl,l),t(e,we,l),t(e,Bl,l),p(be,e,l),t(e,Il,l),t(e,he,l),t(e,Gl,l),p(Je,e,l),t(e,$l,l),t(e,je,l),t(e,vl,l),p(Ue,e,l),t(e,Wl,l),t(e,Te,l),t(e,Xl,l),t(e,fe,l),t(e,zl,l),p(ge,e,l),t(e,Rl,l),t(e,Ce,l),t(e,Al,l),p(Ze,e,l),t(e,_l,l),t(e,Ie,l),Yl=!0},p(e,[l]){const us={};l&2&&(us.$$scope={dirty:l,ctx:e}),J.$set(us);const ys={};l&2&&(ys.$$scope={dirty:l,ctx:e}),j.$set(ys)},i(e){Yl||(d(U.$$.fragment,e),d(T.$$.fragment,e),d(f.$$.fragment,e),d(C.$$.fragment,e),d(Z.$$.fragment,e),d(G.$$.fragment,e),d(v.$$.fragment,e),d(X.$$.fragment,e),d(R.$$.fragment,e),d(A.$$.fragment,e),d(V.$$.fragment,e),d(E.$$.fragment,e),d(F.$$.fragment,e),d(x.$$.fragment,e),d(L.$$.fragment,e),d(D.$$.fragment,e),d(K.$$.fragment,e),d(P.$$.fragment,e),d(O.$$.fragment,e),d(le.$$.fragment,e),d(te.$$.fragment,e),d(ae.$$.fragment,e),d(ie.$$.fragment,e),d(J.$$.fragment,e),d(pe.$$.fragment,e),d(de.$$.fragment,e),d(ce.$$.fragment,e),d(oe.$$.fragment,e),d(j.$$.fragment,e),d(be.$$.fragment,e),d(Je.$$.fragment,e),d(Ue.$$.fragment,e),d(ge.$$.fragment,e),d(Ze.$$.fragment,e),Yl=!0)},o(e){M(U.$$.fragment,e),M(T.$$.fragment,e),M(f.$$.fragment,e),M(C.$$.fragment,e),M(Z.$$.fragment,e),M(G.$$.fragment,e),M(v.$$.fragment,e),M(X.$$.fragment,e),M(R.$$.fragment,e),M(A.$$.fragment,e),M(V.$$.fragment,e),M(E.$$.fragment,e),M(F.$$.fragment,e),M(x.$$.fragment,e),M(L.$$.fragment,e),M(D.$$.fragment,e),M(K.$$.fragment,e),M(P.$$.fragment,e),M(O.$$.fragment,e),M(le.$$.fragment,e),M(te.$$.fragment,e),M(ae.$$.fragment,e),M(ie.$$.fragment,e),M(J.$$.fragment,e),M(pe.$$.fragment,e),M(de.$$.fragment,e),M(ce.$$.fragment,e),M(oe.$$.fragment,e),M(j.$$.fragment,e),M(be.$$.fragment,e),M(Je.$$.fragment,e),M(Ue.$$.fragment,e),M(ge.$$.fragment,e),M(Ze.$$.fragment,e),Yl=!1},d(e){e&&(s(b),s(h),s(Be),s($e),s(ve),s(We),s(g),s(Xe),s(ze),s(Re),s(k),s(Ae),s(B),s(_e),s(I),s(Ye),s(Ve),s($),s(Ne),s(Ee),s(W),s(Fe),s(Qe),s(z),s(He),s(xe),s(Se),s(_),s(Le),s(Y),s(De),s(Ke),s(N),s(qe),s(Pe),s(Oe),s(Q),s(el),s(H),s(ll),s(sl),s(S),s(tl),s(nl),s(al),s(il),s(q),s(rl),s(pl),s(dl),s(ee),s(Ml),s(cl),s(se),s(ml),s(ol),s(ne),s(ul),s(yl),s(wl),s(re),s(bl),s(hl),s(Jl),s(jl),s(Me),s(Ul),s(Tl),s(me),s(fl),s(gl),s(ue),s(Cl),s(ye),s(Zl),s(kl),s(we),s(Bl),s(Il),s(he),s(Gl),s($l),s(je),s(vl),s(Wl),s(Te),s(Xl),s(fe),s(zl),s(Rl),s(Ce),s(Al),s(_l),s(Ie)),s(w),c(U,e),c(T,e),c(f,e),c(C,e),c(Z,e),c(G,e),c(v,e),c(X,e),c(R,e),c(A,e),c(V,e),c(E,e),c(F,e),c(x,e),c(L,e),c(D,e),c(K,e),c(P,e),c(O,e),c(le,e),c(te,e),c(ae,e),c(ie,e),c(J,e),c(pe,e),c(de,e),c(ce,e),c(oe,e),c(j,e),c(be,e),c(Je,e),c(Ue,e),c(ge,e),c(Ze,e)}}}const Xs='{"title":"Komplettes Training","local":"komplettes-training","sections":[{"title":"Vorbereitung für das Training","local":"vorbereitung-für-das-training","sections":[],"depth":3},{"title":"Die Trainingsschleife","local":"die-trainingsschleife","sections":[],"depth":3},{"title":"Die Evaluationsschleife","local":"die-evaluationsschleife","sections":[],"depth":3},{"title":"Verbessere deine Trainingsschleife mit 🤗 Accelerate","local":"verbessere-deine-trainingsschleife-mit--accelerate","sections":[],"depth":3}],"depth":1}';function zs(ke){return Us(()=>{new URLSearchParams(window.location.search).get("fw")}),[]}class Fs extends fs{constructor(w){super(),gs(this,w,zs,Ws,js,{})}}export{Fs as component};

Xet Storage Details

Size:
44.5 kB
·
Xet hash:
fe974fd1ec951540c293abf360886d28bab02c1206820d335565d2e783332c6e

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.