Commit ·
5b58fc5
1
Parent(s): fdc368b
Change the file name
Browse files- train_dit.py +3 -9
train_dit.py
CHANGED
|
@@ -165,9 +165,7 @@ def train(args):
|
|
| 165 |
"step": step_count,
|
| 166 |
"optimizer": optimizer.state_dict(),
|
| 167 |
},
|
| 168 |
-
os.path.join(
|
| 169 |
-
train_config["task_name"], train_config["vae_autoencoder_ckpt_name"]
|
| 170 |
-
),
|
| 171 |
)
|
| 172 |
if (epoch + 1) % 5 == 0: # Save every 5 epochs
|
| 173 |
artifact = wandb.Artifact(
|
|
@@ -176,9 +174,7 @@ def train(args):
|
|
| 176 |
description=f"DIT model checkpoint at epoch {epoch + 1}",
|
| 177 |
)
|
| 178 |
artifact.add_file(
|
| 179 |
-
os.path.join(
|
| 180 |
-
train_config["task_name"], train_config["vae_autoencoder_ckpt_name"]
|
| 181 |
-
),
|
| 182 |
)
|
| 183 |
wandb.log_artifact(artifact)
|
| 184 |
|
|
@@ -186,9 +182,7 @@ def train(args):
|
|
| 186 |
"dit_model_final", type="model", description="Final DIT model checkpoint"
|
| 187 |
)
|
| 188 |
final_artifact.add_file(
|
| 189 |
-
os.path.join(
|
| 190 |
-
train_config["task_name"], train_config["vae_autoencoder_ckpt_name"]
|
| 191 |
-
)
|
| 192 |
)
|
| 193 |
wandb.log_artifact(final_artifact)
|
| 194 |
|
|
|
|
| 165 |
"step": step_count,
|
| 166 |
"optimizer": optimizer.state_dict(),
|
| 167 |
},
|
| 168 |
+
os.path.join(train_config["task_name"], train_config["dit_ckpt_name"]),
|
|
|
|
|
|
|
| 169 |
)
|
| 170 |
if (epoch + 1) % 5 == 0: # Save every 5 epochs
|
| 171 |
artifact = wandb.Artifact(
|
|
|
|
| 174 |
description=f"DIT model checkpoint at epoch {epoch + 1}",
|
| 175 |
)
|
| 176 |
artifact.add_file(
|
| 177 |
+
os.path.join(train_config["task_name"], train_config["dit_ckpt_name"]),
|
|
|
|
|
|
|
| 178 |
)
|
| 179 |
wandb.log_artifact(artifact)
|
| 180 |
|
|
|
|
| 182 |
"dit_model_final", type="model", description="Final DIT model checkpoint"
|
| 183 |
)
|
| 184 |
final_artifact.add_file(
|
| 185 |
+
os.path.join(train_config["task_name"], train_config["dit_ckpt_name"])
|
|
|
|
|
|
|
| 186 |
)
|
| 187 |
wandb.log_artifact(final_artifact)
|
| 188 |
|