Update eval.py
Browse files
eval.py
CHANGED
|
@@ -290,7 +290,21 @@ def get_full_repo_name(model_id: str, organization: str = None, token: str = Non
|
|
| 290 |
|
| 291 |
|
| 292 |
def train_loop(config, model, noise_scheduler, optimizer, train_dataloader, lr_scheduler):
|
| 293 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 294 |
|
| 295 |
|
| 296 |
|
|
|
|
| 290 |
|
| 291 |
|
| 292 |
def train_loop(config, model, noise_scheduler, optimizer, train_dataloader, lr_scheduler):
|
| 293 |
+
# Initialize accelerator and tensorboard logging
|
| 294 |
+
accelerator = Accelerator(
|
| 295 |
+
mixed_precision=config.mixed_precision,
|
| 296 |
+
gradient_accumulation_steps=config.gradient_accumulation_steps,
|
| 297 |
+
log_with="tensorboard",
|
| 298 |
+
project_dir=os.path.join(config.output_dir, "logs"),
|
| 299 |
+
)
|
| 300 |
+
|
| 301 |
+
|
| 302 |
+
model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
| 303 |
+
model, optimizer, train_dataloader, lr_scheduler
|
| 304 |
+
)
|
| 305 |
+
|
| 306 |
+
if accelerator.is_main_process:
|
| 307 |
+
evaluate(config, 1, config.pipeline)
|
| 308 |
|
| 309 |
|
| 310 |
|