Upload folder using huggingface_hub
Browse files
gram_smol_latent/train.py
CHANGED
|
@@ -183,9 +183,12 @@ class VQGANModel(flax.struct.PyTreeNode):
|
|
| 183 |
|
| 184 |
return loss
|
| 185 |
|
|
|
|
| 186 |
|
| 187 |
-
B, H, W, C = reconstructed_images.shape
|
| 188 |
-
reshaped_features = reconstructed_images.reshape(B, -1, C)
|
|
|
|
|
|
|
| 189 |
batched_loss_fn = jax.vmap(calculate_covariance_loss_single, in_axes=0)
|
| 190 |
per_image_losses = batched_loss_fn(reconstructed_images)
|
| 191 |
|
|
|
|
| 183 |
|
| 184 |
return loss
|
| 185 |
|
| 186 |
+
#latents = results_dict["latents"]
|
| 187 |
|
| 188 |
+
#B, H, W, C = reconstructed_images.shape
|
| 189 |
+
#reshaped_features = reconstructed_images.reshape(B, -1, C)
|
| 190 |
+
B, H, W, C = result_dict["latents"].shape
|
| 191 |
+
reshaped_features = result_dict["latents"].reshape(B, -1, C)
|
| 192 |
batched_loss_fn = jax.vmap(calculate_covariance_loss_single, in_axes=0)
|
| 193 |
per_image_losses = batched_loss_fn(reconstructed_images)
|
| 194 |
|