Open-TO commited on
Commit
6f203b4
·
1 Parent(s): fefda04

data and checkpoint added

Browse files
.gitignore ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+ #jupyter
9
+ .ipynb_checkpoints
10
+
11
+ # Distribution / packaging
12
+ .Python
13
+ build/
14
+ develop-eggs/
15
+ dist/
16
+ downloads/
17
+ eggs/
18
+ .eggs/
19
+ lib/
20
+ lib64/
21
+ parts/
22
+ sdist/
23
+ var/
24
+ wheels/
25
+ share/python-wheels/
26
+ *.egg-info/
27
+ .installed.cfg
28
+ *.egg
29
+ MANIFEST
30
+
31
+ # PyInstaller
32
+ # Usually these files are written by a python script from a template
33
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
34
+ *.manifest
35
+ *.spec
36
+
37
+ # Installer logs
38
+ pip-log.txt
39
+ pip-delete-this-directory.txt
40
+
41
+ # Unit test / coverage reports
42
+ htmlcov/
43
+ .tox/
44
+ .nox/
45
+ .coverage
46
+ .coverage.*
47
+ .cache
48
+ nosetests.xml
49
+ coverage.xml
50
+ *.cover
51
+ *.py,cover
52
+ .hypothesis/
53
+ .pytest_cache/
54
+ cover/
55
+
56
+ # Translations
57
+ *.mo
58
+ *.pot
59
+
60
+ # Django stuff:
61
+ *.log
62
+ local_settings.py
63
+ db.sqlite3
64
+ db.sqlite3-journal
65
+
66
+ # Flask stuff:
67
+ instance/
68
+ .webassets-cache
69
+
70
+ # Scrapy stuff:
71
+ .scrapy
72
+
73
+ # Sphinx documentation
74
+ docs/_build/
75
+
76
+ # PyBuilder
77
+ .pybuilder/
78
+ target/
79
+
80
+ # Jupyter Notebook
81
+ .ipynb_checkpoints
82
+
83
+ # IPython
84
+ profile_default/
85
+ ipython_config.py
86
+
87
+ # pyenv
88
+ # For a library or package, you might want to ignore these files since the code is
89
+ # intended to run in multiple environments; otherwise, check them in:
90
+ # .python-version
91
+
92
+ # pipenv
93
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
94
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
95
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
96
+ # install all needed dependencies.
97
+ #Pipfile.lock
98
+
99
+ # poetry
100
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
101
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
102
+ # commonly ignored for libraries.
103
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
104
+ #poetry.lock
105
+
106
+ # pdm
107
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
108
+ #pdm.lock
109
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
110
+ # in version control.
111
+ # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
112
+ .pdm.toml
113
+ .pdm-python
114
+ .pdm-build/
115
+
116
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
117
+ __pypackages__/
118
+
119
+ # Celery stuff
120
+ celerybeat-schedule
121
+ celerybeat.pid
122
+
123
+ # SageMath parsed files
124
+ *.sage.py
125
+
126
+ # Environments
127
+ .env
128
+ .venv
129
+ env/
130
+ venv/
131
+ ENV/
132
+ env.bak/
133
+ venv.bak/
134
+
135
+ # Spyder project settings
136
+ .spyderproject
137
+ .spyproject
138
+
139
+ # Rope project settings
140
+ .ropeproject
141
+
142
+ # mkdocs documentation
143
+ /site
144
+
145
+ # mypy
146
+ .mypy_cache/
147
+ .dmypy.json
148
+ dmypy.json
149
+
150
+ # Pyre type checker
151
+ .pyre/
152
+
153
+ # pytype static type analyzer
154
+ .pytype/
155
+
156
+ # Cython debug symbols
157
+ cython_debug/
158
+
159
+ # PyCharm
160
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
161
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
162
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
163
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
164
+ #.idea/
__init__.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from bikefusion import *
2
+ from torch.utils.data import TensorDataset, random_split
3
+ import torch
4
+ import os
5
+
6
+ current_dir = os.path.dirname(os.path.abspath(__file__))
7
+
8
+ def load_bikefusion_and_data():
9
+ partial_images, masks, parametric, description, targets = load_data(os.path.join(current_dir, 'data/'))
10
+
11
+ training_images = preprocess(targets)
12
+
13
+ dataset = TensorDataset(training_images)
14
+
15
+ # split to training and validation
16
+ train_size = int(0.8 * len(dataset))
17
+ val_size = len(dataset) - train_size
18
+ train_dataset, val_dataset = random_split(dataset, [train_size, val_size], generator=torch.Generator().manual_seed(42))
19
+
20
+ Diffuser = InpaintingDenoisingDiffusion(train_dataset, val_dataset, image_size=128)
21
+ Diffuser.load_checkpoint(os.path.join(current_dir, 'chekpoint/bikefusion.pt'))
22
+
23
+ return Diffuser
bikefusion/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from .data_utils import load_data, preprocess, postprocess, to_mask_map
2
+ from .visualizers import visualize_imagesets, visualize_image_evolution
3
+ from .pipline import InpaintingDenoisingPipeline
4
+ from .diffusion import InpaintingDenoisingDiffusion
bikefusion/data_utils.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torchvision.transforms.functional as F
2
+ import torch
3
+ import numpy as np
4
+ import pickle
5
+
6
+ def to_mask_map(maks, image_size=(80,128)):
7
+ bs = maks.shape[0]
8
+ mask_map = np.ones((bs, 1, image_size[0], image_size[1]))
9
+
10
+ for i in range(bs):
11
+ left, bottom, right, top = maks[i]
12
+ mask_map[i, 0, top:bottom, left:right] = 0
13
+
14
+ return mask_map
15
+
16
+ def pad_to_square(image, target_size=(128, 128)):
17
+ # Calculate padding for each dimension
18
+ h, w = image.shape[1], image.shape[2]
19
+ delta_w = target_size[1] - w
20
+ delta_h = target_size[0] - h
21
+ padding = (delta_w // 2, delta_h // 2, delta_w - (delta_w // 2), delta_h - (delta_h // 2))
22
+
23
+ # Pad with white (255) for each side
24
+ padded_image = F.pad(image, padding, fill=255)
25
+ return padded_image
26
+
27
+ def preprocess(images):
28
+ # Convert arrays to tensors
29
+ images = torch.from_numpy(images).float()
30
+
31
+ # Apply padding to each image in the dataset
32
+ images = torch.stack([pad_to_square(img) for img in images])
33
+
34
+ # Normalize images to [-1, 1]
35
+ images = images / 255.0 * 2 - 1
36
+
37
+ return images
38
+
39
+ def un_pad(image, target_size=(80, 128)):
40
+ # Calculate padding for each dimension
41
+ h, w = image.shape[1], image.shape[2]
42
+ delta_w = w - target_size[1]
43
+ delta_h = h - target_size[0]
44
+
45
+ # Unpad the image
46
+ un_padded_image = image[:, delta_h // 2:h - delta_h // 2, delta_w // 2:w - delta_w // 2]
47
+
48
+ return un_padded_image
49
+
50
+ def postprocess(images):
51
+ # Convert tensors to arrays
52
+ images = images.detach().cpu().numpy()
53
+
54
+ # Unpad each image in the dataset
55
+ images = np.stack([un_pad(img) for img in images])
56
+
57
+ # Rescale images to [0, 255]
58
+ images = (images + 1) / 2 * 255
59
+
60
+ return images.astype(np.uint8)
61
+
62
+ def load_data(split="train", path="data/"):
63
+ """
64
+ A function to load the data.
65
+
66
+ Parameters:
67
+ - split: str, "train" or "test"
68
+
69
+ Returns:
70
+ - masked_images: np.ndarray, shape (n_images, channels, height, width), the images with random masks applied
71
+ - masks: np.ndarray, shape (n_images, 4), the boundaries of the masks (left, bottom, right, top)
72
+ - parametric: np.ndarray, shape (n_images, 3), the parametric representation of the images
73
+ - description: list, the description of the images
74
+ - images: np.ndarray, shape (n_images, channels, height, width), the original images
75
+ """
76
+ masked_images = []
77
+ if split == "train":
78
+ for i in range(5):
79
+ masked_im_slice = np.load(f"{path}masked_train_{i}.npy")
80
+ masked_images.append(masked_im_slice)
81
+ masked_images = np.concatenate(masked_images, axis=0)
82
+
83
+ images = []
84
+ for i in range(5):
85
+ im_slice = np.load(f"{path}images_train_{i}.npy")
86
+ images.append(im_slice)
87
+ images = np.concatenate(images, axis=0)
88
+ else:
89
+ masked_images = np.load(f"{path}masked_test.npy")
90
+ images = None
91
+
92
+ description = pickle.load(open(f"{path}desc_{split}.pkl", "rb"))
93
+
94
+ parametric = np.load(f"{path}param_{split}.npy")
95
+
96
+ masks = np.load(f"{path}mask_{split}.npy")
97
+
98
+
99
+ return masked_images, masks, parametric, description, images
bikefusion/diffusion.py ADDED
@@ -0,0 +1,254 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from diffusers import UNet2DModel
2
+ from diffusers import DDPMScheduler
3
+ from accelerate import Accelerator
4
+ from diffusers.optimization import get_cosine_schedule_with_warmup
5
+ from tqdm.auto import tqdm
6
+ import torch.nn.functional as F
7
+ import torch
8
+ from torch.utils.data import DataLoader, random_split
9
+ from .pipline import InpaintingDenoisingPipeline as ConditionalDenoisingPipeline
10
+ import numpy as np
11
+ import matplotlib.pyplot as plt
12
+ from .visualizers import visualize_imagesets
13
+
14
+ class InpaintingDenoisingDiffusion:
15
+ def __init__(self,
16
+ train_dataset,
17
+ validation_dataset,
18
+ image_size = 128, # the generated image resolution
19
+ n_train_noise_timesteps = 1000, # the number of timesteps for the noise scheduler
20
+ train_batch_size = 16,
21
+ eval_batch_size = 16, # how many images to sample during evaluation
22
+ num_epochs = 50,
23
+ gradient_accumulation_steps = 1,
24
+ learning_rate = 5e-5,
25
+ lr_warmup_steps = 500,
26
+ mixed_precision = "fp16", # `no` for float32, `fp16` for automatic mixed precision
27
+ masking_range = ((24, 104), (0, 128)), # the range of the random masks applied in training
28
+ minimum_mask_portion = 0.4, # the minimum portion of the image to be masked
29
+ maximum_mask_portion = 0.8, # the maximum portion of the image to be masked
30
+ full_image_probability = 0.5, # the probability of applying a full image mask (no image)
31
+ device = None, # "cuda" or "cpu"
32
+ model = None,
33
+ ):
34
+
35
+ self.train_dataset = train_dataset
36
+ self.validation_dataset = validation_dataset
37
+ self.image_size = image_size
38
+ self.train_batch_size = train_batch_size
39
+ self.eval_batch_size = eval_batch_size
40
+ self.num_epochs = num_epochs
41
+ self.gradient_accumulation_steps = gradient_accumulation_steps
42
+ self.learning_rate = learning_rate
43
+ self.lr_warmup_steps = lr_warmup_steps
44
+ self.mixed_precision = mixed_precision
45
+ self.device = device
46
+ self.masking_range = masking_range
47
+ self.full_image_probability = full_image_probability
48
+ self.minimum_mask_portion = minimum_mask_portion
49
+ self.maximum_mask_portion = maximum_mask_portion
50
+ self.n_train_noise_timesteps = n_train_noise_timesteps
51
+
52
+ self.dataloader = DataLoader(self.train_dataset, batch_size=self.train_batch_size, shuffle=True)
53
+ self.val_loader = DataLoader(self.validation_dataset, batch_size=self.eval_batch_size, shuffle=False)
54
+
55
+ if self.device is None:
56
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
57
+
58
+ if model is not None:
59
+ self.unet = model
60
+ else:
61
+ self.unet = UNet2DModel(
62
+ sample_size=self.image_size, # the target image resolution
63
+ in_channels=6, # the number of input channels, 3 for RGB masked images and 3 for RGB noise
64
+ out_channels=3, # the number of output channels (RGB)
65
+ layers_per_block=2, # how many ResNet layers to use per UNet block
66
+ block_out_channels=(128, 256, 512, 768), # the number of output channels for each UNet block
67
+ down_block_types=(
68
+ "DownBlock2D", # a regular ResNet downsampling block
69
+ "AttnDownBlock2D", # a ResNet downsampling block with spatial self-attention
70
+ "AttnDownBlock2D",
71
+ "AttnDownBlock2D"
72
+ ),
73
+ up_block_types=(
74
+ "AttnUpBlock2D",# a ResNet upsampling block with spatial self-attention
75
+ "AttnUpBlock2D",
76
+ "AttnUpBlock2D",
77
+ "UpBlock2D"# a regular ResNet upsampling block
78
+ ),
79
+ ).to(self.device)
80
+
81
+ self.noise_scheduler = DDPMScheduler(num_train_timesteps=n_train_noise_timesteps)
82
+ self.optimizer = torch.optim.AdamW(self.unet.parameters(), lr=self.learning_rate)
83
+ self.lr_scheduler = get_cosine_schedule_with_warmup(
84
+ optimizer=self.optimizer,
85
+ num_warmup_steps=self.lr_warmup_steps,
86
+ num_training_steps=(len(self.dataloader) * self.num_epochs),
87
+ )
88
+ self.current_epoch = 0
89
+
90
+ self.denoiser = ConditionalDenoisingPipeline(self.unet, self.noise_scheduler)
91
+
92
+ def checkpoints(self, path):
93
+ torch.save({
94
+ "model": self.unet.state_dict(),
95
+ "optimizer": self.optimizer.state_dict(),
96
+ "scheduler": self.lr_scheduler.state_dict(),
97
+ "epoch": self.current_epoch
98
+ }, path)
99
+
100
+ def load_checkpoint(self, path):
101
+ checkpoint = torch.load(path)
102
+ if "model" not in checkpoint:
103
+ raise ValueError("Checkpoint does not contain a model state dict")
104
+ else:
105
+ self.unet.load_state_dict(checkpoint["model"])
106
+ if "optimizer" in checkpoint:
107
+ try:
108
+ self.optimizer.load_state_dict(checkpoint["optimizer"])
109
+ except:
110
+ print("Could not load optimizer state dict")
111
+ if "scheduler" in checkpoint:
112
+ try:
113
+ self.lr_scheduler.load_state_dict(checkpoint["scheduler"])
114
+ except:
115
+ print("Could not load scheduler state dict")
116
+ if "epoch" in checkpoint:
117
+ self.current_epoch = checkpoint["epoch"]
118
+ else:
119
+ self.current_epoch = 0
120
+ print("Epoch information not found in checkpoint setting epoch to 0")
121
+ self.lr_scheduler.last_epoch = 0
122
+
123
+ def create_random_masks(self, n_masks):
124
+ masks = torch.ones(n_masks, 1, self.image_size, self.image_size).to(self.device)
125
+ height = self.masking_range[0][1] - self.masking_range[0][0]
126
+ width = self.masking_range[1][1] - self.masking_range[1][0]
127
+ mask_heights = np.random.randint(int(height * self.minimum_mask_portion), int(height * self.maximum_mask_portion), n_masks)
128
+ mask_widths = np.random.randint(int(width * self.minimum_mask_portion), int(width * self.maximum_mask_portion), n_masks)
129
+
130
+ top_positions = np.random.randint(0, height - mask_heights + 1, n_masks)
131
+ left_positions = np.random.randint(0, width - mask_widths + 1, n_masks)
132
+
133
+ top_positions += self.masking_range[0][0]
134
+ left_positions += self.masking_range[1][0]
135
+ for i in range(n_masks):
136
+ if np.random.rand() < self.full_image_probability:
137
+ masks[i] = 0
138
+ continue
139
+ top = top_positions[i]
140
+ left = left_positions[i]
141
+ mask_height = mask_heights[i]
142
+ mask_width = mask_widths[i]
143
+
144
+ bottom = top + mask_height
145
+ right = left + mask_width
146
+
147
+ masks[i, :, top:bottom, left:right] = 0
148
+
149
+ return masks
150
+
151
+ def get_sample_batch(self, n_samples):
152
+ rnd_idx = np.random.choice(len(self.train_dataset), n_samples).astype(int)
153
+ images = torch.stack([self.train_dataset[i][0] for i in rnd_idx]).to(self.device)
154
+ masks = self.create_random_masks(n_samples)
155
+ masked_images = images * masks + (1-masks)
156
+
157
+ return masked_images.cpu().detach().numpy(), masks.cpu().detach().numpy(), images.cpu().detach().numpy()
158
+
159
+ def get_sample_noising(self, n_samples, n_timesteps=5):
160
+ rnd_idx = np.random.randint(0, len(self.train_dataset), n_samples)
161
+ images = torch.stack([self.train_dataset[i][0] for i in rnd_idx])
162
+ noise = torch.randn(images.shape, device=images.device)
163
+ timesteps = torch.linspace(0, self.n_train_noise_timesteps-1, n_timesteps, device=images.device).long()
164
+
165
+ noisy_images = torch.zeros((n_samples, n_timesteps, images.shape[1], images.shape[2], images.shape[3]), device=images.device)
166
+ for i in range(n_timesteps):
167
+ noisy_images[:, i] = self.noise_scheduler.add_noise(images, noise, timesteps[i])
168
+
169
+ return noisy_images
170
+
171
+ def reset_epoch(self):
172
+ self.current_epoch = 0
173
+ self.lr_scheduler.last_epoch = 0
174
+
175
+ def train(self, n_epoch=None, display_interval=1, checkpoint_interval=10, checkpoint_path_prefix="checkpoints"):
176
+ if n_epoch is None:
177
+ n_epoch = self.num_epochs
178
+
179
+ accelerator = Accelerator(
180
+ mixed_precision=self.mixed_precision,
181
+ gradient_accumulation_steps=self.gradient_accumulation_steps,
182
+ )
183
+
184
+
185
+ model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
186
+ self.unet, self.optimizer, self.dataloader, self.lr_scheduler
187
+ )
188
+
189
+ for epoch in range(self.current_epoch, n_epoch):
190
+ model.train()
191
+ progress = tqdm(train_dataloader)
192
+ for images in progress:
193
+ images = images[0]
194
+ images = images.to(model.device)
195
+ bs = images.size(0)
196
+ # create random masks
197
+ masks = self.create_random_masks(bs)
198
+ masked_images = images * masks + (1-masks)
199
+
200
+ noise = torch.randn(images.shape, device=images.device)
201
+ timesteps = torch.randint(
202
+ 0, self.noise_scheduler.config.num_train_timesteps, (bs,), device=images.device,
203
+ dtype=torch.int64
204
+ )
205
+
206
+ noisy_images = self.noise_scheduler.add_noise(images, noise, timesteps)
207
+
208
+ with accelerator.accumulate(model):
209
+ noise_pred = model(torch.cat([masked_images,noisy_images],1), timesteps, return_dict=False)[0]
210
+
211
+ loss = F.mse_loss(noise_pred, noise)
212
+ accelerator.backward(loss)
213
+
214
+ if accelerator.sync_gradients:
215
+ accelerator.clip_grad_norm_(model.parameters(), 1.0)
216
+
217
+ optimizer.step()
218
+ lr_scheduler.step()
219
+ optimizer.zero_grad()
220
+
221
+ progress.set_postfix_str(f"Epoch: {epoch}/{self.num_epochs}, Loss: {loss.item():.7f}")
222
+
223
+ if epoch % display_interval == 0:
224
+ # unmasks one batch of validation images
225
+ images = next(iter(self.val_loader))[0].to(model.device)[0:4]
226
+ masks = self.create_random_masks(images.size(0))
227
+ masked_images = images * masks + (1-masks)
228
+ filled_images = self.generate(images=masked_images, num_inference_steps=100)
229
+ visualize_imagesets(
230
+ (images.cpu().detach().numpy() + 1)/2,
231
+ masks.cpu().detach().numpy(),
232
+ (masked_images.cpu().detach().numpy() + 1)/2,
233
+ (filled_images + 1)/2,
234
+ titles=["Original", "Mask Maps", "Masked", "Filled"]
235
+ )
236
+
237
+ if epoch % checkpoint_interval == 0:
238
+ self.checkpoints(f"{checkpoint_path_prefix}_{epoch}.pt")
239
+
240
+ self.current_epoch += 1
241
+
242
+ def generate(self, images=None, num_inference_steps=100, n_samples=10, noise_seed=None, return_intermediate=False, guidance_function=None, guidance_scale=1.0):
243
+ if images is None and n_samples is None:
244
+ raise ValueError("Either images or n_samples must be provided")
245
+ elif images is None:
246
+ images = torch.ones(n_samples, 3, self.image_size, self.image_size)
247
+
248
+ initia_noise = None
249
+ if noise_seed is not None:
250
+ torch.manual_seed(noise_seed)
251
+ initia_noise = torch.randn_like(images, device=images.device)
252
+
253
+ self.unet.eval()
254
+ return self.denoiser(images, num_inference_steps=num_inference_steps, initial_noise=initia_noise, return_intermediate=return_intermediate, guidance_function=guidance_function, guidance_scale=guidance_scale)
bikefusion/pipline.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ from torch.nn import functional as F
4
+ from tqdm.auto import tqdm
5
+
6
+ class InpaintingDenoisingPipeline:
7
+ def __init__(self, model, scheduler):
8
+ super().__init__()
9
+ self.model = model
10
+ self.scheduler = scheduler
11
+
12
+
13
+ @torch.no_grad()
14
+ def __call__(self, masked_images, num_inference_steps=50, guidance_function=None, guidance_scale=1.0, initial_noise=None, return_intermediate=False):
15
+ # Ensure model and scheduler are in evaluation mode
16
+ self.model.eval()
17
+
18
+ # Initialize the noisy image as the input image
19
+ device = self.model.device
20
+ masked_images = masked_images.to(device)
21
+ Bs = masked_images.size(0)
22
+
23
+ if return_intermediate:
24
+ out = torch.zeros((Bs, num_inference_steps+1, 3, masked_images.shape[2], masked_images.shape[3]), device=device)
25
+
26
+ if initial_noise is not None:
27
+ noise = torch.tensor(initial_noise).float().to(device)
28
+ else:
29
+ noise = torch.randn_like(masked_images, dtype=self.model.dtype).to(self.model.device)
30
+
31
+ self.scheduler.set_timesteps(num_inference_steps)
32
+
33
+ if return_intermediate:
34
+ out[:, 0] = noise
35
+ c = 1
36
+
37
+ for t in tqdm(self.scheduler.timesteps):
38
+ # Get the noise level for this timestep
39
+ model_output = self.model(torch.cat([masked_images,noise],1), t).sample
40
+
41
+ noise = self.scheduler.step(model_output, t, noise).prev_sample
42
+
43
+ if guidance_function is not None:
44
+ with torch.enable_grad():
45
+ noise.requires_grad = True
46
+ guidance_objective = guidance_function(noise, t)
47
+ guidance_objective.backward()
48
+ grads = noise.grad
49
+ noise = noise - guidance_scale * grads
50
+
51
+ if return_intermediate:
52
+ out[:, c] = noise
53
+ c += 1
54
+
55
+ if return_intermediate:
56
+ return out.clamp(-1,1).detach().cpu().numpy()
57
+ else:
58
+ return noise.clamp(-1,1).detach().cpu().numpy()
bikefusion/visualizers.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import matplotlib.pyplot as plt
4
+
5
+ def visualize_imagesets(*args,titles=None):
6
+ """
7
+ A function to visualize image sets in rows.
8
+ Each row corresponds to a different image set with images stacked corresponding to oneanother.
9
+
10
+ Parameters:
11
+ - args: np.ndarray,torch.tensor the image sets to visualize (must be in B x C x H x W format)
12
+ """
13
+
14
+ if titles is not None:
15
+ assert len(titles) == len(args), "Number of titles must match number of image sets"
16
+ elif titles is None:
17
+ pass
18
+ elif not isinstance(titles[0], str) and len(titles[0]) > 1:
19
+ for i in range(len(args)):
20
+ assert len(titles[i]) == len(args[i]), "Number of titles must match number of images in each set or be 1"
21
+
22
+ n_sets = len(args)
23
+ n_images = len(args[0])
24
+ fig, axs = plt.subplots(n_sets, n_images, figsize=(n_images * 3, n_sets * 3))
25
+ for i in range(n_sets):
26
+ for j in range(n_images):
27
+ if isinstance(args[i][j], torch.Tensor):
28
+ image = args[i][j].detach().cpu().numpy()
29
+ elif isinstance(args[i][j], np.ndarray):
30
+ image = args[i][j]
31
+ else:
32
+ raise ValueError("Image must be a numpy array or torch tensor")
33
+ if image.shape[0] == 1:
34
+ if n_sets == 1:
35
+ axs[j].imshow(image[0], cmap="gray")
36
+ else:
37
+ axs[i, j].imshow(image[0], cmap="gray")
38
+ else:
39
+ if n_sets == 1:
40
+ axs[j].imshow(np.transpose(image, (1, 2, 0)))
41
+ else:
42
+ axs[i, j].imshow(np.transpose(image, (1, 2, 0)))
43
+ if n_sets == 1:
44
+ axs[j].axis("off")
45
+ else:
46
+ axs[i, j].axis("off")
47
+ if titles is not None and not isinstance(titles[i], str):
48
+ if len(titles[i]) == 1:
49
+ axs[j].set_title(titles[i][0])
50
+ else:
51
+ axs[i, j].set_title(titles[i][j])
52
+ elif titles is not None:
53
+ if n_sets == 1:
54
+ axs[j].set_title(titles[i])
55
+ else:
56
+ axs[i, j].set_title(titles[i])
57
+
58
+
59
+ plt.tight_layout()
60
+ plt.show()
61
+
62
+ def visualize_image_evolution(images, titles=None):
63
+ """
64
+ A function to visualize the evolution of images over time.
65
+
66
+ Parameters:
67
+ - images: np.ndarray, shape (n_images, n_timesteps, channels, height, width), the images to visualize
68
+ - titles: list, the titles for each iteration
69
+ """
70
+ if isinstance(images, torch.Tensor):
71
+ images = images.detach().cpu().numpy()
72
+
73
+ if titles is not None:
74
+ assert len(titles) == images.shape[1], "Number of titles must match number of timesteps"
75
+ n_images = images.shape[0]
76
+ n_timesteps = images.shape[1]
77
+ fig, axs = plt.subplots(n_images, n_timesteps, figsize=(n_timesteps * 3, n_images * 3))
78
+ for i in range(n_images):
79
+ for j in range(n_timesteps):
80
+ if images.shape[2] == 1:
81
+ axs[i, j].imshow(images[i, j, 0], cmap="gray")
82
+ else:
83
+ axs[i, j].imshow(np.transpose(images[i, j], (1, 2, 0)))
84
+ axs[i, j].axis("off")
85
+ if titles is not None:
86
+ axs[i, j].set_title(titles[j])
87
+ plt.tight_layout()
88
+ plt.show()
chekpoint/bikefusion.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8e68138c27aea63527caa1344ed9e1aa9ad21580f0ac8194b4495917b20409ee
3
+ size 678272190
data/desc_test.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2afd555e135b03e8176e726e692d26c131838bf6ba397c13b1238582401e21c7
3
+ size 682467
data/desc_train.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8314c63c08692c3f016b7bdd87d97620e4e978aa72cbd637f79ef4f6438babf2
3
+ size 6844466
data/images_train_0.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d003aba1dd5d8a35b7b467ef926d09a7b1d4ed4e309876ed49f5f81c181125c9
3
+ size 61440128
data/images_train_1.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:34560a1a54310f8dd8ca831f2510df098323ef2c7509e2aaeed6f80d319f46a1
3
+ size 61440128
data/images_train_2.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b8035d533235b5af64352964181a2a2d8ffb051de7def627d4cf6ceba641e04b
3
+ size 61440128
data/images_train_3.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ba7a6d2c142189b1278aa7fd528d90af054531884acd923747a66f0f4a9668ca
3
+ size 61440128
data/images_train_4.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6e23641baa6c970ebb16c2c1a0b34d4d63201fba7d03028673de4c8ee466ea0e
3
+ size 61440128
data/mask_test.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:341234393c5924a807c991ef17f788c1a559ecb37d0e43e515bb372f7d9d526d
3
+ size 32128
data/mask_train.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b01ffbdc95fea079ddc07151ea3a7eb282aea5b38467dbe738c4e420973a9b8f
3
+ size 320128
data/masked_test.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:52cb7f4d6b14828ad8bafc095ebf4ec4fa33aeda7d8a8a7e112365b279b49e67
3
+ size 30720128
data/masked_train_0.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:32183c7f1c0600392cc4f16429162c1982e2495241036eb5576af3db2d49e50d
3
+ size 61440128
data/masked_train_1.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e5f5f32b54c71830452fef97e33fad31d8be6133cd99988f59639de94a3b5a2c
3
+ size 61440128
data/masked_train_2.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:971dd8d938d8d29cfbdd7b74811b6fa6870646795601768c6bf9e79bf85be61c
3
+ size 61440128
data/masked_train_3.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6bf79d2d514de1840f7b583d91164abb75f49bc0470547da9609027be84cdcc5
3
+ size 61440128
data/masked_train_4.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e7eaaadb2a3673d9f95639a891bf53d0c3dcb8fbfeed9e011720de6a9dbd51c5
3
+ size 61440128
data/param_test.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7eea53c5db5f0062570ae63deacc0df05ccada087781940c648cd54d3f5390df
3
+ size 768128
data/param_train.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d267b272bf69c08cc302190bbf47eb7a0e31edc70180cf7010d02888f2b1e8fb
3
+ size 7680128