Add files using upload-large-folder tool
Browse files- pipeline_diffusionsat.py +1 -0
- unet/sat_unet.py +3 -1
pipeline_diffusionsat.py
CHANGED
|
@@ -161,6 +161,7 @@ class DiffusionSatPipeline(DiffusionPipeline):
|
|
| 161 |
disable_vae_slicing = DiffusersStableDiffusionPipeline.disable_vae_slicing
|
| 162 |
enable_sequential_cpu_offload = DiffusersStableDiffusionPipeline.enable_sequential_cpu_offload
|
| 163 |
_execution_device = DiffusersStableDiffusionPipeline._execution_device
|
|
|
|
| 164 |
_encode_prompt = DiffusersStableDiffusionPipeline._encode_prompt
|
| 165 |
run_safety_checker = DiffusersStableDiffusionPipeline.run_safety_checker
|
| 166 |
decode_latents = DiffusersStableDiffusionPipeline.decode_latents
|
|
|
|
| 161 |
disable_vae_slicing = DiffusersStableDiffusionPipeline.disable_vae_slicing
|
| 162 |
enable_sequential_cpu_offload = DiffusersStableDiffusionPipeline.enable_sequential_cpu_offload
|
| 163 |
_execution_device = DiffusersStableDiffusionPipeline._execution_device
|
| 164 |
+
encode_prompt = DiffusersStableDiffusionPipeline.encode_prompt
|
| 165 |
_encode_prompt = DiffusersStableDiffusionPipeline._encode_prompt
|
| 166 |
run_safety_checker = DiffusersStableDiffusionPipeline.run_safety_checker
|
| 167 |
decode_latents = DiffusersStableDiffusionPipeline.decode_latents
|
unet/sat_unet.py
CHANGED
|
@@ -74,7 +74,9 @@ class SatUNet(UNet2DConditionModel):
|
|
| 74 |
# Reuse the same projection used for timestep encoding to stay aligned with base embeddings.
|
| 75 |
projected = self.time_proj(metadata.view(-1)).view(md_bsz, self.num_metadata, -1).to(dtype=dtype)
|
| 76 |
|
| 77 |
-
|
|
|
|
|
|
|
| 78 |
for idx, md_embed in enumerate(self.metadata_embedding):
|
| 79 |
md_emb = md_emb + md_embed(projected[:, idx, :])
|
| 80 |
|
|
|
|
| 74 |
# Reuse the same projection used for timestep encoding to stay aligned with base embeddings.
|
| 75 |
projected = self.time_proj(metadata.view(-1)).view(md_bsz, self.num_metadata, -1).to(dtype=dtype)
|
| 76 |
|
| 77 |
+
# md_embed outputs time_embed_dim (1280), not projected.shape[-1] (320)
|
| 78 |
+
time_embed_dim = self.time_embedding.linear_2.out_features
|
| 79 |
+
md_emb = torch.zeros((md_bsz, time_embed_dim), device=metadata.device, dtype=dtype)
|
| 80 |
for idx, md_embed in enumerate(self.metadata_embedding):
|
| 81 |
md_emb = md_emb + md_embed(projected[:, idx, :])
|
| 82 |
|