| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | import os |
| | import re |
| | import shutil |
| | import sys |
| | import tempfile |
| | import unittest |
| |
|
| |
|
# Absolute path of the repository root: the directory three levels above this test file.
git_repo_path = __file__
for _ in range(3):
    git_repo_path = os.path.dirname(git_repo_path)
git_repo_path = os.path.abspath(git_repo_path)
# Put the repository's top-level `utils/` directory on the import path so `check_copies` can be imported.
sys.path.append(os.path.join(git_repo_path, "utils"))
| |
|
| | import check_copies |
| |
|
| |
|
| | |
| | |
# Expected source of the `DDPMSchedulerOutput` class body. test_find_code_in_diffusers compares it with
# assertEqual against check_copies.find_code_in_diffusers("schedulers.scheduling_ddpm.DDPMSchedulerOutput"),
# so this literal has to be updated by hand whenever DDPMSchedulerOutput changes in scheduling_ddpm.py.
# NOTE(review): as shown here the literal carries no leading indentation, yet check_copy_consistency appends
# it directly after a `class ...:` header, so it presumably needs the class-body indentation used in
# scheduling_ddpm.py -- verify against that file.
REFERENCE_CODE = """ \"""
Output class for the scheduler's `step` function output.

Args:
prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
denoising loop.
pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
`pred_original_sample` can be used to preview progress or for guidance.
\"""

prev_sample: torch.Tensor
pred_original_sample: Optional[torch.Tensor] = None
"""
| |
|
| |
|
class CopyCheckTester(unittest.TestCase):
    """Tests for the "# Copied from" consistency checker in `utils/check_copies.py`.

    Every test runs against a scratch directory that holds a copy of `schedulers/scheduling_ddpm.py`,
    with `check_copies.DIFFUSERS_PATH` temporarily pointed at that directory.
    """

    def setUp(self):
        """Create a scratch diffusers tree containing a copy of the DDPM scheduler module."""
        self.diffusers_dir = tempfile.mkdtemp()
        # Register cleanups right away (they run LIFO after each test) so that a failure later in setUp,
        # e.g. the copy below, neither leaks the temp dir nor leaves DIFFUSERS_PATH redirected.
        # The current DIFFUSERS_PATH is captured now and restored as-is afterwards.
        self.addCleanup(shutil.rmtree, self.diffusers_dir)
        self.addCleanup(setattr, check_copies, "DIFFUSERS_PATH", check_copies.DIFFUSERS_PATH)

        os.makedirs(os.path.join(self.diffusers_dir, "schedulers/"))
        check_copies.DIFFUSERS_PATH = self.diffusers_dir
        shutil.copy(
            os.path.join(git_repo_path, "src/diffusers/schedulers/scheduling_ddpm.py"),
            os.path.join(self.diffusers_dir, "schedulers/scheduling_ddpm.py"),
        )

    def check_copy_consistency(self, comment, class_name, class_code, overwrite_result=None):
        """Write a "# Copied from" class to a scratch file and check it with `check_copies`.

        Args:
            comment: The "# Copied from ..." line placed directly above the generated class.
            class_name: Name of the generated class (declared as a subclass of `nn.Module`).
            class_code: Source of the generated class body.
            overwrite_result: If None, assert that `is_copy_consistent` reports no inconsistencies.
                Otherwise run `is_copy_consistent(..., overwrite=True)` and assert that the rewritten
                file equals this body under the same comment and class header.
        """
        header = comment + f"\nclass {class_name}(nn.Module):\n"
        code = check_copies.run_ruff(header + class_code)
        fname = os.path.join(self.diffusers_dir, "new_code.py")
        with open(fname, "w", newline="\n") as f:
            f.write(code)

        if overwrite_result is None:
            diffs = check_copies.is_copy_consistent(fname)
            self.assertEqual(len(diffs), 0, f"Unexpected copy inconsistencies: {diffs}")
            return

        check_copies.is_copy_consistent(fname, overwrite=True)
        # The written file went through run_ruff, so format the expectation the same way before comparing.
        expected = check_copies.run_ruff(header + overwrite_result)
        with open(fname, "r") as f:
            # assertEqual, not assertTrue: assertTrue's second argument is only a failure message, so it
            # would pass for any non-empty file without comparing contents.
            self.assertEqual(f.read(), expected)

    def test_find_code_in_diffusers(self):
        """`find_code_in_diffusers` returns exactly the reference source of DDPMSchedulerOutput."""
        code = check_copies.find_code_in_diffusers("schedulers.scheduling_ddpm.DDPMSchedulerOutput")
        self.assertEqual(code, REFERENCE_CODE)

    def test_is_copy_consistent(self):
        """Exercise `is_copy_consistent` on exact copies, renamed copies and the overwrite path."""
        # Exact copy, with an extra trailing newline after the class body.
        self.check_copy_consistency(
            "# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput",
            "DDPMSchedulerOutput",
            REFERENCE_CODE + "\n",
        )

        # Exact copy without the extra trailing newline.
        self.check_copy_consistency(
            "# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput",
            "DDPMSchedulerOutput",
            REFERENCE_CODE,
        )

        # Copy using a `with DDPM->Test` rename pattern.
        self.check_copy_consistency(
            "# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->Test",
            "TestSchedulerOutput",
            re.sub("DDPM", "Test", REFERENCE_CODE),
        )

        # Copy renamed to a very long class name; substitute the same pattern the comment names (DDPM),
        # mirroring the rename case above.
        long_class_name = "TestClassWithAReallyLongNameBecauseSomePeopleLikeThatForSomeReason"
        self.check_copy_consistency(
            f"# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->{long_class_name}",
            f"{long_class_name}SchedulerOutput",
            re.sub("DDPM", long_class_name, REFERENCE_CODE),
        )

        # Overwrite path: the checker rewrites the copy, and the file must then match the renamed code.
        self.check_copy_consistency(
            "# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->Test",
            "TestSchedulerOutput",
            REFERENCE_CODE,
            overwrite_result=re.sub("DDPM", "Test", REFERENCE_CODE),
        )
| |
|