"""Training entry point for the EEGConformer cognitive-load classifier.

Loads the dataset from ``data/``, makes an 80/20 train/test split, trains an
EEGConformer with AdamW + cosine warm restarts, and saves the weights with the
best held-out accuracy to ``models/best_model.pth``.
"""
import os
import time

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, TensorDataset

from src.dataset_loader import CognitiveLoadDataset
from src.model import EEGConformer

# --- Config for H200 (FP32) ---
BATCH_SIZE = 1024  # Large batch chosen for H200 throughput
EPOCHS = 2000
LEARNING_RATE = 0.001
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

DATA_DIR = "data"
MODEL_DIR = "models"
BEST_MODEL_PATH = os.path.join(MODEL_DIR, "best_model.pth")


def print_step(step, msg):
    """Print a section header ``[step] msg`` followed by a divider line."""
    print(f"\n[{step}] {msg}")
    print("-" * 50)


class _PlainProgress:
    """Minimal stand-in for a tqdm bar when tqdm is not installed.

    Accepts (and ignores) tqdm's keyword arguments, iterates the wrapped
    iterable unchanged, and makes ``set_postfix`` a no-op so callers need no
    special-casing.
    """

    def __init__(self, iterable, **_kwargs):
        self._iterable = iterable

    def __iter__(self):
        return iter(self._iterable)

    def set_postfix(self, *_args, **_kwargs):
        pass


def _progress_factory():
    """Return ``tqdm`` if it is importable, otherwise ``_PlainProgress``."""
    try:
        from tqdm import tqdm
    except ImportError:
        return _PlainProgress
    return tqdm


def _train_one_epoch(model, batches, criterion, optimizer):
    """Run one optimisation pass over ``batches``.

    Args:
        model: network being trained (already moved to ``DEVICE``).
        batches: iterable of ``(inputs, labels)`` tensors that also exposes
            ``set_postfix`` (a tqdm bar or ``_PlainProgress``).
        criterion: loss applied as ``criterion(outputs, labels)``.
        optimizer: optimizer over ``model``'s parameters.

    Returns:
        ``(mean_loss, accuracy_percent)`` over every sample seen this epoch;
        ``(0.0, 0.0)`` if no batches were produced.
    """
    model.train()
    loss_sum = 0.0
    correct = 0
    total = 0
    for inputs, labels in batches:
        inputs = inputs.to(DEVICE, non_blocking=True)
        labels = labels.to(DEVICE, non_blocking=True)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        batch_size = labels.size(0)
        batch_loss = loss.item()  # single host sync per batch, reused below
        # criterion uses its default 'mean' reduction, so weight by batch size
        # to get a true per-sample mean across uneven final batches.
        loss_sum += batch_loss * batch_size
        correct += outputs.argmax(dim=1).eq(labels).sum().item()
        total += batch_size
        batches.set_postfix({"Loss": f"{batch_loss:.4f}", "Acc": f"{100.0 * correct / total:.2f}%"})

    if total == 0:
        return 0.0, 0.0
    return loss_sum / total, 100.0 * correct / total


def _evaluate(model, loader):
    """Return top-1 accuracy (percent) of ``model`` on ``loader``; 0.0 if it is empty."""
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs = inputs.to(DEVICE, non_blocking=True)
            labels = labels.to(DEVICE, non_blocking=True)
            outputs = model(inputs)
            correct += outputs.argmax(dim=1).eq(labels).sum().item()
            total += labels.size(0)
    return 100.0 * correct / total if total else 0.0


