KublaiKhan1 commited on
Commit
0e74bea
·
verified ·
1 Parent(s): 1da375f

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. gram_smol_latent/train.py +5 -2
gram_smol_latent/train.py CHANGED
@@ -183,9 +183,12 @@ class VQGANModel(flax.struct.PyTreeNode):
183
 
184
  return loss
185
 
 
186
 
187
- B, H, W, C = reconstructed_images.shape
188
- reshaped_features = reconstructed_images.reshape(B, -1, C)
 
 
189
  batched_loss_fn = jax.vmap(calculate_covariance_loss_single, in_axes=0)
190
  per_image_losses = batched_loss_fn(reconstructed_images)
191
 
 
183
 
184
  return loss
185
 
186
+ #latents = results_dict["latents"]
187
 
188
+ #B, H, W, C = reconstructed_images.shape
189
+ #reshaped_features = reconstructed_images.reshape(B, -1, C)
190
+ B, H, W, C = result_dict["latents"].shape
191
+ reshaped_features = result_dict["latents"].reshape(B, -1, C)
192
  batched_loss_fn = jax.vmap(calculate_covariance_loss_single, in_axes=0)
193
  per_image_losses = batched_loss_fn(reconstructed_images)
194