Upload folder using huggingface_hub

GramSmall/train.py (+33 -12) CHANGED
@@ -161,19 +161,36 @@ class VQGANModel(flax.struct.PyTreeNode):
         assert reconstructed_images.shape == images.shape


-
-
-
-
-
+        def calculate_covariance_loss_single(image):
+            """Calculates the covariance loss for one image."""
+            # image.shape is (H, W, C)
+            C = image.shape[-1]
+
+            # Reshape the spatial dimensions into one dimension of "observations"
+            # New shape: (H*W, C)
+            reshaped_features = image.reshape(-1, C)
+
+            # Calculate the covariance matrix of the channels.
+            # We treat each channel as a variable and spatial locations as observations.
+            # The resulting shape will be (C, C).
+            cov_matrix = jnp.cov(reshaped_features, rowvar=False)
+
+            # The target is the identity matrix of size (C, C)
+            identity_matrix = jnp.eye(C)
+
+            # The loss is the sum of squared differences (Frobenius norm squared)
+            loss = jnp.sum(jnp.square(cov_matrix - identity_matrix))
+
+            return loss
+

-
-
-
-
-
-        gram_loss =
-
+        B, H, W, C = reconstructed_images.shape
+        reshaped_features = reconstructed_images.reshape(B, -1, C)
+        batched_loss_fn = jax.vmap(calculate_covariance_loss_single, in_axes=0)
+        per_image_losses = batched_loss_fn(reconstructed_images)
+
+        gram_loss = jnp.mean(per_image_losses) * 1
+        # Gram loss is very low - let's crank it up until it starts harming things?


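For reference, a self-contained sketch of the covariance penalty this hunk introduces, runnable on hypothetical random data (`jax`/`jnp` are assumed to be imported as in train.py; the batch shape is made up). One observation: the batch-level `reshaped_features = reconstructed_images.reshape(B, -1, C)` in the hunk is never used, since `jax.vmap` maps the per-image function over `reconstructed_images` directly.

```python
import jax
import jax.numpy as jnp

def covariance_loss_single(image):
    """Squared Frobenius distance between the channel covariance and identity."""
    C = image.shape[-1]
    features = image.reshape(-1, C)        # (H*W, C): spatial positions as observations
    cov = jnp.cov(features, rowvar=False)  # (C, C) covariance across channels
    return jnp.sum(jnp.square(cov - jnp.eye(C)))

# Hypothetical stand-in for reconstructed_images, shape (B, H, W, C).
images = jax.random.normal(jax.random.PRNGKey(0), (4, 8, 8, 3))
per_image_losses = jax.vmap(covariance_loss_single)(images)  # shape (B,)
gram_loss = jnp.mean(per_image_losses)
```

Despite the `gram_loss` name, this penalizes the channel covariance rather than a Gram matrix of raw features; the two coincide only for zero-mean features, up to the 1/(N-1) normalization inside `jnp.cov`.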
@@ -207,6 +224,7 @@ class VQGANModel(flax.struct.PyTreeNode):
             + (quantizer_loss * FLAGS.model['quantizer_loss_ratio']) \
             + (d_loss_for_vae * FLAGS.model['g_adversarial_loss_weight']) \
             + (perceptual_loss * FLAGS.model['perceptual_loss_weight']) \
+            + gram_loss
             #+ (smooth_loss * FLAGS.model['pl_weight'] )
         codebook_usage = result_dict['usage'] if 'usage' in result_dict else 0.0

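This hunk splices `gram_loss` into the generator's total loss, unweighted, while the neighboring terms are scaled by `FLAGS.model[...]` weights. For reference, a minimal sketch of the same combination pattern, with hypothetical weight and loss values standing in for the ones defined elsewhere in train.py:

```python
# Hypothetical numbers; only the combination pattern mirrors train.py.
weights = {
    'quantizer_loss_ratio': 1.0,
    'g_adversarial_loss_weight': 0.1,
    'perceptual_loss_weight': 0.1,
}
quantizer_loss, d_loss_for_vae, perceptual_loss, gram_loss = 0.5, 0.3, 0.2, 0.01

loss = (quantizer_loss * weights['quantizer_loss_ratio']) \
     + (d_loss_for_vae * weights['g_adversarial_loss_weight']) \
     + (perceptual_loss * weights['perceptual_loss_weight']) \
     + gram_loss
```

Giving the new term its own `FLAGS.model` weight instead of the hard-coded `* 1` factor would make its scale tunable from config, which the "crank it up" comment in the first hunk suggests will be needed.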
@@ -218,6 +236,7 @@ class VQGANModel(flax.struct.PyTreeNode):
             'perceptual_loss': perceptual_loss,
             'quantizer_loss': quantizer_loss,
             'codebook_usage': codebook_usage,
+            'cov loss': gram_loss
             #'pl_loss': smooth_loss,
         }

@@ -581,6 +600,8 @@ def main(_):

         model, update_info = model.update(batch_images)

+        print(update_info)
+
         if i % FLAGS.log_interval == 0:
             update_info = jax.tree.map(lambda x: x.mean(), update_info)
             train_metrics = {f'training/{k}': v for k, v in update_info.items()}
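The added `print(update_info)` sits before the `log_interval` gate, so it fires on every step and forces a host transfer of the metric arrays; presumably a temporary debugging aid for the new loss. For reference, a small sketch of the aggregation path it feeds, with a hypothetical `update_info` dict (note the new metric key `'cov loss'` contains a space, so it will surface as `training/cov loss`):

```python
import jax
import jax.numpy as jnp

# Hypothetical metrics standing in for update_info.
update_info = {'cov loss': jnp.array([0.2, 0.4]),
               'quantizer_loss': jnp.array([0.9, 1.1])}

# Mirrors the logging branch: average every leaf, then prefix keys for the logger.
update_info = jax.tree.map(lambda x: x.mean(), update_info)
train_metrics = {f'training/{k}': v for k, v in update_info.items()}
# {'training/cov loss': Array(0.3, ...), 'training/quantizer_loss': Array(1.0, ...)}
```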