Trainer [[trainer]]
[Trainer]λ Transformers λΌμ΄λΈλ¬λ¦¬μ ꡬνλ PyTorch λͺ¨λΈμ λ°λ³΅νμ¬ νλ ¨ λ° νκ° κ³Όμ μ
λλ€. νλ ¨μ νμν μμ(λͺ¨λΈ, ν ν¬λμ΄μ , λ°μ΄ν°μ
, νκ° ν¨μ, νλ ¨ νμ΄νΌνλΌλ―Έν° λ±)λ§ μ 곡νλ©΄ [Trainer]κ° νμν λλ¨Έμ§ μμ
μ μ²λ¦¬ν©λλ€. μ΄λ₯Ό ν΅ν΄ μ§μ νλ ¨ 루νλ₯Ό μμ±νμ§ μκ³ λ λΉ λ₯΄κ² νλ ¨μ μμν μ μμ΅λλ€. λν [Trainer]λ κ°λ ₯ν λ§μΆ€ μ€μ κ³Ό λ€μν νλ ¨ μ΅μ
μ μ 곡νμ¬ μ¬μ©μ λ§μΆ€ νλ ¨μ΄ κ°λ₯ν©λλ€.
Transformersλ [Trainer] ν΄λμ€ μΈμλ λ²μμ΄λ μμ½κ³Ό κ°μ μνμ€-ν¬-μνμ€ μμ
μ μν [Seq2SeqTrainer] ν΄λμ€λ μ 곡ν©λλ€. λν TRL λΌμ΄λΈλ¬λ¦¬μλ [Trainer] ν΄λμ€λ₯Ό κ°μΈκ³ Llama-2 λ° Mistralκ³Ό κ°μ μΈμ΄ λͺ¨λΈμ μλ νκ· κΈ°λ²μΌλ‘ νλ ¨νλ λ° μ΅μ νλ [~trl.SFTTrainer] ν΄λμ€ μ
λλ€. [~trl.SFTTrainer]λ μνμ€ ν¨νΉ, LoRA, μμν λ° DeepSpeedμ κ°μ κΈ°λ₯μ μ§μνμ¬ ν¬κΈ° μκ΄μμ΄ λͺ¨λΈ ν¨μ¨μ μΌλ‘ νμ₯ν μ μμ΅λλ€.
μ΄λ€ λ€λ₯Έ [Trainer] μ ν ν΄λμ€μ λν΄ λ μκ³ μΆλ€λ©΄ API μ°Έμ‘°λ₯Ό νμΈνμ¬ μΈμ μ΄λ€ ν΄λμ€κ° μ ν©ν μ§ μΌλ§λ μ§ νμΈνμΈμ. μΌλ°μ μΌλ‘ [Trainer]λ κ°μ₯ λ€μ¬λ€λ₯ν μ΅μ
μΌλ‘, λ€μν μμ
μ μ ν©ν©λλ€. [Seq2SeqTrainer]λ μνμ€-ν¬-μνμ€ μμ
μ μν΄ μ€κ³λμκ³ , [~trl.SFTTrainer]λ μΈμ΄ λͺ¨λΈ νλ ¨μ μν΄ μ€κ³λμμ΅λλ€.
μμνκΈ° μ μ, λΆμ° νκ²½μμ PyTorch νλ ¨κ³Ό μ€νμ ν μ μκ² Accelerate λΌμ΄λΈλ¬λ¦¬κ° μ€μΉλμλμ§ νμΈνμΈμ.
pip install accelerate
# μ
κ·Έλ μ΄λ
pip install accelerate --upgrade
μ΄ κ°μ΄λλ [Trainer] ν΄λμ€μ λν κ°μλ₯Ό μ 곡ν©λλ€.
κΈ°λ³Έ μ¬μ©λ² [[basic-usage]]
[Trainer]λ κΈ°λ³Έμ μΈ νλ ¨ 루νμ νμν λͺ¨λ μ½λλ₯Ό ν¬ν¨νκ³ μμ΅λλ€.
- μμ€μ κ³μ°νλ νλ ¨ λ¨κ³λ₯Ό μνν©λλ€.
- [
~accelerate.Accelerator.backward] λ©μλλ‘ κ·Έλ μ΄λμΈνΈλ₯Ό κ³μ°ν©λλ€. - κ·Έλ μ΄λμΈνΈλ₯Ό κΈ°λ°μΌλ‘ κ°μ€μΉλ₯Ό μ λ°μ΄νΈν©λλ€.
- μ ν΄μ§ μν μμ λλ¬ν λκΉμ§ μ΄ κ³Όμ μ λ°λ³΅ν©λλ€.
[Trainer] ν΄λμ€λ PyTorchμ νλ ¨ κ³Όμ μ μ΅μνμ§ μκ±°λ λ§ μμν κ²½μ°μλ νλ ¨μ΄ κ°λ₯νλλ‘ νμν λͺ¨λ μ½λλ₯Ό μΆμννμμ΅λλ€. λν λ§€λ² νλ ¨ 루νλ₯Ό μμ μμ±νμ§ μμλ λλ©°, νλ ¨μ νμν λͺ¨λΈκ³Ό λ°μ΄ν°μ
κ°μ νμ κ΅¬μ± μμλ§ μ 곡νλ©΄, [Trainer] ν΄λμ€κ° λλ¨Έμ§λ₯Ό μ²λ¦¬ν©λλ€.
νλ ¨ μ΅μ
μ΄λ νμ΄νΌνλΌλ―Έν°λ₯Ό μ§μ νλ €λ©΄, [TrainingArguments] ν΄λμ€μμ νμΈ ν μ μμ΅λλ€. μλ₯Ό λ€μ΄, λͺ¨λΈμ μ μ₯ν λλ ν 리λ₯Ό output_dirμ μ μνκ³ , νλ ¨ νμ Hubλ‘ λͺ¨λΈμ νΈμνλ €λ©΄ push_to_hub=Trueλ‘ μ€μ ν©λλ€.
from transformers import TrainingArguments
training_args = TrainingArguments(
output_dir="your-model",
learning_rate=2e-5,
per_device_train_batch_size=16,
per_device_eval_batch_size=16,
num_train_epochs=2,
weight_decay=0.01,
eval_strategy="epoch",
save_strategy="epoch",
load_best_model_at_end=True,
push_to_hub=True,
)
training_argsλ₯Ό [Trainer]μ λͺ¨λΈ, λ°μ΄ν°μ
, λ°μ΄ν°μ
μ μ²λ¦¬ λꡬ(λ°μ΄ν° μ νμ λ°λΌ ν ν¬λμ΄μ , νΉμ§ μΆμΆκΈ° λλ μ΄λ―Έμ§ νλ‘μΈμμΌ μ μμ), λ°μ΄ν° μμ§κΈ° λ° νλ ¨ μ€ νμΈν μ§νλ₯Ό κ³μ°ν ν¨μλ₯Ό ν¨κ» μ λ¬νμΈμ.
λ§μ§λ§μΌλ‘, [~Trainer.train]λ₯Ό νΈμΆνμ¬ νλ ¨μ μμνμΈμ!
from transformers import Trainer
trainer = Trainer(
model=model,
args=training_args,
train_dataset=dataset["train"],
eval_dataset=dataset["test"],
tokenizer=tokenizer,
data_collator=data_collator,
compute_metrics=compute_metrics,
)
trainer.train()
체ν¬ν¬μΈνΈ [[checkpoints]]
[Trainer] ν΄λμ€λ [TrainingArguments]μ output_dir λ§€κ°λ³μμ μ§μ λ λλ ν 리μ λͺ¨λΈ 체ν¬ν¬μΈνΈλ₯Ό μ μ₯ν©λλ€. 체ν¬ν¬μΈνΈλ checkpoint-000 νμ ν΄λμ μ μ₯λλ©°, μ¬κΈ°μ λμ μ«μλ νλ ¨ λ¨κ³μ ν΄λΉν©λλ€. 체ν¬ν¬μΈνΈλ₯Ό μ μ₯νλ©΄ λμ€μ νλ ¨μ μ¬κ°ν λ μ μ©ν©λλ€.
# μ΅μ 체ν¬ν¬μΈνΈμμ μ¬κ°
trainer.train(resume_from_checkpoint=True)
# μΆλ ₯ λλ ν 리μ μ μ₯λ νΉμ 체ν¬ν¬μΈνΈμμ μ¬κ°
trainer.train(resume_from_checkpoint="your-model/checkpoint-1000")
체ν¬ν¬μΈνΈλ₯Ό Hubμ νΈμνλ €λ©΄ [TrainingArguments]μμ push_to_hub=Trueλ‘ μ€μ νμ¬ μ»€λ°νκ³ νΈμν μ μμ΅λλ€. 체ν¬ν¬μΈνΈ μ μ₯ λ°©λ²μ κ²°μ νλ λ€λ₯Έ μ΅μ
μ hub_strategy λ§€κ°λ³μμμ μ€μ ν©λλ€:
hub_strategy="checkpoint"λ μ΅μ 체ν¬ν¬μΈνΈλ₯Ό "last-checkpoint"λΌλ νμ ν΄λμ νΈμνμ¬ νλ ¨μ μ¬κ°ν μ μμ΅λλ€.hub_strategy="all_checkpoints"λ λͺ¨λ 체ν¬ν¬μΈνΈλ₯Όoutput_dirμ μ μλ λλ ν 리μ νΈμν©λλ€(λͺ¨λΈ 리ν¬μ§ν 리μμ ν΄λλΉ νλμ 체ν¬ν¬μΈνΈλ₯Ό λ³Ό μ μμ΅λλ€).
체ν¬ν¬μΈνΈμμ νλ ¨μ μ¬κ°ν λ, [Trainer]λ 체ν¬ν¬μΈνΈκ° μ μ₯λ λμ λμΌν Python, NumPy λ° PyTorch RNG μνλ₯Ό μ μ§νλ €κ³ ν©λλ€. νμ§λ§ PyTorchλ κΈ°λ³Έ μ€μ μΌλ‘ 'μΌκ΄λ κ²°κ³Όλ₯Ό 보μ₯νμ§ μμ'μΌλ‘ λ§μ΄ λμ΄μκΈ° λλ¬Έμ, RNG μνκ° λμΌν κ²μ΄λΌκ³ 보μ₯ν μ μμ΅λλ€. λ°λΌμ, μΌκ΄λ κ²°κ³Όκ° λ³΄μ₯λλλ‘ νμ±ν νλ €λ©΄, λλ€μ± μ μ΄ κ°μ΄λλ₯Ό μ°Έκ³ νμ¬ νλ ¨μ μμ ν μΌκ΄λ κ²°κ³Όλ₯Ό 보μ₯ λ°λλ‘ λ§λ€κΈ° μν΄ νμ±νν μ μλ νλͺ©μ νμΈνμΈμ. λ€λ§, νΉμ μ€μ μ κ²°μ μ μΌλ‘ λ§λ€λ©΄ νλ ¨μ΄ λλ €μ§ μ μμ΅λλ€.
Trainer λ§μΆ€ μ€μ [[customize-the-trainer]]
[Trainer] ν΄λμ€λ μ κ·Όμ±κ³Ό μ©μ΄μ±μ μΌλμ λκ³ μ€κ³λμμ§λ§, λ λ€μν κΈ°λ₯μ μνλ μ¬μ©μλ€μ μν΄ λ€μν λ§μΆ€ μ€μ μ΅μ
μ μ 곡ν©λλ€. [Trainer]μ λ§μ λ©μλλ μλΈν΄λμ€ν λ° μ€λ²λΌμ΄λνμ¬ μνλ κΈ°λ₯μ μ 곡ν μ μμΌλ©°, μ΄λ₯Ό ν΅ν΄ μ 체 νλ ¨ 루νλ₯Ό λ€μ μμ±ν νμ μμ΄ μνλ κΈ°λ₯μ μΆκ°ν μ μμ΅λλ€. μ΄λ¬ν λ©μλμλ λ€μμ΄ ν¬ν¨λ©λλ€:
- [
~Trainer.get_train_dataloader]λ νλ ¨ λ°μ΄ν°λ‘λλ₯Ό μμ±ν©λλ€. - [
~Trainer.get_eval_dataloader]λ νκ° λ°μ΄ν°λ‘λλ₯Ό μμ±ν©λλ€. - [
~Trainer.get_test_dataloader]λ ν μ€νΈ λ°μ΄ν°λ‘λλ₯Ό μμ±ν©λλ€. - [
~Trainer.log]λ νλ ¨μ λͺ¨λν°λ§νλ λ€μν κ°μ²΄μ λν μ 보λ₯Ό λ‘κ·Έλ‘ λ¨κΉλλ€. - [
~Trainer.create_optimizer_and_scheduler]λ__init__μμ μ λ¬λμ§ μμ κ²½μ° μ΅ν°λ§μ΄μ μ νμ΅λ₯ μ€μΌμ€λ¬λ₯Ό μμ±ν©λλ€. μ΄λ€μ κ°κ° [~Trainer.create_optimizer] λ° [~Trainer.create_scheduler]λ‘ λ³λλ‘ λ§μΆ€ μ€μ ν μ μμ΅λλ€. - [
~Trainer.compute_loss]λ νλ ¨ μ λ ₯ λ°°μΉμ λν μμ€μ κ³μ°ν©λλ€. - [
~Trainer.training_step]λ νλ ¨ λ¨κ³λ₯Ό μνν©λλ€. - [
~Trainer.prediction_step]λ μμΈ‘ λ° ν μ€νΈ λ¨κ³λ₯Ό μνν©λλ€. - [
~Trainer.evaluate]λ λͺ¨λΈμ νκ°νκ³ νκ° μ§νμ λ°νν©λλ€. - [
~Trainer.predict]λ ν μ€νΈ μΈνΈμ λν μμΈ‘(λ μ΄λΈμ΄ μλ κ²½μ° μ§ν ν¬ν¨)μ μνν©λλ€.
μλ₯Ό λ€μ΄, [~Trainer.compute_loss] λ©μλλ₯Ό λ§μΆ€ μ€μ νμ¬ κ°μ€ μμ€μ μ¬μ©νλ €λ κ²½μ°:
from torch import nn
from transformers import Trainer
class CustomTrainer(Trainer):
def compute_loss(self,
model, inputs, return_outputs=False):
labels = inputs.pop("labels")
# μλ°©ν₯ μ ν
outputs = model(**inputs)
logits = outputs.get("logits")
# μλ‘ λ€λ₯Έ κ°μ€μΉλ‘ 3κ°μ λ μ΄λΈμ λν μ¬μ©μ μ μ μμ€μ κ³μ°
loss_fct = nn.CrossEntropyLoss(weight=torch.tensor([1.0, 2.0, 3.0], device=model.device))
loss = loss_fct(logits.view(-1, self.model.config.num_labels), labels.view(-1))
return (loss, outputs) if return_outputs else loss
μ½λ°± [[callbacks]]
[Trainer]λ₯Ό λ§μΆ€ μ€μ νλ λ λ€λ₯Έ λ°©λ²μ μ½λ°±μ μ¬μ©νλ κ²μ
λλ€. μ½λ°±μ νλ ¨ 루νμμ λ³νλ₯Ό μ£Όμ§ μμ΅λλ€. νλ ¨ 루νμ μνλ₯Ό κ²μ¬ν ν μνμ λ°λΌ μΌλΆ μμ
(μ‘°κΈ° μ’
λ£, κ²°κ³Ό λ‘κ·Έ λ±)μ μ€νν©λλ€. μ¦, μ½λ°±μ μ¬μ©μ μ μ μμ€ ν¨μμ κ°μ κ²μ ꡬννλ λ° μ¬μ©ν μ μμΌλ©°, μ΄λ₯Ό μν΄μλ [~Trainer.compute_loss] λ©μλλ₯Ό μλΈν΄λμ€ννκ³ μ€λ²λΌμ΄λν΄μΌ ν©λλ€.
μλ₯Ό λ€μ΄, νλ ¨ 루νμ 10λ¨κ³ ν μ‘°κΈ° μ’ λ£ μ½λ°±μ μΆκ°νλ €λ©΄ λ€μκ³Ό κ°μ΄ ν©λλ€.
from transformers import TrainerCallback
class EarlyStoppingCallback(TrainerCallback):
def __init__(self, num_steps=10):
self.num_steps = num_steps
def on_step_end(self, args, state, control, **kwargs):
if state.global_step >= self.num_steps:
return {"should_training_stop": True}
else:
return {}
κ·Έλ° λ€μ, μ΄λ₯Ό [Trainer]μ callback λ§€κ°λ³μμ μ λ¬ν©λλ€.
from transformers import Trainer
trainer = Trainer(
model=model,
args=training_args,
train_dataset=dataset["train"],
eval_dataset=dataset["test"],
tokenizer=tokenizer,
data_collator=data_collator,
compute_metrics=compute_metrics,
callbacks=[EarlyStoppingCallback()],
)
λ‘κΉ [[logging]]
λ‘κΉ APIμ λν μμΈν λ΄μ©μ λ‘κΉ API λ νΌλ°μ€λ₯Ό νμΈνμΈμ.
[Trainer]λ κΈ°λ³Έμ μΌλ‘ logging.INFOλ‘ μ€μ λμ΄ μμ΄ μ€λ₯, κ²½κ³ λ° κΈ°ν κΈ°λ³Έ μ 보λ₯Ό λ³΄κ³ ν©λλ€. λΆμ° νκ²½μμλ [Trainer] 볡μ λ³Έμ΄ logging.WARNINGμΌλ‘ μ€μ λμ΄ μ€λ₯μ κ²½κ³ λ§ λ³΄κ³ ν©λλ€. [TrainingArguments]μ log_level λ° log_level_replica λ§€κ°λ³μλ‘ λ‘κ·Έ λ 벨μ λ³κ²½ν μ μμ΅λλ€.
κ° λ
Έλμ λ‘κ·Έ λ 벨 μ€μ μ ꡬμ±νλ €λ©΄ log_on_each_node λ§€κ°λ³μλ₯Ό μ¬μ©νμ¬ κ° λ
Έλμμ λ‘κ·Έ λ 벨μ μ¬μ©ν μ§ μλλ©΄ μ£Ό λ
Έλμμλ§ μ¬μ©ν μ§ κ²°μ νμΈμ.
[Trainer]λ [Trainer.__init__] λ©μλμμ κ° λ
Έλμ λν΄ λ‘κ·Έ λ 벨μ λ³λλ‘ μ€μ νλ―λ‘, λ€λ₯Έ Transformers κΈ°λ₯μ μ¬μ©ν κ²½μ° [Trainer] κ°μ²΄λ₯Ό μμ±νκΈ° μ μ μ΄λ₯Ό 미리 μ€μ νλ κ²μ΄ μ’μ΅λλ€.
μλ₯Ό λ€μ΄, λ©μΈ μ½λμ λͺ¨λμ κ° λ Έλμ λ°λΌ λμΌν λ‘κ·Έ λ 벨μ μ¬μ©νλλ‘ μ€μ νλ €λ©΄ λ€μκ³Ό κ°μ΄ ν©λλ€.
logger = logging.getLogger(__name__)
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
handlers=[logging.StreamHandler(sys.stdout)],
)
log_level = training_args.get_process_log_level()
logger.setLevel(log_level)
datasets.utils.logging.set_verbosity(log_level)
transformers.utils.logging.set_verbosity(log_level)
trainer = Trainer(...)
κ° λ
Έλμμ κΈ°λ‘λ λ΄μ©μ ꡬμ±νκΈ° μν΄ log_levelκ³Ό log_level_replicaλ₯Ό λ€μν μ‘°ν©μΌλ‘ μ¬μ©ν΄λ³΄μΈμ.
my_app.py ... --log_level warning --log_level_replica error
λ©ν° λ
Έλ νκ²½μμλ log_on_each_node 0 λ§€κ°λ³μλ₯Ό μΆκ°ν©λλ€.
my_app.py ... --log_level warning --log_level_replica error --log_on_each_node 0
# μ€λ₯λ§ λ³΄κ³ νλλ‘ μ€μ
my_app.py ... --log_level error --log_level_replica error --log_on_each_node 0
NEFTune [[neftune]]
NEFTuneμ νλ ¨ μ€ μλ² λ© λ²‘ν°μ λ
Έμ΄μ¦λ₯Ό μΆκ°νμ¬ μ±λ₯μ ν₯μμν¬ μ μλ κΈ°μ μ
λλ€. [Trainer]μμ μ΄λ₯Ό νμ±ννλ €λ©΄ [TrainingArguments]μ neftune_noise_alpha λ§€κ°λ³μλ₯Ό μ€μ νμ¬ λ
Έμ΄μ¦μ μμ μ‘°μ ν©λλ€.
from transformers import TrainingArguments, Trainer
training_args = TrainingArguments(..., neftune_noise_alpha=0.1)
trainer = Trainer(..., args=training_args)
NEFTuneμ μμμΉ λͺ»ν λμμ νΌν λͺ©μ μΌλ‘ μ²μ μλ² λ© λ μ΄μ΄λ‘ 볡μνκΈ° μν΄ νλ ¨ ν λΉνμ±ν λ©λλ€.
GaLore [[galore]]
Gradient Low-Rank Projection (GaLore)μ μ 체 λ§€κ°λ³μλ₯Ό νμ΅νλ©΄μλ LoRAμ κ°μ μΌλ°μ μΈ μ κ³μ μ μ λ°©λ²λ³΄λ€ λ λ©λͺ¨λ¦¬ ν¨μ¨μ μΈ μ κ³μ νμ΅ μ λ΅μ λλ€.
λ¨Όμ GaLore 곡μ 리ν¬μ§ν 리λ₯Ό μ€μΉν©λλ€:
pip install galore-torch
κ·Έλ° λ€μ optimμ ["galore_adamw", "galore_adafactor", "galore_adamw_8bit"] μ€ νλμ ν¨κ» optim_target_modulesλ₯Ό μΆκ°ν©λλ€. μ΄λ μ μ©νλ €λ λμ λͺ¨λ μ΄λ¦μ ν΄λΉνλ λ¬Έμμ΄, μ κ· ννμ λλ μ 체 κ²½λ‘μ λͺ©λ‘μΌ μ μμ΅λλ€. μλλ end-to-end μμ μ€ν¬λ¦½νΈμ
λλ€(νμν κ²½μ° pip install trl datasetsλ₯Ό μ€ν):
import torch
import datasets
import trl
from transformers import TrainingArguments, AutoConfig, AutoTokenizer, AutoModelForCausalLM
train_dataset = datasets.load_dataset('imdb', split='train')
args = TrainingArguments(
output_dir="./test-galore",
max_steps=100,
per_device_train_batch_size=2,
optim="galore_adamw",
optim_target_modules=["attn", "mlp"]
)
model_id = "google/gemma-2b"
config = AutoConfig.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_config(config).to(0)
trainer = trl.SFTTrainer(
model=model,
args=args,
train_dataset=train_dataset,
dataset_text_field='text',
max_seq_length=512,
)
trainer.train()
GaLoreκ° μ§μνλ μΆκ° λ§€κ°λ³μλ₯Ό μ λ¬νλ €λ©΄ optim_argsλ₯Ό μ€μ ν©λλ€. μλ₯Ό λ€μ΄:
import torch
import datasets
import trl
from transformers import TrainingArguments, AutoConfig, AutoTokenizer, AutoModelForCausalLM
train_dataset = datasets.load_dataset('imdb', split='train')
args = TrainingArguments(
output_dir="./test-galore",
max_steps=100,
per_device_train_batch_size=2,
optim="galore_adamw",
optim_target_modules=["attn", "mlp"],
optim_args="rank=64, update_proj_gap=100, scale=0.10",
)
model_id = "google/gemma-2b"
config = AutoConfig.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_config(config).to(0)
trainer = trl.SFTTrainer(
model=model,
args=args,
train_dataset=train_dataset,
dataset_text_field='text',
max_seq_length=512,
)
trainer.train()
ν΄λΉ λ°©λ²μ λν μμΈν λ΄μ©μ μλ³Έ 리ν¬μ§ν 리 λλ λ Όλ¬Έμ μ°Έκ³ νμΈμ.
νμ¬ GaLore λ μ΄μ΄λ‘ κ°μ£Όλλ Linear λ μ΄μ΄λ§ νλ ¨ ν μ μμΌλ©°, μ κ³μ λΆν΄λ₯Ό μ¬μ©νμ¬ νλ ¨λκ³ λλ¨Έμ§ λ μ΄μ΄λ κΈ°μ‘΄ λ°©μμΌλ‘ μ΅μ νλ©λλ€.
νλ ¨ μμ μ μ μκ°μ΄ μ½κ° 걸릴 μ μμ΅λλ€(NVIDIA A100μμ 2B λͺ¨λΈμ κ²½μ° μ½ 3λΆ), νμ§λ§ μ΄ν νλ ¨μ μννκ² μ§νλ©λλ€.
λ€μκ³Ό κ°μ΄ μ΅ν°λ§μ΄μ μ΄λ¦μ layerwiseλ₯Ό μΆκ°νμ¬ λ μ΄μ΄λ³ μ΅μ νλ₯Ό μνν μλ μμ΅λλ€:
import torch
import datasets
import trl
from transformers import TrainingArguments, AutoConfig, AutoTokenizer, AutoModelForCausalLM
train_dataset = datasets.load_dataset('imdb', split='train')
args = TrainingArguments(
output_dir="./test-galore",
max_steps=100,
per_device_train_batch_size=2,
optim="galore_adamw_layerwise",
optim_target_modules=["attn", "mlp"]
)
model_id = "google/gemma-2b"
config = AutoConfig.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_config(config).to(0)
trainer = trl.SFTTrainer(
model=model,
args=args,
train_dataset=train_dataset,
dataset_text_field='text',
max_seq_length=512,
)
trainer.train()
λ μ΄μ΄λ³ μ΅μ νλ λ€μ μ€νμ μ΄λ©° DDP(λΆμ° λ°μ΄ν° λ³λ ¬)λ₯Ό μ§μνμ§ μμΌλ―λ‘, λ¨μΌ GPUμμλ§ νλ ¨ μ€ν¬λ¦½νΈλ₯Ό μ€νν μ μμ΅λλ€. μμΈν λ΄μ©μ μ΄ λ¬Έμλ₯Όμ μ°Έμ‘°νμΈμ. gradient clipping, DeepSpeed λ± λ€λ₯Έ κΈ°λ₯μ κΈ°λ³Έμ μΌλ‘ μ§μλμ§ μμ μ μμ΅λλ€. μ΄λ¬ν λ¬Έμ κ° λ°μνλ©΄ GitHubμ μ΄μλ₯Ό μ¬λ €μ£ΌμΈμ.
LOMO μ΅ν°λ§μ΄μ [[lomo-optimizer]]
LOMO μ΅ν°λ§μ΄μ λ μ νλ μμμΌλ‘ λν μΈμ΄ λͺ¨λΈμ μ 체 λ§€κ°λ³μ λ―ΈμΈ μ‘°μ κ³Ό μ μν νμ΅λ₯ μ ν΅ν μ λ©λͺ¨λ¦¬ μ΅μ ν(AdaLomo)μμ λμ
λμμ΅λλ€.
μ΄λ€μ λͺ¨λ ν¨μ¨μ μΈ μ 체 λ§€κ°λ³μ λ―ΈμΈ μ‘°μ λ°©λ²μΌλ‘ ꡬμ±λμ΄ μμ΅λλ€. μ΄λ¬ν μ΅ν°λ§μ΄μ λ€μ λ©λͺ¨λ¦¬ μ¬μ©λμ μ€μ΄κΈ° μν΄ κ·Έλ μ΄λμΈνΈ κ³μ°κ³Ό λ§€κ°λ³μ μ
λ°μ΄νΈλ₯Ό νλμ λ¨κ³λ‘ μ΅ν©ν©λλ€. LOMOμμ μ§μλλ μ΅ν°λ§μ΄μ λ "lomo"μ "adalomo"μ
λλ€. λ¨Όμ pypiμμ pip install lomo-optimλ₯Ό ν΅ν΄ lomoλ₯Ό μ€μΉνκ±°λ, GitHub μμ€μμ pip install git+https://github.com/OpenLMLab/LOMO.gitλ‘ μ€μΉνμΈμ.
μ μμ λ°λ₯΄λ©΄, grad_norm μμ΄ AdaLomoλ₯Ό μ¬μ©νλ κ²μ΄ λ λμ μ±λ₯κ³Ό λμ μ²λ¦¬λμ μ 곡νλ€κ³ ν©λλ€.
λ€μμ IMDB λ°μ΄ν°μ μμ google/gemma-2bλ₯Ό μ΅λ μ λ°λλ‘ λ―ΈμΈ μ‘°μ νλ κ°λ¨ν μ€ν¬λ¦½νΈμ λλ€:
import torch
import datasets
from transformers import TrainingArguments, AutoTokenizer, AutoModelForCausalLM
import trl
train_dataset = datasets.load_dataset('imdb', split='train')
args = TrainingArguments(
output_dir="./test-lomo",
max_steps=1000,
per_device_train_batch_size=4,
optim="adalomo",
gradient_checkpointing=True,
logging_strategy="steps",
logging_steps=1,
learning_rate=2e-6,
save_strategy="no",
run_name="lomo-imdb",
)
model_id = "google/gemma-2b"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True).to(0)
trainer = trl.SFTTrainer(
model=model,
args=args,
train_dataset=train_dataset,
dataset_text_field='text',
max_seq_length=1024,
)
trainer.train()
Accelerateμ Trainer [[accelerate-and-trainer]]
[Trainer] ν΄λμ€λ Accelerateλ‘ κ΅¬λλλ©°, μ΄λ FullyShardedDataParallel (FSDP) λ° DeepSpeedμ κ°μ ν΅ν©μ μ§μνλ λΆμ° νκ²½μμ PyTorch λͺ¨λΈμ μ½κ² νλ ¨ν μ μλ λΌμ΄λΈλ¬λ¦¬μ
λλ€.
FSDP μ€λ© μ λ΅, CPU μ€νλ‘λ λ° [Trainer]μ ν¨κ» μ¬μ©ν μ μλ λ λ§μ κΈ°λ₯μ μμλ³΄λ €λ©΄ Fully Sharded Data Parallel κ°μ΄λλ₯Ό νμΈνμΈμ.
[Trainer]μ Accelerateλ₯Ό μ¬μ©νλ €λ©΄ accelerate.config λͺ
λ Ήμ μ€ννμ¬ νλ ¨ νκ²½μ μ€μ νμΈμ. μ΄ λͺ
λ Ήμ νλ ¨ μ€ν¬λ¦½νΈλ₯Ό μ€νν λ μ¬μ©ν config_file.yamlμ μμ±ν©λλ€. μλ₯Ό λ€μ΄, λ€μ μμλ μ€μ ν μ μλ μΌλΆ κ΅¬μ± μμ
λλ€.
compute_environment: LOCAL_MACHINE
distributed_type: MULTI_GPU
downcast_bf16: 'no'
gpu_ids: all
machine_rank: 0 # λ
Έλμ λ°λΌ μμλ₯Ό λ³κ²½νμΈμ
main_process_ip: 192.168.20.1
main_process_port: 9898
main_training_function: main
mixed_precision: fp16
num_machines: 2
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
compute_environment: LOCAL_MACHINE
distributed_type: FSDP
downcast_bf16: 'no'
fsdp_config:
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_backward_prefetch_policy: BACKWARD_PRE
fsdp_forward_prefetch: true
fsdp_offload_params: false
fsdp_sharding_strategy: 1
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_sync_module_states: true
fsdp_transformer_layer_cls_to_wrap: BertLayer
fsdp_use_orig_params: true
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 2
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
compute_environment: LOCAL_MACHINE
deepspeed_config:
deepspeed_config_file: /home/user/configs/ds_zero3_config.json
zero3_init_flag: true
distributed_type: DEEPSPEED
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
num_machines: 1
num_processes: 4
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
compute_environment: LOCAL_MACHINE
deepspeed_config:
gradient_accumulation_steps: 1
gradient_clipping: 0.7
offload_optimizer_device: cpu
offload_param_device: cpu
zero3_init_flag: true
zero_stage: 2
distributed_type: DEEPSPEED
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 4
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
accelerate_launch λͺ
λ Ήμ Accelerateμ [Trainer]λ₯Ό μ¬μ©νμ¬ λΆμ° μμ€ν
μμ νλ ¨ μ€ν¬λ¦½νΈλ₯Ό μ€ννλ κΆμ₯ λ°©λ²μ΄λ©°, config_file.yamlμ μ§μ λ λ§€κ°λ³μλ₯Ό μ¬μ©ν©λλ€. μ΄ νμΌμ Accelerate μΊμ ν΄λμ μ μ₯λλ©° accelerate_launchλ₯Ό μ€νν λ μλμΌλ‘ λ‘λλ©λλ€.
μλ₯Ό λ€μ΄, FSDP ꡬμ±μ μ¬μ©νμ¬ run_glue.py νλ ¨ μ€ν¬λ¦½νΈλ₯Ό μ€ννλ €λ©΄ λ€μκ³Ό κ°μ΄ ν©λλ€:
accelerate launch \
./examples/pytorch/text-classification/run_glue.py \
--model_name_or_path google-bert/bert-base-cased \
--task_name $TASK_NAME \
--do_train \
--do_eval \
--max_seq_length 128 \
--per_device_train_batch_size 16 \
--learning_rate 5e-5 \
--num_train_epochs 3 \
--output_dir /tmp/$TASK_NAME/ \
--overwrite_output_dir
config_file.yaml νμΌμ λ§€κ°λ³μλ₯Ό μ§μ μ§μ ν μλ μμ΅λλ€:
accelerate launch --num_processes=2 \
--use_fsdp \
--mixed_precision=bf16 \
--fsdp_auto_wrap_policy=TRANSFORMER_BASED_WRAP \
--fsdp_transformer_layer_cls_to_wrap="BertLayer" \
--fsdp_sharding_strategy=1 \
--fsdp_state_dict_type=FULL_STATE_DICT \
./examples/pytorch/text-classification/run_glue.py \
--model_name_or_path google-bert/bert-base-cased \
--task_name $TASK_NAME \
--do_train \
--do_eval \
--max_seq_length 128 \
--per_device_train_batch_size 16 \
--learning_rate 5e-5 \
--num_train_epochs 3 \
--output_dir /tmp/$TASK_NAME/ \
--overwrite_output_dir
accelerate_launchμ μ¬μ©μ μ μ ꡬμ±μ λν΄ λ μμλ³΄λ €λ©΄ Accelerate μ€ν¬λ¦½νΈ μ€ν νν 리μΌμ νμΈνμΈμ.