import torch
from torch.utils.data import DataLoader, random_split
from torch.cuda.amp import autocast, GradScaler
from tqdm import tqdm
import numpy as np
import os
import datetime
import pandas as pd
import matplotlib.pyplot as plt
import joblib

from dataloader import MultiHouseDataset
from hierarchical_diffusion_model import HierarchicalDiffusionModel

# --- Device selection ---
if torch.cuda.is_available():
    DEVICE = "cuda"
    torch.backends.cudnn.benchmark = True
    torch.backends.cuda.matmul.allow_tf32 = True
    print("Using NVIDIA CUDA backend.")
elif torch.backends.mps.is_available():
    DEVICE = "mps"
    print("Using Apple MPS backend.")
else:
    DEVICE = "cpu"
    print("Using CPU.")

# --- Training hyperparameters ---
EPOCHS = 200
LEARNING_RATE = 1e-4
BATCH_SIZE = 512
USE_AMP = True
GRADIENT_CLIP_VAL = 0.1
WINDOW_DURATION = '14_days'
DATA_DIRECTORY = './data/per_house'
NUM_WORKERS = max(1, (os.cpu_count() or 2) // 2)  # os.cpu_count() may return None
PIN_MEMORY = (DEVICE == "cuda")  # pinning host memory only helps CUDA transfers

# --- Model hyperparameters ---
USE_ATTENTION = True
DROPOUT = 0.1
HIDDEN_SIZE = 512
EMBEDDING_DIM = 64
DIFFUSION_TIMESTEPS = 500
DOWNSCALE_FACTOR = 4


def calculate_window_size(duration: str) -> int:
    """Translate a named window duration into a sample count (30-minute data)."""
    SAMPLES_PER_DAY = 48
    mapping = {
        '2_days': 2 * SAMPLES_PER_DAY,
        '7_days': 7 * SAMPLES_PER_DAY,
        '14_days': 14 * SAMPLES_PER_DAY,
        '15_days': 15 * SAMPLES_PER_DAY,
        '30_days': 30 * SAMPLES_PER_DAY,
    }
    if duration not in mapping:
        raise ValueError(f"Invalid WINDOW_DURATION: {duration}")
    return mapping[duration]


def denormalize_data(normalized_data, scaler_path='global_scaler.gz'):
    """Invert the global scaler; accepts (N, F) or (B, T, F) arrays."""
    scaler = joblib.load(scaler_path)
    original_shape = normalized_data.shape
    if len(original_shape) == 3:
        batch_size, seq_len, features = original_shape
        normalized_flat = normalized_data.reshape(-1, features)
        denormalized_flat = scaler.inverse_transform(normalized_flat)
        return denormalized_flat.reshape(original_shape)
    else:
        return scaler.inverse_transform(normalized_data)


def moving_average(data, window_size):
    """'valid'-mode moving average; output has len(data) - window_size + 1 points."""
    return np.convolve(data, np.ones(window_size), 'valid') / window_size


def save_and_plot_loss(loss_dict, title, filepath, window_size=10):
    plt.figure(figsize=(12, 6))
    for label, losses in loss_dict.items():
        # Persist the raw curve as CSV alongside the plot.
        pd.DataFrame({label: losses}).to_csv(
            f"{filepath}_{label.lower().replace(' ', '_')}.csv", index=False
        )
        plt.plot(losses, label=f'Raw {label}', alpha=0.3)
        if len(losses) > window_size:
            smoothed_losses = moving_average(losses, window_size)
            plt.plot(np.arange(window_size - 1, len(losses)), smoothed_losses,
                     label=f'Smoothed {label}')
    plt.title(title)
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True)
    plt.savefig(f"{filepath}.png")
    plt.close()
    print(f"Loss plot saved to {filepath}.png")


def train_diffusion(log_dir, model_save_path):
    print("--- Starting Hierarchical Diffusion Training ---")
    window_size = calculate_window_size(WINDOW_DURATION)
    print(f"Using window duration: {WINDOW_DURATION} ({window_size} samples)")

    dataset = MultiHouseDataset(
        data_dir=DATA_DIRECTORY,
        window_size=window_size,
        step_size=window_size // 2,  # 50% overlap between consecutive windows
        limit_to_one_year=False,
    )
    print(f"Dataset loaded: {len(dataset)} samples, {dataset.num_houses} houses, "
          f"{dataset[0][0].shape[1]} features.")

    val_split = 0.1
    val_size = int(len(dataset) * val_split)
    train_size = len(dataset) - val_size
    train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
    print(f"Train size: {train_size}, Validation size: {val_size}")

    train_dataloader = DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=NUM_WORKERS,
        pin_memory=PIN_MEMORY,
        drop_last=True,
    )
    val_dataloader = DataLoader(
        val_dataset,
        batch_size=BATCH_SIZE * 2,  # no gradients held, so larger batches fit
        shuffle=False,
        num_workers=NUM_WORKERS,
        pin_memory=PIN_MEMORY,
    )
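    # Per-channel loss weights for the 4 input features. The 8x weight on
    # channel index 1 comes from the original configuration; which physical
    # quantity that channel holds (presumably the primary load signal) is an
    # assumption -- it depends on MultiHouseDataset's column ordering.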
    channel_weights = torch.tensor([1.0, 8.0, 1.0, 1.0], device=DEVICE)
    print(f"Using channel weights: {channel_weights}")

    model = HierarchicalDiffusionModel(
        in_channels=dataset[0][0].shape[1],
        num_houses=dataset.num_houses,
        downscale_factor=DOWNSCALE_FACTOR,
        channel_weights=channel_weights,
        embedding_dim=EMBEDDING_DIM,
        hidden_dims=[HIDDEN_SIZE // 4, HIDDEN_SIZE // 2, HIDDEN_SIZE],
        dropout=DROPOUT,
        use_attention=USE_ATTENTION,
        num_timesteps=DIFFUSION_TIMESTEPS,
        blocks_per_level=3,
    ).to(DEVICE)

    optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)
    scaler = GradScaler(enabled=(USE_AMP and DEVICE == "cuda"))

    train_losses, val_losses = [], []
    best_val_loss = float('inf')

    print(f"Starting training for {EPOCHS} epochs...")
    for epoch in range(EPOCHS):
        # --- Training pass ---
        model.train()
        total_train_loss = 0.0
        pbar = tqdm(train_dataloader, desc=f"Epoch {epoch+1}/{EPOCHS} (Train)")
        for clean_data, conditions in pbar:
            clean_data = clean_data.to(DEVICE, non_blocking=PIN_MEMORY)
            conditions = {k: v.to(DEVICE, non_blocking=PIN_MEMORY)
                          for k, v in conditions.items()}

            optimizer.zero_grad(set_to_none=True)
            with autocast(enabled=(USE_AMP and DEVICE == "cuda")):
                loss = model(clean_data, conditions)

            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)  # unscale before clipping raw gradients
            torch.nn.utils.clip_grad_norm_(model.parameters(), GRADIENT_CLIP_VAL)
            scaler.step(optimizer)
            scaler.update()

            total_train_loss += loss.item()
            pbar.set_postfix({'loss': f'{loss.item():.6f}',
                              'lr': f'{scheduler.get_last_lr()[0]:.2e}'})

        avg_train_loss = total_train_loss / len(train_dataloader)
        train_losses.append(avg_train_loss)

        # --- Validation pass ---
        model.eval()
        total_val_loss = 0.0
        with torch.no_grad():
            for clean_data, conditions in tqdm(val_dataloader, desc="Validating"):
                clean_data = clean_data.to(DEVICE, non_blocking=PIN_MEMORY)
                conditions = {k: v.to(DEVICE, non_blocking=PIN_MEMORY)
                              for k, v in conditions.items()}
                with autocast(enabled=(USE_AMP and DEVICE == "cuda")):
                    loss = model(clean_data, conditions)
                total_val_loss += loss.item()

        avg_val_loss = total_val_loss / len(val_dataloader)
        val_losses.append(avg_val_loss)
        print(f"Epoch {epoch+1}/{EPOCHS} | Train Loss: {avg_train_loss:.6f} "
              f"| Val Loss: {avg_val_loss:.6f}")

        # Checkpoint on validation improvement only.
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save(model.state_dict(), model_save_path)
            print(f"New best model saved to {model_save_path} "
                  f"(Val Loss: {best_val_loss:.6f})")

        scheduler.step()

    print("--- Training complete ---")
    save_and_plot_loss(
        {'Train Loss': train_losses, 'Validation Loss': val_losses},
        'Hierarchical Diffusion Model Training & Validation Loss',
        os.path.join(log_dir, 'diffusion_loss_curves'),
    )
    return dataset


if __name__ == "__main__":
    timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    run_name = f"hierarchical_diffusion_{WINDOW_DURATION}_{timestamp}"
    log_dir = os.path.join("./training_logs", run_name)
    os.makedirs(log_dir, exist_ok=True)
    model_path = os.path.join(log_dir, 'best_hierarchical_model.pth')

    print(f"Starting new run: {run_name}")
    print(f"Logs and models will be saved to: {log_dir}")

    full_dataset = train_diffusion(log_dir=log_dir, model_save_path=model_path)

    print("\nTraining and best model saving complete.")
    print(f"Model saved to: {model_path}")
    print(f"Loss curves saved to: {os.path.join(log_dir, 'diffusion_loss_curves.png')}")
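
# ---------------------------------------------------------------------------
# Hedged sketch (defined but never called): how a downstream script might
# reload the best checkpoint and denormalize generated windows with the
# helpers above. The constructor arguments mirror those in train_diffusion;
# `model.sample(...)` is a HYPOTHETICAL method name -- substitute whatever
# sampling entry point HierarchicalDiffusionModel actually exposes.
# Example usage: _example_reload_and_denormalize(model_path, full_dataset)
# ---------------------------------------------------------------------------
def _example_reload_and_denormalize(model_path: str, dataset) -> np.ndarray:
    model = HierarchicalDiffusionModel(
        in_channels=dataset[0][0].shape[1],
        num_houses=dataset.num_houses,
        downscale_factor=DOWNSCALE_FACTOR,
        channel_weights=torch.tensor([1.0, 8.0, 1.0, 1.0], device=DEVICE),
        embedding_dim=EMBEDDING_DIM,
        hidden_dims=[HIDDEN_SIZE // 4, HIDDEN_SIZE // 2, HIDDEN_SIZE],
        dropout=DROPOUT,
        use_attention=USE_ATTENTION,
        num_timesteps=DIFFUSION_TIMESTEPS,
        blocks_per_level=3,
    ).to(DEVICE)
    model.load_state_dict(torch.load(model_path, map_location=DEVICE))
    model.eval()
    with torch.no_grad():
        samples = model.sample(num_samples=4)  # HYPOTHETICAL API -- adjust.
    # denormalize_data expects a (batch, seq_len, features) array and the
    # fitted scaler saved by the preprocessing pipeline at 'global_scaler.gz'.
    return denormalize_data(samples.cpu().numpy())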