def train():
    """Run the full pipeline: load data, split, build the model, train, checkpoint.

    Returns early (after printing an error) if the data folder is missing, no
    samples were loaded, the inputs are not 3-D, or labels are negative.
    """
    print("=" * 60)
    print(" COGNITIVE LOAD DETECTION SYSTEM - TRAINING ENGINE ")
    print("=" * 60)
    print(f"Device: {DEVICE.upper()}")
    print("Precision: Full (FP32) - Max Accuracy")

    # 1. Load Data
    print_step("1/4", "Loading & Verifying Dataset...")
    if not os.path.exists(DATA_DIR):
        print("Error: Data folder not found!")
        return

    loader = CognitiveLoadDataset(DATA_DIR)
    loader.load_data()
    X, y = loader.get_data()
    # Convert immediately: get_data() may hand back Python lists (no .shape).
    X = np.asarray(X, dtype=np.float32)
    # CrossEntropyLoss expects int64 class-index targets.
    y = np.asarray(y, dtype=np.int64)

    shape = X.shape if len(X) > 0 else "N/A"
    print(f"Data Loaded Successfully.\n Samples: {len(X)}\n Shape: {shape}")

    if len(X) == 0:
        print("CRITICAL ERROR: No data loaded. Aborting.")
        return
    if X.ndim != 3:
        # The model below is built from X.shape[1] (channels) and X.shape[2] (time points).
        print(f"CRITICAL ERROR: Expected 3-D input (samples, channels, time_points), got {X.shape}. Aborting.")
        return
    if y.min() < 0:
        print("CRITICAL ERROR: Labels must be non-negative class indices. Aborting.")
        return

    # 2. Preprocessing
    print_step("2/4", "Preprocessing & Splitting...")
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
    print(f" Train Set: {X_train.shape[0]} samples")
    print(f" Test Set: {X_test.shape[0]} samples")

    train_dataset = TensorDataset(torch.tensor(X_train), torch.tensor(y_train))
    test_dataset = TensorDataset(torch.tensor(X_test), torch.tensor(y_test))

    # Pinned host memory only benefits (and only avoids warnings) when copying to a GPU.
    pin = DEVICE == "cuda"
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, pin_memory=pin)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, pin_memory=pin)
    print(" DataLoaders ready.")

    # 3. Model Setup
    print_step("3/4", "Initializing EEGConformer Model...")
    # Size the output layer from the largest label id so every target is a valid
    # index for CrossEntropyLoss, even if some class ids never occur.
    n_classes = int(y.max()) + 1
    n_present = len(np.unique(y))
    print(f" Detected Classes: {n_classes}")
    if n_present != n_classes:
        print(f" [Warn] Only {n_present} of {n_classes} class ids occur in the labels.")

    model = EEGConformer(n_classes=n_classes, channels=X.shape[1], time_points=X.shape[2])
    model = model.to(DEVICE)
    model.float()  # Enforce FP32

    if int(torch.__version__.split(".")[0]) >= 2 and DEVICE == "cuda":
        # torch.compile causes issues with some custom models/ops, so it is
        # deliberately not applied here.
        print(" [Info] torch.compile() disabled for stability.")

    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
    optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=0.01)
    # Cosine annealing with warm restarts: first cycle lasts 50 epochs (T_0=50),
    # each later cycle doubles in length (T_mult=2). Stepped once per epoch.
    scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=50, T_mult=2)

    # 4. Training Loop
    print_step("4/4", "Starting Training Phase...")
    progress = _progress_factory()
    best_acc = 0.0

    for epoch in range(EPOCHS):
        start = time.perf_counter()
        pbar = progress(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}", unit="batch")
        train_loss, train_acc = _train_one_epoch(model, pbar, criterion, optimizer)
        val_acc = _evaluate(model, test_loader)
        epoch_time = time.perf_counter() - start

        print(
            f" --> Epoch {epoch+1} Summary: Train Loss: {train_loss:.4f} | "
            f"Train Acc: {train_acc:.2f}% | Val Acc: {val_acc:.2f}% | Time: {epoch_time:.1f}s"
        )

        # NOTE: the checkpoint is selected on the same split reported as "Val Acc",
        # so best_acc is an optimistic estimate of generalisation.
        if val_acc > best_acc:
            best_acc = val_acc
            os.makedirs(MODEL_DIR, exist_ok=True)
            torch.save(model.state_dict(), BEST_MODEL_PATH)
            print(f" [+] New Best Model Saved! ({best_acc:.2f}%)")

        # Epoch-based cosine schedule; it does not consume val_acc.
        scheduler.step()

    print("\n" + "=" * 60)
    print(f"TRAINING COMPLETE. Best Accuracy: {best_acc:.2f}%")
    print("=" * 60)


if __name__ == "__main__":
    train()