Spaces:
Sleeping
Sleeping
split model setup and task execution
Browse files
main.py
CHANGED
|
@@ -14,7 +14,8 @@ from rewards import get_reward_losses
|
|
| 14 |
from training import LatentNoiseTrainer, get_optimizer
|
| 15 |
|
| 16 |
|
| 17 |
-
def
|
|
|
|
| 18 |
seed_everything(args.seed)
|
| 19 |
bf.makedirs(f"{args.save_dir}/logs/{args.task}")
|
| 20 |
# Set up logging and name settings
|
|
@@ -92,6 +93,10 @@ def main(args, progress_callback=None):
|
|
| 92 |
)
|
| 93 |
enable_grad = not args.no_optim
|
| 94 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 95 |
if args.task == "single":
|
| 96 |
init_latents = torch.randn(shape, device=device, dtype=dtype)
|
| 97 |
latents = torch.nn.Parameter(init_latents, requires_grad=enable_grad)
|
|
@@ -269,7 +274,12 @@ def main(args, progress_callback=None):
|
|
| 269 |
# log total rewards
|
| 270 |
logging.info(f"Mean initial rewards: {total_init_rewards}")
|
| 271 |
logging.info(f"Mean best rewards: {total_best_rewards}")
|
| 272 |
-
|
| 273 |
-
|
| 274 |
args = parse_args()
|
| 275 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
from training import LatentNoiseTrainer, get_optimizer
|
| 15 |
|
| 16 |
|
| 17 |
+
def setup(args):
|
| 18 |
+
#args = parse_args()
|
| 19 |
seed_everything(args.seed)
|
| 20 |
bf.makedirs(f"{args.save_dir}/logs/{args.task}")
|
| 21 |
# Set up logging and name settings
|
|
|
|
| 93 |
)
|
| 94 |
enable_grad = not args.no_optim
|
| 95 |
|
| 96 |
+
return args, trainer, device, dtype, shape, enable_grad, settings
|
| 97 |
+
|
| 98 |
+
def execute_task(args, trainer, device, dtype, shape, enable_grad, settings, progress_callback=None):
|
| 99 |
+
#args = parse_args()
|
| 100 |
if args.task == "single":
|
| 101 |
init_latents = torch.randn(shape, device=device, dtype=dtype)
|
| 102 |
latents = torch.nn.Parameter(init_latents, requires_grad=enable_grad)
|
|
|
|
| 274 |
# log total rewards
|
| 275 |
logging.info(f"Mean initial rewards: {total_init_rewards}")
|
| 276 |
logging.info(f"Mean best rewards: {total_best_rewards}")
|
| 277 |
+
|
| 278 |
+
def main():
|
| 279 |
args = parse_args()
|
| 280 |
+
args, trainer, device, dtype, shape, enable_grad, settings = setup(args)
|
| 281 |
+
execute_task(args, trainer, device, dtype, shape, enable_grad, settings)
|
| 282 |
+
|
| 283 |
+
|
| 284 |
+
if __name__ == "__main__":
|
| 285 |
+
main()
|