Spanicin committed on
Commit
476d0fb
Β·
verified Β·
1 Parent(s): 0a1421d

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +39 -11
app.py CHANGED
@@ -582,12 +582,20 @@ def train_model(data_path, epochs, batch_size, learning_rate, image_size, save_n
582
  global MODEL, TEXT_ENCODER, DIFFUSION, DEVICE, CONFIG
583
 
584
  try:
 
 
 
 
 
 
585
  # Setup
586
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 
587
  CONFIG = {
588
- "base_channels": 64,
589
- "channel_mults": (1, 2, 4),
590
- "context_dim": 256,
591
  "image_size": image_size,
592
  "timesteps": 1000
593
  }
@@ -623,13 +631,14 @@ def train_model(data_path, epochs, batch_size, learning_rate, image_size, save_n
623
  logs = [f"πŸš€ Training started on {DEVICE}"]
624
  logs.append(f"πŸ“Š Model parameters: {num_params:,}")
625
  logs.append(f"πŸ“ Training samples: {len(train_dataset)}")
 
 
626
  logs.append("-" * 40)
627
 
628
- total_steps = epochs * len(train_loader)
629
- current_step = 0
630
-
631
- for epoch in range(epochs):
632
  epoch_loss = 0
 
 
633
  for images, texts in train_loader:
634
  images = images.to(DEVICE)
635
  context = TEXT_ENCODER(texts, DEVICE)
@@ -641,10 +650,28 @@ def train_model(data_path, epochs, batch_size, learning_rate, image_size, save_n
641
  optimizer.step()
642
 
643
  epoch_loss += loss.item()
644
- current_step += 1
 
 
 
 
 
 
 
 
645
 
646
- avg_loss = epoch_loss / len(train_loader)
647
- logs.append(f"Epoch {epoch+1}/{epochs}: loss = {avg_loss:.4f}")
 
 
 
 
 
 
 
 
 
 
648
 
649
  # Save model
650
  MODEL.eval()
@@ -662,7 +689,8 @@ def train_model(data_path, epochs, batch_size, learning_rate, image_size, save_n
662
  return "\n".join(logs)
663
 
664
  except Exception as e:
665
- return f"❌ Training failed: {str(e)}"
 
666
 
667
 
668
  def load_checkpoint(checkpoint_file):
 
582
  global MODEL, TEXT_ENCODER, DIFFUSION, DEVICE, CONFIG
583
 
584
  try:
585
+ # Clear GPU memory
586
+ import gc
587
+ gc.collect()
588
+ if torch.cuda.is_available():
589
+ torch.cuda.empty_cache()
590
+
591
  # Setup
592
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
593
+
594
+ # Use smaller model for T4 GPU
595
  CONFIG = {
596
+ "base_channels": 48, # Reduced from 64
597
+ "channel_mults": (1, 2, 4), # Keep same
598
+ "context_dim": 192, # Reduced from 256
599
  "image_size": image_size,
600
  "timesteps": 1000
601
  }
 
631
  logs = [f"πŸš€ Training started on {DEVICE}"]
632
  logs.append(f"πŸ“Š Model parameters: {num_params:,}")
633
  logs.append(f"πŸ“ Training samples: {len(train_dataset)}")
634
+ logs.append(f"πŸ–ΌοΈ Image size: {image_size}x{image_size}")
635
+ logs.append(f"πŸ“¦ Batch size: {batch_size}")
636
  logs.append("-" * 40)
637
 
638
+ for epoch in range(int(epochs)):
 
 
 
639
  epoch_loss = 0
640
+ batch_count = 0
641
+
642
  for images, texts in train_loader:
643
  images = images.to(DEVICE)
644
  context = TEXT_ENCODER(texts, DEVICE)
 
650
  optimizer.step()
651
 
652
  epoch_loss += loss.item()
653
+ batch_count += 1
654
+
655
+ # Clear cache periodically
656
+ if batch_count % 50 == 0:
657
+ if torch.cuda.is_available():
658
+ torch.cuda.empty_cache()
659
+
660
+ avg_loss = epoch_loss / max(batch_count, 1)
661
+ logs.append(f"Epoch {epoch+1}/{int(epochs)}: loss = {avg_loss:.4f}")
662
 
663
+ # Save checkpoint every 10 epochs
664
+ if (epoch + 1) % 10 == 0:
665
+ print(f"Epoch {epoch+1}: loss = {avg_loss:.4f}")
666
+ checkpoint_path = f"checkpoints/{save_name}_epoch{epoch+1}.pt"
667
+ os.makedirs("checkpoints", exist_ok=True)
668
+ torch.save({
669
+ "model_state_dict": MODEL.state_dict(),
670
+ "text_encoder_state_dict": TEXT_ENCODER.state_dict(),
671
+ "config": CONFIG,
672
+ "epoch": epoch + 1
673
+ }, checkpoint_path)
674
+ logs.append(f"πŸ’Ύ Checkpoint saved: {checkpoint_path}")
675
 
676
  # Save model
677
  MODEL.eval()
 
689
  return "\n".join(logs)
690
 
691
  except Exception as e:
692
+ import traceback
693
+ return f"❌ Training failed: {str(e)}\n{traceback.format_exc()}"
694
 
695
 
696
  def load_checkpoint(checkpoint_file):