| | |
| | |
| | |
| | |
| |
|
| | import torch as th |
| |
|
| |
|
def get_generator(generator, num_samples=0, seed=0):
    """
    Build a random-number generator by name.

    :param generator: one of "dummy", "determ" or "determ-indiv".
    :param num_samples: total number of samples; used only by the
        deterministic generators.
    :param seed: base seed; used only by the deterministic generators.
    :return: a DummyGenerator, DeterministicGenerator or
        DeterministicIndividualGenerator instance.
    :raises NotImplementedError: if ``generator`` is not a known name.
    """
    if generator == "dummy":
        return DummyGenerator()
    elif generator == "determ":
        return DeterministicGenerator(num_samples, seed)
    elif generator == "determ-indiv":
        return DeterministicIndividualGenerator(num_samples, seed)
    else:
        # Keep the exception type callers may already catch, but say which
        # name was rejected and which ones are supported.
        raise NotImplementedError(
            f"unknown generator {generator!r}; "
            "expected one of 'dummy', 'determ', 'determ-indiv'"
        )
| |
|
| |
|
class DummyGenerator:
    """
    Non-deterministic generator backed by torch's global RNG.

    Every method hands its arguments, unchanged, to the torch function of the
    same name, so results follow whatever the process-wide torch seed is.
    """

    def randn(self, *shape_args, **options):
        """Forward to ``th.randn``."""
        sampler = th.randn
        return sampler(*shape_args, **options)

    def randint(self, *bounds_and_size, **options):
        """Forward to ``th.randint``."""
        sampler = th.randint
        return sampler(*bounds_and_size, **options)

    def randn_like(self, *like_args, **options):
        """Forward to ``th.randn_like``."""
        sampler = th.randn_like
        return sampler(*like_args, **options)
| |
|
| |
|
class DeterministicGenerator:
    """
    RNG to deterministically sample num_samples samples that does not depend on
    batch_size or mpi_machines.

    A single RNG draws randomness for all ``num_samples`` samples at once; the
    rows belonging to the current batch are then selected by index.
    """

    def __init__(self, num_samples, seed=0):
        print("Warning: Distributed not initialised, using single rank")
        # Single-rank setup: this process owns every sample index.
        self.rank = 0
        self.world_size = 1
        self.num_samples = num_samples
        # Global index of the first sample of the next batch.
        self.done_samples = 0
        self.seed = seed
        self.rng_cpu = th.Generator()
        if th.cuda.is_available():
            # Generator on the default CUDA device.
            # NOTE(review): assumes the default CUDA device is the one
            # samples are drawn on in this single-rank setup.
            self.rng_cuda = th.Generator(th.device("cuda"))
        self.set_seed(seed)

    @staticmethod
    def _as_shape(size):
        """Accept both ``f(2, 3)`` and torch-style ``f((2, 3))`` size arguments."""
        if len(size) == 1 and isinstance(size[0], (tuple, list)):
            return tuple(size[0])
        return size

    def get_global_size_and_indices(self, size):
        """
        Map a batch shape to the full-dataset shape and this batch's row indices.

        :param size: batch shape; ``size[0]`` is the batch size.
        :return: ``(global_size, indices)``. ``global_size`` is
            ``(num_samples, *size[1:])``. ``indices`` is a 1-D tensor of
            ``size[0]`` global sample indices, starting at ``done_samples``.
            Indices past the end are clamped to ``num_samples - 1``, so an
            overhanging batch repeats the last sample.
        """
        global_size = (self.num_samples, *size[1:])
        indices = th.arange(
            self.done_samples + self.rank,
            self.done_samples + self.world_size * int(size[0]),
            self.world_size,
        )
        indices = th.clamp(indices, 0, self.num_samples - 1)
        assert (
            len(indices) == size[0]
        ), f"rank={self.rank}, ws={self.world_size}, l={len(indices)}, bs={size[0]}"
        return global_size, indices

    def get_generator(self, device):
        """Return the CPU generator for CPU devices, otherwise the CUDA generator."""
        # NOTE: ``rng_cuda`` only exists when CUDA was available at construction.
        return self.rng_cpu if th.device(device).type == "cpu" else self.rng_cuda

    def randn(self, *size, dtype=th.float, device="cpu"):
        """
        Sample standard-normal noise for the current batch.

        Draws a ``(num_samples, *size[1:])`` tensor and returns the rows for
        this batch. ``size`` may be given as separate ints or as one tuple.
        Successive calls continue from the advanced RNG state. Results only
        repeat after the RNG is re-seeded, e.g. via ``set_done_samples``.
        """
        size = self._as_shape(size)
        global_size, indices = self.get_global_size_and_indices(size)
        generator = self.get_generator(device)
        noise = th.randn(*global_size, generator=generator, dtype=dtype, device=device)
        return noise[indices]

    def randint(self, low, high, size, dtype=th.long, device="cpu"):
        """
        Sample integers in ``[low, high)`` for the current batch.

        Draws a ``(num_samples, *size[1:])`` tensor and returns this batch's rows.
        """
        global_size, indices = self.get_global_size_and_indices(size)
        generator = self.get_generator(device)
        values = th.randint(
            low, high, generator=generator, size=global_size, dtype=dtype, device=device
        )
        return values[indices]

    def randn_like(self, tensor):
        """Sample batch noise matching ``tensor``'s shape, dtype and device."""
        size, dtype, device = tensor.size(), tensor.dtype, tensor.device
        return self.randn(*size, dtype=dtype, device=device)

    def set_done_samples(self, done_samples):
        """Set the index of the next batch's first sample and re-seed with ``self.seed``."""
        self.done_samples = done_samples
        self.set_seed(self.seed)

    def get_seed(self):
        """Return the seed given at construction (``set_seed`` does not change it)."""
        return self.seed

    def set_seed(self, seed):
        """
        Re-seed the underlying generator(s).

        ``self.seed`` is left unchanged, so a later ``set_done_samples``
        re-seeds with the constructor seed again.
        """
        self.rng_cpu.manual_seed(seed)
        if th.cuda.is_available():
            self.rng_cuda.manual_seed(seed)
