KublaiKhan1 committed on
Commit
1a7c3b5
·
verified ·
1 Parent(s): 561f2b7

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. GramSmall/train.py +33 -12
GramSmall/train.py CHANGED
@@ -161,19 +161,36 @@ class VQGANModel(flax.struct.PyTreeNode):
161
  assert reconstructed_images.shape == images.shape
162
 
163
 
164
- #Gram is not normalized, so let's try that first.
165
- reshaped_latents = result_dict["latents"].reshape(result_dict["latents"].shape[0],-1,result_dict["latents"].shape[-1])
166
- #Reshape to batch x patches x embeddings
167
- #Calculate gram matrix
168
- x_transposed = jnp.transpose(reshaped_latents, (0, 2, 1))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
169
 
170
- gram_matrix = jnp.matmul(reshaped_latents, x_transposed)
171
- diagonal_elements = jnp.einsum('bii->bi', gram_matrix)
172
- sum_of_diagonals = jnp.sum(diagonal_elements)
173
- total_sum = jnp.sum(gram_matrix)
174
- gram_loss = total_sum - sum_of_diagonals
175
- gram_loss = gram_loss / 992 #divide by 32x32 - 32
176
- gram_loss = gram_loss / 40 #Try this for now
177
 
178
 
179
 
@@ -207,6 +224,7 @@ class VQGANModel(flax.struct.PyTreeNode):
207
  + (quantizer_loss * FLAGS.model['quantizer_loss_ratio']) \
208
  + (d_loss_for_vae * FLAGS.model['g_adversarial_loss_weight']) \
209
  + (perceptual_loss * FLAGS.model['perceptual_loss_weight']) \
 
210
  #+ (smooth_loss * FLAGS.model['pl_weight'] )
211
  codebook_usage = result_dict['usage'] if 'usage' in result_dict else 0.0
212
 
@@ -218,6 +236,7 @@ class VQGANModel(flax.struct.PyTreeNode):
218
  'perceptual_loss': perceptual_loss,
219
  'quantizer_loss': quantizer_loss,
220
  'codebook_usage': codebook_usage,
 
221
  #'pl_loss': smooth_loss,
222
  }
223
 
@@ -581,6 +600,8 @@ def main(_):
581
 
582
  model, update_info = model.update(batch_images)
583
 
 
 
584
  if i % FLAGS.log_interval == 0:
585
  update_info = jax.tree.map(lambda x: x.mean(), update_info)
586
  train_metrics = {f'training/{k}': v for k, v in update_info.items()}
 
161
  assert reconstructed_images.shape == images.shape
162
 
163
 
164
+ def calculate_covariance_loss_single(image):
165
+ """Calculates the covariance loss for one image."""
166
+ # image.shape is (H, W, C)
167
+ C = image.shape[-1]
168
+
169
+ # Reshape the spatial dimensions into one dimension of "observations"
170
+ # New shape: (H*W, C)
171
+ reshaped_features = image.reshape(-1, C)
172
+
173
+ # Calculate the covariance matrix of the channels.
174
+ # We treat each channel as a variable and spatial locations as observations.
175
+ # The resulting shape will be (C, C).
176
+ cov_matrix = jnp.cov(reshaped_features, rowvar=False)
177
+
178
+ # The target is the identity matrix of size (C, C)
179
+ identity_matrix = jnp.eye(C)
180
+
181
+ # The loss is the sum of squared differences (Frobenius norm squared)
182
+ loss = jnp.sum(jnp.square(cov_matrix - identity_matrix))
183
+
184
+ return loss
185
+
186
 
187
+ B, H, W, C = reconstructed_images.shape
188
+ reshaped_features = reconstructed_images.reshape(B, -1, C)
189
+ batched_loss_fn = jax.vmap(calculate_covariance_loss_single, in_axes=0)
190
+ per_image_losses = batched_loss_fn(reconstructed_images)
191
+
192
+ gram_loss = jnp.mean(per_image_losses) * 1
193
+ #Gram loss is very low - let's crank it up until it starts harming things?
194
 
195
 
196
 
 
224
  + (quantizer_loss * FLAGS.model['quantizer_loss_ratio']) \
225
  + (d_loss_for_vae * FLAGS.model['g_adversarial_loss_weight']) \
226
  + (perceptual_loss * FLAGS.model['perceptual_loss_weight']) \
227
+ + gram_loss
228
  #+ (smooth_loss * FLAGS.model['pl_weight'] )
229
  codebook_usage = result_dict['usage'] if 'usage' in result_dict else 0.0
230
 
 
236
  'perceptual_loss': perceptual_loss,
237
  'quantizer_loss': quantizer_loss,
238
  'codebook_usage': codebook_usage,
239
+ 'cov loss': gram_loss
240
  #'pl_loss': smooth_loss,
241
  }
242
 
 
600
 
601
  model, update_info = model.update(batch_images)
602
 
603
+ print(update_info)
604
+
605
  if i % FLAGS.log_interval == 0:
606
  update_info = jax.tree.map(lambda x: x.mean(), update_info)
607
  train_metrics = {f'training/{k}': v for k, v in update_info.items()}