Buckets:

rtrm's picture
download
raw
44.5 kB
import{s as Ua,o as Ta,n as ha}from"../chunks/scheduler.37c15a92.js";import{S as fa,i as Ca,g as M,s as t,r as o,A as Za,h as d,f as a,c as n,j as ba,u as r,x as u,k as ja,y as Ia,a as s,v as p,d as i,t as c,w as m,m as Ba,n as $a}from"../chunks/index.2bf4358c.js";import{T as Ja}from"../chunks/Tip.363c041f.js";import{Y as wa}from"../chunks/Youtube.1e50a667.js";import{C as y}from"../chunks/CodeBlock.4e987730.js";import{C as Ga}from"../chunks/CourseFloatingBanner.6add7356.js";import{H as ge}from"../chunks/Heading.8ada512a.js";import{E as ga}from"../chunks/getInferenceSnippets.5cee47d1.js";function ka(Be){let b,j="✏️ <strong>Inténtalo!</strong> Modifica el bucle de entrenamiento anterior para ajustar tu modelo en el conjunto de datos SST-2.";return{c(){b=M("p"),b.innerHTML=j},l(J){b=d(J,"P",{"data-svelte-h":!0}),u(b)!=="svelte-vsapoy"&&(b.innerHTML=j)},m(J,$e){s(J,b,$e)},p:ha,d(J){J&&a(b)}}}function Xa(Be){let b;return{c(){b=Ba('⚠️ Para obtener el beneficio de la aceleración ofrecida por los TPUs de la\n nube, recomendamos rellenar las muestras hasta una longitud fija con los\n argumentos `padding="max_length"` y `max_length` del tokenizador.')},l(j){b=$a(j,'⚠️ Para obtener el beneficio de la aceleración ofrecida por los TPUs de la\n nube, recomendamos rellenar las muestras hasta una longitud fija con los\n argumentos `padding="max_length"` y `max_length` del tokenizador.')},m(j,J){s(j,b,J)},d(j){j&&a(b)}}}function va(Be){let b,j,J,$e,T,ke,h,Xe,f,ve,C,zl="Ahora veremos como obtener los mismos resultados de la última sección sin hacer uso de la clase <code>Trainer</code>. De nuevo, asumimos que has hecho el procesamiento de datos en la sección 2. Aquí mostramos un resumen que cubre todo lo que necesitarás.",We,Z,Re,I,_e,B,Nl="Antes de escribir nuestro bucle de entrenamiento, necesitaremos definir algunos objetos. Los primeros son los <code>dataloaders</code> (literalmente, “cargadores de datos”) que usaremos para iterar sobre lotes. Pero antes de que podamos definir esos <code>dataloaders</code>, necesitamos aplicar un poquito de preprocesamiento a nuestro <code>tokenized_datasets</code>, para encargarnos de algunas cosas que el <code>Trainer</code> hizo por nosotros de manera automática. Específicamente, necesitamos:",Ye,$,El="<li>Remover las columnas correspondientes a valores que el model no espera (como las columnas <code>sentence1</code> y <code>sentence2</code>).</li> <li>Renombrar la columna <code>label</code> con <code>labels</code> (porque el modelo espera el argumento llamado <code>labels</code>).</li> <li>Configurar el formato de los conjuntos de datos para que retornen tensores PyTorch en lugar de listas.</li>",Ae,G,Fl="Nuestro <code>tokenized_datasets</code> tiene un método para cada uno de esos pasos:",Ve,g,ze,k,Ql="Ahora podemos verificar que el resultado solo tiene columnas que nuestro modelo aceptará:",Ne,X,Ee,v,Hl="Ahora que esto esta hecho, es fácil definir nuestros <code>dataloaders</code>:",Fe,W,Qe,R,xl="Para verificar rápidamente que no hubo errores en el procesamiento de datos, podemos inspeccionar un lote de la siguiente manera:",He,_,xe,Y,Se,A,Sl="Nótese que los tamaños serán un poco distintos en tu caso ya que configuramos <code>shuffle=True</code> para el dataloader de entrenamiento y estamos rellenando a la máxima longitud dentro del lote.",qe,V,ql="Ahora que hemos completado el preprocesamiento de datos (un objetivo gratificante y al mismo tiempo elusivo para cual cualquier practicante de ML), enfoquémonos en el modelo. Lo vamos a crear exactamente como lo hicimos en la sección anterior.",Le,z,Ke,N,Ll="Para asegurarnos de que todo va a salir sin problems durante el entrenamiento, vamos a pasar un lote a este modelo:",Pe,E,De,F,Oe,Q,Kl="Todos los modelos 🤗 Transformers van a retornar la pérdida cuando se pasan los <code>labels</code>, y también obtenemos los logits (dos por cada entrada en nuestro lote, asi que es un tensor de tamaño 8 x 2).",el,H,Pl='Estamos casi listos para escribir nuestro bucle de entrenamiento! Nos están faltando dos cosas: un optimizador y un programador de la tasa de aprendizaje. Ya que estamos tratando de replicar a mano lo que el <code>Trainer</code> estaba haciendo, usaremos los mismos valores por defecto. El optimizador usado por el <code>Trainer</code> es <code>AdamW</code>, que es el mismo que Adam, pero con un cambio para la regularización de decremento de los pesos (ver <a href="https://arxiv.org/abs/1711.05101" rel="nofollow">“Decoupled Weight Decay Regularization”</a> por Ilya Loshchilov y Frank Hutter):',ll,x,al,S,Dl="Finalmente, el programador por defecto de la tasa de aprendizaje es un decremento lineal desde al valor máximo (5e-5) hasta 0. Para definirlo apropiadamente, necesitamos saber el número de pasos de entrenamiento que vamos a tener, el cual viene dado por el número de épocas que deseamos correr multiplicado por el número de lotes de entrenamiento (que es el largo de nuestro dataloader de entrenamiento). El <code>Trainer</code> usa tres épocas por defecto, asi que usaremos eso:",sl,q,tl,L,nl,K,ol,P,Ol="Una última cosa: vamos a querer usar el GPU si tenemos acceso a uno (en un CPU, el entrenamiento puede tomar varias horas en lugar de unos pocos minutos). Para hacer esto, definimos un <code>device</code> sobre el que pondremos nuestro modelo y nuestros lotes:",rl,D,pl,O,il,ee,ea="¡Ya está todo listo para entrenar! Para tener una idea de cuándo va a terminar el entrenamiento, adicionamos una barra de progreso sobre el número de pasos de entrenamiento, usando la librería <code>tqdm</code>:",cl,le,ml,ae,la="Puedes ver que la parte central del bucle de entrenamiento luce bastante como el de la introducción. No se incluyó ningún tipo de reportes, asi que este bucle de entrenamiento no va a indicar como se esta desempeñando el modelo. Para eso necesitamos añadir un bucle de evaluación.",Ml,se,dl,te,aa="Como lo hicimos anteriormente, usaremos una métrica ofrecida por la librería 🤗 Evaluate. Ya hemos visto el método <code>metric.compute()</code>, pero de hecho las métricas se pueden acumular sobre los lotes a medida que avanzamos en el bucle de predicción con el método <code>add_batch()</code>. Una vez que hemos acumulado todos los lotes, podemos obtener el resultado final con <code>metric.compute()</code>. Aquí se muestra cómo se puede implementar en un bucle de evaluación:",ul,ne,yl,oe,bl,re,sa="De nuevo, tus resultados serán un tanto diferente debido a la inicialización aleatoria en la cabeza del modelo y el mezclado de los datos, pero deberían tener valores similares.",jl,w,Jl,pe,wl,ie,Ul,ce,ta='El bucle de entrenamiento que definimos anteriormente trabaja bien en una sola CPU o GPU. Pero usando la librería <a href="https://github.com/huggingface/accelerate" rel="nofollow">Accelerate 🤗</a>, con solo pocos ajustes podemos habilitar el entrenamiento distribuido en múltiples GPUs o CPUs. Comenzando con la creación de los dataloaders de entrenamiento y validación, aquí se muestra como luce nuestro bucle de entrenamiento:',Tl,me,hl,Me,na="Y aquí están los cambios:",fl,de,Cl,ue,oa="La primera línea a agregarse es la línea del <code>import</code>. La segunda línea crea un objeto <code>Accelerator</code> que revisa el ambiente e inicializa la configuración distribuida apropiada. La librería 🤗 Accelerate se encarga de asignarte el dispositivo, para que puedas remover las líneas que ponen el modelo en el dispositivo (o si prefieres, cámbialas para usar el <code>accelerator.device</code> en lugar de <code>device</code>).",Zl,ye,ra="Ahora la mayor parte del trabajo se hace en la línea que envía los <code>dataloaders</code>, el modelo y el optimizador al <code>accelerator.prepare()</code>. Este va a envolver esos objetos en el contenedor apropiado para asegurarse que tu entrenamiento distribuido funcione como se espera. Los cambios que quedan son remover la línea que coloca el lote en el <code>device</code> (de nuevo, si deseas dejarlo así bastaría con cambiarlo para que use el <code>accelerator.device</code>) y reemplazar <code>loss.backward()</code> con <code>accelerator.backward(loss)</code>.",Il,U,Bl,be,pa="Si deseas copiarlo y pegarlo para probar, así es como luce el bucle completo de entrenamiento con 🤗 Accelerate:",$l,je,Gl,Je,ia="Colocando esto en un script <code>train.py</code> permitirá que el mismo sea ejecutable en cualquier configuración distribuida. Para probarlo en tu configuración distribuida, ejecuta el siguiente comando:",gl,we,kl,Ue,ca="el cual hará algunas preguntas y guardará tus respuestas en un archivo de configuración usado por este comando:",Xl,Te,vl,he,ma="el cual iniciará en entrenamiento distribuido.",Wl,fe,Ma="Si deseas ejecutar esto en un Notebook (por ejemplo, para probarlo con TPUs en Colab), solo pega el código en una <code>training_function()</code> y ejecuta la última celda con:",Rl,Ce,_l,Ze,da='Puedes encontrar más ejemplos en el <a href="https://github.com/huggingface/accelerate/tree/main/examples" rel="nofollow">repositorio 🤗 Accelerate</a>.',Yl,Ie,Al,Ge,Vl;return T=new ge({props:{title:"Un entrenamiento completo",local:"un-entrenamiento-completo",headingTag:"h1"}}),h=new Ga({props:{chapter:3,classNames:"absolute z-10 right-0 top-0",notebooks:[{label:"Google Colab",value:"https://colab.research.google.com/github/huggingface/notebooks/blob/master/course/es/chapter3/section4.ipynb"},{label:"Aws Studio",value:"https://studiolab.sagemaker.aws/import/github/huggingface/notebooks/blob/master/course/es/chapter3/section4.ipynb"}]}}),f=new wa({props:{id:"Dh9CL8fyG80"}}),Z=new y({props:{code:"ZnJvbSUyMGRhdGFzZXRzJTIwaW1wb3J0JTIwbG9hZF9kYXRhc2V0JTBBZnJvbSUyMHRyYW5zZm9ybWVycyUyMGltcG9ydCUyMEF1dG9Ub2tlbml6ZXIlMkMlMjBEYXRhQ29sbGF0b3JXaXRoUGFkZGluZyUwQSUwQXJhd19kYXRhc2V0cyUyMCUzRCUyMGxvYWRfZGF0YXNldCglMjJnbHVlJTIyJTJDJTIwJTIybXJwYyUyMiklMEFjaGVja3BvaW50JTIwJTNEJTIwJTIyYmVydC1iYXNlLXVuY2FzZWQlMjIlMEF0b2tlbml6ZXIlMjAlM0QlMjBBdXRvVG9rZW5pemVyLmZyb21fcHJldHJhaW5lZChjaGVja3BvaW50KSUwQSUwQSUwQWRlZiUyMHRva2VuaXplX2Z1bmN0aW9uKGV4YW1wbGUpJTNBJTBBJTIwJTIwJTIwJTIwcmV0dXJuJTIwdG9rZW5pemVyKGV4YW1wbGUlNUIlMjJzZW50ZW5jZTElMjIlNUQlMkMlMjBleGFtcGxlJTVCJTIyc2VudGVuY2UyJTIyJTVEJTJDJTIwdHJ1bmNhdGlvbiUzRFRydWUpJTBBJTBBJTBBdG9rZW5pemVkX2RhdGFzZXRzJTIwJTNEJTIwcmF3X2RhdGFzZXRzLm1hcCh0b2tlbml6ZV9mdW5jdGlvbiUyQyUyMGJhdGNoZWQlM0RUcnVlKSUwQWRhdGFfY29sbGF0b3IlMjAlM0QlMjBEYXRhQ29sbGF0b3JXaXRoUGFkZGluZyh0b2tlbml6ZXIlM0R0b2tlbml6ZXIp",highlighted:`<span class="hljs-keyword">from</span> datasets <span class="hljs-keyword">import</span> load_dataset
<span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> AutoTokenizer, DataCollatorWithPadding
raw_datasets = load_dataset(<span class="hljs-string">&quot;glue&quot;</span>, <span class="hljs-string">&quot;mrpc&quot;</span>)
checkpoint = <span class="hljs-string">&quot;bert-base-uncased&quot;</span>
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
<span class="hljs-keyword">def</span> <span class="hljs-title function_">tokenize_function</span>(<span class="hljs-params">example</span>):
<span class="hljs-keyword">return</span> tokenizer(example[<span class="hljs-string">&quot;sentence1&quot;</span>], example[<span class="hljs-string">&quot;sentence2&quot;</span>], truncation=<span class="hljs-literal">True</span>)
tokenized_datasets = raw_datasets.<span class="hljs-built_in">map</span>(tokenize_function, batched=<span class="hljs-literal">True</span>)
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)`,wrap:!1}}),I=new ge({props:{title:"Prepárate para el entrenamiento",local:"prepárate-para-el-entrenamiento",headingTag:"h3"}}),g=new y({props:{code:"dG9rZW5pemVkX2RhdGFzZXRzJTIwJTNEJTIwdG9rZW5pemVkX2RhdGFzZXRzLnJlbW92ZV9jb2x1bW5zKCU1QiUyMnNlbnRlbmNlMSUyMiUyQyUyMCUyMnNlbnRlbmNlMiUyMiUyQyUyMCUyMmlkeCUyMiU1RCklMEF0b2tlbml6ZWRfZGF0YXNldHMlMjAlM0QlMjB0b2tlbml6ZWRfZGF0YXNldHMucmVuYW1lX2NvbHVtbiglMjJsYWJlbCUyMiUyQyUyMCUyMmxhYmVscyUyMiklMEF0b2tlbml6ZWRfZGF0YXNldHMuc2V0X2Zvcm1hdCglMjJ0b3JjaCUyMiklMEF0b2tlbml6ZWRfZGF0YXNldHMlNUIlMjJ0cmFpbiUyMiU1RC5jb2x1bW5fbmFtZXM=",highlighted:`tokenized_datasets = tokenized_datasets.remove_columns([<span class="hljs-string">&quot;sentence1&quot;</span>, <span class="hljs-string">&quot;sentence2&quot;</span>, <span class="hljs-string">&quot;idx&quot;</span>])
tokenized_datasets = tokenized_datasets.rename_column(<span class="hljs-string">&quot;label&quot;</span>, <span class="hljs-string">&quot;labels&quot;</span>)
tokenized_datasets.set_format(<span class="hljs-string">&quot;torch&quot;</span>)
tokenized_datasets[<span class="hljs-string">&quot;train&quot;</span>].column_names`,wrap:!1}}),X=new y({props:{code:"JTVCJTIyYXR0ZW50aW9uX21hc2slMjIlMkMlMjAlMjJpbnB1dF9pZHMlMjIlMkMlMjAlMjJsYWJlbHMlMjIlMkMlMjAlMjJ0b2tlbl90eXBlX2lkcyUyMiU1RA==",highlighted:'[<span class="hljs-string">&quot;attention_mask&quot;</span>, <span class="hljs-string">&quot;input_ids&quot;</span>, <span class="hljs-string">&quot;labels&quot;</span>, <span class="hljs-string">&quot;token_type_ids&quot;</span>]',wrap:!1}}),W=new y({props:{code:"ZnJvbSUyMHRvcmNoLnV0aWxzLmRhdGElMjBpbXBvcnQlMjBEYXRhTG9hZGVyJTBBJTBBdHJhaW5fZGF0YWxvYWRlciUyMCUzRCUyMERhdGFMb2FkZXIoJTBBJTIwJTIwJTIwJTIwdG9rZW5pemVkX2RhdGFzZXRzJTVCJTIydHJhaW4lMjIlNUQlMkMlMjBzaHVmZmxlJTNEVHJ1ZSUyQyUyMGJhdGNoX3NpemUlM0Q4JTJDJTIwY29sbGF0ZV9mbiUzRGRhdGFfY29sbGF0b3IlMEEpJTBBZXZhbF9kYXRhbG9hZGVyJTIwJTNEJTIwRGF0YUxvYWRlciglMEElMjAlMjAlMjAlMjB0b2tlbml6ZWRfZGF0YXNldHMlNUIlMjJ2YWxpZGF0aW9uJTIyJTVEJTJDJTIwYmF0Y2hfc2l6ZSUzRDglMkMlMjBjb2xsYXRlX2ZuJTNEZGF0YV9jb2xsYXRvciUwQSk=",highlighted:`<span class="hljs-keyword">from</span> torch.utils.data <span class="hljs-keyword">import</span> DataLoader
train_dataloader = DataLoader(
tokenized_datasets[<span class="hljs-string">&quot;train&quot;</span>], shuffle=<span class="hljs-literal">True</span>, batch_size=<span class="hljs-number">8</span>, collate_fn=data_collator
)
eval_dataloader = DataLoader(
tokenized_datasets[<span class="hljs-string">&quot;validation&quot;</span>], batch_size=<span class="hljs-number">8</span>, collate_fn=data_collator
)`,wrap:!1}}),_=new y({props:{code:"Zm9yJTIwYmF0Y2glMjBpbiUyMHRyYWluX2RhdGFsb2FkZXIlM0ElMEElMjAlMjAlMjAlMjBicmVhayUwQSU3QmslM0ElMjB2LnNoYXBlJTIwZm9yJTIwayUyQyUyMHYlMjBpbiUyMGJhdGNoLml0ZW1zKCklN0Q=",highlighted:`<span class="hljs-keyword">for</span> batch <span class="hljs-keyword">in</span> train_dataloader:
<span class="hljs-keyword">break</span>
{k: v.shape <span class="hljs-keyword">for</span> k, v <span class="hljs-keyword">in</span> batch.items()}`,wrap:!1}}),Y=new y({props:{code:"JTdCJ2F0dGVudGlvbl9tYXNrJyUzQSUyMHRvcmNoLlNpemUoJTVCOCUyQyUyMDY1JTVEKSUyQyUwQSUyMCdpbnB1dF9pZHMnJTNBJTIwdG9yY2guU2l6ZSglNUI4JTJDJTIwNjUlNUQpJTJDJTBBJTIwJ2xhYmVscyclM0ElMjB0b3JjaC5TaXplKCU1QjglNUQpJTJDJTBBJTIwJ3Rva2VuX3R5cGVfaWRzJyUzQSUyMHRvcmNoLlNpemUoJTVCOCUyQyUyMDY1JTVEKSU3RA==",highlighted:`{<span class="hljs-string">&#x27;attention_mask&#x27;</span>: torch.Size([<span class="hljs-number">8</span>, <span class="hljs-number">65</span>]),
<span class="hljs-string">&#x27;input_ids&#x27;</span>: torch.Size([<span class="hljs-number">8</span>, <span class="hljs-number">65</span>]),
<span class="hljs-string">&#x27;labels&#x27;</span>: torch.Size([<span class="hljs-number">8</span>]),
<span class="hljs-string">&#x27;token_type_ids&#x27;</span>: torch.Size([<span class="hljs-number">8</span>, <span class="hljs-number">65</span>])}`,wrap:!1}}),z=new y({props:{code:"ZnJvbSUyMHRyYW5zZm9ybWVycyUyMGltcG9ydCUyMEF1dG9Nb2RlbEZvclNlcXVlbmNlQ2xhc3NpZmljYXRpb24lMEElMEFtb2RlbCUyMCUzRCUyMEF1dG9Nb2RlbEZvclNlcXVlbmNlQ2xhc3NpZmljYXRpb24uZnJvbV9wcmV0cmFpbmVkKGNoZWNrcG9pbnQlMkMlMjBudW1fbGFiZWxzJTNEMik=",highlighted:`<span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> AutoModelForSequenceClassification
model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=<span class="hljs-number">2</span>)`,wrap:!1}}),E=new y({props:{code:"b3V0cHV0cyUyMCUzRCUyMG1vZGVsKCoqYmF0Y2gpJTBBcHJpbnQob3V0cHV0cy5sb3NzJTJDJTIwb3V0cHV0cy5sb2dpdHMuc2hhcGUp",highlighted:`outputs = model(**batch)
<span class="hljs-built_in">print</span>(outputs.loss, outputs.logits.shape)`,wrap:!1}}),F=new y({props:{code:"dGVuc29yKDAuNTQ0MSUyQyUyMGdyYWRfZm4lM0QlM0NObGxMb3NzQmFja3dhcmQlM0UpJTIwdG9yY2guU2l6ZSglNUI4JTJDJTIwMiU1RCk=",highlighted:'tensor(<span class="hljs-number">0.5441</span>, grad_fn=&lt;NllLossBackward&gt;) torch.Size([<span class="hljs-number">8</span>, <span class="hljs-number">2</span>])',wrap:!1}}),x=new y({props:{code:"ZnJvbSUyMHRvcmNoLm9wdGltJTIwaW1wb3J0JTIwQWRhbVclMEElMEFvcHRpbWl6ZXIlMjAlM0QlMjBBZGFtVyhtb2RlbC5wYXJhbWV0ZXJzKCklMkMlMjBsciUzRDVlLTUp",highlighted:`<span class="hljs-keyword">from</span> torch.optim <span class="hljs-keyword">import</span> AdamW
optimizer = AdamW(model.parameters(), lr=<span class="hljs-number">5e-5</span>)`,wrap:!1}}),q=new y({props:{code:"ZnJvbSUyMHRyYW5zZm9ybWVycyUyMGltcG9ydCUyMGdldF9zY2hlZHVsZXIlMEElMEFudW1fZXBvY2hzJTIwJTNEJTIwMyUwQW51bV90cmFpbmluZ19zdGVwcyUyMCUzRCUyMG51bV9lcG9jaHMlMjAqJTIwbGVuKHRyYWluX2RhdGFsb2FkZXIpJTBBbHJfc2NoZWR1bGVyJTIwJTNEJTIwZ2V0X3NjaGVkdWxlciglMEElMjAlMjAlMjAlMjAlMjJsaW5lYXIlMjIlMkMlMEElMjAlMjAlMjAlMjBvcHRpbWl6ZXIlM0RvcHRpbWl6ZXIlMkMlMEElMjAlMjAlMjAlMjBudW1fd2FybXVwX3N0ZXBzJTNEMCUyQyUwQSUyMCUyMCUyMCUyMG51bV90cmFpbmluZ19zdGVwcyUzRG51bV90cmFpbmluZ19zdGVwcyUyQyUwQSklMEFwcmludChudW1fdHJhaW5pbmdfc3RlcHMp",highlighted:`<span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> get_scheduler
num_epochs = <span class="hljs-number">3</span>
num_training_steps = num_epochs * <span class="hljs-built_in">len</span>(train_dataloader)
lr_scheduler = get_scheduler(
<span class="hljs-string">&quot;linear&quot;</span>,
optimizer=optimizer,
num_warmup_steps=<span class="hljs-number">0</span>,
num_training_steps=num_training_steps,
)
<span class="hljs-built_in">print</span>(num_training_steps)`,wrap:!1}}),L=new y({props:{code:"MTM3Nw==",highlighted:'<span class="hljs-number">1377</span>',wrap:!1}}),K=new ge({props:{title:"El bucle de entrenamiento",local:"el-bucle-de-entrenamiento",headingTag:"h3"}}),D=new y({props:{code:"aW1wb3J0JTIwdG9yY2glMEElMEFkZXZpY2UlMjAlM0QlMjB0b3JjaC5kZXZpY2UoJTIyY3VkYSUyMiklMjBpZiUyMHRvcmNoLmN1ZGEuaXNfYXZhaWxhYmxlKCklMjBlbHNlJTIwdG9yY2guZGV2aWNlKCUyMmNwdSUyMiklMEFtb2RlbC50byhkZXZpY2UpJTBBZGV2aWNl",highlighted:`<span class="hljs-keyword">import</span> torch
device = torch.device(<span class="hljs-string">&quot;cuda&quot;</span>) <span class="hljs-keyword">if</span> torch.cuda.is_available() <span class="hljs-keyword">else</span> torch.device(<span class="hljs-string">&quot;cpu&quot;</span>)
model.to(device)
device`,wrap:!1}}),O=new y({props:{code:"ZGV2aWNlKHR5cGUlM0QnY3VkYScp",highlighted:'device(<span class="hljs-built_in">type</span>=<span class="hljs-string">&#x27;cuda&#x27;</span>)',wrap:!1}}),le=new y({props:{code:"ZnJvbSUyMHRxZG0uYXV0byUyMGltcG9ydCUyMHRxZG0lMEElMEFwcm9ncmVzc19iYXIlMjAlM0QlMjB0cWRtKHJhbmdlKG51bV90cmFpbmluZ19zdGVwcykpJTBBJTBBbW9kZWwudHJhaW4oKSUwQWZvciUyMGVwb2NoJTIwaW4lMjByYW5nZShudW1fZXBvY2hzKSUzQSUwQSUyMCUyMCUyMCUyMGZvciUyMGJhdGNoJTIwaW4lMjB0cmFpbl9kYXRhbG9hZGVyJTNBJTBBJTIwJTIwJTIwJTIwJTIwJTIwJTIwJTIwYmF0Y2glMjAlM0QlMjAlN0JrJTNBJTIwdi50byhkZXZpY2UpJTIwZm9yJTIwayUyQyUyMHYlMjBpbiUyMGJhdGNoLml0ZW1zKCklN0QlMEElMjAlMjAlMjAlMjAlMjAlMjAlMjAlMjBvdXRwdXRzJTIwJTNEJTIwbW9kZWwoKipiYXRjaCklMEElMjAlMjAlMjAlMjAlMjAlMjAlMjAlMjBsb3NzJTIwJTNEJTIwb3V0cHV0cy5sb3NzJTBBJTIwJTIwJTIwJTIwJTIwJTIwJTIwJTIwbG9zcy5iYWNrd2FyZCgpJTBBJTBBJTIwJTIwJTIwJTIwJTIwJTIwJTIwJTIwb3B0aW1pemVyLnN0ZXAoKSUwQSUyMCUyMCUyMCUyMCUyMCUyMCUyMCUyMGxyX3NjaGVkdWxlci5zdGVwKCklMEElMjAlMjAlMjAlMjAlMjAlMjAlMjAlMjBvcHRpbWl6ZXIuemVyb19ncmFkKCklMEElMjAlMjAlMjAlMjAlMjAlMjAlMjAlMjBwcm9ncmVzc19iYXIudXBkYXRlKDEp",highlighted:`<span class="hljs-keyword">from</span> tqdm.auto <span class="hljs-keyword">import</span> tqdm
progress_bar = tqdm(<span class="hljs-built_in">range</span>(num_training_steps))
model.train()
<span class="hljs-keyword">for</span> epoch <span class="hljs-keyword">in</span> <span class="hljs-built_in">range</span>(num_epochs):
<span class="hljs-keyword">for</span> batch <span class="hljs-keyword">in</span> train_dataloader:
batch = {k: v.to(device) <span class="hljs-keyword">for</span> k, v <span class="hljs-keyword">in</span> batch.items()}
outputs = model(**batch)
loss = outputs.loss
loss.backward()
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
progress_bar.update(<span class="hljs-number">1</span>)`,wrap:!1}}),se=new ge({props:{title:"El bucle de evaluación",local:"el-bucle-de-evaluación",headingTag:"h3"}}),ne=new y({props:{code:"aW1wb3J0JTIwZXZhbHVhdGUlMEElMEFtZXRyaWMlMjAlM0QlMjBldmFsdWF0ZS5sb2FkKCUyMmdsdWUlMjIlMkMlMjAlMjJtcnBjJTIyKSUwQW1vZGVsLmV2YWwoKSUwQWZvciUyMGJhdGNoJTIwaW4lMjBldmFsX2RhdGFsb2FkZXIlM0ElMEElMjAlMjAlMjAlMjBiYXRjaCUyMCUzRCUyMCU3QmslM0ElMjB2LnRvKGRldmljZSklMjBmb3IlMjBrJTJDJTIwdiUyMGluJTIwYmF0Y2guaXRlbXMoKSU3RCUwQSUyMCUyMCUyMCUyMHdpdGglMjB0b3JjaC5ub19ncmFkKCklM0ElMEElMjAlMjAlMjAlMjAlMjAlMjAlMjAlMjBvdXRwdXRzJTIwJTNEJTIwbW9kZWwoKipiYXRjaCklMEElMEElMjAlMjAlMjAlMjBsb2dpdHMlMjAlM0QlMjBvdXRwdXRzLmxvZ2l0cyUwQSUyMCUyMCUyMCUyMHByZWRpY3Rpb25zJTIwJTNEJTIwdG9yY2guYXJnbWF4KGxvZ2l0cyUyQyUyMGRpbSUzRC0xKSUwQSUyMCUyMCUyMCUyMG1ldHJpYy5hZGRfYmF0Y2gocHJlZGljdGlvbnMlM0RwcmVkaWN0aW9ucyUyQyUyMHJlZmVyZW5jZXMlM0RiYXRjaCU1QiUyMmxhYmVscyUyMiU1RCklMEElMEFtZXRyaWMuY29tcHV0ZSgp",highlighted:`<span class="hljs-keyword">import</span> evaluate
metric = evaluate.load(<span class="hljs-string">&quot;glue&quot;</span>, <span class="hljs-string">&quot;mrpc&quot;</span>)
model.<span class="hljs-built_in">eval</span>()
<span class="hljs-keyword">for</span> batch <span class="hljs-keyword">in</span> eval_dataloader:
batch = {k: v.to(device) <span class="hljs-keyword">for</span> k, v <span class="hljs-keyword">in</span> batch.items()}
<span class="hljs-keyword">with</span> torch.no_grad():
outputs = model(**batch)
logits = outputs.logits
predictions = torch.argmax(logits, dim=-<span class="hljs-number">1</span>)
metric.add_batch(predictions=predictions, references=batch[<span class="hljs-string">&quot;labels&quot;</span>])
metric.compute()`,wrap:!1}}),oe=new y({props:{code:"JTdCJ2FjY3VyYWN5JyUzQSUyMDAuODQzMTM3MjU0OTAxOTYwOCUyQyUyMCdmMSclM0ElMjAwLjg5MDc4NDk4MjkzNTE1MzUlN0Q=",highlighted:'{<span class="hljs-string">&#x27;accuracy&#x27;</span>: <span class="hljs-number">0.8431372549019608</span>, <span class="hljs-string">&#x27;f1&#x27;</span>: <span class="hljs-number">0.8907849829351535</span>}',wrap:!1}}),w=new Ja({props:{$$slots:{default:[ka]},$$scope:{ctx:Be}}}),pe=new ge({props:{title:"Repotencia tu bucle de entrenamiento con Accelerate 🤗",local:"repotencia-tu-bucle-de-entrenamiento-con-accelerate-",headingTag:"h3"}}),ie=new wa({props:{id:"s7dy8QRgjJ0"}}),me=new y({props:{code:"ZnJvbSUyMHRvcmNoLm9wdGltJTIwaW1wb3J0JTIwQWRhbVclMEFmcm9tJTIwdHJhbnNmb3JtZXJzJTIwaW1wb3J0JTIwQXV0b01vZGVsRm9yU2VxdWVuY2VDbGFzc2lmaWNhdGlvbiUyQyUyMGdldF9zY2hlZHVsZXIlMEElMEFtb2RlbCUyMCUzRCUyMEF1dG9Nb2RlbEZvclNlcXVlbmNlQ2xhc3NpZmljYXRpb24uZnJvbV9wcmV0cmFpbmVkKGNoZWNrcG9pbnQlMkMlMjBudW1fbGFiZWxzJTNEMiklMEFvcHRpbWl6ZXIlMjAlM0QlMjBBZGFtVyhtb2RlbC5wYXJhbWV0ZXJzKCklMkMlMjBsciUzRDNlLTUpJTBBJTBBZGV2aWNlJTIwJTNEJTIwdG9yY2guZGV2aWNlKCUyMmN1ZGElMjIpJTIwaWYlMjB0b3JjaC5jdWRhLmlzX2F2YWlsYWJsZSgpJTIwZWxzZSUyMHRvcmNoLmRldmljZSglMjJjcHUlMjIpJTBBbW9kZWwudG8oZGV2aWNlKSUwQSUwQW51bV9lcG9jaHMlMjAlM0QlMjAzJTBBbnVtX3RyYWluaW5nX3N0ZXBzJTIwJTNEJTIwbnVtX2Vwb2NocyUyMColMjBsZW4odHJhaW5fZGF0YWxvYWRlciklMEFscl9zY2hlZHVsZXIlMjAlM0QlMjBnZXRfc2NoZWR1bGVyKCUwQSUyMCUyMCUyMCUyMCUyMmxpbmVhciUyMiUyQyUwQSUyMCUyMCUyMCUyMG9wdGltaXplciUzRG9wdGltaXplciUyQyUwQSUyMCUyMCUyMCUyMG51bV93YXJtdXBfc3RlcHMlM0QwJTJDJTBBJTIwJTIwJTIwJTIwbnVtX3RyYWluaW5nX3N0ZXBzJTNEbnVtX3RyYWluaW5nX3N0ZXBzJTJDJTBBKSUwQSUwQXByb2dyZXNzX2JhciUyMCUzRCUyMHRxZG0ocmFuZ2UobnVtX3RyYWluaW5nX3N0ZXBzKSklMEElMEFtb2RlbC50cmFpbigpJTBBZm9yJTIwZXBvY2glMjBpbiUyMHJhbmdlKG51bV9lcG9jaHMpJTNBJTBBJTIwJTIwJTIwJTIwZm9yJTIwYmF0Y2glMjBpbiUyMHRyYWluX2RhdGFsb2FkZXIlM0ElMEElMjAlMjAlMjAlMjAlMjAlMjAlMjAlMjBiYXRjaCUyMCUzRCUyMCU3QmslM0ElMjB2LnRvKGRldmljZSklMjBmb3IlMjBrJTJDJTIwdiUyMGluJTIwYmF0Y2guaXRlbXMoKSU3RCUwQSUyMCUyMCUyMCUyMCUyMCUyMCUyMCUyMG91dHB1dHMlMjAlM0QlMjBtb2RlbCgqKmJhdGNoKSUwQSUyMCUyMCUyMCUyMCUyMCUyMCUyMCUyMGxvc3MlMjAlM0QlMjBvdXRwdXRzLmxvc3MlMEElMjAlMjAlMjAlMjAlMjAlMjAlMjAlMjBsb3NzLmJhY2t3YXJkKCklMEElMEElMjAlMjAlMjAlMjAlMjAlMjAlMjAlMjBvcHRpbWl6ZXIuc3RlcCgpJTBBJTIwJTIwJTIwJTIwJTIwJTIwJTIwJTIwbHJfc2NoZWR1bGVyLnN0ZXAoKSUwQSUyMCUyMCUyMCUyMCUyMCUyMCUyMCUyMG9wdGltaXplci56ZXJvX2dyYWQoKSUwQSUyMCUyMCUyMCUyMCUyMCUyMCUyMCUyMHByb2dyZXNzX2Jhci51cGRhdGUoMSk=",highlighted:`<span class="hljs-keyword">from</span> torch.optim <span class="hljs-keyword">import</span> AdamW
<span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> AutoModelForSequenceClassification, get_scheduler
model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=<span class="hljs-number">2</span>)
optimizer = AdamW(model.parameters(), lr=<span class="hljs-number">3e-5</span>)
device = torch.device(<span class="hljs-string">&quot;cuda&quot;</span>) <span class="hljs-keyword">if</span> torch.cuda.is_available() <span class="hljs-keyword">else</span> torch.device(<span class="hljs-string">&quot;cpu&quot;</span>)
model.to(device)
num_epochs = <span class="hljs-number">3</span>
num_training_steps = num_epochs * <span class="hljs-built_in">len</span>(train_dataloader)
lr_scheduler = get_scheduler(
<span class="hljs-string">&quot;linear&quot;</span>,
optimizer=optimizer,
num_warmup_steps=<span class="hljs-number">0</span>,
num_training_steps=num_training_steps,
)
progress_bar = tqdm(<span class="hljs-built_in">range</span>(num_training_steps))
model.train()
<span class="hljs-keyword">for</span> epoch <span class="hljs-keyword">in</span> <span class="hljs-built_in">range</span>(num_epochs):
<span class="hljs-keyword">for</span> batch <span class="hljs-keyword">in</span> train_dataloader:
batch = {k: v.to(device) <span class="hljs-keyword">for</span> k, v <span class="hljs-keyword">in</span> batch.items()}
outputs = model(**batch)
loss = outputs.loss
loss.backward()
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
progress_bar.update(<span class="hljs-number">1</span>)`,wrap:!1}}),de=new y({props:{code:"JTJCJTIwZnJvbSUyMGFjY2VsZXJhdGUlMjBpbXBvcnQlMjBBY2NlbGVyYXRvciUwQSUyMCUyMGZyb20lMjB0b3JjaC5vcHRpbSUyMGltcG9ydCUyMEFkYW1XJTBBJTIwJTIwZnJvbSUyMHRyYW5zZm9ybWVycyUyMGltcG9ydCUyMEF1dG9Nb2RlbEZvclNlcXVlbmNlQ2xhc3NpZmljYXRpb24lMkMlMjBnZXRfc2NoZWR1bGVyJTBBJTBBJTJCJTIwYWNjZWxlcmF0b3IlMjAlM0QlMjBBY2NlbGVyYXRvcigpJTBBJTBBJTIwJTIwbW9kZWwlMjAlM0QlMjBBdXRvTW9kZWxGb3JTZXF1ZW5jZUNsYXNzaWZpY2F0aW9uLmZyb21fcHJldHJhaW5lZChjaGVja3BvaW50JTJDJTIwbnVtX2xhYmVscyUzRDIpJTBBJTIwJTIwb3B0aW1pemVyJTIwJTNEJTIwQWRhbVcobW9kZWwucGFyYW1ldGVycygpJTJDJTIwbHIlM0QzZS01KSUwQSUwQS0lMjBkZXZpY2UlMjAlM0QlMjB0b3JjaC5kZXZpY2UoJTIyY3VkYSUyMiklMjBpZiUyMHRvcmNoLmN1ZGEuaXNfYXZhaWxhYmxlKCklMjBlbHNlJTIwdG9yY2guZGV2aWNlKCUyMmNwdSUyMiklMEEtJTIwbW9kZWwudG8oZGV2aWNlKSUwQSUwQSUyQiUyMHRyYWluX2RhdGFsb2FkZXIlMkMlMjBldmFsX2RhdGFsb2FkZXIlMkMlMjBtb2RlbCUyQyUyMG9wdGltaXplciUyMCUzRCUyMGFjY2VsZXJhdG9yLnByZXBhcmUoJTBBJTJCJTIwJTIwJTIwJTIwJTIwdHJhaW5fZGF0YWxvYWRlciUyQyUyMGV2YWxfZGF0YWxvYWRlciUyQyUyMG1vZGVsJTJDJTIwb3B0aW1pemVyJTBBJTJCJTIwKSUwQSUwQSUyMCUyMG51bV9lcG9jaHMlMjAlM0QlMjAzJTBBJTIwJTIwbnVtX3RyYWluaW5nX3N0ZXBzJTIwJTNEJTIwbnVtX2Vwb2NocyUyMColMjBsZW4odHJhaW5fZGF0YWxvYWRlciklMEElMjAlMjBscl9zY2hlZHVsZXIlMjAlM0QlMjBnZXRfc2NoZWR1bGVyKCUwQSUyMCUyMCUyMCUyMCUyMCUyMCUyMmxpbmVhciUyMiUyQyUwQSUyMCUyMCUyMCUyMCUyMCUyMG9wdGltaXplciUzRG9wdGltaXplciUyQyUwQSUyMCUyMCUyMCUyMCUyMCUyMG51bV93YXJtdXBfc3RlcHMlM0QwJTJDJTBBJTIwJTIwJTIwJTIwJTIwJTIwbnVtX3RyYWluaW5nX3N0ZXBzJTNEbnVtX3RyYWluaW5nX3N0ZXBzJTBBJTIwJTIwKSUwQSUwQSUyMCUyMHByb2dyZXNzX2JhciUyMCUzRCUyMHRxZG0ocmFuZ2UobnVtX3RyYWluaW5nX3N0ZXBzKSklMEElMEElMjAlMjBtb2RlbC50cmFpbigpJTBBJTIwJTIwZm9yJTIwZXBvY2glMjBpbiUyMHJhbmdlKG51bV9lcG9jaHMpJTNBJTBBJTIwJTIwJTIwJTIwJTIwJTIwZm9yJTIwYmF0Y2glMjBpbiUyMHRyYWluX2RhdGFsb2FkZXIlM0ElMEEtJTIwJTIwJTIwJTIwJTIwJTIwJTIwJTIwJTIwYmF0Y2glMjAlM0QlMjAlN0JrJTNBJTIwdi50byhkZXZpY2UpJTIwZm9yJTIwayUyQyUyMHYlMjBpbiUyMGJhdGNoLml0ZW1zKCklN0QlMEElMjAlMjAlMjAlMjAlMjAlMjAlMjAlMjAlMjAlMjBvdXRwdXRzJTIwJTNEJTIwbW9kZWwoKipiYXRjaCklMEElMjAlMjAlMjAlMjAlMjAlMjAlMjAlMjAlMjAlMjBsb3NzJTIwJTNEJTIwb3V0cHV0cy5sb3NzJTBBLSUyMCUyMCUyMCUyMCUyMCUyMCUyMCUyMCUyMGxvc3MuYmFja3dhcmQoKSUwQSUyQiUyMCUyMCUyMCUyMCUyMCUyMCUyMCUyMCUyMGFjY2VsZXJhdG9yLmJhY2t3YXJkKGxvc3MpJTBBJTBBJTIwJTIwJTIwJTIwJTIwJTIwJTIwJTIwJTIwJTIwb3B0aW1pemVyLnN0ZXAoKSUwQSUyMCUyMCUyMCUyMCUyMCUyMCUyMCUyMCUyMCUyMGxyX3NjaGVkdWxlci5zdGVwKCklMEElMjAlMjAlMjAlMjAlMjAlMjAlMjAlMjAlMjAlMjBvcHRpbWl6ZXIuemVyb19ncmFkKCklMEElMjAlMjAlMjAlMjAlMjAlMjAlMjAlMjAlMjAlMjBwcm9ncmVzc19iYXIudXBkYXRlKDEp",highlighted:`<span class="hljs-addition">+ from accelerate import Accelerator</span>
from torch.optim import AdamW
from transformers import AutoModelForSequenceClassification, get_scheduler
<span class="hljs-addition">+ accelerator = Accelerator()</span>
model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=2)
optimizer = AdamW(model.parameters(), lr=3e-5)
<span class="hljs-deletion">- device = torch.device(&quot;cuda&quot;) if torch.cuda.is_available() else torch.device(&quot;cpu&quot;)</span>
<span class="hljs-deletion">- model.to(device)</span>
<span class="hljs-addition">+ train_dataloader, eval_dataloader, model, optimizer = accelerator.prepare(</span>
<span class="hljs-addition">+ train_dataloader, eval_dataloader, model, optimizer</span>
<span class="hljs-addition">+ )</span>
num_epochs = 3
num_training_steps = num_epochs * len(train_dataloader)
lr_scheduler = get_scheduler(
&quot;linear&quot;,
optimizer=optimizer,
num_warmup_steps=0,
num_training_steps=num_training_steps
)
progress_bar = tqdm(range(num_training_steps))
model.train()
for epoch in range(num_epochs):
for batch in train_dataloader:
<span class="hljs-deletion">- batch = {k: v.to(device) for k, v in batch.items()}</span>
outputs = model(**batch)
loss = outputs.loss
<span class="hljs-deletion">- loss.backward()</span>
<span class="hljs-addition">+ accelerator.backward(loss)</span>
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
progress_bar.update(1)`,wrap:!1}}),U=new Ja({props:{$$slots:{default:[Xa]},$$scope:{ctx:Be}}}),je=new y({props:{code:"ZnJvbSUyMGFjY2VsZXJhdGUlMjBpbXBvcnQlMjBBY2NlbGVyYXRvciUwQWZyb20lMjB0b3JjaC5vcHRpbSUyMGltcG9ydCUyMEFkYW1XJTBBZnJvbSUyMHRyYW5zZm9ybWVycyUyMGltcG9ydCUyMEF1dG9Nb2RlbEZvclNlcXVlbmNlQ2xhc3NpZmljYXRpb24lMkMlMjBnZXRfc2NoZWR1bGVyJTBBJTBBYWNjZWxlcmF0b3IlMjAlM0QlMjBBY2NlbGVyYXRvcigpJTBBJTBBbW9kZWwlMjAlM0QlMjBBdXRvTW9kZWxGb3JTZXF1ZW5jZUNsYXNzaWZpY2F0aW9uLmZyb21fcHJldHJhaW5lZChjaGVja3BvaW50JTJDJTIwbnVtX2xhYmVscyUzRDIpJTBBb3B0aW1pemVyJTIwJTNEJTIwQWRhbVcobW9kZWwucGFyYW1ldGVycygpJTJDJTIwbHIlM0QzZS01KSUwQSUwQXRyYWluX2RsJTJDJTIwZXZhbF9kbCUyQyUyMG1vZGVsJTJDJTIwb3B0aW1pemVyJTIwJTNEJTIwYWNjZWxlcmF0b3IucHJlcGFyZSglMEElMjAlMjAlMjAlMjB0cmFpbl9kYXRhbG9hZGVyJTJDJTIwZXZhbF9kYXRhbG9hZGVyJTJDJTIwbW9kZWwlMkMlMjBvcHRpbWl6ZXIlMEEpJTBBJTBBbnVtX2Vwb2NocyUyMCUzRCUyMDMlMEFudW1fdHJhaW5pbmdfc3RlcHMlMjAlM0QlMjBudW1fZXBvY2hzJTIwKiUyMGxlbih0cmFpbl9kbCklMEFscl9zY2hlZHVsZXIlMjAlM0QlMjBnZXRfc2NoZWR1bGVyKCUwQSUyMCUyMCUyMCUyMCUyMmxpbmVhciUyMiUyQyUwQSUyMCUyMCUyMCUyMG9wdGltaXplciUzRG9wdGltaXplciUyQyUwQSUyMCUyMCUyMCUyMG51bV93YXJtdXBfc3RlcHMlM0QwJTJDJTBBJTIwJTIwJTIwJTIwbnVtX3RyYWluaW5nX3N0ZXBzJTNEbnVtX3RyYWluaW5nX3N0ZXBzJTJDJTBBKSUwQSUwQXByb2dyZXNzX2JhciUyMCUzRCUyMHRxZG0ocmFuZ2UobnVtX3RyYWluaW5nX3N0ZXBzKSklMEElMEFtb2RlbC50cmFpbigpJTBBZm9yJTIwZXBvY2glMjBpbiUyMHJhbmdlKG51bV9lcG9jaHMpJTNBJTBBJTIwJTIwJTIwJTIwZm9yJTIwYmF0Y2glMjBpbiUyMHRyYWluX2RsJTNBJTBBJTIwJTIwJTIwJTIwJTIwJTIwJTIwJTIwb3V0cHV0cyUyMCUzRCUyMG1vZGVsKCoqYmF0Y2gpJTBBJTIwJTIwJTIwJTIwJTIwJTIwJTIwJTIwbG9zcyUyMCUzRCUyMG91dHB1dHMubG9zcyUwQSUyMCUyMCUyMCUyMCUyMCUyMCUyMCUyMGFjY2VsZXJhdG9yLmJhY2t3YXJkKGxvc3MpJTBBJTBBJTIwJTIwJTIwJTIwJTIwJTIwJTIwJTIwb3B0aW1pemVyLnN0ZXAoKSUwQSUyMCUyMCUyMCUyMCUyMCUyMCUyMCUyMGxyX3NjaGVkdWxlci5zdGVwKCklMEElMjAlMjAlMjAlMjAlMjAlMjAlMjAlMjBvcHRpbWl6ZXIuemVyb19ncmFkKCklMEElMjAlMjAlMjAlMjAlMjAlMjAlMjAlMjBwcm9ncmVzc19iYXIudXBkYXRlKDEp",highlighted:`<span class="hljs-keyword">from</span> accelerate <span class="hljs-keyword">import</span> Accelerator
<span class="hljs-keyword">from</span> torch.optim <span class="hljs-keyword">import</span> AdamW
<span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> AutoModelForSequenceClassification, get_scheduler
accelerator = Accelerator()
model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=<span class="hljs-number">2</span>)
optimizer = AdamW(model.parameters(), lr=<span class="hljs-number">3e-5</span>)
train_dl, eval_dl, model, optimizer = accelerator.prepare(
train_dataloader, eval_dataloader, model, optimizer
)
num_epochs = <span class="hljs-number">3</span>
num_training_steps = num_epochs * <span class="hljs-built_in">len</span>(train_dl)
lr_scheduler = get_scheduler(
<span class="hljs-string">&quot;linear&quot;</span>,
optimizer=optimizer,
num_warmup_steps=<span class="hljs-number">0</span>,
num_training_steps=num_training_steps,
)
progress_bar = tqdm(<span class="hljs-built_in">range</span>(num_training_steps))
model.train()
<span class="hljs-keyword">for</span> epoch <span class="hljs-keyword">in</span> <span class="hljs-built_in">range</span>(num_epochs):
<span class="hljs-keyword">for</span> batch <span class="hljs-keyword">in</span> train_dl:
outputs = model(**batch)
loss = outputs.loss
accelerator.backward(loss)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
progress_bar.update(<span class="hljs-number">1</span>)`,wrap:!1}}),we=new y({props:{code:"YWNjZWxlcmF0ZSUyMGNvbmZpZw==",highlighted:"accelerate config",wrap:!1}}),Te=new y({props:{code:"YWNjZWxlcmF0ZSUyMGxhdW5jaCUyMHRyYWluLnB5",highlighted:'accelerate <span class="hljs-built_in">launch</span> train.py',wrap:!1}}),Ce=new y({props:{code:"ZnJvbSUyMGFjY2VsZXJhdGUlMjBpbXBvcnQlMjBub3RlYm9va19sYXVuY2hlciUwQSUwQW5vdGVib29rX2xhdW5jaGVyKHRyYWluaW5nX2Z1bmN0aW9uKQ==",highlighted:`<span class="hljs-keyword">from</span> accelerate <span class="hljs-keyword">import</span> notebook_launcher
notebook_launcher(training_function)`,wrap:!1}}),Ie=new ga({props:{source:"https://github.com/huggingface/course/blob/main/chapters/es/chapter3/4.mdx"}}),{c(){b=M("meta"),j=t(),J=M("p"),$e=t(),o(T.$$.fragment),ke=t(),o(h.$$.fragment),Xe=t(),o(f.$$.fragment),ve=t(),C=M("p"),C.innerHTML=zl,We=t(),o(Z.$$.fragment),Re=t(),o(I.$$.fragment),_e=t(),B=M("p"),B.innerHTML=Nl,Ye=t(),$=M("ul"),$.innerHTML=El,Ae=t(),G=M("p"),G.innerHTML=Fl,Ve=t(),o(g.$$.fragment),ze=t(),k=M("p"),k.textContent=Ql,Ne=t(),o(X.$$.fragment),Ee=t(),v=M("p"),v.innerHTML=Hl,Fe=t(),o(W.$$.fragment),Qe=t(),R=M("p"),R.textContent=xl,He=t(),o(_.$$.fragment),xe=t(),o(Y.$$.fragment),Se=t(),A=M("p"),A.innerHTML=Sl,qe=t(),V=M("p"),V.textContent=ql,Le=t(),o(z.$$.fragment),Ke=t(),N=M("p"),N.textContent=Ll,Pe=t(),o(E.$$.fragment),De=t(),o(F.$$.fragment),Oe=t(),Q=M("p"),Q.innerHTML=Kl,el=t(),H=M("p"),H.innerHTML=Pl,ll=t(),o(x.$$.fragment),al=t(),S=M("p"),S.innerHTML=Dl,sl=t(),o(q.$$.fragment),tl=t(),o(L.$$.fragment),nl=t(),o(K.$$.fragment),ol=t(),P=M("p"),P.innerHTML=Ol,rl=t(),o(D.$$.fragment),pl=t(),o(O.$$.fragment),il=t(),ee=M("p"),ee.innerHTML=ea,cl=t(),o(le.$$.fragment),ml=t(),ae=M("p"),ae.textContent=la,Ml=t(),o(se.$$.fragment),dl=t(),te=M("p"),te.innerHTML=aa,ul=t(),o(ne.$$.fragment),yl=t(),o(oe.$$.fragment),bl=t(),re=M("p"),re.textContent=sa,jl=t(),o(w.$$.fragment),Jl=t(),o(pe.$$.fragment),wl=t(),o(ie.$$.fragment),Ul=t(),ce=M("p"),ce.innerHTML=ta,Tl=t(),o(me.$$.fragment),hl=t(),Me=M("p"),Me.textContent=na,fl=t(),o(de.$$.fragment),Cl=t(),ue=M("p"),ue.innerHTML=oa,Zl=t(),ye=M("p"),ye.innerHTML=ra,Il=t(),o(U.$$.fragment),Bl=t(),be=M("p"),be.textContent=pa,$l=t(),o(je.$$.fragment),Gl=t(),Je=M("p"),Je.innerHTML=ia,gl=t(),o(we.$$.fragment),kl=t(),Ue=M("p"),Ue.textContent=ca,Xl=t(),o(Te.$$.fragment),vl=t(),he=M("p"),he.textContent=ma,Wl=t(),fe=M("p"),fe.innerHTML=Ma,Rl=t(),o(Ce.$$.fragment),_l=t(),Ze=M("p"),Ze.innerHTML=da,Yl=t(),o(Ie.$$.fragment),Al=t(),Ge=M("p"),this.h()},l(e){const l=Za("svelte-u9bgzb",document.head);b=d(l,"META",{name:!0,content:!0}),l.forEach(a),j=n(e),J=d(e,"P",{}),ba(J).forEach(a),$e=n(e),r(T.$$.fragment,e),ke=n(e),r(h.$$.fragment,e),Xe=n(e),r(f.$$.fragment,e),ve=n(e),C=d(e,"P",{"data-svelte-h":!0}),u(C)!=="svelte-1ifzt2p"&&(C.innerHTML=zl),We=n(e),r(Z.$$.fragment,e),Re=n(e),r(I.$$.fragment,e),_e=n(e),B=d(e,"P",{"data-svelte-h":!0}),u(B)!=="svelte-22cbaw"&&(B.innerHTML=Nl),Ye=n(e),$=d(e,"UL",{"data-svelte-h":!0}),u($)!=="svelte-16hg34w"&&($.innerHTML=El),Ae=n(e),G=d(e,"P",{"data-svelte-h":!0}),u(G)!=="svelte-1cy77bo"&&(G.innerHTML=Fl),Ve=n(e),r(g.$$.fragment,e),ze=n(e),k=d(e,"P",{"data-svelte-h":!0}),u(k)!=="svelte-8bk5f4"&&(k.textContent=Ql),Ne=n(e),r(X.$$.fragment,e),Ee=n(e),v=d(e,"P",{"data-svelte-h":!0}),u(v)!=="svelte-18n18t9"&&(v.innerHTML=Hl),Fe=n(e),r(W.$$.fragment,e),Qe=n(e),R=d(e,"P",{"data-svelte-h":!0}),u(R)!=="svelte-nt5g3z"&&(R.textContent=xl),He=n(e),r(_.$$.fragment,e),xe=n(e),r(Y.$$.fragment,e),Se=n(e),A=d(e,"P",{"data-svelte-h":!0}),u(A)!=="svelte-1yvm52h"&&(A.innerHTML=Sl),qe=n(e),V=d(e,"P",{"data-svelte-h":!0}),u(V)!=="svelte-1954u46"&&(V.textContent=ql),Le=n(e),r(z.$$.fragment,e),Ke=n(e),N=d(e,"P",{"data-svelte-h":!0}),u(N)!=="svelte-o48clr"&&(N.textContent=Ll),Pe=n(e),r(E.$$.fragment,e),De=n(e),r(F.$$.fragment,e),Oe=n(e),Q=d(e,"P",{"data-svelte-h":!0}),u(Q)!=="svelte-1ozt76r"&&(Q.innerHTML=Kl),el=n(e),H=d(e,"P",{"data-svelte-h":!0}),u(H)!=="svelte-112fcrr"&&(H.innerHTML=Pl),ll=n(e),r(x.$$.fragment,e),al=n(e),S=d(e,"P",{"data-svelte-h":!0}),u(S)!=="svelte-1qdu95l"&&(S.innerHTML=Dl),sl=n(e),r(q.$$.fragment,e),tl=n(e),r(L.$$.fragment,e),nl=n(e),r(K.$$.fragment,e),ol=n(e),P=d(e,"P",{"data-svelte-h":!0}),u(P)!=="svelte-zx0gfa"&&(P.innerHTML=Ol),rl=n(e),r(D.$$.fragment,e),pl=n(e),r(O.$$.fragment,e),il=n(e),ee=d(e,"P",{"data-svelte-h":!0}),u(ee)!=="svelte-hoxau0"&&(ee.innerHTML=ea),cl=n(e),r(le.$$.fragment,e),ml=n(e),ae=d(e,"P",{"data-svelte-h":!0}),u(ae)!=="svelte-fku1tj"&&(ae.textContent=la),Ml=n(e),r(se.$$.fragment,e),dl=n(e),te=d(e,"P",{"data-svelte-h":!0}),u(te)!=="svelte-hvg8j7"&&(te.innerHTML=aa),ul=n(e),r(ne.$$.fragment,e),yl=n(e),r(oe.$$.fragment,e),bl=n(e),re=d(e,"P",{"data-svelte-h":!0}),u(re)!=="svelte-10ja0p6"&&(re.textContent=sa),jl=n(e),r(w.$$.fragment,e),Jl=n(e),r(pe.$$.fragment,e),wl=n(e),r(ie.$$.fragment,e),Ul=n(e),ce=d(e,"P",{"data-svelte-h":!0}),u(ce)!=="svelte-p7kfwl"&&(ce.innerHTML=ta),Tl=n(e),r(me.$$.fragment,e),hl=n(e),Me=d(e,"P",{"data-svelte-h":!0}),u(Me)!=="svelte-1sujeb6"&&(Me.textContent=na),fl=n(e),r(de.$$.fragment,e),Cl=n(e),ue=d(e,"P",{"data-svelte-h":!0}),u(ue)!=="svelte-pen57s"&&(ue.innerHTML=oa),Zl=n(e),ye=d(e,"P",{"data-svelte-h":!0}),u(ye)!=="svelte-1jgz58l"&&(ye.innerHTML=ra),Il=n(e),r(U.$$.fragment,e),Bl=n(e),be=d(e,"P",{"data-svelte-h":!0}),u(be)!=="svelte-1bt6f5g"&&(be.textContent=pa),$l=n(e),r(je.$$.fragment,e),Gl=n(e),Je=d(e,"P",{"data-svelte-h":!0}),u(Je)!=="svelte-19tcwhc"&&(Je.innerHTML=ia),gl=n(e),r(we.$$.fragment,e),kl=n(e),Ue=d(e,"P",{"data-svelte-h":!0}),u(Ue)!=="svelte-14mqyau"&&(Ue.textContent=ca),Xl=n(e),r(Te.$$.fragment,e),vl=n(e),he=d(e,"P",{"data-svelte-h":!0}),u(he)!=="svelte-hs7y8g"&&(he.textContent=ma),Wl=n(e),fe=d(e,"P",{"data-svelte-h":!0}),u(fe)!=="svelte-1lpwgeq"&&(fe.innerHTML=Ma),Rl=n(e),r(Ce.$$.fragment,e),_l=n(e),Ze=d(e,"P",{"data-svelte-h":!0}),u(Ze)!=="svelte-olr8oz"&&(Ze.innerHTML=da),Yl=n(e),r(Ie.$$.fragment,e),Al=n(e),Ge=d(e,"P",{}),ba(Ge).forEach(a),this.h()},h(){ja(b,"name","hf:doc:metadata"),ja(b,"content",Wa)},m(e,l){Ia(document.head,b),s(e,j,l),s(e,J,l),s(e,$e,l),p(T,e,l),s(e,ke,l),p(h,e,l),s(e,Xe,l),p(f,e,l),s(e,ve,l),s(e,C,l),s(e,We,l),p(Z,e,l),s(e,Re,l),p(I,e,l),s(e,_e,l),s(e,B,l),s(e,Ye,l),s(e,$,l),s(e,Ae,l),s(e,G,l),s(e,Ve,l),p(g,e,l),s(e,ze,l),s(e,k,l),s(e,Ne,l),p(X,e,l),s(e,Ee,l),s(e,v,l),s(e,Fe,l),p(W,e,l),s(e,Qe,l),s(e,R,l),s(e,He,l),p(_,e,l),s(e,xe,l),p(Y,e,l),s(e,Se,l),s(e,A,l),s(e,qe,l),s(e,V,l),s(e,Le,l),p(z,e,l),s(e,Ke,l),s(e,N,l),s(e,Pe,l),p(E,e,l),s(e,De,l),p(F,e,l),s(e,Oe,l),s(e,Q,l),s(e,el,l),s(e,H,l),s(e,ll,l),p(x,e,l),s(e,al,l),s(e,S,l),s(e,sl,l),p(q,e,l),s(e,tl,l),p(L,e,l),s(e,nl,l),p(K,e,l),s(e,ol,l),s(e,P,l),s(e,rl,l),p(D,e,l),s(e,pl,l),p(O,e,l),s(e,il,l),s(e,ee,l),s(e,cl,l),p(le,e,l),s(e,ml,l),s(e,ae,l),s(e,Ml,l),p(se,e,l),s(e,dl,l),s(e,te,l),s(e,ul,l),p(ne,e,l),s(e,yl,l),p(oe,e,l),s(e,bl,l),s(e,re,l),s(e,jl,l),p(w,e,l),s(e,Jl,l),p(pe,e,l),s(e,wl,l),p(ie,e,l),s(e,Ul,l),s(e,ce,l),s(e,Tl,l),p(me,e,l),s(e,hl,l),s(e,Me,l),s(e,fl,l),p(de,e,l),s(e,Cl,l),s(e,ue,l),s(e,Zl,l),s(e,ye,l),s(e,Il,l),p(U,e,l),s(e,Bl,l),s(e,be,l),s(e,$l,l),p(je,e,l),s(e,Gl,l),s(e,Je,l),s(e,gl,l),p(we,e,l),s(e,kl,l),s(e,Ue,l),s(e,Xl,l),p(Te,e,l),s(e,vl,l),s(e,he,l),s(e,Wl,l),s(e,fe,l),s(e,Rl,l),p(Ce,e,l),s(e,_l,l),s(e,Ze,l),s(e,Yl,l),p(Ie,e,l),s(e,Al,l),s(e,Ge,l),Vl=!0},p(e,[l]){const ua={};l&2&&(ua.$$scope={dirty:l,ctx:e}),w.$set(ua);const ya={};l&2&&(ya.$$scope={dirty:l,ctx:e}),U.$set(ya)},i(e){Vl||(i(T.$$.fragment,e),i(h.$$.fragment,e),i(f.$$.fragment,e),i(Z.$$.fragment,e),i(I.$$.fragment,e),i(g.$$.fragment,e),i(X.$$.fragment,e),i(W.$$.fragment,e),i(_.$$.fragment,e),i(Y.$$.fragment,e),i(z.$$.fragment,e),i(E.$$.fragment,e),i(F.$$.fragment,e),i(x.$$.fragment,e),i(q.$$.fragment,e),i(L.$$.fragment,e),i(K.$$.fragment,e),i(D.$$.fragment,e),i(O.$$.fragment,e),i(le.$$.fragment,e),i(se.$$.fragment,e),i(ne.$$.fragment,e),i(oe.$$.fragment,e),i(w.$$.fragment,e),i(pe.$$.fragment,e),i(ie.$$.fragment,e),i(me.$$.fragment,e),i(de.$$.fragment,e),i(U.$$.fragment,e),i(je.$$.fragment,e),i(we.$$.fragment,e),i(Te.$$.fragment,e),i(Ce.$$.fragment,e),i(Ie.$$.fragment,e),Vl=!0)},o(e){c(T.$$.fragment,e),c(h.$$.fragment,e),c(f.$$.fragment,e),c(Z.$$.fragment,e),c(I.$$.fragment,e),c(g.$$.fragment,e),c(X.$$.fragment,e),c(W.$$.fragment,e),c(_.$$.fragment,e),c(Y.$$.fragment,e),c(z.$$.fragment,e),c(E.$$.fragment,e),c(F.$$.fragment,e),c(x.$$.fragment,e),c(q.$$.fragment,e),c(L.$$.fragment,e),c(K.$$.fragment,e),c(D.$$.fragment,e),c(O.$$.fragment,e),c(le.$$.fragment,e),c(se.$$.fragment,e),c(ne.$$.fragment,e),c(oe.$$.fragment,e),c(w.$$.fragment,e),c(pe.$$.fragment,e),c(ie.$$.fragment,e),c(me.$$.fragment,e),c(de.$$.fragment,e),c(U.$$.fragment,e),c(je.$$.fragment,e),c(we.$$.fragment,e),c(Te.$$.fragment,e),c(Ce.$$.fragment,e),c(Ie.$$.fragment,e),Vl=!1},d(e){e&&(a(j),a(J),a($e),a(ke),a(Xe),a(ve),a(C),a(We),a(Re),a(_e),a(B),a(Ye),a($),a(Ae),a(G),a(Ve),a(ze),a(k),a(Ne),a(Ee),a(v),a(Fe),a(Qe),a(R),a(He),a(xe),a(Se),a(A),a(qe),a(V),a(Le),a(Ke),a(N),a(Pe),a(De),a(Oe),a(Q),a(el),a(H),a(ll),a(al),a(S),a(sl),a(tl),a(nl),a(ol),a(P),a(rl),a(pl),a(il),a(ee),a(cl),a(ml),a(ae),a(Ml),a(dl),a(te),a(ul),a(yl),a(bl),a(re),a(jl),a(Jl),a(wl),a(Ul),a(ce),a(Tl),a(hl),a(Me),a(fl),a(Cl),a(ue),a(Zl),a(ye),a(Il),a(Bl),a(be),a($l),a(Gl),a(Je),a(gl),a(kl),a(Ue),a(Xl),a(vl),a(he),a(Wl),a(fe),a(Rl),a(_l),a(Ze),a(Yl),a(Al),a(Ge)),a(b),m(T,e),m(h,e),m(f,e),m(Z,e),m(I,e),m(g,e),m(X,e),m(W,e),m(_,e),m(Y,e),m(z,e),m(E,e),m(F,e),m(x,e),m(q,e),m(L,e),m(K,e),m(D,e),m(O,e),m(le,e),m(se,e),m(ne,e),m(oe,e),m(w,e),m(pe,e),m(ie,e),m(me,e),m(de,e),m(U,e),m(je,e),m(we,e),m(Te,e),m(Ce,e),m(Ie,e)}}}const Wa='{"title":"Un entrenamiento completo","local":"un-entrenamiento-completo","sections":[{"title":"Prepárate para el entrenamiento","local":"prepárate-para-el-entrenamiento","sections":[],"depth":3},{"title":"El bucle de entrenamiento","local":"el-bucle-de-entrenamiento","sections":[],"depth":3},{"title":"El bucle de evaluación","local":"el-bucle-de-evaluación","sections":[],"depth":3},{"title":"Repotencia tu bucle de entrenamiento con Accelerate 🤗","local":"repotencia-tu-bucle-de-entrenamiento-con-accelerate-","sections":[],"depth":3}],"depth":1}';function Ra(Be){return Ta(()=>{new URLSearchParams(window.location.search).get("fw")}),[]}class Qa extends fa{constructor(b){super(),Ca(this,b,Ra,va,Ua,{})}}export{Qa as component};

Xet Storage Details

Size:
44.5 kB
·
Xet hash:
fcf6db55756d0cce2d077668529a91b87015666a5a7aaac7f53196eb4ca980fe

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.