| import torch |
| import torch.nn.functional as F |
|
|
| from diffusers import VQDiffusionScheduler |
|
|
| from .test_schedulers import SchedulerCommonTest |
|
|
|
|
class VQDiffusionSchedulerTest(SchedulerCommonTest):
    """Scheduler tests for ``VQDiffusionScheduler``.

    The sample and model fixtures here take ``num_vec_classes``. Samples are
    integer class indices (drawn with ``torch.randint``), and the dummy model
    returns log-softmax-normalized scores over the classes.
    """

    scheduler_classes = (VQDiffusionScheduler,)

    def get_scheduler_config(self, **kwargs):
        """Return the default scheduler config, updated with ``kwargs``.

        Args:
            **kwargs: Config entries that replace or extend the defaults.

        Returns:
            dict: Keyword arguments used to construct the scheduler.
        """
        config = {
            "num_vec_classes": 4097,
            "num_train_timesteps": 100,
        }
        config.update(**kwargs)
        return config

    def dummy_sample(self, num_vec_classes):
        """Return a random batch of integer class indices.

        Args:
            num_vec_classes: Exclusive upper bound for the sampled indices.

        Returns:
            torch.Tensor: Integer tensor of shape ``(4, 64)``, i.e. a batch of
            4 flattened 8x8 grids. Each entry is drawn uniformly from
            ``[0, num_vec_classes)``.
        """
        batch_size = 4
        height = 8
        width = 8

        sample = torch.randint(0, num_vec_classes, (batch_size, height * width))
        return sample

    @property
    def dummy_sample_deter(self):
        """Not supported for this scheduler; accessing it fails the test.

        Raises:
            AssertionError: Always. The error is raised explicitly rather than
                via ``assert False``. That way it is not stripped when Python
                runs with ``-O``, which would make this property silently
                return ``None``.
        """
        raise AssertionError(
            "dummy_sample_deter is not supported by VQDiffusionSchedulerTest; "
            "use dummy_sample(num_vec_classes) instead"
        )

    def dummy_model(self, num_vec_classes):
        """Return a stand-in model that emits random log-probabilities.

        Args:
            num_vec_classes: Size of the class vocabulary configured on the
                scheduler.

        Returns:
            A callable ``model(sample, t, *args)`` that ignores ``t`` and
            ``args``. It returns a float32 tensor of shape
            ``(batch_size, num_vec_classes - 1, num_latent_pixels)`` that is
            log-softmax normalized along dim 1.
        """

        def model(sample, t, *args):
            # sample must be 2-D: (batch_size, num_latent_pixels).
            batch_size, num_latent_pixels = sample.shape
            # One fewer class than the vocabulary. Presumably the model never
            # predicts the masked class; TODO confirm against
            # VQDiffusionScheduler.step.
            logits = torch.rand((batch_size, num_vec_classes - 1, num_latent_pixels))
            # Normalize in float64 for precision, then hand back float32.
            return F.log_softmax(logits.double(), dim=1).float()

        return model

    def test_timesteps(self):
        """Run ``check_over_configs`` over several ``num_train_timesteps`` values."""
        for timesteps in [2, 5, 100, 1000]:
            self.check_over_configs(num_train_timesteps=timesteps)

    def test_num_vec_classes(self):
        """Run ``check_over_configs`` over several ``num_vec_classes`` values."""
        for num_vec_classes in [5, 100, 1000, 4000]:
            self.check_over_configs(num_vec_classes=num_vec_classes)

    def test_time_indices(self):
        """Run ``check_over_forward`` at the first, middle and last timestep.

        Timesteps 0, 50 and 99 all lie within the default
        ``num_train_timesteps=100`` from ``get_scheduler_config``.
        """
        for t in [0, 50, 99]:
            self.check_over_forward(time_step=t)

    def test_add_noise_device(self):
        """Intentionally disabled: overrides the inherited test with a no-op."""
|
|