| |
|
| |
|
class DeterministicIndividualGenerator:
    """
    RNG to deterministically sample num_samples samples that does not depend on
    batch_size or mpi_machines.

    Uses a separate RNG for each sample, so only the current batch's randomness
    is materialized, which reduces memory usage.
    """

    def __init__(self, num_samples, seed=0):
        print("Warning: Distributed not initialised, using single rank")
        # Single-rank setup: this process owns every sample index.
        self.rank = 0
        self.world_size = 1
        self.num_samples = num_samples
        # Global index of the first sample of the next batch.
        self.done_samples = 0
        self.seed = seed
        self.rng_cpu = [th.Generator() for _ in range(num_samples)]
        if th.cuda.is_available():
            # One generator per sample on the default CUDA device.
            # NOTE(review): assumes the default CUDA device is the intended one.
            cuda_device = th.device("cuda")
            self.rng_cuda = [th.Generator(cuda_device) for _ in range(num_samples)]
        self.set_seed(seed)

    @staticmethod
    def _as_shape(size):
        """Accept both ``f(2, 3)`` and torch-style ``f((2, 3))`` size arguments."""
        if len(size) == 1 and isinstance(size[0], (tuple, list)):
            return tuple(size[0])
        return size

    def get_size_and_indices(self, size):
        """
        Map a batch shape to the per-sample draw shape and the batch's sample indices.

        :param size: batch shape; ``size[0]`` is the batch size.
        :return: ``((1, *size[1:]), indices)``. ``indices`` is a 1-D tensor of
            ``size[0]`` global sample indices, starting at ``done_samples``.
            Indices past the end are clamped to ``num_samples - 1``.
        """
        indices = th.arange(
            self.done_samples + self.rank,
            self.done_samples + self.world_size * int(size[0]),
            self.world_size,
        )
        indices = th.clamp(indices, 0, self.num_samples - 1)
        assert (
            len(indices) == size[0]
        ), f"rank={self.rank}, ws={self.world_size}, l={len(indices)}, bs={size[0]}"
        return (1, *size[1:]), indices

    def get_generator(self, device):
        """Return the list of CPU generators for CPU devices, otherwise the CUDA list."""
        # NOTE: ``rng_cuda`` only exists when CUDA was available at construction.
        return self.rng_cpu if th.device(device).type == "cpu" else self.rng_cuda

    def randn(self, *size, dtype=th.float, device="cpu"):
        """
        Sample standard-normal noise for the current batch.

        Each row is drawn from that sample's own generator. ``size`` may be
        given as separate ints or as one tuple. An empty batch returns an
        empty ``(0, *size[1:])`` tensor.
        """
        size = self._as_shape(size)
        per_sample_size, indices = self.get_size_and_indices(size)
        generators = self.get_generator(device)
        if len(indices) == 0:
            # th.cat rejects an empty list, so build the empty result directly.
            return th.empty((0, *per_sample_size[1:]), dtype=dtype, device=device)
        return th.cat(
            [
                th.randn(*per_sample_size, generator=generators[i], dtype=dtype, device=device)
                for i in indices.tolist()
            ],
            dim=0,
        )

    def randint(self, low, high, size, dtype=th.long, device="cpu"):
        """
        Sample integers in ``[low, high)`` for the current batch.

        Each row is drawn from that sample's own generator. An empty batch
        returns an empty ``(0, *size[1:])`` tensor.
        """
        per_sample_size, indices = self.get_size_and_indices(size)
        generators = self.get_generator(device)
        if len(indices) == 0:
            # th.cat rejects an empty list, so build the empty result directly.
            return th.empty((0, *per_sample_size[1:]), dtype=dtype, device=device)
        return th.cat(
            [
                th.randint(
                    low,
                    high,
                    generator=generators[i],
                    size=per_sample_size,
                    dtype=dtype,
                    device=device,
                )
                for i in indices.tolist()
            ],
            dim=0,
        )

    def randn_like(self, tensor):
        """Sample batch noise matching ``tensor``'s shape, dtype and device."""
        size, dtype, device = tensor.size(), tensor.dtype, tensor.device
        return self.randn(*size, dtype=dtype, device=device)

    def set_done_samples(self, done_samples):
        """Set the index of the next batch's first sample (does not re-seed)."""
        self.done_samples = done_samples

    def get_seed(self):
        """Return the seed given at construction (``set_seed`` does not change it)."""
        return self.seed

    def set_seed(self, seed):
        """Seed generator ``i`` with ``i + num_samples * seed`` so seeds never collide across samples."""
        for i, rng_cpu in enumerate(self.rng_cpu):
            rng_cpu.manual_seed(i + self.num_samples * seed)
        if th.cuda.is_available():
            for i, rng_cuda in enumerate(self.rng_cuda):
                rng_cuda.manual_seed(i + self.num_samples * seed)
| |
|