YashNagraj75 commited on
Commit
5b58fc5
·
1 Parent(s): fdc368b

Change the file name

Browse files
Files changed (1) hide show
  1. train_dit.py +3 -9
train_dit.py CHANGED
@@ -165,9 +165,7 @@ def train(args):
165
  "step": step_count,
166
  "optimizer": optimizer.state_dict(),
167
  },
168
- os.path.join(
169
- train_config["task_name"], train_config["vae_autoencoder_ckpt_name"]
170
- ),
171
  )
172
  if (epoch + 1) % 5 == 0: # Save every 5 epochs
173
  artifact = wandb.Artifact(
@@ -176,9 +174,7 @@ def train(args):
176
  description=f"DIT model checkpoint at epoch {epoch + 1}",
177
  )
178
  artifact.add_file(
179
- os.path.join(
180
- train_config["task_name"], train_config["vae_autoencoder_ckpt_name"]
181
- ),
182
  )
183
  wandb.log_artifact(artifact)
184
 
@@ -186,9 +182,7 @@ def train(args):
186
  "dit_model_final", type="model", description="Final DIT model checkpoint"
187
  )
188
  final_artifact.add_file(
189
- os.path.join(
190
- train_config["task_name"], train_config["vae_autoencoder_ckpt_name"]
191
- )
192
  )
193
  wandb.log_artifact(final_artifact)
194
 
 
165
  "step": step_count,
166
  "optimizer": optimizer.state_dict(),
167
  },
168
+ os.path.join(train_config["task_name"], train_config["dit_ckpt_name"]),
 
 
169
  )
170
  if (epoch + 1) % 5 == 0: # Save every 5 epochs
171
  artifact = wandb.Artifact(
 
174
  description=f"DIT model checkpoint at epoch {epoch + 1}",
175
  )
176
  artifact.add_file(
177
+ os.path.join(train_config["task_name"], train_config["dit_ckpt_name"]),
 
 
178
  )
179
  wandb.log_artifact(artifact)
180
 
 
182
  "dit_model_final", type="model", description="Final DIT model checkpoint"
183
  )
184
  final_artifact.add_file(
185
+ os.path.join(train_config["task_name"], train_config["dit_ckpt_name"])
 
 
186
  )
187
  wandb.log_artifact(final_artifact)
188