| import os
|
| import torch
|
| import torch.nn as nn
|
| import torch.optim as optim
|
| from torch.utils.data import DataLoader, TensorDataset
|
| from sklearn.model_selection import train_test_split
|
| from src.dataset_loader import CognitiveLoadDataset
|
| from src.model import EEGConformer
|
| import numpy as np
|
| import time
|
|
|
|
|
# Training hyperparameters consumed by train().
BATCH_SIZE = 1_024      # samples per DataLoader batch (train and test)
EPOCHS = 2_000          # number of passes over the training set
LEARNING_RATE = 1e-3    # initial learning rate handed to the optimizer

# Target device string: use the GPU when CUDA is available, else the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
|
|
|
def print_step(step, msg):
    """Print a stage header followed by a 50-character dashed rule.

    Args:
        step: Stage label shown in square brackets (e.g. ``"1/4"``).
        msg: Human-readable description of the stage.
    """
    header = f"\n[{step}] {msg}"
    rule = "-" * 50
    # A single call with a newline separator emits exactly the same text as
    # printing the header and the rule one after the other.
    print(header, rule, sep="\n")
|
|
|
def _train_one_epoch(model, batches, criterion, optimizer):
    """Run one optimisation pass over *batches* and return train accuracy (%).

    Args:
        model: Network to optimise; switched to training mode here.
        batches: Iterable of ``(inputs, labels)`` pairs. Either the raw
            DataLoader or a tqdm wrapper around it.
        criterion: Loss function called as ``criterion(outputs, labels)``.
        optimizer: Optimizer stepped once per batch.

    Returns:
        Running top-1 accuracy over the whole pass, in percent.
    """
    model.train()
    # Only a tqdm wrapper exposes set_postfix; a bare DataLoader does not.
    show_postfix = hasattr(batches, "set_postfix")
    correct = 0
    total = 0

    for inputs, labels in batches:
        inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

        if show_postfix:
            acc = 100. * correct / total
            batches.set_postfix({"Loss": f"{loss.item():.4f}", "Acc": f"{acc:.2f}%"})

    return 100. * correct / total


def _evaluate_accuracy(model, loader):
    """Return the top-1 accuracy (%) of *model* over every batch in *loader*.

    The model is switched to eval mode and gradients are disabled.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
            outputs = model(inputs)
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
    return 100. * correct / total


def train():
    """Train an EEGConformer on the dataset under ``data`` and keep the best weights.

    Stages (each announced with ``print_step``):
      1. Load samples through ``CognitiveLoadDataset("data")``; abort with a
         console message if the folder is missing or no samples are returned.
      2. Cast to float32 / int64 arrays and split 80/20 with a fixed seed.
      3. Build the model from the data shape (``X.shape[1]`` is passed as
         ``channels`` and ``X.shape[2]`` as ``time_points``), plus the loss,
         optimizer and learning-rate scheduler.
      4. Train for ``EPOCHS`` epochs, evaluating on the held-out split after
         each one and writing ``models/best_model.pth`` whenever validation
         accuracy improves.

    Returns:
        None. Progress and results are reported on stdout.
    """
    print("="*60)
    print(" COGNITIVE LOAD DETECTION SYSTEM - TRAINING ENGINE ")
    print("="*60)
    print(f"Device: {DEVICE.upper()}")
    print("Precision: Full (FP32) - Max Accuracy")

    # ---- 1/4: load -------------------------------------------------------
    print_step("1/4", "Loading & Verifying Dataset...")
    data_path = os.path.join("data")
    if not os.path.exists(data_path):
        print("Error: Data folder not found!")
        return

    loader = CognitiveLoadDataset(data_path)
    loader.load_data()
    X, y = loader.get_data()

    msg = f"Data Loaded Successfully.\n Samples: {len(X)}\n Shape: {X.shape if len(X)>0 else 'N/A'}"
    print(msg)

    if len(X) == 0:
        print("CRITICAL ERROR: No data loaded. Aborting.")
        return

    # ---- 2/4: preprocess and split ---------------------------------------
    print_step("2/4", "Preprocessing & Splitting...")
    X = np.array(X, dtype=np.float32)
    y = np.array(y, dtype=np.longlong)

    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
    print(f" Train Set: {X_train.shape[0]} samples")
    print(f" Test Set: {X_test.shape[0]} samples")

    train_dataset = TensorDataset(torch.tensor(X_train), torch.tensor(y_train))
    test_dataset = TensorDataset(torch.tensor(X_test), torch.tensor(y_test))

    # Pinned host memory only helps host-to-GPU copies, so request it only
    # when training on CUDA; on CPU-only machines it is of no use.
    train_loader = DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        pin_memory=(DEVICE == "cuda"),
    )
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
    print(" DataLoaders ready.")

    # ---- 3/4: model, loss, optimizer, scheduler --------------------------
    print_step("3/4", "Initializing EEGConformer Model...")
    n_classes = len(np.unique(y)) if len(y) > 0 else 3
    print(f" Detected Classes: {n_classes}")

    model = EEGConformer(n_classes=n_classes, channels=X.shape[1], time_points=X.shape[2])
    model = model.to(DEVICE)
    model.float()

    if int(torch.__version__.split('.')[0]) >= 2 and DEVICE == 'cuda':
        print(" [Info] torch.compile() disabled for stability.")

    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
    optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=0.01)
    scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=50, T_mult=2)

    # ---- 4/4: training loop ----------------------------------------------
    print_step("4/4", "Starting Training Phase...")
    best_acc = 0.0

    # tqdm is optional. Without it, iterate the DataLoader directly instead of
    # going through a stub that would have to mimic tqdm's keyword arguments
    # and its set_postfix method.
    try:
        from tqdm import tqdm
    except ImportError:
        tqdm = None

    for epoch in range(EPOCHS):
        start = time.time()

        if tqdm is not None:
            batches = tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}", unit="batch")
        else:
            batches = train_loader

        train_acc = _train_one_epoch(model, batches, criterion, optimizer)
        val_acc = _evaluate_accuracy(model, test_loader)
        epoch_time = time.time() - start

        print(f" --> Epoch {epoch+1} Summary: Train Acc: {train_acc:.2f}% | Val Acc: {val_acc:.2f}% | Time: {epoch_time:.1f}s")

        if val_acc > best_acc:
            best_acc = val_acc
            # exist_ok avoids the check-then-create race of a separate
            # os.path.exists() test.
            os.makedirs("models", exist_ok=True)
            torch.save(model.state_dict(), "models/best_model.pth")
            print(f" [+] New Best Model Saved! ({best_acc:.2f}%)")

        # The scheduler is advanced once per epoch, after evaluation.
        scheduler.step()

    print("\n" + "="*60)
    print(f"TRAINING COMPLETE. Best Accuracy: {best_acc:.2f}%")
    print("="*60)
|
|
|
# Script entry point: start training only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    train()
|
|
|