Spanicin commited on
Commit
ccadaa5
·
verified ·
1 Parent(s): bbe45e3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -1
app.py CHANGED
@@ -648,12 +648,26 @@ def train_model(data_path, epochs, batch_size, learning_rate, image_size, save_n
648
  logs.append(f"Epoch {epoch+1}/{epochs}: loss = {avg_loss:.4f}")
649
 
650
  if (epoch + 1) % 10 == 0:
 
651
  torch.save({
652
  "model_state_dict": MODEL.state_dict(),
653
  "text_encoder_state_dict": TEXT_ENCODER.state_dict(),
654
  "config": CONFIG
655
- }, f"checkpoints/{save_name}_epoch{epoch+1}.pt")
656
  logs.append(f"💾 Saved checkpoint at epoch {epoch+1}")
 
 
 
 
 
 
 
 
 
 
 
 
 
657
 
658
  # Save model
659
  MODEL.eval()
 
648
  logs.append(f"Epoch {epoch+1}/{epochs}: loss = {avg_loss:.4f}")
649
 
650
  if (epoch + 1) % 10 == 0:
651
+ ckpt_path = f"checkpoints/{save_name}_epoch{epoch+1}.pt"
652
  torch.save({
653
  "model_state_dict": MODEL.state_dict(),
654
  "text_encoder_state_dict": TEXT_ENCODER.state_dict(),
655
  "config": CONFIG
656
+ }, ckpt_path)
657
  logs.append(f"💾 Saved checkpoint at epoch {epoch+1}")
658
+ try:
659
+ from huggingface_hub import HfApi
660
+ api = HfApi()
661
+ api.upload_file(
662
+ path_or_fileobj=ckpt_path,
663
+ path_in_repo=ckpt_path,
664
+ repo_id="Spanicin/candlestick-diffusion",
665
+ repo_type="space",
666
+ token=os.environ.get("HF_TOKEN")
667
+ )
668
+ logs.append(f"☁️ Uploaded to repo")
669
+ except Exception as e:
670
+ logs.append(f"⚠️ Upload failed: {e}")
671
 
672
  # Save model
673
  MODEL.eval()