GaumlessGraham commited on
Commit
18ee31e
·
verified ·
1 Parent(s): b710d18

Update eval.py

Browse files
Files changed (1) hide show
  1. eval.py +15 -1
eval.py CHANGED
@@ -290,7 +290,21 @@ def get_full_repo_name(model_id: str, organization: str = None, token: str = Non
290
 
291
 
292
  def train_loop(config, model, noise_scheduler, optimizer, train_dataloader, lr_scheduler):
293
- evaluate(config, 1, config.pipeline)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
294
 
295
 
296
 
 
290
 
291
 
292
  def train_loop(config, model, noise_scheduler, optimizer, train_dataloader, lr_scheduler):
293
+ # Initialize accelerator and tensorboard logging
294
+ accelerator = Accelerator(
295
+ mixed_precision=config.mixed_precision,
296
+ gradient_accumulation_steps=config.gradient_accumulation_steps,
297
+ log_with="tensorboard",
298
+ project_dir=os.path.join(config.output_dir, "logs"),
299
+ )
300
+
301
+
302
+ model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
303
+ model, optimizer, train_dataloader, lr_scheduler
304
+ )
305
+
306
+ if accelerator.is_main_process:
307
+ evaluate(config, 1, config.pipeline)
308
 
309
 
310