Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -647,14 +647,7 @@ def train_model(data_path, epochs, batch_size, learning_rate, image_size, save_n
|
|
| 647 |
avg_loss = epoch_loss / len(train_loader)
|
| 648 |
logs.append(f"Epoch {epoch+1}/{epochs}: loss = {avg_loss:.4f}")
|
| 649 |
|
| 650 |
-
|
| 651 |
-
ckpt_path = f"checkpoints/{save_name}_epoch{epoch+1}.pt"
|
| 652 |
-
torch.save({
|
| 653 |
-
"model_state_dict": MODEL.state_dict(),
|
| 654 |
-
"text_encoder_state_dict": TEXT_ENCODER.state_dict(),
|
| 655 |
-
"config": CONFIG
|
| 656 |
-
}, ckpt_path)
|
| 657 |
-
print(f"💾 Saved checkpoint at epoch {epoch+1}")
|
| 658 |
|
| 659 |
# Save model
|
| 660 |
MODEL.eval()
|
|
|
|
| 647 |
avg_loss = epoch_loss / len(train_loader)
|
| 648 |
logs.append(f"Epoch {epoch+1}/{epochs}: loss = {avg_loss:.4f}")
|
| 649 |
|
| 650 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 651 |
|
| 652 |
# Save model
|
| 653 |
MODEL.eval()
|