Update pipeline_spad.py
Browse files- pipeline_spad.py +20 -23
pipeline_spad.py
CHANGED
|
@@ -81,7 +81,7 @@ class SPADPipeline(DiffusionPipeline):
|
|
| 81 |
batch_size = len(prompt) if isinstance(prompt, list) else 1
|
| 82 |
device = self.device
|
| 83 |
|
| 84 |
-
#
|
| 85 |
if elevations is None or azimuths is None:
|
| 86 |
elevations = [45] * 4
|
| 87 |
azimuths = [0, 90, 180, 270]
|
|
@@ -90,23 +90,23 @@ class SPADPipeline(DiffusionPipeline):
|
|
| 90 |
camera_batch = self.generate_camera_batch(elevations, azimuths, use_abs=self.use_abs_extrinsics)
|
| 91 |
camera_batch = {k: v[None].repeat_interleave(batch_size, dim=0).to(device) for k, v in camera_batch.items()}
|
| 92 |
|
| 93 |
-
#
|
| 94 |
blob = self.get_gaussian_image(sigma=blob_sigma).to(device)
|
| 95 |
camera_batch["img"] = blob.unsqueeze(0).unsqueeze(0).repeat(batch_size, n_views, 1, 1, 1)
|
| 96 |
|
| 97 |
-
#
|
| 98 |
text_input_ids = self.tokenizer(prompt, padding="max_length", max_length=self.tokenizer.model_max_length, return_tensors="pt").input_ids.to(device)
|
| 99 |
text_embeddings = self.text_encoder(text_input_ids)[0]
|
| 100 |
|
| 101 |
-
#
|
| 102 |
max_length = text_input_ids.shape[-1]
|
| 103 |
uncond_input = self.tokenizer([""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt")
|
| 104 |
uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(device))[0]
|
| 105 |
|
| 106 |
-
#
|
| 107 |
camera_embeddings = self.cc_projection(camera_batch["cam"]).to(device)
|
| 108 |
|
| 109 |
-
#
|
| 110 |
latent_height, latent_width = self.vae.config.sample_size // 8, self.vae.config.sample_size // 8
|
| 111 |
latents = self.prepare_latents(
|
| 112 |
batch_size,
|
|
@@ -119,36 +119,33 @@ class SPADPipeline(DiffusionPipeline):
|
|
| 119 |
generator=None,
|
| 120 |
)
|
| 121 |
|
| 122 |
-
#
|
| 123 |
epi_constraint_masks = torch.ones(batch_size, n_views, latent_height, latent_width, n_views, latent_height, latent_width, dtype=torch.bool, device=device)
|
| 124 |
|
| 125 |
-
#
|
| 126 |
plucker_embeds = torch.zeros(batch_size, n_views, 6, latent_height, latent_width, device=device)
|
| 127 |
|
| 128 |
latent_height, latent_width = 64, 64 # Fixed to match the required shape [batch_size, 1, 4, 64, 64]
|
| 129 |
n_objects = 2;
|
| 130 |
latents = torch.randn(n_objects, n_views, 4, 64, 64, device=device, dtype=self.unet.dtype)
|
| 131 |
|
| 132 |
-
#
|
| 133 |
# self.scheduler.set_timesteps(num_inference_steps)
|
| 134 |
self.scheduler.set_timesteps(50)
|
| 135 |
-
#
|
| 136 |
-
text_embeddings = text_embeddings.repeat(n_objects, 1, 1) # Shape: [2, max_seq_len,
|
| 137 |
|
| 138 |
-
#
|
| 139 |
text_embeddings = text_embeddings.unsqueeze(1).repeat(1, n_views, 1, 1)
|
| 140 |
camera_embeddings = camera_embeddings.repeat(n_objects, 1, 1, 1)
|
| 141 |
-
#
|
| 142 |
for t in tqdm(self.scheduler.timesteps):
|
| 143 |
-
#
|
| 144 |
# timesteps = torch.full((batch_size, 1, 1), t, device=device, dtype=torch.long)
|
| 145 |
timesteps = torch.full((n_objects, n_views), t, device=device, dtype=torch.long)
|
| 146 |
|
| 147 |
-
#
|
| 148 |
context = [
|
| 149 |
-
# text_embeddings.unsqueeze(1), # [batch_size, 1, max_seq_len, 768]
|
| 150 |
-
# camera_embeddings.unsqueeze(1) * 0.0, # [batch_size, 1, 1280] * 0.0
|
| 151 |
-
# epi_constraint_masks # Keep this as is for now
|
| 152 |
text_embeddings.to(device), # [n_objects, n_views, max_seq_len, 768]
|
| 153 |
camera_embeddings, # [n_objects, n_views, 1280]
|
| 154 |
torch.ones(n_objects, n_views, 6, 32, 32).to(device)
|
|
@@ -161,22 +158,22 @@ class SPADPipeline(DiffusionPipeline):
|
|
| 161 |
context=context
|
| 162 |
)
|
| 163 |
|
| 164 |
-
#
|
| 165 |
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
| 166 |
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
| 167 |
|
| 168 |
|
| 169 |
-
#
|
| 170 |
latents = self.scheduler.step(noise_pred, t, latents).prev_sample
|
| 171 |
|
| 172 |
# reduce latents
|
| 173 |
-
#EXPERIMENTAL
|
| 174 |
latents_reshaped = latents[:, 0, :, :, :] # Selecting the first view
|
| 175 |
|
| 176 |
-
#
|
| 177 |
images = self.vae.decode(latents_reshaped / self.vae.config.scaling_factor, return_dict=False)[0]
|
| 178 |
|
| 179 |
-
#
|
| 180 |
images = (images / 2 + 0.5).clamp(0, 1)
|
| 181 |
|
| 182 |
if images.dim() == 5:
|
|
|
|
| 81 |
batch_size = len(prompt) if isinstance(prompt, list) else 1
|
| 82 |
device = self.device
|
| 83 |
|
| 84 |
+
# generate camera batch
|
| 85 |
if elevations is None or azimuths is None:
|
| 86 |
elevations = [45] * 4
|
| 87 |
azimuths = [0, 90, 180, 270]
|
|
|
|
| 90 |
camera_batch = self.generate_camera_batch(elevations, azimuths, use_abs=self.use_abs_extrinsics)
|
| 91 |
camera_batch = {k: v[None].repeat_interleave(batch_size, dim=0).to(device) for k, v in camera_batch.items()}
|
| 92 |
|
| 93 |
+
# prepare gaussian blob initialization
|
| 94 |
blob = self.get_gaussian_image(sigma=blob_sigma).to(device)
|
| 95 |
camera_batch["img"] = blob.unsqueeze(0).unsqueeze(0).repeat(batch_size, n_views, 1, 1, 1)
|
| 96 |
|
| 97 |
+
# encode text
|
| 98 |
text_input_ids = self.tokenizer(prompt, padding="max_length", max_length=self.tokenizer.model_max_length, return_tensors="pt").input_ids.to(device)
|
| 99 |
text_embeddings = self.text_encoder(text_input_ids)[0]
|
| 100 |
|
| 101 |
+
# prepare unconditional embeddings for classifier-free guidance
|
| 102 |
max_length = text_input_ids.shape[-1]
|
| 103 |
uncond_input = self.tokenizer([""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt")
|
| 104 |
uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(device))[0]
|
| 105 |
|
| 106 |
+
# encode camera data
|
| 107 |
camera_embeddings = self.cc_projection(camera_batch["cam"]).to(device)
|
| 108 |
|
| 109 |
+
# prepare latents
|
| 110 |
latent_height, latent_width = self.vae.config.sample_size // 8, self.vae.config.sample_size // 8
|
| 111 |
latents = self.prepare_latents(
|
| 112 |
batch_size,
|
|
|
|
| 119 |
generator=None,
|
| 120 |
)
|
| 121 |
|
| 122 |
+
# prepare epi_constraint_masks (placeholder- replace with actual implementation later - MIGHT AFFECT PERFORMANCE)
|
| 123 |
epi_constraint_masks = torch.ones(batch_size, n_views, latent_height, latent_width, n_views, latent_height, latent_width, dtype=torch.bool, device=device)
|
| 124 |
|
| 125 |
+
# prepare plucker embeddings (placeholder, replace with actual implementation - MIGHT AFFECT PERFORMANCE)
|
| 126 |
plucker_embeds = torch.zeros(batch_size, n_views, 6, latent_height, latent_width, device=device)
|
| 127 |
|
| 128 |
latent_height, latent_width = 64, 64 # Fixed to match the required shape [batch_size, 1, 4, 64, 64]
|
| 129 |
n_objects = 2;
|
| 130 |
latents = torch.randn(n_objects, n_views, 4, 64, 64, device=device, dtype=self.unet.dtype)
|
| 131 |
|
| 132 |
+
# set up scheduler
|
| 133 |
# self.scheduler.set_timesteps(num_inference_steps)
|
| 134 |
self.scheduler.set_timesteps(50)
|
| 135 |
+
# repeat text_embeddings to match the desired dimensions
|
| 136 |
+
text_embeddings = text_embeddings.repeat(n_objects, 1, 1) # Shape: [2, max_seq_len, 768]
|
| 137 |
|
| 138 |
+
# reshape text_embeddings to match [n_objects, n_views, max_seq_len, 512]
|
| 139 |
text_embeddings = text_embeddings.unsqueeze(1).repeat(1, n_views, 1, 1)
|
| 140 |
camera_embeddings = camera_embeddings.repeat(n_objects, 1, 1, 1)
|
| 141 |
+
# denoising loop
|
| 142 |
for t in tqdm(self.scheduler.timesteps):
|
| 143 |
+
# expand timesteps to match shape [batch_size, 1, 1]
|
| 144 |
# timesteps = torch.full((batch_size, 1, 1), t, device=device, dtype=torch.long)
|
| 145 |
timesteps = torch.full((n_objects, n_views), t, device=device, dtype=torch.long)
|
| 146 |
|
| 147 |
+
# prepare context
|
| 148 |
context = [
|
|
|
|
|
|
|
|
|
|
| 149 |
text_embeddings.to(device), # [n_objects, n_views, max_seq_len, 768]
|
| 150 |
camera_embeddings, # [n_objects, n_views, 1280]
|
| 151 |
torch.ones(n_objects, n_views, 6, 32, 32).to(device)
|
|
|
|
| 158 |
context=context
|
| 159 |
)
|
| 160 |
|
| 161 |
+
# perform guidance
|
| 162 |
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
| 163 |
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
| 164 |
|
| 165 |
|
| 166 |
+
# compute previous noisy sample
|
| 167 |
latents = self.scheduler.step(noise_pred, t, latents).prev_sample
|
| 168 |
|
| 169 |
# reduce latents
|
| 170 |
+
#EXPERIMENTAL - MIGHT AFFECT PERFORMANCE
|
| 171 |
latents_reshaped = latents[:, 0, :, :, :] # Selecting the first view
|
| 172 |
|
| 173 |
+
# decode latents
|
| 174 |
images = self.vae.decode(latents_reshaped / self.vae.config.scaling_factor, return_dict=False)[0]
|
| 175 |
|
| 176 |
+
# post-process images
|
| 177 |
images = (images / 2 + 0.5).clamp(0, 1)
|
| 178 |
|
| 179 |
if images.dim() == 5:
|