Spaces:
Runtime error
Runtime error
Upload app.py
Browse files
app.py
CHANGED
|
@@ -582,12 +582,20 @@ def train_model(data_path, epochs, batch_size, learning_rate, image_size, save_n
|
|
| 582 |
global MODEL, TEXT_ENCODER, DIFFUSION, DEVICE, CONFIG
|
| 583 |
|
| 584 |
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 585 |
# Setup
|
| 586 |
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
|
|
|
|
|
| 587 |
CONFIG = {
|
| 588 |
-
"base_channels": 64
|
| 589 |
-
"channel_mults": (1, 2, 4),
|
| 590 |
-
"context_dim": 256
|
| 591 |
"image_size": image_size,
|
| 592 |
"timesteps": 1000
|
| 593 |
}
|
|
@@ -623,13 +631,14 @@ def train_model(data_path, epochs, batch_size, learning_rate, image_size, save_n
|
|
| 623 |
logs = [f"π Training started on {DEVICE}"]
|
| 624 |
logs.append(f"π Model parameters: {num_params:,}")
|
| 625 |
logs.append(f"π Training samples: {len(train_dataset)}")
|
|
|
|
|
|
|
| 626 |
logs.append("-" * 40)
|
| 627 |
|
| 628 |
-
|
| 629 |
-
current_step = 0
|
| 630 |
-
|
| 631 |
-
for epoch in range(epochs):
|
| 632 |
epoch_loss = 0
|
|
|
|
|
|
|
| 633 |
for images, texts in train_loader:
|
| 634 |
images = images.to(DEVICE)
|
| 635 |
context = TEXT_ENCODER(texts, DEVICE)
|
|
@@ -641,10 +650,28 @@ def train_model(data_path, epochs, batch_size, learning_rate, image_size, save_n
|
|
| 641 |
optimizer.step()
|
| 642 |
|
| 643 |
epoch_loss += loss.item()
|
| 644 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 645 |
|
| 646 |
-
|
| 647 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 648 |
|
| 649 |
# Save model
|
| 650 |
MODEL.eval()
|
|
@@ -662,7 +689,8 @@ def train_model(data_path, epochs, batch_size, learning_rate, image_size, save_n
|
|
| 662 |
return "\n".join(logs)
|
| 663 |
|
| 664 |
except Exception as e:
|
| 665 |
-
|
|
|
|
| 666 |
|
| 667 |
|
| 668 |
def load_checkpoint(checkpoint_file):
|
|
|
|
| 582 |
global MODEL, TEXT_ENCODER, DIFFUSION, DEVICE, CONFIG
|
| 583 |
|
| 584 |
try:
|
| 585 |
+
# Clear GPU memory
|
| 586 |
+
import gc
|
| 587 |
+
gc.collect()
|
| 588 |
+
if torch.cuda.is_available():
|
| 589 |
+
torch.cuda.empty_cache()
|
| 590 |
+
|
| 591 |
# Setup
|
| 592 |
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 593 |
+
|
| 594 |
+
# Use smaller model for T4 GPU
|
| 595 |
CONFIG = {
|
| 596 |
+
"base_channels": 48, # Reduced from 64
|
| 597 |
+
"channel_mults": (1, 2, 4), # Keep same
|
| 598 |
+
"context_dim": 192, # Reduced from 256
|
| 599 |
"image_size": image_size,
|
| 600 |
"timesteps": 1000
|
| 601 |
}
|
|
|
|
| 631 |
logs = [f"π Training started on {DEVICE}"]
|
| 632 |
logs.append(f"π Model parameters: {num_params:,}")
|
| 633 |
logs.append(f"π Training samples: {len(train_dataset)}")
|
| 634 |
+
logs.append(f"🖼️ Image size: {image_size}x{image_size}")
|
| 635 |
+
logs.append(f"📦 Batch size: {batch_size}")
|
| 636 |
logs.append("-" * 40)
|
| 637 |
|
| 638 |
+
for epoch in range(int(epochs)):
|
|
|
|
|
|
|
|
|
|
| 639 |
epoch_loss = 0
|
| 640 |
+
batch_count = 0
|
| 641 |
+
|
| 642 |
for images, texts in train_loader:
|
| 643 |
images = images.to(DEVICE)
|
| 644 |
context = TEXT_ENCODER(texts, DEVICE)
|
|
|
|
| 650 |
optimizer.step()
|
| 651 |
|
| 652 |
epoch_loss += loss.item()
|
| 653 |
+
batch_count += 1
|
| 654 |
+
|
| 655 |
+
# Clear cache periodically
|
| 656 |
+
if batch_count % 50 == 0:
|
| 657 |
+
if torch.cuda.is_available():
|
| 658 |
+
torch.cuda.empty_cache()
|
| 659 |
+
|
| 660 |
+
avg_loss = epoch_loss / max(batch_count, 1)
|
| 661 |
+
logs.append(f"Epoch {epoch+1}/{int(epochs)}: loss = {avg_loss:.4f}")
|
| 662 |
|
| 663 |
+
# Save checkpoint every 10 epochs
|
| 664 |
+
if (epoch + 1) % 10 == 0:
|
| 665 |
+
print(f"Epoch {epoch+1}: loss = {avg_loss:.4f}")
|
| 666 |
+
checkpoint_path = f"checkpoints/{save_name}_epoch{epoch+1}.pt"
|
| 667 |
+
os.makedirs("checkpoints", exist_ok=True)
|
| 668 |
+
torch.save({
|
| 669 |
+
"model_state_dict": MODEL.state_dict(),
|
| 670 |
+
"text_encoder_state_dict": TEXT_ENCODER.state_dict(),
|
| 671 |
+
"config": CONFIG,
|
| 672 |
+
"epoch": epoch + 1
|
| 673 |
+
}, checkpoint_path)
|
| 674 |
+
logs.append(f"💾 Checkpoint saved: {checkpoint_path}")
|
| 675 |
|
| 676 |
# Save model
|
| 677 |
MODEL.eval()
|
|
|
|
| 689 |
return "\n".join(logs)
|
| 690 |
|
| 691 |
except Exception as e:
|
| 692 |
+
import traceback
|
| 693 |
+
return f"❌ Training failed: {str(e)}\n{traceback.format_exc()}"
|
| 694 |
|
| 695 |
|
| 696 |
def load_checkpoint(checkpoint_file):
|