BiliSakura commited on
Commit
0c5f308
·
verified ·
1 Parent(s): 11480ee

Add files using upload-large-folder tool

Browse files
Files changed (2) hide show
  1. pipeline_diffusionsat.py +1 -0
  2. unet/sat_unet.py +3 -1
pipeline_diffusionsat.py CHANGED
@@ -161,6 +161,7 @@ class DiffusionSatPipeline(DiffusionPipeline):
161
  disable_vae_slicing = DiffusersStableDiffusionPipeline.disable_vae_slicing
162
  enable_sequential_cpu_offload = DiffusersStableDiffusionPipeline.enable_sequential_cpu_offload
163
  _execution_device = DiffusersStableDiffusionPipeline._execution_device
 
164
  _encode_prompt = DiffusersStableDiffusionPipeline._encode_prompt
165
  run_safety_checker = DiffusersStableDiffusionPipeline.run_safety_checker
166
  decode_latents = DiffusersStableDiffusionPipeline.decode_latents
 
161
  disable_vae_slicing = DiffusersStableDiffusionPipeline.disable_vae_slicing
162
  enable_sequential_cpu_offload = DiffusersStableDiffusionPipeline.enable_sequential_cpu_offload
163
  _execution_device = DiffusersStableDiffusionPipeline._execution_device
164
+ encode_prompt = DiffusersStableDiffusionPipeline.encode_prompt
165
  _encode_prompt = DiffusersStableDiffusionPipeline._encode_prompt
166
  run_safety_checker = DiffusersStableDiffusionPipeline.run_safety_checker
167
  decode_latents = DiffusersStableDiffusionPipeline.decode_latents
unet/sat_unet.py CHANGED
@@ -74,7 +74,9 @@ class SatUNet(UNet2DConditionModel):
74
  # Reuse the same projection used for timestep encoding to stay aligned with base embeddings.
75
  projected = self.time_proj(metadata.view(-1)).view(md_bsz, self.num_metadata, -1).to(dtype=dtype)
76
 
77
- md_emb = projected.new_zeros((md_bsz, projected.shape[-1]))
 
 
78
  for idx, md_embed in enumerate(self.metadata_embedding):
79
  md_emb = md_emb + md_embed(projected[:, idx, :])
80
 
 
74
  # Reuse the same projection used for timestep encoding to stay aligned with base embeddings.
75
  projected = self.time_proj(metadata.view(-1)).view(md_bsz, self.num_metadata, -1).to(dtype=dtype)
76
 
77
+ # md_embed outputs time_embed_dim (1280), not projected.shape[-1] (320)
78
+ time_embed_dim = self.time_embedding.linear_2.out_features
79
+ md_emb = torch.zeros((md_bsz, time_embed_dim), device=metadata.device, dtype=dtype)
80
  for idx, md_embed in enumerate(self.metadata_embedding):
81
  md_emb = md_emb + md_embed(projected[:, idx, :])
82