floorplan-segmentation / hf_train.py
hallelu's picture
Upload 4 files
69f257e verified
#!/usr/bin/env python3
"""
🏠 Floorplan Segmentation Training on Hugging Face
Complete training script with proper logging and error handling
"""
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import cv2
import numpy as np
from tqdm import tqdm
import os
import matplotlib.pyplot as plt
import time
import gc
from datetime import datetime
if __name__ == "__main__":
    # Announce the run only when executed as a script, so importing this
    # module (e.g. to reuse UltraSimpleModel for inference) produces no output.
    print("πŸš€ Starting Floorplan Segmentation Training on Hugging Face...")
    print(f"⏰ Started at: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
# ============================================================================
# 1. MODEL ARCHITECTURE
# ============================================================================
class UltraSimpleModel(nn.Module):
    """Minimal encoder/decoder network for per-pixel classification.

    The encoder applies three stages of (conv-ReLU-conv-ReLU-maxpool), halving
    the spatial size each time (32 -> 64 -> 128 channels). The decoder applies
    three transposed-conv x2 upsamplings back to the input resolution and ends
    with a 1x1 conv producing ``n_classes`` logits per pixel. Input height and
    width must be divisible by 8 for the output to match the input size.

    Args:
        n_channels: Number of input image channels.
        n_classes: Number of output class logits per pixel.
    """

    def __init__(self, n_channels=3, n_classes=5):
        super().__init__()

        # Downsampling path: layer order is kept identical to the original
        # hand-written Sequential so state_dict keys ("encoder.<i>.*") match.
        down = []
        width = n_channels
        for stage_width in (32, 64, 128):
            down += [
                nn.Conv2d(width, stage_width, 3, padding=1),
                nn.ReLU(),
                nn.Conv2d(stage_width, stage_width, 3, padding=1),
                nn.ReLU(),
                nn.MaxPool2d(2),
            ]
            width = stage_width
        self.encoder = nn.Sequential(*down)

        # Upsampling path: two (upsample, ReLU, refine conv, ReLU) stages,
        # then a final upsample and the 1x1 classification head.
        up = []
        for stage_width in (64, 32):
            up += [
                nn.ConvTranspose2d(width, stage_width, 2, stride=2),
                nn.ReLU(),
                nn.Conv2d(stage_width, stage_width, 3, padding=1),
                nn.ReLU(),
            ]
            width = stage_width
        up += [
            nn.ConvTranspose2d(width, 16, 2, stride=2),
            nn.ReLU(),
            nn.Conv2d(16, n_classes, 1),
        ]
        self.decoder = nn.Sequential(*up)

    def forward(self, x):
        """Return per-pixel class logits for the image batch ``x``."""
        return self.decoder(self.encoder(x))
# ============================================================================
# 2. DATASET CLASS
# ============================================================================
class SimpleDataset(Dataset):
    """Paired floorplan image / segmentation-mask samples read from one directory.

    A sample is any ``<stem>_image.png`` file that has a sibling
    ``<stem>_mask.png`` in the same directory.

    ``__getitem__`` returns ``(image, mask)``:
        image: float32 tensor of shape (3, image_size, image_size), RGB,
            scaled to [0, 1].
        mask: int64 tensor of shape (image_size, image_size) holding the raw
            grayscale pixel values of the mask file.

    Args:
        data_dir: Directory containing the image/mask pairs.
        image_size: Side length both image and mask are resized to.
    """

    # Filename suffixes that identify the two halves of a sample.
    _IMAGE_SUFFIX = '_image.png'
    _MASK_SUFFIX = '_mask.png'

    def __init__(self, data_dir, image_size=224):
        self.data_dir = data_dir
        self.image_size = image_size
        # Sort so sample indices are stable across runs and filesystems;
        # os.listdir returns entries in arbitrary order.
        self.image_files = [
            name
            for name in sorted(os.listdir(data_dir))
            if name.endswith(self._IMAGE_SUFFIX)
            and os.path.exists(os.path.join(data_dir, self._mask_name(name)))
        ]
        print(f"πŸ“Š Found {len(self.image_files)} image-mask pairs in {data_dir}")

    @classmethod
    def _mask_name(cls, image_file):
        """Map ``<stem>_image.png`` to ``<stem>_mask.png`` (only the suffix is replaced)."""
        return image_file[:-len(cls._IMAGE_SUFFIX)] + cls._MASK_SUFFIX

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        """Load, resize and tensorize the ``idx``-th image/mask pair.

        Raises:
            FileNotFoundError: if OpenCV cannot read the image or the mask.
        """
        image_file = self.image_files[idx]
        image_path = os.path.join(self.data_dir, image_file)
        mask_path = os.path.join(self.data_dir, self._mask_name(image_file))
        size = (self.image_size, self.image_size)

        # cv2.imread reports failure by returning None instead of raising.
        image = cv2.imread(image_path)
        if image is None:
            raise FileNotFoundError(f"Could not read image: {image_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = cv2.resize(image, size)

        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise FileNotFoundError(f"Could not read mask: {mask_path}")
        # Nearest-neighbour keeps labels discrete: the default bilinear
        # interpolation blends neighbouring class ids into invalid ones at edges.
        mask = cv2.resize(mask, size, interpolation=cv2.INTER_NEAREST)

        image = torch.from_numpy(image).float().permute(2, 0, 1) / 255.0
        # NOTE(review): pixel values are used directly as class ids, so they
        # are assumed to already lie in [0, n_classes) -- confirm against the
        # preprocessing that produced processed_data.
        mask = torch.from_numpy(mask).long()
        return image, mask
# ============================================================================
# 3. TRAINING SETUP
# ============================================================================
def setup_training():
    """Prepare the runtime and hand back the fixed training hyper-parameters.

    Side effects: empties the CUDA cache, runs ``gc.collect()`` and prints the
    selected device (plus GPU name and total memory when CUDA is available)
    followed by the training configuration.

    Returns:
        tuple: ``(device, batch_size, image_size, epochs, learning_rate)``.
    """
    print("πŸ”§ Setting up training environment...")

    # Release cached GPU memory and collect garbage before anything is allocated.
    torch.cuda.empty_cache()
    gc.collect()

    cuda_ok = torch.cuda.is_available()
    device = torch.device('cuda' if cuda_ok else 'cpu')
    print(f"βœ… Using device: {device}")
    if cuda_ok:
        print(f"βœ… GPU: {torch.cuda.get_device_name(0)}")
        total_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
        print(f"βœ… GPU Memory: {total_gb:.1f} GB")

    # Hyper-parameters are fixed here; callers receive them via the return tuple.
    batch_size, image_size, epochs, learning_rate = 4, 224, 50, 1e-4

    summary = (
        "πŸ”„ Training Configuration:",
        f" Batch size: {batch_size}",
        f" Image size: {image_size}x{image_size}",
        f" Epochs: {epochs}",
        f" Learning rate: {learning_rate}",
    )
    for line in summary:
        print(line)

    return device, batch_size, image_size, epochs, learning_rate
def create_data_loaders(BATCH_SIZE, IMAGE_SIZE, data_root='processed_data', num_workers=2):
    """Build the training and validation DataLoaders.

    Expects ``<data_root>/train`` and ``<data_root>/val`` directories laid
    out as ``SimpleDataset`` reads them.

    Args:
        BATCH_SIZE: Samples per batch for both loaders.
        IMAGE_SIZE: Side length images and masks are resized to.
        data_root: Directory holding the ``train`` and ``val`` splits.
        num_workers: Worker processes used by each DataLoader.

    Returns:
        ``(train_loader, val_loader)``, or ``(None, None)`` when the data root
        or a split directory is missing, or a split has no image/mask pairs.
        A message explaining the problem is printed in every ``None`` case.
    """
    print("πŸ“Š Creating data loaders...")

    if not os.path.exists(data_root):
        print(f"❌ {data_root} directory not found!")
        print("πŸ’‘ Please upload processed_data.zip to this repository")
        return None, None

    split_dirs = {split: os.path.join(data_root, split) for split in ('train', 'val')}
    missing = [path for path in split_dirs.values() if not os.path.isdir(path)]
    if missing:
        # Without this check os.listdir inside SimpleDataset would raise instead.
        print(f"❌ Missing data split directories: {', '.join(missing)}")
        return None, None

    train_dataset = SimpleDataset(split_dirs['train'], image_size=IMAGE_SIZE)
    val_dataset = SimpleDataset(split_dirs['val'], image_size=IMAGE_SIZE)

    # Training averages losses over len(loader); an empty split would otherwise
    # surface later as a ZeroDivisionError, so fail fast with a clear message.
    if len(train_dataset) == 0 or len(val_dataset) == 0:
        print("❌ No image-mask pairs found in the train and/or val split!")
        return None, None

    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True,
                              num_workers=num_workers)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False,
                            num_workers=num_workers)

    print(f"βœ… Data loaders created!")
    print(f" Training batches: {len(train_loader)}")
    print(f" Validation batches: {len(val_loader)}")
    return train_loader, val_loader
