Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -646,6 +646,14 @@ def train_model(data_path, epochs, batch_size, learning_rate, image_size, save_n
|
|
| 646 |
|
| 647 |
avg_loss = epoch_loss / len(train_loader)
|
| 648 |
logs.append(f"Epoch {epoch+1}/{epochs}: loss = {avg_loss:.4f}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 649 |
|
| 650 |
# Save model
|
| 651 |
MODEL.eval()
|
|
@@ -656,6 +664,19 @@ def train_model(data_path, epochs, batch_size, learning_rate, image_size, save_n
|
|
| 656 |
"text_encoder_state_dict": TEXT_ENCODER.state_dict(),
|
| 657 |
"config": CONFIG
|
| 658 |
}, save_path)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 659 |
|
| 660 |
logs.append("-" * 40)
|
| 661 |
logs.append(f"✅ Model saved to {save_path}")
|
|
|
|
| 646 |
|
| 647 |
avg_loss = epoch_loss / len(train_loader)
|
| 648 |
logs.append(f"Epoch {epoch+1}/{epochs}: loss = {avg_loss:.4f}")
|
| 649 |
+
|
| 650 |
+
if (epoch + 1) % 10 == 0:
|
| 651 |
+
torch.save({
|
| 652 |
+
"model_state_dict": MODEL.state_dict(),
|
| 653 |
+
"text_encoder_state_dict": TEXT_ENCODER.state_dict(),
|
| 654 |
+
"config": CONFIG
|
| 655 |
+
}, f"checkpoints/{save_name}_epoch{epoch+1}.pt")
|
| 656 |
+
logs.append(f"💾 Saved checkpoint at epoch {epoch+1}")
|
| 657 |
|
| 658 |
# Save model
|
| 659 |
MODEL.eval()
|
|
|
|
| 664 |
"text_encoder_state_dict": TEXT_ENCODER.state_dict(),
|
| 665 |
"config": CONFIG
|
| 666 |
}, save_path)
|
| 667 |
+
|
| 668 |
+
try:
|
| 669 |
+
from huggingface_hub import HfApi
|
| 670 |
+
api = HfApi()
|
| 671 |
+
api.upload_file(
|
| 672 |
+
path_or_fileobj=save_path,
|
| 673 |
+
path_in_repo=f"checkpoints/{save_name}.pt",
|
| 674 |
+
repo_id="Spanicin/candlestick-diffusion",
|
| 675 |
+
repo_type="space"
|
| 676 |
+
)
|
| 677 |
+
logs.append("☁️ Checkpoint uploaded to repo")
|
| 678 |
+
except Exception as e:
|
| 679 |
+
logs.append(f"⚠️ Upload failed: {e}")
|
| 680 |
|
| 681 |
logs.append("-" * 40)
|
| 682 |
logs.append(f"✅ Model saved to {save_path}")
|