Spanicin commited on
Commit
bbe45e3
·
verified ·
1 Parent(s): be479fa

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -0
app.py CHANGED
@@ -646,6 +646,14 @@ def train_model(data_path, epochs, batch_size, learning_rate, image_size, save_n
646
 
647
  avg_loss = epoch_loss / len(train_loader)
648
  logs.append(f"Epoch {epoch+1}/{epochs}: loss = {avg_loss:.4f}")
 
 
 
 
 
 
 
 
649
 
650
  # Save model
651
  MODEL.eval()
@@ -656,6 +664,19 @@ def train_model(data_path, epochs, batch_size, learning_rate, image_size, save_n
656
  "text_encoder_state_dict": TEXT_ENCODER.state_dict(),
657
  "config": CONFIG
658
  }, save_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
659
 
660
  logs.append("-" * 40)
661
  logs.append(f"✅ Model saved to {save_path}")
 
646
 
647
  avg_loss = epoch_loss / len(train_loader)
648
  logs.append(f"Epoch {epoch+1}/{epochs}: loss = {avg_loss:.4f}")
649
+
650
+ if (epoch + 1) % 10 == 0:
651
+ torch.save({
652
+ "model_state_dict": MODEL.state_dict(),
653
+ "text_encoder_state_dict": TEXT_ENCODER.state_dict(),
654
+ "config": CONFIG
655
+ }, f"checkpoints/{save_name}_epoch{epoch+1}.pt")
656
+ logs.append(f"💾 Saved checkpoint at epoch {epoch+1}")
657
 
658
  # Save model
659
  MODEL.eval()
 
664
  "text_encoder_state_dict": TEXT_ENCODER.state_dict(),
665
  "config": CONFIG
666
  }, save_path)
667
+
668
+ try:
669
+ from huggingface_hub import HfApi
670
+ api = HfApi()
671
+ api.upload_file(
672
+ path_or_fileobj=save_path,
673
+ path_in_repo=f"checkpoints/{save_name}.pt",
674
+ repo_id="Spanicin/candlestick-diffusion",
675
+ repo_type="space"
676
+ )
677
+ logs.append("☁️ Checkpoint uploaded to repo")
678
+ except Exception as e:
679
+ logs.append(f"⚠️ Upload failed: {e}")
680
 
681
  logs.append("-" * 40)
682
  logs.append(f"✅ Model saved to {save_path}")