# ============================================================================
# 4. TRAINING LOOP
# ============================================================================
def _train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over *loader*; return the sample-weighted mean loss."""
    model.train()
    total_loss = 0.0
    n_samples = 0
    pbar = tqdm(loader, desc="Training")
    for batch_idx, (images, masks) in enumerate(pbar):
        images = images.to(device)
        masks = masks.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, masks)
        loss.backward()
        optimizer.step()

        # The criterion averages within a batch; re-weight by batch size so a
        # short final batch does not count as much as a full one.
        batch_loss = loss.item()
        total_loss += batch_loss * images.size(0)
        n_samples += images.size(0)

        pbar.set_postfix({
            'Loss': f'{batch_loss:.4f}',
            'GPU': f'{torch.cuda.memory_allocated()/1e9:.1f}GB'
        })
        # Periodically release cached allocator memory.
        if batch_idx % 100 == 0:
            torch.cuda.empty_cache()
    return total_loss / n_samples


def _evaluate(model, loader, criterion, device):
    """Return the sample-weighted mean loss of *model* over *loader*, without gradients."""
    model.eval()
    total_loss = 0.0
    n_samples = 0
    with torch.no_grad():
        pbar = tqdm(loader, desc="Validation")
        for images, masks in pbar:
            images = images.to(device)
            masks = masks.to(device)
            batch_loss = criterion(model(images), masks).item()
            total_loss += batch_loss * images.size(0)
            n_samples += images.size(0)
            pbar.set_postfix({'Loss': f'{batch_loss:.4f}'})
    return total_loss / n_samples


def _save_checkpoint(path, epoch, model, optimizer, scheduler, best_val_loss, history, config):
    """Serialize model/optimizer/scheduler state plus training metadata to *path*."""
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'best_val_loss': best_val_loss,
        'history': history,
        'config': config,
    }, path)


def train_model(model, train_loader, val_loader, device, EPOCHS, LEARNING_RATE):
    """Train *model* with Adam and cosine LR decay, checkpointing as it goes.

    Each epoch does the following:
      * runs a training pass and a validation pass;
      * steps the scheduler and logs progress;
      * writes ``best_model.pth`` whenever validation loss improves;
      * writes ``checkpoint_epoch_<n>.pth`` every 10 epochs.

    All checkpoint files go to the current working directory.

    Args:
        model: Network producing per-pixel class logits; expected to already
            be on *device*.
        train_loader: DataLoader yielding ``(images, masks)`` for training.
        val_loader: DataLoader yielding ``(images, masks)`` for validation.
        device: Device each batch is moved to.
        EPOCHS: Number of epochs; also the cosine schedule's ``T_max``.
        LEARNING_RATE: Initial Adam learning rate.

    Returns:
        dict: per-epoch lists under ``train_loss``, ``val_loss`` and
        ``learning_rate``. ``learning_rate`` is the LR used during that epoch.

    Raises:
        ValueError: if either loader yields no batches.
    """
    if len(train_loader) == 0 or len(val_loader) == 0:
        raise ValueError("train_loader and val_loader must each contain at least one batch")

    print(f"\n🎯 Starting training for {EPOCHS} epochs...")

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS, eta_min=1e-6)

    # Metadata stored with every checkpoint. Batch and image size come from
    # the loader actually used instead of being hard-coded.
    # NOTE(review): n_channels / n_classes mirror how main() builds the model;
    # keep them in sync if the architecture changes.
    config = {
        'model_type': 'ultra_simple',
        'n_channels': 3,
        'n_classes': 5,
        'image_size': getattr(train_loader.dataset, 'image_size', 224),
        'batch_size': train_loader.batch_size,
    }

    history = {
        'train_loss': [],
        'val_loss': [],
        'learning_rate': []
    }
    best_val_loss = float('inf')
    start_time = time.time()

    for epoch in range(EPOCHS):
        epoch_start_time = time.time()
        print(f"\nπŸ“… Epoch {epoch+1}/{EPOCHS}")

        # Read the LR before stepping: after scheduler.step() it already holds
        # the value for the *next* epoch.
        current_lr = optimizer.param_groups[0]['lr']

        avg_train_loss = _train_one_epoch(model, train_loader, criterion, optimizer, device)
        avg_val_loss = _evaluate(model, val_loader, criterion, device)

        # One scheduler step per epoch, matching T_max=EPOCHS.
        scheduler.step()

        history['train_loss'].append(avg_train_loss)
        history['val_loss'].append(avg_val_loss)
        history['learning_rate'].append(current_lr)

        epoch_time = time.time() - epoch_start_time

        print(f"πŸ“Š Train Loss: {avg_train_loss:.4f}")
        print(f" Val Loss: {avg_val_loss:.4f}")
        print(f"πŸ“Š Learning Rate: {current_lr:.6f}")
        print(f" GPU Memory: {torch.cuda.memory_allocated()/1e9:.2f} GB")
        print(f"⏱️ Epoch time: {epoch_time:.1f}s")

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            _save_checkpoint('best_model.pth', epoch, model, optimizer, scheduler,
                             best_val_loss, history, config)
            print(f"βœ… New best model saved! Loss: {best_val_loss:.4f}")

        if (epoch + 1) % 10 == 0:
            checkpoint_path = f'checkpoint_epoch_{epoch+1}.pth'
            _save_checkpoint(checkpoint_path, epoch, model, optimizer, scheduler,
                             best_val_loss, history, config)
            print(f"πŸ’Ύ Checkpoint saved: {checkpoint_path}")

        torch.cuda.empty_cache()

        if (epoch + 1) % 5 == 0:
            elapsed_time = time.time() - start_time
            avg_epoch_time = elapsed_time / (epoch + 1)
            estimated_time = (EPOCHS - (epoch + 1)) * avg_epoch_time
            print(f"\nπŸ“ˆ Progress Update:")
            print(f" Epochs completed: {epoch+1}/{EPOCHS}")
            print(f" Best validation loss: {best_val_loss:.4f}")
            print(f" Average epoch time: {avg_epoch_time:.1f}s")
            print(f" Estimated time remaining: {estimated_time/60:.1f} minutes")

    total_time = time.time() - start_time
    print(f"\nπŸŽ‰ Training completed!")
    print(f"⏱️ Total time: {total_time/3600:.1f} hours")
    print(f" Best validation loss: {best_val_loss:.4f}")
    return history
# ============================================================================
# 5. VISUALIZATION
# ============================================================================
def plot_training_history(history, save_path='training_history.png'):
    """Plot loss curves and the learning-rate schedule side by side and save them.

    Args:
        history: dict with per-epoch lists under ``train_loss``, ``val_loss``
            and ``learning_rate``.
        save_path: Output image path.

    Does nothing when ``history['train_loss']`` is empty.
    """
    if not history['train_loss']:
        return

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
    try:
        ax1.plot(history['train_loss'], label='Train Loss')
        ax1.plot(history['val_loss'], label='Val Loss')
        ax1.set_title('Training and Validation Loss')
        ax1.set_xlabel('Epoch')
        ax1.set_ylabel('Loss')
        ax1.legend()
        ax1.grid(True)

        ax2.plot(history['learning_rate'], label='Learning Rate')
        ax2.set_title('Learning Rate Schedule')
        ax2.set_xlabel('Epoch')
        ax2.set_ylabel('Learning Rate')
        ax2.legend()
        ax2.grid(True)

        fig.tight_layout()
        fig.savefig(save_path, dpi=150, bbox_inches='tight')
    finally:
        # pyplot keeps every open figure alive until closed; release it even
        # if saving fails.
        plt.close(fig)
    print(f"πŸ“Š Training history plotted and saved as '{save_path}'")
# ============================================================================
# 6. MAIN FUNCTION
# ============================================================================
def main():
    """Entry point: set up, load data, build and train the model, plot history.

    Returns:
        int: exit status. ``0`` on success; ``1`` when the training data is
        unavailable or any step raises. Exceptions are reported with a
        traceback instead of propagating.
    """
    import traceback

    try:
        device, BATCH_SIZE, IMAGE_SIZE, EPOCHS, LEARNING_RATE = setup_training()

        train_loader, val_loader = create_data_loaders(BATCH_SIZE, IMAGE_SIZE)
        if train_loader is None:
            # create_data_loaders prints the reason before returning None.
            return 1

        model = UltraSimpleModel(n_channels=3, n_classes=5).to(device)
        print(f"βœ… Model created! Parameters: {sum(p.numel() for p in model.parameters()):,}")

        history = train_model(model, train_loader, val_loader, device, EPOCHS, LEARNING_RATE)
        plot_training_history(history)

        print("\nβœ… Training completed successfully!")
        print("πŸ’Ύ Best model saved as 'best_model.pth'")
        print("πŸ“Š Training history saved as 'training_history.png'")
        return 0
    except Exception as e:
        # Top-level boundary: report everything, but signal failure to the caller.
        print(f"❌ Training failed with error: {e}")
        traceback.print_exc()
        return 1


if __name__ == "__main__":
    import sys

    # A non-zero exit status lets the job runner detect a failed training run.
    sys.exit(main())