Buckets:
| import{s as Ut,o as Tt,n as ht}from"../chunks/scheduler.37c15a92.js";import{S as ft,i as Ct,g as m,s,r as i,A as Zt,h as u,f as t,c as n,j as bt,u as r,x as d,k as jt,y as Bt,a,v as p,d as c,t as M,w as o,m as vt,n as It}from"../chunks/index.2bf4358c.js";import{T as Jt}from"../chunks/Tip.363c041f.js";import{Y as wt}from"../chunks/Youtube.1e50a667.js";import{C as y}from"../chunks/CodeBlock.4e987730.js";import{C as $t}from"../chunks/CourseFloatingBanner.6add7356.js";import{H as Ge,E as Gt}from"../chunks/getInferenceSnippets.24b50994.js";function kt(ve){let b,j="✏️ <strong>Încercați!</strong> Modificați bucla de antrenament anterioară pentru a vă rafina modelul pe dataset-ul SST-2.";return{c(){b=m("p"),b.innerHTML=j},l(J){b=u(J,"P",{"data-svelte-h":!0}),d(b)!=="svelte-105erym"&&(b.innerHTML=j)},m(J,Ie){a(J,b,Ie)},p:ht,d(J){J&&t(b)}}}function Xt(ve){let b;return{c(){b=vt('⚠️ Pentru a beneficia de creșterea vitezei oferită de Cloud TPU-uri, vă recomandăm să împachetați mostrele la o lungime fixă folosind argumentele `padding="max_length"` și `max_length` ale tokenizer-ului.')},l(j){b=It(j,'⚠️ Pentru a beneficia de creșterea vitezei oferită de Cloud TPU-uri, vă recomandăm să împachetați mostrele la o lungime fixă folosind argumentele `padding="max_length"` și `max_length` ale tokenizer-ului.')},m(j,J){a(j,b,J)},d(j){j&&t(b)}}}function Wt(ve){let b,j,J,Ie,T,ke,h,Xe,f,We,C,zl="Acum vom vedea cum să obținem aceleași rezultate ca în secțiunea anterioară, dar fără să folosim clasa <code>Trainer</code>. Din nou, presupunem că ați parcurs deja procesarea datelor din secțiunea 2. Iată un scurt rezumat al tot ceea ce veți avea nevoie:",ge,Z,Re,B,_e,v,Nl="Înainte de a scrie efectiv bucla de antrenament, va trebui să definim câteva obiecte. Primele sunt încărcătoarele de date (dataloaders) pe care le vom folosi pentru a itera pe batch-uri. Dar, înainte de a putea defini acele dataloaders, trebuie să aplicăm un pic de postprocesare dataset-urilor noastre <code>tokenized_datasets</code>, pentru a ne ocupa de câteva lucruri pe care <code>Trainer</code> le făcea automat pentru noi. Mai exact, trebuie să:",Ye,I,El="<li>Eliminăm coloanele care corespund valorilor pe care modelul nu le așteaptă (cum ar fi coloanele <code>sentence1</code> și <code>sentence2</code>).</li> <li>Redenumim coloana <code>label</code> în <code>labels</code> (pentru că modelul se așteaptă ca argumentul să se numească <code>labels</code>).</li> <li>Setăm formatul dataset-urilor astfel încât să returneze tensori PyTorch în loc de liste.</li>",Ae,$,Ql="<code>tokenized_datasets</code> are câte o metodă pentru fiecare dintre acești pași:",Ve,G,ze,k,Fl="Putem apoi să verificăm că rezultatul are doar coloanele pe care modelul le va accepta:",Ne,X,Ee,W,xl="Acum, după ce am terminat acest pas, putem defini foarte ușor dataloader-urile noastre:",Qe,g,Fe,R,Hl="Pentru a verifica rapid că nu există nicio eroare în procesarea datelor, putem inspecta un batch astfel:",xe,_,He,Y,Se,A,Sl="Observați că formele reale ar putea fi ușor diferite pentru voi, pentru că am setat <code>shuffle=True</code> în dataloader-ul nostru de antrenament și pentru că împachetăm (padding) la lungimea maximă în interiorul batch-ului.",qe,V,ql="Acum că am terminat complet procesarea datelor (un obiectiv satisfăcător, dar uneori greu de atins pentru orice practician ML), să trecem la model. Îl instanțiem exact ca în secțiunea anterioară:",Ke,z,Le,N,Kl="Pentru a ne asigura că totul va decurge fără probleme în timpul antrenamentului, trecem batch-ul nostru prin model:",Pe,E,De,Q,Oe,F,Ll="Toate modelele 🤗 Transformers vor returna pierderea (loss) când <code>labels</code> sunt furnizate, și, de asemenea, obținem logits (două pentru fiecare intrare în batch, deci un tensor de mărimea 8 x 2).",el,x,Pl='Suntem aproape gata să scriem bucla de antrenament! Ne mai lipsesc două lucruri: un optimizer și un scheduler pentru rata de învățare. Pentru că încercăm să reproducem ceea ce făcea <code>Trainer</code>, vom folosi aceleași valori implicite. Optimizer-ul folosit de <code>Trainer</code> este <code>AdamW</code>, care este același cu Adam, dar cu o abordare particulară pentru regularizarea weight decay (vedeți lucrarea <a href="https://arxiv.org/abs/1711.05101" rel="nofollow">“Decoupled Weight Decay Regularization”</a> de Ilya Loshchilov și Frank Hutter):',ll,H,tl,S,Dl="În final, scheduler-ul pentru rata de învățare folosit implicit este doar o descreștere liniară de la valoarea maximă (5e-5) la 0. Pentru a-l defini corect, trebuie să știm numărul de pași de antrenament pe care îi vom face, care este numărul de epoci dorit înmulțit cu numărul de batch-uri de antrenament (care este lungimea dataloader-ului nostru de antrenament). <code>Trainer</code> folosește trei epoci implicit, așa că vom urma acest exemplu:",al,q,sl,K,nl,L,il,P,Ol="Încă un lucru: vom dori să folosim GPU-ul dacă avem acces la unul (pe un CPU, antrenamentul poate dura câteva ore în loc de câteva minute). Pentru asta, definim un <code>device</code> pe care vom pune modelul și batch-urile noastre:",rl,D,pl,O,cl,ee,et="Acum suntem gata de antrenament! Pentru a ne face o idee despre momentul în care se va termina antrenamentul, adăugăm o bară de progres peste numărul de pași de antrenament, folosind biblioteca <code>tqdm</code>:",Ml,le,ol,te,lt="Observați că partea principală a buclei de antrenament arată foarte asemănător cu cea din introducere. Nu am cerut niciun raport, așa că această buclă de antrenament nu ne va spune nimic despre performanța modelului. Pentru a avea feedback, trebuie să adăugăm o buclă de evaluare.",ml,ae,ul,se,tt="Ca și înainte, vom folosi o metrică oferită de biblioteca 🤗 Evaluate. Am văzut deja metoda <code>metric.compute()</code>, dar metricle pot de fapt să acumuleze batch-uri pentru noi în timp ce parcurgem bucla de predicție, cu metoda <code>add_batch()</code>. Odată ce am acumulat toate batch-urile, putem obține rezultatul final cu <code>metric.compute()</code>. Iată cum să implementăm toate acestea într-o buclă de evaluare:",dl,ne,yl,ie,bl,re,at="Din nou, rezultatele voastre vor fi ușor diferite din cauza aleatorietății în inițializarea layer-ului final (model head) și a amestecării datelor, dar ar trebui să fie în aceeași zonă valorică.",jl,w,Jl,pe,wl,ce,Ul,Me,st='Bucla de antrenament pe care am definit-o anterior funcționează bine pe un singur CPU sau GPU. Dar, folosind biblioteca <a href="https://github.com/huggingface/accelerate" rel="nofollow">🤗 Accelerate</a>, cu doar câteva ajustări putem activa antrenarea distribuită pe mai multe GPU-uri sau TPU-uri. Pornind de la crearea dataloader-urilor de antrenament și validare, iată cum arată bucla noastră manuală de antrenament:',Tl,oe,hl,me,nt="Iar aici sunt modificările:",fl,ue,Cl,de,it="Prima linie de adăugat este linia de import. A doua linie instanțiază un obiect <code>Accelerator</code> care va examina mediul și va inițializa setarea distribuită corespunzătoare. 🤗 Accelerate se ocupă de plasarea pe device pentru voi, așa că puteți elimina liniile care pun modelul pe device (sau, dacă preferați, le puteți schimba să folosească <code>accelerator.device</code> în loc de <code>device</code>).",Zl,ye,rt="Apoi, partea principală a muncii este făcută în linia care trimite dataloaders, modelul și optimizer-ul la <code>accelerator.prepare()</code>. Aceasta va împacheta acele obiecte în containerul potrivit pentru a vă asigura că antrenarea distribuită funcționează corespunzător. Restul modificărilor constau în eliminarea liniei care mută batch-ul pe <code>device</code> (din nou, dacă doriți să o păstrați puteți doar să o schimbați să folosească <code>accelerator.device</code>) și înlocuirea <code>loss.backward()</code> cu <code>accelerator.backward(loss)</code>.",Bl,U,vl,be,pt="Dacă vreți să copiați și să lipiți pentru a vă juca, iată cum arată bucla completă de antrenament cu 🤗 Accelerate:",Il,je,$l,Je,ct="Plasând acest cod într-un fișier <code>train.py</code> îl face rulabil pe orice tip de configurare distribuită. Pentru a-l încerca în configurarea voastră distribuită, rulați comanda:",Gl,we,kl,Ue,Mt="care vă va cere să răspundeți la câteva întrebări și vă va crea un fișier de configurare folosit de comanda:",Xl,Te,Wl,he,ot="care va porni antrenarea distribuită.",gl,fe,mt="Dacă vreți să încercați asta într-un Notebook (de exemplu, pentru a-l testa cu TPU-uri pe Colab), doar lipiți codul într-o funcție <code>training_function()</code> și rulați într-o celulă finală:",Rl,Ce,_l,Ze,ut='Puteți găsi mai multe exemple în <a href="https://github.com/huggingface/accelerate/tree/main/examples" rel="nofollow">repo-ul 🤗 Accelerate</a>.',Yl,Be,Al,$e,Vl;return T=new Ge({props:{title:"O instruire completă",local:"o-instruire-completă",headingTag:"h1"}}),h=new $t({props:{chapter:3,classNames:"absolute z-10 right-0 top-0",notebooks:[{label:"Google Colab",value:"https://colab.research.google.com/github/huggingface/notebooks/blob/master/course/en/chapter3/section4.ipynb"},{label:"Aws Studio",value:"https://studiolab.sagemaker.aws/import/github/huggingface/notebooks/blob/master/course/en/chapter3/section4.ipynb"}]}}),f=new wt({props:{id:"Dh9CL8fyG80"}}),Z=new y({props:{code:"ZnJvbSUyMGRhdGFzZXRzJTIwaW1wb3J0JTIwbG9hZF9kYXRhc2V0JTBBZnJvbSUyMHRyYW5zZm9ybWVycyUyMGltcG9ydCUyMEF1dG9Ub2tlbml6ZXIlMkMlMjBEYXRhQ29sbGF0b3JXaXRoUGFkZGluZyUwQSUwQXJhd19kYXRhc2V0cyUyMCUzRCUyMGxvYWRfZGF0YXNldCglMjJnbHVlJTIyJTJDJTIwJTIybXJwYyUyMiklMEFjaGVja3BvaW50JTIwJTNEJTIwJTIyYmVydC1iYXNlLXVuY2FzZWQlMjIlMEF0b2tlbml6ZXIlMjAlM0QlMjBBdXRvVG9rZW5pemVyLmZyb21fcHJldHJhaW5lZChjaGVja3BvaW50KSUwQSUwQSUwQWRlZiUyMHRva2VuaXplX2Z1bmN0aW9uKGV4YW1wbGUpJTNBJTBBJTIwJTIwJTIwJTIwcmV0dXJuJTIwdG9rZW5pemVyKGV4YW1wbGUlNUIlMjJzZW50ZW5jZTElMjIlNUQlMkMlMjBleGFtcGxlJTVCJTIyc2VudGVuY2UyJTIyJTVEJTJDJTIwdHJ1bmNhdGlvbiUzRFRydWUpJTBBJTBBJTBBdG9rZW5pemVkX2RhdGFzZXRzJTIwJTNEJTIwcmF3X2RhdGFzZXRzLm1hcCh0b2tlbml6ZV9mdW5jdGlvbiUyQyUyMGJhdGNoZWQlM0RUcnVlKSUwQWRhdGFfY29sbGF0b3IlMjAlM0QlMjBEYXRhQ29sbGF0b3JXaXRoUGFkZGluZyh0b2tlbml6ZXIlM0R0b2tlbml6ZXIp",highlighted:`<span class="hljs-keyword">from</span> datasets <span class="hljs-keyword">import</span> load_dataset | |
| <span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> AutoTokenizer, DataCollatorWithPadding | |
| raw_datasets = load_dataset(<span class="hljs-string">"glue"</span>, <span class="hljs-string">"mrpc"</span>) | |
| checkpoint = <span class="hljs-string">"bert-base-uncased"</span> | |
| tokenizer = AutoTokenizer.from_pretrained(checkpoint) | |
| <span class="hljs-keyword">def</span> <span class="hljs-title function_">tokenize_function</span>(<span class="hljs-params">example</span>): | |
| <span class="hljs-keyword">return</span> tokenizer(example[<span class="hljs-string">"sentence1"</span>], example[<span class="hljs-string">"sentence2"</span>], truncation=<span class="hljs-literal">True</span>) | |
| tokenized_datasets = raw_datasets.<span class="hljs-built_in">map</span>(tokenize_function, batched=<span class="hljs-literal">True</span>) | |
| data_collator = DataCollatorWithPadding(tokenizer=tokenizer)`,wrap:!1}}),B=new Ge({props:{title:"Pregătirea pentru antrenament",local:"pregătirea-pentru-antrenament",headingTag:"h3"}}),G=new y({props:{code:"dG9rZW5pemVkX2RhdGFzZXRzJTIwJTNEJTIwdG9rZW5pemVkX2RhdGFzZXRzLnJlbW92ZV9jb2x1bW5zKCU1QiUyMnNlbnRlbmNlMSUyMiUyQyUyMCUyMnNlbnRlbmNlMiUyMiUyQyUyMCUyMmlkeCUyMiU1RCklMEF0b2tlbml6ZWRfZGF0YXNldHMlMjAlM0QlMjB0b2tlbml6ZWRfZGF0YXNldHMucmVuYW1lX2NvbHVtbiglMjJsYWJlbCUyMiUyQyUyMCUyMmxhYmVscyUyMiklMEF0b2tlbml6ZWRfZGF0YXNldHMuc2V0X2Zvcm1hdCglMjJ0b3JjaCUyMiklMEF0b2tlbml6ZWRfZGF0YXNldHMlNUIlMjJ0cmFpbiUyMiU1RC5jb2x1bW5fbmFtZXM=",highlighted:`tokenized_datasets = tokenized_datasets.remove_columns([<span class="hljs-string">"sentence1"</span>, <span class="hljs-string">"sentence2"</span>, <span class="hljs-string">"idx"</span>]) | |
| tokenized_datasets = tokenized_datasets.rename_column(<span class="hljs-string">"label"</span>, <span class="hljs-string">"labels"</span>) | |
| tokenized_datasets.set_format(<span class="hljs-string">"torch"</span>) | |
| tokenized_datasets[<span class="hljs-string">"train"</span>].column_names`,wrap:!1}}),X=new y({props:{code:"JTVCJTIyYXR0ZW50aW9uX21hc2slMjIlMkMlMjAlMjJpbnB1dF9pZHMlMjIlMkMlMjAlMjJsYWJlbHMlMjIlMkMlMjAlMjJ0b2tlbl90eXBlX2lkcyUyMiU1RA==",highlighted:'[<span class="hljs-string">"attention_mask"</span>, <span class="hljs-string">"input_ids"</span>, <span class="hljs-string">"labels"</span>, <span class="hljs-string">"token_type_ids"</span>]',wrap:!1}}),g=new y({props:{code:"ZnJvbSUyMHRvcmNoLnV0aWxzLmRhdGElMjBpbXBvcnQlMjBEYXRhTG9hZGVyJTBBJTBBdHJhaW5fZGF0YWxvYWRlciUyMCUzRCUyMERhdGFMb2FkZXIoJTBBJTIwJTIwJTIwJTIwdG9rZW5pemVkX2RhdGFzZXRzJTVCJTIydHJhaW4lMjIlNUQlMkMlMjBzaHVmZmxlJTNEVHJ1ZSUyQyUyMGJhdGNoX3NpemUlM0Q4JTJDJTIwY29sbGF0ZV9mbiUzRGRhdGFfY29sbGF0b3IlMEEpJTBBZXZhbF9kYXRhbG9hZGVyJTIwJTNEJTIwRGF0YUxvYWRlciglMEElMjAlMjAlMjAlMjB0b2tlbml6ZWRfZGF0YXNldHMlNUIlMjJ2YWxpZGF0aW9uJTIyJTVEJTJDJTIwYmF0Y2hfc2l6ZSUzRDglMkMlMjBjb2xsYXRlX2ZuJTNEZGF0YV9jb2xsYXRvciUwQSk=",highlighted:`<span class="hljs-keyword">from</span> torch.utils.data <span class="hljs-keyword">import</span> DataLoader | |
| train_dataloader = DataLoader( | |
| tokenized_datasets[<span class="hljs-string">"train"</span>], shuffle=<span class="hljs-literal">True</span>, batch_size=<span class="hljs-number">8</span>, collate_fn=data_collator | |
| ) | |
| eval_dataloader = DataLoader( | |
| tokenized_datasets[<span class="hljs-string">"validation"</span>], batch_size=<span class="hljs-number">8</span>, collate_fn=data_collator | |
| )`,wrap:!1}}),_=new y({props:{code:"Zm9yJTIwYmF0Y2glMjBpbiUyMHRyYWluX2RhdGFsb2FkZXIlM0ElMEElMjAlMjAlMjAlMjBicmVhayUwQSU3QmslM0ElMjB2LnNoYXBlJTIwZm9yJTIwayUyQyUyMHYlMjBpbiUyMGJhdGNoLml0ZW1zKCklN0Q=",highlighted:`<span class="hljs-keyword">for</span> batch <span class="hljs-keyword">in</span> train_dataloader: | |
| <span class="hljs-keyword">break</span> | |
| {k: v.shape <span class="hljs-keyword">for</span> k, v <span class="hljs-keyword">in</span> batch.items()}`,wrap:!1}}),Y=new y({props:{code:"JTdCJ2F0dGVudGlvbl9tYXNrJyUzQSUyMHRvcmNoLlNpemUoJTVCOCUyQyUyMDY1JTVEKSUyQyUwQSUyMCdpbnB1dF9pZHMnJTNBJTIwdG9yY2guU2l6ZSglNUI4JTJDJTIwNjUlNUQpJTJDJTBBJTIwJ2xhYmVscyclM0ElMjB0b3JjaC5TaXplKCU1QjglNUQpJTJDJTBBJTIwJ3Rva2VuX3R5cGVfaWRzJyUzQSUyMHRvcmNoLlNpemUoJTVCOCUyQyUyMDY1JTVEKSU3RA==",highlighted:`{<span class="hljs-string">'attention_mask'</span>: torch.Size([<span class="hljs-number">8</span>, <span class="hljs-number">65</span>]), | |
| <span class="hljs-string">'input_ids'</span>: torch.Size([<span class="hljs-number">8</span>, <span class="hljs-number">65</span>]), | |
| <span class="hljs-string">'labels'</span>: torch.Size([<span class="hljs-number">8</span>]), | |
| <span class="hljs-string">'token_type_ids'</span>: torch.Size([<span class="hljs-number">8</span>, <span class="hljs-number">65</span>])}`,wrap:!1}}),z=new y({props:{code:"ZnJvbSUyMHRyYW5zZm9ybWVycyUyMGltcG9ydCUyMEF1dG9Nb2RlbEZvclNlcXVlbmNlQ2xhc3NpZmljYXRpb24lMEElMEFtb2RlbCUyMCUzRCUyMEF1dG9Nb2RlbEZvclNlcXVlbmNlQ2xhc3NpZmljYXRpb24uZnJvbV9wcmV0cmFpbmVkKGNoZWNrcG9pbnQlMkMlMjBudW1fbGFiZWxzJTNEMik=",highlighted:`<span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> AutoModelForSequenceClassification | |
| model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=<span class="hljs-number">2</span>)`,wrap:!1}}),E=new y({props:{code:"b3V0cHV0cyUyMCUzRCUyMG1vZGVsKCoqYmF0Y2gpJTBBcHJpbnQob3V0cHV0cy5sb3NzJTJDJTIwb3V0cHV0cy5sb2dpdHMuc2hhcGUp",highlighted:`outputs = model(**batch) | |
| <span class="hljs-built_in">print</span>(outputs.loss, outputs.logits.shape)`,wrap:!1}}),Q=new y({props:{code:"dGVuc29yKDAuNTQ0MSUyQyUyMGdyYWRfZm4lM0QlM0NObGxMb3NzQmFja3dhcmQlM0UpJTIwdG9yY2guU2l6ZSglNUI4JTJDJTIwMiU1RCk=",highlighted:'tensor(<span class="hljs-number">0.5441</span>, grad_fn=<NllLossBackward>) torch.Size([<span class="hljs-number">8</span>, <span class="hljs-number">2</span>])',wrap:!1}}),H=new y({props:{code:"ZnJvbSUyMHRyYW5zZm9ybWVycyUyMGltcG9ydCUyMEFkYW1XJTBBJTBBb3B0aW1pemVyJTIwJTNEJTIwQWRhbVcobW9kZWwucGFyYW1ldGVycygpJTJDJTIwbHIlM0Q1ZS01KQ==",highlighted:`<span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> AdamW | |
| optimizer = AdamW(model.parameters(), lr=<span class="hljs-number">5e-5</span>)`,wrap:!1}}),q=new y({props:{code:"ZnJvbSUyMHRyYW5zZm9ybWVycyUyMGltcG9ydCUyMGdldF9zY2hlZHVsZXIlMEElMEFudW1fZXBvY2hzJTIwJTNEJTIwMyUwQW51bV90cmFpbmluZ19zdGVwcyUyMCUzRCUyMG51bV9lcG9jaHMlMjAqJTIwbGVuKHRyYWluX2RhdGFsb2FkZXIpJTBBbHJfc2NoZWR1bGVyJTIwJTNEJTIwZ2V0X3NjaGVkdWxlciglMEElMjAlMjAlMjAlMjAlMjJsaW5lYXIlMjIlMkMlMEElMjAlMjAlMjAlMjBvcHRpbWl6ZXIlM0RvcHRpbWl6ZXIlMkMlMEElMjAlMjAlMjAlMjBudW1fd2FybXVwX3N0ZXBzJTNEMCUyQyUwQSUyMCUyMCUyMCUyMG51bV90cmFpbmluZ19zdGVwcyUzRG51bV90cmFpbmluZ19zdGVwcyUyQyUwQSklMEFwcmludChudW1fdHJhaW5pbmdfc3RlcHMp",highlighted:`<span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> get_scheduler | |
| num_epochs = <span class="hljs-number">3</span> | |
| num_training_steps = num_epochs * <span class="hljs-built_in">len</span>(train_dataloader) | |
| lr_scheduler = get_scheduler( | |
| <span class="hljs-string">"linear"</span>, | |
| optimizer=optimizer, | |
| num_warmup_steps=<span class="hljs-number">0</span>, | |
| num_training_steps=num_training_steps, | |
| ) | |
| <span class="hljs-built_in">print</span>(num_training_steps)`,wrap:!1}}),K=new y({props:{code:"MTM3Nw==",highlighted:'<span class="hljs-number">1377</span>',wrap:!1}}),L=new Ge({props:{title:"Bucla de antrenament",local:"bucla-de-antrenament",headingTag:"h3"}}),D=new y({props:{code:"aW1wb3J0JTIwdG9yY2glMEElMEFkZXZpY2UlMjAlM0QlMjB0b3JjaC5kZXZpY2UoJTIyY3VkYSUyMiklMjBpZiUyMHRvcmNoLmN1ZGEuaXNfYXZhaWxhYmxlKCklMjBlbHNlJTIwdG9yY2guZGV2aWNlKCUyMmNwdSUyMiklMEFtb2RlbC50byhkZXZpY2UpJTBBZGV2aWNl",highlighted:`<span class="hljs-keyword">import</span> torch | |
| device = torch.device(<span class="hljs-string">"cuda"</span>) <span class="hljs-keyword">if</span> torch.cuda.is_available() <span class="hljs-keyword">else</span> torch.device(<span class="hljs-string">"cpu"</span>) | |
| model.to(device) | |
| device`,wrap:!1}}),O=new y({props:{code:"ZGV2aWNlKHR5cGUlM0QnY3VkYScp",highlighted:'device(<span class="hljs-built_in">type</span>=<span class="hljs-string">'cuda'</span>)',wrap:!1}}),le=new y({props:{code:"ZnJvbSUyMHRxZG0uYXV0byUyMGltcG9ydCUyMHRxZG0lMEElMEFwcm9ncmVzc19iYXIlMjAlM0QlMjB0cWRtKHJhbmdlKG51bV90cmFpbmluZ19zdGVwcykpJTBBJTBBbW9kZWwudHJhaW4oKSUwQWZvciUyMGVwb2NoJTIwaW4lMjByYW5nZShudW1fZXBvY2hzKSUzQSUwQSUyMCUyMCUyMCUyMGZvciUyMGJhdGNoJTIwaW4lMjB0cmFpbl9kYXRhbG9hZGVyJTNBJTBBJTIwJTIwJTIwJTIwJTIwJTIwJTIwJTIwYmF0Y2glMjAlM0QlMjAlN0JrJTNBJTIwdi50byhkZXZpY2UpJTIwZm9yJTIwayUyQyUyMHYlMjBpbiUyMGJhdGNoLml0ZW1zKCklN0QlMEElMjAlMjAlMjAlMjAlMjAlMjAlMjAlMjBvdXRwdXRzJTIwJTNEJTIwbW9kZWwoKipiYXRjaCklMEElMjAlMjAlMjAlMjAlMjAlMjAlMjAlMjBsb3NzJTIwJTNEJTIwb3V0cHV0cy5sb3NzJTBBJTIwJTIwJTIwJTIwJTIwJTIwJTIwJTIwbG9zcy5iYWNrd2FyZCgpJTBBJTBBJTIwJTIwJTIwJTIwJTIwJTIwJTIwJTIwb3B0aW1pemVyLnN0ZXAoKSUwQSUyMCUyMCUyMCUyMCUyMCUyMCUyMCUyMGxyX3NjaGVkdWxlci5zdGVwKCklMEElMjAlMjAlMjAlMjAlMjAlMjAlMjAlMjBvcHRpbWl6ZXIuemVyb19ncmFkKCklMEElMjAlMjAlMjAlMjAlMjAlMjAlMjAlMjBwcm9ncmVzc19iYXIudXBkYXRlKDEp",highlighted:`<span class="hljs-keyword">from</span> tqdm.auto <span class="hljs-keyword">import</span> tqdm | |
| progress_bar = tqdm(<span class="hljs-built_in">range</span>(num_training_steps)) | |
| model.train() | |
| <span class="hljs-keyword">for</span> epoch <span class="hljs-keyword">in</span> <span class="hljs-built_in">range</span>(num_epochs): | |
| <span class="hljs-keyword">for</span> batch <span class="hljs-keyword">in</span> train_dataloader: | |
| batch = {k: v.to(device) <span class="hljs-keyword">for</span> k, v <span class="hljs-keyword">in</span> batch.items()} | |
| outputs = model(**batch) | |
| loss = outputs.loss | |
| loss.backward() | |
| optimizer.step() | |
| lr_scheduler.step() | |
| optimizer.zero_grad() | |
| progress_bar.update(<span class="hljs-number">1</span>)`,wrap:!1}}),ae=new Ge({props:{title:"Bucla de evaluare",local:"bucla-de-evaluare",headingTag:"h3"}}),ne=new y({props:{code:"aW1wb3J0JTIwZXZhbHVhdGUlMEElMEFtZXRyaWMlMjAlM0QlMjBldmFsdWF0ZS5sb2FkKCUyMmdsdWUlMjIlMkMlMjAlMjJtcnBjJTIyKSUwQW1vZGVsLmV2YWwoKSUwQWZvciUyMGJhdGNoJTIwaW4lMjBldmFsX2RhdGFsb2FkZXIlM0ElMEElMjAlMjAlMjAlMjBiYXRjaCUyMCUzRCUyMCU3QmslM0ElMjB2LnRvKGRldmljZSklMjBmb3IlMjBrJTJDJTIwdiUyMGluJTIwYmF0Y2guaXRlbXMoKSU3RCUwQSUyMCUyMCUyMCUyMHdpdGglMjB0b3JjaC5ub19ncmFkKCklM0ElMEElMjAlMjAlMjAlMjAlMjAlMjAlMjAlMjBvdXRwdXRzJTIwJTNEJTIwbW9kZWwoKipiYXRjaCklMEElMEElMjAlMjAlMjAlMjBsb2dpdHMlMjAlM0QlMjBvdXRwdXRzLmxvZ2l0cyUwQSUyMCUyMCUyMCUyMHByZWRpY3Rpb25zJTIwJTNEJTIwdG9yY2guYXJnbWF4KGxvZ2l0cyUyQyUyMGRpbSUzRC0xKSUwQSUyMCUyMCUyMCUyMG1ldHJpYy5hZGRfYmF0Y2gocHJlZGljdGlvbnMlM0RwcmVkaWN0aW9ucyUyQyUyMHJlZmVyZW5jZXMlM0RiYXRjaCU1QiUyMmxhYmVscyUyMiU1RCklMEElMEFtZXRyaWMuY29tcHV0ZSgp",highlighted:`<span class="hljs-keyword">import</span> evaluate | |
| metric = evaluate.load(<span class="hljs-string">"glue"</span>, <span class="hljs-string">"mrpc"</span>) | |
| model.<span class="hljs-built_in">eval</span>() | |
| <span class="hljs-keyword">for</span> batch <span class="hljs-keyword">in</span> eval_dataloader: | |
| batch = {k: v.to(device) <span class="hljs-keyword">for</span> k, v <span class="hljs-keyword">in</span> batch.items()} | |
| <span class="hljs-keyword">with</span> torch.no_grad(): | |
| outputs = model(**batch) | |
| logits = outputs.logits | |
| predictions = torch.argmax(logits, dim=-<span class="hljs-number">1</span>) | |
| metric.add_batch(predictions=predictions, references=batch[<span class="hljs-string">"labels"</span>]) | |
| metric.compute()`,wrap:!1}}),ie=new y({props:{code:"JTdCJ2FjY3VyYWN5JyUzQSUyMDAuODQzMTM3MjU0OTAxOTYwOCUyQyUyMCdmMSclM0ElMjAwLjg5MDc4NDk4MjkzNTE1MzUlN0Q=",highlighted:'{<span class="hljs-string">'accuracy'</span>: <span class="hljs-number">0.8431372549019608</span>, <span class="hljs-string">'f1'</span>: <span class="hljs-number">0.8907849829351535</span>}',wrap:!1}}),w=new Jt({props:{$$slots:{default:[kt]},$$scope:{ctx:ve}}}),pe=new Ge({props:{title:"Îmbunătățiți circuitul de antrenament cu 🤗 Accelerate",local:"îmbunătățiți-circuitul-de-antrenament-cu-accelerate",headingTag:"h3"}}),ce=new wt({props:{id:"s7dy8QRgjJ0"}}),oe=new y({props:{code:"ZnJvbSUyMHRyYW5zZm9ybWVycyUyMGltcG9ydCUyMEFkYW1XJTJDJTIwQXV0b01vZGVsRm9yU2VxdWVuY2VDbGFzc2lmaWNhdGlvbiUyQyUyMGdldF9zY2hlZHVsZXIlMEElMEFtb2RlbCUyMCUzRCUyMEF1dG9Nb2RlbEZvclNlcXVlbmNlQ2xhc3NpZmljYXRpb24uZnJvbV9wcmV0cmFpbmVkKGNoZWNrcG9pbnQlMkMlMjBudW1fbGFiZWxzJTNEMiklMEFvcHRpbWl6ZXIlMjAlM0QlMjBBZGFtVyhtb2RlbC5wYXJhbWV0ZXJzKCklMkMlMjBsciUzRDNlLTUpJTBBJTBBZGV2aWNlJTIwJTNEJTIwdG9yY2guZGV2aWNlKCUyMmN1ZGElMjIpJTIwaWYlMjB0b3JjaC5jdWRhLmlzX2F2YWlsYWJsZSgpJTIwZWxzZSUyMHRvcmNoLmRldmljZSglMjJjcHUlMjIpJTBBbW9kZWwudG8oZGV2aWNlKSUwQSUwQW51bV9lcG9jaHMlMjAlM0QlMjAzJTBBbnVtX3RyYWluaW5nX3N0ZXBzJTIwJTNEJTIwbnVtX2Vwb2NocyUyMColMjBsZW4odHJhaW5fZGF0YWxvYWRlciklMEFscl9zY2hlZHVsZXIlMjAlM0QlMjBnZXRfc2NoZWR1bGVyKCUwQSUyMCUyMCUyMCUyMCUyMmxpbmVhciUyMiUyQyUwQSUyMCUyMCUyMCUyMG9wdGltaXplciUzRG9wdGltaXplciUyQyUwQSUyMCUyMCUyMCUyMG51bV93YXJtdXBfc3RlcHMlM0QwJTJDJTBBJTIwJTIwJTIwJTIwbnVtX3RyYWluaW5nX3N0ZXBzJTNEbnVtX3RyYWluaW5nX3N0ZXBzJTJDJTBBKSUwQSUwQXByb2dyZXNzX2JhciUyMCUzRCUyMHRxZG0ocmFuZ2UobnVtX3RyYWluaW5nX3N0ZXBzKSklMEElMEFtb2RlbC50cmFpbigpJTBBZm9yJTIwZXBvY2glMjBpbiUyMHJhbmdlKG51bV9lcG9jaHMpJTNBJTBBJTIwJTIwJTIwJTIwZm9yJTIwYmF0Y2glMjBpbiUyMHRyYWluX2RhdGFsb2FkZXIlM0ElMEElMjAlMjAlMjAlMjAlMjAlMjAlMjAlMjBiYXRjaCUyMCUzRCUyMCU3QmslM0ElMjB2LnRvKGRldmljZSklMjBmb3IlMjBrJTJDJTIwdiUyMGluJTIwYmF0Y2guaXRlbXMoKSU3RCUwQSUyMCUyMCUyMCUyMCUyMCUyMCUyMCUyMG91dHB1dHMlMjAlM0QlMjBtb2RlbCgqKmJhdGNoKSUwQSUyMCUyMCUyMCUyMCUyMCUyMCUyMCUyMGxvc3MlMjAlM0QlMjBvdXRwdXRzLmxvc3MlMEElMjAlMjAlMjAlMjAlMjAlMjAlMjAlMjBsb3NzLmJhY2t3YXJkKCklMEElMEElMjAlMjAlMjAlMjAlMjAlMjAlMjAlMjBvcHRpbWl6ZXIuc3RlcCgpJTBBJTIwJTIwJTIwJTIwJTIwJTIwJTIwJTIwbHJfc2NoZWR1bGVyLnN0ZXAoKSUwQSUyMCUyMCUyMCUyMCUyMCUyMCUyMCUyMG9wdGltaXplci56ZXJvX2dyYWQoKSUwQSUyMCUyMCUyMCUyMCUyMCUyMCUyMCUyMHByb2dyZXNzX2Jhci51cGRhdGUoMSk=",highlighted:`<span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> AdamW, AutoModelForSequenceClassification, get_scheduler | |
| model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=<span class="hljs-number">2</span>) | |
| optimizer = AdamW(model.parameters(), lr=<span class="hljs-number">3e-5</span>) | |
| device = torch.device(<span class="hljs-string">"cuda"</span>) <span class="hljs-keyword">if</span> torch.cuda.is_available() <span class="hljs-keyword">else</span> torch.device(<span class="hljs-string">"cpu"</span>) | |
| model.to(device) | |
| num_epochs = <span class="hljs-number">3</span> | |
| num_training_steps = num_epochs * <span class="hljs-built_in">len</span>(train_dataloader) | |
| lr_scheduler = get_scheduler( | |
| <span class="hljs-string">"linear"</span>, | |
| optimizer=optimizer, | |
| num_warmup_steps=<span class="hljs-number">0</span>, | |
| num_training_steps=num_training_steps, | |
| ) | |
| progress_bar = tqdm(<span class="hljs-built_in">range</span>(num_training_steps)) | |
| model.train() | |
| <span class="hljs-keyword">for</span> epoch <span class="hljs-keyword">in</span> <span class="hljs-built_in">range</span>(num_epochs): | |
| <span class="hljs-keyword">for</span> batch <span class="hljs-keyword">in</span> train_dataloader: | |
| batch = {k: v.to(device) <span class="hljs-keyword">for</span> k, v <span class="hljs-keyword">in</span> batch.items()} | |
| outputs = model(**batch) | |
| loss = outputs.loss | |
| loss.backward() | |
| optimizer.step() | |
| lr_scheduler.step() | |
| optimizer.zero_grad() | |
| progress_bar.update(<span class="hljs-number">1</span>)`,wrap:!1}}),ue=new y({props:{code:"JTJCJTIwZnJvbSUyMGFjY2VsZXJhdGUlMjBpbXBvcnQlMjBBY2NlbGVyYXRvciUwQSUyMCUyMGZyb20lMjB0cmFuc2Zvcm1lcnMlMjBpbXBvcnQlMjBBZGFtVyUyQyUyMEF1dG9Nb2RlbEZvclNlcXVlbmNlQ2xhc3NpZmljYXRpb24lMkMlMjBnZXRfc2NoZWR1bGVyJTBBJTBBJTJCJTIwYWNjZWxlcmF0b3IlMjAlM0QlMjBBY2NlbGVyYXRvcigpJTBBJTBBJTIwJTIwbW9kZWwlMjAlM0QlMjBBdXRvTW9kZWxGb3JTZXF1ZW5jZUNsYXNzaWZpY2F0aW9uLmZyb21fcHJldHJhaW5lZChjaGVja3BvaW50JTJDJTIwbnVtX2xhYmVscyUzRDIpJTBBJTIwJTIwb3B0aW1pemVyJTIwJTNEJTIwQWRhbVcobW9kZWwucGFyYW1ldGVycygpJTJDJTIwbHIlM0QzZS01KSUwQSUwQS0lMjBkZXZpY2UlMjAlM0QlMjB0b3JjaC5kZXZpY2UoJTIyY3VkYSUyMiklMjBpZiUyMHRvcmNoLmN1ZGEuaXNfYXZhaWxhYmxlKCklMjBlbHNlJTIwdG9yY2guZGV2aWNlKCUyMmNwdSUyMiklMEEtJTIwbW9kZWwudG8oZGV2aWNlKSUwQSUwQSUyQiUyMHRyYWluX2RhdGFsb2FkZXIlMkMlMjBldmFsX2RhdGFsb2FkZXIlMkMlMjBtb2RlbCUyQyUyMG9wdGltaXplciUyMCUzRCUyMGFjY2VsZXJhdG9yLnByZXBhcmUoJTBBJTJCJTIwJTIwJTIwJTIwJTIwdHJhaW5fZGF0YWxvYWRlciUyQyUyMGV2YWxfZGF0YWxvYWRlciUyQyUyMG1vZGVsJTJDJTIwb3B0aW1pemVyJTBBJTJCJTIwKSUwQSUwQSUyMCUyMG51bV9lcG9jaHMlMjAlM0QlMjAzJTBBJTIwJTIwbnVtX3RyYWluaW5nX3N0ZXBzJTIwJTNEJTIwbnVtX2Vwb2NocyUyMColMjBsZW4odHJhaW5fZGF0YWxvYWRlciklMEElMjAlMjBscl9zY2hlZHVsZXIlMjAlM0QlMjBnZXRfc2NoZWR1bGVyKCUwQSUyMCUyMCUyMCUyMCUyMCUyMCUyMmxpbmVhciUyMiUyQyUwQSUyMCUyMCUyMCUyMCUyMCUyMG9wdGltaXplciUzRG9wdGltaXplciUyQyUwQSUyMCUyMCUyMCUyMCUyMCUyMG51bV93YXJtdXBfc3RlcHMlM0QwJTJDJTBBJTIwJTIwJTIwJTIwJTIwJTIwbnVtX3RyYWluaW5nX3N0ZXBzJTNEbnVtX3RyYWluaW5nX3N0ZXBzJTBBJTIwJTIwKSUwQSUwQSUyMCUyMHByb2dyZXNzX2JhciUyMCUzRCUyMHRxZG0ocmFuZ2UobnVtX3RyYWluaW5nX3N0ZXBzKSklMEElMEElMjAlMjBtb2RlbC50cmFpbigpJTBBJTIwJTIwZm9yJTIwZXBvY2glMjBpbiUyMHJhbmdlKG51bV9lcG9jaHMpJTNBJTBBJTIwJTIwJTIwJTIwJTIwJTIwZm9yJTIwYmF0Y2glMjBpbiUyMHRyYWluX2RhdGFsb2FkZXIlM0ElMEEtJTIwJTIwJTIwJTIwJTIwJTIwJTIwJTIwJTIwYmF0Y2glMjAlM0QlMjAlN0JrJTNBJTIwdi50byhkZXZpY2UpJTIwZm9yJTIwayUyQyUyMHYlMjBpbiUyMGJhdGNoLml0ZW1zKCklN0QlMEElMjAlMjAlMjAlMjAlMjAlMjAlMjAlMjAlMjAlMjBvdXRwdXRzJTIwJTNEJTIwbW9kZWwoKipiYXRjaCklMEElMjAlMjAlMjAlMjAlMjAlMjAlMjAlMjAlMjAlMjBsb3NzJTIwJTNEJTIwb3V0cHV0cy5sb3NzJTBBLSUyMCUyMCUyMCUyMCUyMCUyMCUyMCUyMCUyMGxvc3MuYmFja3dhcmQoKSUwQSUyQiUyMCUyMCUyMCUyMCUyMCUyMCUyMCUyMCUyMGFjY2VsZXJhdG9yLmJhY2t3YXJkKGxvc3MpJTBBJTBBJTIwJTIwJTIwJTIwJTIwJTIwJTIwJTIwJTIwJTIwb3B0aW1pemVyLnN0ZXAoKSUwQSUyMCUyMCUyMCUyMCUyMCUyMCUyMCUyMCUyMCUyMGxyX3NjaGVkdWxlci5zdGVwKCklMEElMjAlMjAlMjAlMjAlMjAlMjAlMjAlMjAlMjAlMjBvcHRpbWl6ZXIuemVyb19ncmFkKCklMEElMjAlMjAlMjAlMjAlMjAlMjAlMjAlMjAlMjAlMjBwcm9ncmVzc19iYXIudXBkYXRlKDEp",highlighted:`<span class="hljs-addition">+ from accelerate import Accelerator</span> | |
| from transformers import AdamW, AutoModelForSequenceClassification, get_scheduler | |
| <span class="hljs-addition">+ accelerator = Accelerator()</span> | |
| model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=2) | |
| optimizer = AdamW(model.parameters(), lr=3e-5) | |
| <span class="hljs-deletion">- device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")</span> | |
| <span class="hljs-deletion">- model.to(device)</span> | |
| <span class="hljs-addition">+ train_dataloader, eval_dataloader, model, optimizer = accelerator.prepare(</span> | |
| <span class="hljs-addition">+ train_dataloader, eval_dataloader, model, optimizer</span> | |
| <span class="hljs-addition">+ )</span> | |
| num_epochs = 3 | |
| num_training_steps = num_epochs * len(train_dataloader) | |
| lr_scheduler = get_scheduler( | |
| "linear", | |
| optimizer=optimizer, | |
| num_warmup_steps=0, | |
| num_training_steps=num_training_steps | |
| ) | |
| progress_bar = tqdm(range(num_training_steps)) | |
| model.train() | |
| for epoch in range(num_epochs): | |
| for batch in train_dataloader: | |
| <span class="hljs-deletion">- batch = {k: v.to(device) for k, v in batch.items()}</span> | |
| outputs = model(**batch) | |
| loss = outputs.loss | |
| <span class="hljs-deletion">- loss.backward()</span> | |
| <span class="hljs-addition">+ accelerator.backward(loss)</span> | |
| optimizer.step() | |
| lr_scheduler.step() | |
| optimizer.zero_grad() | |
| progress_bar.update(1)`,wrap:!1}}),U=new Jt({props:{$$slots:{default:[Xt]},$$scope:{ctx:ve}}}),je=new y({props:{code:"ZnJvbSUyMGFjY2VsZXJhdGUlMjBpbXBvcnQlMjBBY2NlbGVyYXRvciUwQWZyb20lMjB0cmFuc2Zvcm1lcnMlMjBpbXBvcnQlMjBBZGFtVyUyQyUyMEF1dG9Nb2RlbEZvclNlcXVlbmNlQ2xhc3NpZmljYXRpb24lMkMlMjBnZXRfc2NoZWR1bGVyJTBBJTBBYWNjZWxlcmF0b3IlMjAlM0QlMjBBY2NlbGVyYXRvcigpJTBBJTBBbW9kZWwlMjAlM0QlMjBBdXRvTW9kZWxGb3JTZXF1ZW5jZUNsYXNzaWZpY2F0aW9uLmZyb21fcHJldHJhaW5lZChjaGVja3BvaW50JTJDJTIwbnVtX2xhYmVscyUzRDIpJTBBb3B0aW1pemVyJTIwJTNEJTIwQWRhbVcobW9kZWwucGFyYW1ldGVycygpJTJDJTIwbHIlM0QzZS01KSUwQSUwQXRyYWluX2RsJTJDJTIwZXZhbF9kbCUyQyUyMG1vZGVsJTJDJTIwb3B0aW1pemVyJTIwJTNEJTIwYWNjZWxlcmF0b3IucHJlcGFyZSglMEElMjAlMjAlMjAlMjB0cmFpbl9kYXRhbG9hZGVyJTJDJTIwZXZhbF9kYXRhbG9hZGVyJTJDJTIwbW9kZWwlMkMlMjBvcHRpbWl6ZXIlMEEpJTBBJTBBbnVtX2Vwb2NocyUyMCUzRCUyMDMlMEFudW1fdHJhaW5pbmdfc3RlcHMlMjAlM0QlMjBudW1fZXBvY2hzJTIwKiUyMGxlbih0cmFpbl9kbCklMEFscl9zY2hlZHVsZXIlMjAlM0QlMjBnZXRfc2NoZWR1bGVyKCUwQSUyMCUyMCUyMCUyMCUyMmxpbmVhciUyMiUyQyUwQSUyMCUyMCUyMCUyMG9wdGltaXplciUzRG9wdGltaXplciUyQyUwQSUyMCUyMCUyMCUyMG51bV93YXJtdXBfc3RlcHMlM0QwJTJDJTBBJTIwJTIwJTIwJTIwbnVtX3RyYWluaW5nX3N0ZXBzJTNEbnVtX3RyYWluaW5nX3N0ZXBzJTJDJTBBKSUwQSUwQXByb2dyZXNzX2JhciUyMCUzRCUyMHRxZG0ocmFuZ2UobnVtX3RyYWluaW5nX3N0ZXBzKSklMEElMEFtb2RlbC50cmFpbigpJTBBZm9yJTIwZXBvY2glMjBpbiUyMHJhbmdlKG51bV9lcG9jaHMpJTNBJTBBJTIwJTIwJTIwJTIwZm9yJTIwYmF0Y2glMjBpbiUyMHRyYWluX2RsJTNBJTBBJTIwJTIwJTIwJTIwJTIwJTIwJTIwJTIwb3V0cHV0cyUyMCUzRCUyMG1vZGVsKCoqYmF0Y2gpJTBBJTIwJTIwJTIwJTIwJTIwJTIwJTIwJTIwbG9zcyUyMCUzRCUyMG91dHB1dHMubG9zcyUwQSUyMCUyMCUyMCUyMCUyMCUyMCUyMCUyMGFjY2VsZXJhdG9yLmJhY2t3YXJkKGxvc3MpJTBBJTBBJTIwJTIwJTIwJTIwJTIwJTIwJTIwJTIwb3B0aW1pemVyLnN0ZXAoKSUwQSUyMCUyMCUyMCUyMCUyMCUyMCUyMCUyMGxyX3NjaGVkdWxlci5zdGVwKCklMEElMjAlMjAlMjAlMjAlMjAlMjAlMjAlMjBvcHRpbWl6ZXIuemVyb19ncmFkKCklMEElMjAlMjAlMjAlMjAlMjAlMjAlMjAlMjBwcm9ncmVzc19iYXIudXBkYXRlKDEp",highlighted:`<span class="hljs-keyword">from</span> accelerate <span class="hljs-keyword">import</span> Accelerator | |
| <span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> AdamW, AutoModelForSequenceClassification, get_scheduler | |
| accelerator = Accelerator() | |
| model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=<span class="hljs-number">2</span>) | |
| optimizer = AdamW(model.parameters(), lr=<span class="hljs-number">3e-5</span>) | |
| train_dl, eval_dl, model, optimizer = accelerator.prepare( | |
| train_dataloader, eval_dataloader, model, optimizer | |
| ) | |
| num_epochs = <span class="hljs-number">3</span> | |
| num_training_steps = num_epochs * <span class="hljs-built_in">len</span>(train_dl) | |
| lr_scheduler = get_scheduler( | |
| <span class="hljs-string">"linear"</span>, | |
| optimizer=optimizer, | |
| num_warmup_steps=<span class="hljs-number">0</span>, | |
| num_training_steps=num_training_steps, | |
| ) | |
| progress_bar = tqdm(<span class="hljs-built_in">range</span>(num_training_steps)) | |
| model.train() | |
| <span class="hljs-keyword">for</span> epoch <span class="hljs-keyword">in</span> <span class="hljs-built_in">range</span>(num_epochs): | |
| <span class="hljs-keyword">for</span> batch <span class="hljs-keyword">in</span> train_dl: | |
| outputs = model(**batch) | |
| loss = outputs.loss | |
| accelerator.backward(loss) | |
| optimizer.step() | |
| lr_scheduler.step() | |
| optimizer.zero_grad() | |
| progress_bar.update(<span class="hljs-number">1</span>)`,wrap:!1}}),we=new y({props:{code:"YWNjZWxlcmF0ZSUyMGNvbmZpZw==",highlighted:"accelerate config",wrap:!1}}),Te=new y({props:{code:"YWNjZWxlcmF0ZSUyMGxhdW5jaCUyMHRyYWluLnB5",highlighted:'accelerate <span class="hljs-built_in">launch</span> train.py',wrap:!1}}),Ce=new y({props:{code:"ZnJvbSUyMGFjY2VsZXJhdGUlMjBpbXBvcnQlMjBub3RlYm9va19sYXVuY2hlciUwQSUwQW5vdGVib29rX2xhdW5jaGVyKHRyYWluaW5nX2Z1bmN0aW9uKQ==",highlighted:`<span class="hljs-keyword">from</span> accelerate <span class="hljs-keyword">import</span> notebook_launcher | |
| notebook_launcher(training_function)`,wrap:!1}}),Be=new Gt({props:{source:"https://github.com/huggingface/course/blob/main/chapters/rum/chapter3/4.mdx"}}),{c(){b=m("meta"),j=s(),J=m("p"),Ie=s(),i(T.$$.fragment),ke=s(),i(h.$$.fragment),Xe=s(),i(f.$$.fragment),We=s(),C=m("p"),C.innerHTML=zl,ge=s(),i(Z.$$.fragment),Re=s(),i(B.$$.fragment),_e=s(),v=m("p"),v.innerHTML=Nl,Ye=s(),I=m("ul"),I.innerHTML=El,Ae=s(),$=m("p"),$.innerHTML=Ql,Ve=s(),i(G.$$.fragment),ze=s(),k=m("p"),k.textContent=Fl,Ne=s(),i(X.$$.fragment),Ee=s(),W=m("p"),W.textContent=xl,Qe=s(),i(g.$$.fragment),Fe=s(),R=m("p"),R.textContent=Hl,xe=s(),i(_.$$.fragment),He=s(),i(Y.$$.fragment),Se=s(),A=m("p"),A.innerHTML=Sl,qe=s(),V=m("p"),V.textContent=ql,Ke=s(),i(z.$$.fragment),Le=s(),N=m("p"),N.textContent=Kl,Pe=s(),i(E.$$.fragment),De=s(),i(Q.$$.fragment),Oe=s(),F=m("p"),F.innerHTML=Ll,el=s(),x=m("p"),x.innerHTML=Pl,ll=s(),i(H.$$.fragment),tl=s(),S=m("p"),S.innerHTML=Dl,al=s(),i(q.$$.fragment),sl=s(),i(K.$$.fragment),nl=s(),i(L.$$.fragment),il=s(),P=m("p"),P.innerHTML=Ol,rl=s(),i(D.$$.fragment),pl=s(),i(O.$$.fragment),cl=s(),ee=m("p"),ee.innerHTML=et,Ml=s(),i(le.$$.fragment),ol=s(),te=m("p"),te.textContent=lt,ml=s(),i(ae.$$.fragment),ul=s(),se=m("p"),se.innerHTML=tt,dl=s(),i(ne.$$.fragment),yl=s(),i(ie.$$.fragment),bl=s(),re=m("p"),re.textContent=at,jl=s(),i(w.$$.fragment),Jl=s(),i(pe.$$.fragment),wl=s(),i(ce.$$.fragment),Ul=s(),Me=m("p"),Me.innerHTML=st,Tl=s(),i(oe.$$.fragment),hl=s(),me=m("p"),me.textContent=nt,fl=s(),i(ue.$$.fragment),Cl=s(),de=m("p"),de.innerHTML=it,Zl=s(),ye=m("p"),ye.innerHTML=rt,Bl=s(),i(U.$$.fragment),vl=s(),be=m("p"),be.textContent=pt,Il=s(),i(je.$$.fragment),$l=s(),Je=m("p"),Je.innerHTML=ct,Gl=s(),i(we.$$.fragment),kl=s(),Ue=m("p"),Ue.textContent=Mt,Xl=s(),i(Te.$$.fragment),Wl=s(),he=m("p"),he.textContent=ot,gl=s(),fe=m("p"),fe.innerHTML=mt,Rl=s(),i(Ce.$$.fragment),_l=s(),Ze=m("p"),Ze.innerHTML=ut,Yl=s(),i(Be.$$.fragment),Al=s(),$e=m("p"),this.h()},l(e){const l=Zt("svelte-u9bgzb",document.head);b=u(l,"META",{name:!0,content:!0}),l.forEach(t),j=n(e),J=u(e,"P",{}),bt(J).forEach(t),Ie=n(e),r(T.$$.fragment,e),ke=n(e),r(h.$$.fragment,e),Xe=n(e),r(f.$$.fragment,e),We=n(e),C=u(e,"P",{"data-svelte-h":!0}),d(C)!=="svelte-q6dk4a"&&(C.innerHTML=zl),ge=n(e),r(Z.$$.fragment,e),Re=n(e),r(B.$$.fragment,e),_e=n(e),v=u(e,"P",{"data-svelte-h":!0}),d(v)!=="svelte-uiid4o"&&(v.innerHTML=Nl),Ye=n(e),I=u(e,"UL",{"data-svelte-h":!0}),d(I)!=="svelte-c2aek9"&&(I.innerHTML=El),Ae=n(e),$=u(e,"P",{"data-svelte-h":!0}),d($)!=="svelte-1v3h9d0"&&($.innerHTML=Ql),Ve=n(e),r(G.$$.fragment,e),ze=n(e),k=u(e,"P",{"data-svelte-h":!0}),d(k)!=="svelte-4ka9th"&&(k.textContent=Fl),Ne=n(e),r(X.$$.fragment,e),Ee=n(e),W=u(e,"P",{"data-svelte-h":!0}),d(W)!=="svelte-awckyh"&&(W.textContent=xl),Qe=n(e),r(g.$$.fragment,e),Fe=n(e),R=u(e,"P",{"data-svelte-h":!0}),d(R)!=="svelte-r4tu2n"&&(R.textContent=Hl),xe=n(e),r(_.$$.fragment,e),He=n(e),r(Y.$$.fragment,e),Se=n(e),A=u(e,"P",{"data-svelte-h":!0}),d(A)!=="svelte-iz9nb8"&&(A.innerHTML=Sl),qe=n(e),V=u(e,"P",{"data-svelte-h":!0}),d(V)!=="svelte-q76el0"&&(V.textContent=ql),Ke=n(e),r(z.$$.fragment,e),Le=n(e),N=u(e,"P",{"data-svelte-h":!0}),d(N)!=="svelte-1wnox7m"&&(N.textContent=Kl),Pe=n(e),r(E.$$.fragment,e),De=n(e),r(Q.$$.fragment,e),Oe=n(e),F=u(e,"P",{"data-svelte-h":!0}),d(F)!=="svelte-am270m"&&(F.innerHTML=Ll),el=n(e),x=u(e,"P",{"data-svelte-h":!0}),d(x)!=="svelte-qdgcv9"&&(x.innerHTML=Pl),ll=n(e),r(H.$$.fragment,e),tl=n(e),S=u(e,"P",{"data-svelte-h":!0}),d(S)!=="svelte-1teovpf"&&(S.innerHTML=Dl),al=n(e),r(q.$$.fragment,e),sl=n(e),r(K.$$.fragment,e),nl=n(e),r(L.$$.fragment,e),il=n(e),P=u(e,"P",{"data-svelte-h":!0}),d(P)!=="svelte-18mygbx"&&(P.innerHTML=Ol),rl=n(e),r(D.$$.fragment,e),pl=n(e),r(O.$$.fragment,e),cl=n(e),ee=u(e,"P",{"data-svelte-h":!0}),d(ee)!=="svelte-qjyknt"&&(ee.innerHTML=et),Ml=n(e),r(le.$$.fragment,e),ol=n(e),te=u(e,"P",{"data-svelte-h":!0}),d(te)!=="svelte-1uejh3i"&&(te.textContent=lt),ml=n(e),r(ae.$$.fragment,e),ul=n(e),se=u(e,"P",{"data-svelte-h":!0}),d(se)!=="svelte-15vnicp"&&(se.innerHTML=tt),dl=n(e),r(ne.$$.fragment,e),yl=n(e),r(ie.$$.fragment,e),bl=n(e),re=u(e,"P",{"data-svelte-h":!0}),d(re)!=="svelte-1qxzzsw"&&(re.textContent=at),jl=n(e),r(w.$$.fragment,e),Jl=n(e),r(pe.$$.fragment,e),wl=n(e),r(ce.$$.fragment,e),Ul=n(e),Me=u(e,"P",{"data-svelte-h":!0}),d(Me)!=="svelte-1epvec1"&&(Me.innerHTML=st),Tl=n(e),r(oe.$$.fragment,e),hl=n(e),me=u(e,"P",{"data-svelte-h":!0}),d(me)!=="svelte-1a2jma4"&&(me.textContent=nt),fl=n(e),r(ue.$$.fragment,e),Cl=n(e),de=u(e,"P",{"data-svelte-h":!0}),d(de)!=="svelte-nbet76"&&(de.innerHTML=it),Zl=n(e),ye=u(e,"P",{"data-svelte-h":!0}),d(ye)!=="svelte-rpzdpp"&&(ye.innerHTML=rt),Bl=n(e),r(U.$$.fragment,e),vl=n(e),be=u(e,"P",{"data-svelte-h":!0}),d(be)!=="svelte-1ygvpse"&&(be.textContent=pt),Il=n(e),r(je.$$.fragment,e),$l=n(e),Je=u(e,"P",{"data-svelte-h":!0}),d(Je)!=="svelte-1ch6shh"&&(Je.innerHTML=ct),Gl=n(e),r(we.$$.fragment,e),kl=n(e),Ue=u(e,"P",{"data-svelte-h":!0}),d(Ue)!=="svelte-silivt"&&(Ue.textContent=Mt),Xl=n(e),r(Te.$$.fragment,e),Wl=n(e),he=u(e,"P",{"data-svelte-h":!0}),d(he)!=="svelte-1a33atb"&&(he.textContent=ot),gl=n(e),fe=u(e,"P",{"data-svelte-h":!0}),d(fe)!=="svelte-3m7ys1"&&(fe.innerHTML=mt),Rl=n(e),r(Ce.$$.fragment,e),_l=n(e),Ze=u(e,"P",{"data-svelte-h":!0}),d(Ze)!=="svelte-1jw41xa"&&(Ze.innerHTML=ut),Yl=n(e),r(Be.$$.fragment,e),Al=n(e),$e=u(e,"P",{}),bt($e).forEach(t),this.h()},h(){jt(b,"name","hf:doc:metadata"),jt(b,"content",gt)},m(e,l){Bt(document.head,b),a(e,j,l),a(e,J,l),a(e,Ie,l),p(T,e,l),a(e,ke,l),p(h,e,l),a(e,Xe,l),p(f,e,l),a(e,We,l),a(e,C,l),a(e,ge,l),p(Z,e,l),a(e,Re,l),p(B,e,l),a(e,_e,l),a(e,v,l),a(e,Ye,l),a(e,I,l),a(e,Ae,l),a(e,$,l),a(e,Ve,l),p(G,e,l),a(e,ze,l),a(e,k,l),a(e,Ne,l),p(X,e,l),a(e,Ee,l),a(e,W,l),a(e,Qe,l),p(g,e,l),a(e,Fe,l),a(e,R,l),a(e,xe,l),p(_,e,l),a(e,He,l),p(Y,e,l),a(e,Se,l),a(e,A,l),a(e,qe,l),a(e,V,l),a(e,Ke,l),p(z,e,l),a(e,Le,l),a(e,N,l),a(e,Pe,l),p(E,e,l),a(e,De,l),p(Q,e,l),a(e,Oe,l),a(e,F,l),a(e,el,l),a(e,x,l),a(e,ll,l),p(H,e,l),a(e,tl,l),a(e,S,l),a(e,al,l),p(q,e,l),a(e,sl,l),p(K,e,l),a(e,nl,l),p(L,e,l),a(e,il,l),a(e,P,l),a(e,rl,l),p(D,e,l),a(e,pl,l),p(O,e,l),a(e,cl,l),a(e,ee,l),a(e,Ml,l),p(le,e,l),a(e,ol,l),a(e,te,l),a(e,ml,l),p(ae,e,l),a(e,ul,l),a(e,se,l),a(e,dl,l),p(ne,e,l),a(e,yl,l),p(ie,e,l),a(e,bl,l),a(e,re,l),a(e,jl,l),p(w,e,l),a(e,Jl,l),p(pe,e,l),a(e,wl,l),p(ce,e,l),a(e,Ul,l),a(e,Me,l),a(e,Tl,l),p(oe,e,l),a(e,hl,l),a(e,me,l),a(e,fl,l),p(ue,e,l),a(e,Cl,l),a(e,de,l),a(e,Zl,l),a(e,ye,l),a(e,Bl,l),p(U,e,l),a(e,vl,l),a(e,be,l),a(e,Il,l),p(je,e,l),a(e,$l,l),a(e,Je,l),a(e,Gl,l),p(we,e,l),a(e,kl,l),a(e,Ue,l),a(e,Xl,l),p(Te,e,l),a(e,Wl,l),a(e,he,l),a(e,gl,l),a(e,fe,l),a(e,Rl,l),p(Ce,e,l),a(e,_l,l),a(e,Ze,l),a(e,Yl,l),p(Be,e,l),a(e,Al,l),a(e,$e,l),Vl=!0},p(e,[l]){const dt={};l&2&&(dt.$$scope={dirty:l,ctx:e}),w.$set(dt);const yt={};l&2&&(yt.$$scope={dirty:l,ctx:e}),U.$set(yt)},i(e){Vl||(c(T.$$.fragment,e),c(h.$$.fragment,e),c(f.$$.fragment,e),c(Z.$$.fragment,e),c(B.$$.fragment,e),c(G.$$.fragment,e),c(X.$$.fragment,e),c(g.$$.fragment,e),c(_.$$.fragment,e),c(Y.$$.fragment,e),c(z.$$.fragment,e),c(E.$$.fragment,e),c(Q.$$.fragment,e),c(H.$$.fragment,e),c(q.$$.fragment,e),c(K.$$.fragment,e),c(L.$$.fragment,e),c(D.$$.fragment,e),c(O.$$.fragment,e),c(le.$$.fragment,e),c(ae.$$.fragment,e),c(ne.$$.fragment,e),c(ie.$$.fragment,e),c(w.$$.fragment,e),c(pe.$$.fragment,e),c(ce.$$.fragment,e),c(oe.$$.fragment,e),c(ue.$$.fragment,e),c(U.$$.fragment,e),c(je.$$.fragment,e),c(we.$$.fragment,e),c(Te.$$.fragment,e),c(Ce.$$.fragment,e),c(Be.$$.fragment,e),Vl=!0)},o(e){M(T.$$.fragment,e),M(h.$$.fragment,e),M(f.$$.fragment,e),M(Z.$$.fragment,e),M(B.$$.fragment,e),M(G.$$.fragment,e),M(X.$$.fragment,e),M(g.$$.fragment,e),M(_.$$.fragment,e),M(Y.$$.fragment,e),M(z.$$.fragment,e),M(E.$$.fragment,e),M(Q.$$.fragment,e),M(H.$$.fragment,e),M(q.$$.fragment,e),M(K.$$.fragment,e),M(L.$$.fragment,e),M(D.$$.fragment,e),M(O.$$.fragment,e),M(le.$$.fragment,e),M(ae.$$.fragment,e),M(ne.$$.fragment,e),M(ie.$$.fragment,e),M(w.$$.fragment,e),M(pe.$$.fragment,e),M(ce.$$.fragment,e),M(oe.$$.fragment,e),M(ue.$$.fragment,e),M(U.$$.fragment,e),M(je.$$.fragment,e),M(we.$$.fragment,e),M(Te.$$.fragment,e),M(Ce.$$.fragment,e),M(Be.$$.fragment,e),Vl=!1},d(e){e&&(t(j),t(J),t(Ie),t(ke),t(Xe),t(We),t(C),t(ge),t(Re),t(_e),t(v),t(Ye),t(I),t(Ae),t($),t(Ve),t(ze),t(k),t(Ne),t(Ee),t(W),t(Qe),t(Fe),t(R),t(xe),t(He),t(Se),t(A),t(qe),t(V),t(Ke),t(Le),t(N),t(Pe),t(De),t(Oe),t(F),t(el),t(x),t(ll),t(tl),t(S),t(al),t(sl),t(nl),t(il),t(P),t(rl),t(pl),t(cl),t(ee),t(Ml),t(ol),t(te),t(ml),t(ul),t(se),t(dl),t(yl),t(bl),t(re),t(jl),t(Jl),t(wl),t(Ul),t(Me),t(Tl),t(hl),t(me),t(fl),t(Cl),t(de),t(Zl),t(ye),t(Bl),t(vl),t(be),t(Il),t($l),t(Je),t(Gl),t(kl),t(Ue),t(Xl),t(Wl),t(he),t(gl),t(fe),t(Rl),t(_l),t(Ze),t(Yl),t(Al),t($e)),t(b),o(T,e),o(h,e),o(f,e),o(Z,e),o(B,e),o(G,e),o(X,e),o(g,e),o(_,e),o(Y,e),o(z,e),o(E,e),o(Q,e),o(H,e),o(q,e),o(K,e),o(L,e),o(D,e),o(O,e),o(le,e),o(ae,e),o(ne,e),o(ie,e),o(w,e),o(pe,e),o(ce,e),o(oe,e),o(ue,e),o(U,e),o(je,e),o(we,e),o(Te,e),o(Ce,e),o(Be,e)}}}const gt='{"title":"O instruire completă","local":"o-instruire-completă","sections":[{"title":"Pregătirea pentru antrenament","local":"pregătirea-pentru-antrenament","sections":[],"depth":3},{"title":"Bucla de antrenament","local":"bucla-de-antrenament","sections":[],"depth":3},{"title":"Bucla de evaluare","local":"bucla-de-evaluare","sections":[],"depth":3},{"title":"Îmbunătățiți circuitul de antrenament cu 🤗 Accelerate","local":"îmbunătățiți-circuitul-de-antrenament-cu-accelerate","sections":[],"depth":3}],"depth":1}';function Rt(ve){return Tt(()=>{new URLSearchParams(window.location.search).get("fw")}),[]}class Qt extends ft{constructor(b){super(),Ct(this,b,Rt,Wt,Ut,{})}}export{Qt as component}; | |
Xet Storage Details
- Size:
- 44 kB
- Xet hash:
- d1510324b9182d3ab2966272baa975d8ccae7d074475158c6aed70e538f62359
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.