Spanicin commited on
Commit
04ea678
·
verified ·
1 Parent(s): 1b34ccd

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +1 -17
app.py CHANGED
@@ -607,7 +607,7 @@ def train_model(data_path, epochs, batch_size, learning_rate, image_size, save_n
607
  train_dataset = ChartDataset(data_path, image_size=image_size, split="train")
608
  train_loader = DataLoader(
609
  train_dataset, batch_size=batch_size, shuffle=True,
610
- num_workers=0, pin_memory=True, drop_last=True, collate_fn=collate_fn
611
  )
612
 
613
  # Optimizer
@@ -642,12 +642,9 @@ def train_model(data_path, epochs, batch_size, learning_rate, image_size, save_n
642
 
643
  epoch_loss += loss.item()
644
  current_step += 1
645
- print(f"Step {current_step}, loss: {loss.item():.4f}")
646
 
647
  avg_loss = epoch_loss / len(train_loader)
648
  logs.append(f"Epoch {epoch+1}/{epochs}: loss = {avg_loss:.4f}")
649
-
650
-
651
 
652
  # Save model
653
  MODEL.eval()
@@ -658,19 +655,6 @@ def train_model(data_path, epochs, batch_size, learning_rate, image_size, save_n
658
  "text_encoder_state_dict": TEXT_ENCODER.state_dict(),
659
  "config": CONFIG
660
  }, save_path)
661
-
662
- try:
663
- from huggingface_hub import HfApi
664
- api = HfApi()
665
- api.upload_file(
666
- path_or_fileobj=save_path,
667
- path_in_repo=f"checkpoints/{save_name}.pt",
668
- repo_id="Spanicin/candlestick-diffusion",
669
- repo_type="space"
670
- )
671
- logs.append("☁️ Checkpoint uploaded to repo")
672
- except Exception as e:
673
- logs.append(f"⚠️ Upload failed: {e}")
674
 
675
  logs.append("-" * 40)
676
  logs.append(f"✅ Model saved to {save_path}")
 
607
  train_dataset = ChartDataset(data_path, image_size=image_size, split="train")
608
  train_loader = DataLoader(
609
  train_dataset, batch_size=batch_size, shuffle=True,
610
+ num_workers=2, pin_memory=True, drop_last=True, collate_fn=collate_fn
611
  )
612
 
613
  # Optimizer
 
642
 
643
  epoch_loss += loss.item()
644
  current_step += 1
 
645
 
646
  avg_loss = epoch_loss / len(train_loader)
647
  logs.append(f"Epoch {epoch+1}/{epochs}: loss = {avg_loss:.4f}")
 
 
648
 
649
  # Save model
650
  MODEL.eval()
 
655
  "text_encoder_state_dict": TEXT_ENCODER.state_dict(),
656
  "config": CONFIG
657
  }, save_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
658
 
659
  logs.append("-" * 40)
660
  logs.append(f"✅ Model saved to {save_path}")