| import os | |
| import sys | |
| import json | |
| import argparse | |
| import numpy as np | |
| import math | |
| from einops import rearrange | |
| import time | |
| import random | |
| import h5py | |
| from tqdm import tqdm | |
| import webdataset as wds | |
| import gc | |
| import matplotlib.pyplot as plt | |
| import torch | |
| import torch.nn as nn | |
| from torchvision import transforms | |
| from torchvision.transforms import ToPILImage #CHANGED (added) | |
| from accelerate import Accelerator, DeepSpeedPlugin | |
# TF32 matmuls are faster than standard float32, so allow them on CUDA.
torch.backends.cuda.matmul.allow_tf32 = True

# custom functions #
import utils

global_batch_size = 128

### Multi-GPU config ###
# Read this process's rank from the RANK environment variable; when the
# variable is unset the rank defaults to 0.
local_rank = int(os.getenv('RANK', 0))
print("LOCAL RANK ", local_rank)

# Count the visible CUDA devices, treating "none found" as a single device.
num_devices = torch.cuda.device_count() or 1

# Plain Accelerate setup (no DeepSpeed plugin).
accelerator = Accelerator(split_batches=False)
### UNCOMMENT BELOW STUFF TO USE DEEPSPEED (also comment out the immediately above "accelerator = " line) ###
# if num_devices <= 1 and utils.is_interactive():
#     # can emulate a distributed environment for deepspeed to work in jupyter notebook
#     os.environ["MASTER_ADDR"] = "localhost"
#     os.environ["MASTER_PORT"] = str(np.random.randint(10000)+9000)
#     os.environ["RANK"] = "0"
#     os.environ["LOCAL_RANK"] = "0"
#     os.environ["WORLD_SIZE"] = "1"
#     os.environ["GLOBAL_BATCH_SIZE"] = str(global_batch_size) # set this to your batch size!
#     global_batch_size = os.environ["GLOBAL_BATCH_SIZE"]
# # alter the deepspeed config according to your global and local batch size
# if local_rank == 0:
#     with open('deepspeed_config_stage2.json', 'r') as file:
#         config = json.load(file)
#     config['train_batch_size'] = int(os.environ["GLOBAL_BATCH_SIZE"])
#     config['train_micro_batch_size_per_gpu'] = int(os.environ["GLOBAL_BATCH_SIZE"]) // num_devices
#     with open('deepspeed_config_stage2.json', 'w') as file:
#         json.dump(config, file)
# else:
#     # give some time for the local_rank=0 gpu to prep new deepspeed config file
#     time.sleep(10)
# deepspeed_plugin = DeepSpeedPlugin("deepspeed_config_stage2.json")
# accelerator = Accelerator(split_batches=False, deepspeed_plugin=deepspeed_plugin)