Spaces:
Paused
Paused
Update train.py
Browse files
train.py
CHANGED
|
@@ -116,6 +116,8 @@ def train_model(model, tokenizer, dataset, push):
|
|
| 116 |
num_warmup_steps=args.warmup_steps,
|
| 117 |
num_training_steps=len(dataset) * args.num_train_epochs // args.gradient_accumulation_steps
|
| 118 |
)
|
|
|
|
|
|
|
| 119 |
|
| 120 |
dataset = dataset.map(lambda examples: format_prompts(examples, tokenizer), batched=True)
|
| 121 |
trainer = trl.SFTTrainer(
|
|
@@ -127,11 +129,7 @@ def train_model(model, tokenizer, dataset, push):
|
|
| 127 |
max_seq_length=MAX_SEQ_LENGTH,
|
| 128 |
optimizers=(optimizer, scheduler)
|
| 129 |
)
|
| 130 | -
| 131 | - model, optimizer = accelerator.prepare(model, optimizer)
| 132 | - trainer.model = model
| 133 | - trainer.optimizer = optimizer
| 134 | - trainer = accelerator.prepare(trainer)
|
| 135 |
trainer.train()
|
| 136 |
|
| 137 |
trained_model = trainer.model
|
|
|
|
| 116 |
num_warmup_steps=args.warmup_steps,
|
| 117 |
num_training_steps=len(dataset) * args.num_train_epochs // args.gradient_accumulation_steps
|
| 118 |
)
|
| 119 | +
| 120 | + model, optimizer = accelerator.prepare(model, optimizer)
|
| 121 |
|
| 122 |
dataset = dataset.map(lambda examples: format_prompts(examples, tokenizer), batched=True)
|
| 123 |
trainer = trl.SFTTrainer(
|
|
|
|
| 129 |
max_seq_length=MAX_SEQ_LENGTH,
|
| 130 |
optimizers=(optimizer, scheduler)
|
| 131 |
)
|
| 132 | +
|
|
|
|
|
|
|
|
|
|
|
|
|
| 133 |
trainer.train()
|
| 134 |
|
| 135 |
trained_model = trainer.model
|