eeg-cognitive-load / src /train.py
dodo-2100's picture
Upload folder using huggingface_hub
2afe0cd verified
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import train_test_split
from src.dataset_loader import CognitiveLoadDataset
from src.model import EEGConformer
import numpy as np
import time
# --- Config for H200 (FP32) ---
# Training hyper-parameters. The large batch size targets throughput on an
# H200 GPU; reduce it on smaller devices.
BATCH_SIZE = 1024
EPOCHS = 2_000
LEARNING_RATE = 1e-3

# Run on the GPU whenever PyTorch can see one, otherwise fall back to CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
def print_step(step, msg):
    """Print a stage header for the training pipeline.

    Emits a blank line, ``[step] msg``, and a 50-character dash rule.
    """
    rule = "-" * 50
    print(f"\n[{step}] {msg}", rule, sep="\n")
def _percent(correct, total):
    """Return ``correct / total`` as a percentage, or 0.0 when *total* is 0."""
    return 100.0 * correct / total if total else 0.0


class _NoProgressBar:
    """Minimal stand-in for ``tqdm`` used when tqdm is not installed.

    Accepts and ignores tqdm's keyword arguments (``desc``, ``unit``, ...),
    iterates the wrapped iterable unchanged, and makes ``set_postfix`` a
    no-op, so the training loop can make the same calls either way.
    """

    def __init__(self, iterable, **_kwargs):
        self._iterable = iterable

    def __iter__(self):
        return iter(self._iterable)

    def set_postfix(self, *_args, **_kwargs):
        pass


def _progress_bar_factory():
    """Return ``tqdm.tqdm`` when importable, otherwise ``_NoProgressBar``."""
    try:
        from tqdm import tqdm
    except ImportError:
        return _NoProgressBar
    return tqdm


def _load_arrays(data_path):
    """Load the dataset under *data_path* via ``CognitiveLoadDataset``.

    Returns:
        ``(X, y)`` as ``float32`` / ``int64`` NumPy arrays, or ``None`` (after
        printing the reason) when the folder is missing, no samples were
        loaded, or X and y disagree in length.
    """
    if not os.path.exists(data_path):
        print("Error: Data folder not found!")
        return None

    loader = CognitiveLoadDataset(data_path)
    loader.load_data()
    X, y = loader.get_data()

    # Convert *before* reading .shape: get_data() is not guaranteed to return
    # ndarrays, and plain lists have no .shape attribute.
    X = np.asarray(X, dtype=np.float32)
    y = np.asarray(y, dtype=np.int64)

    shape = X.shape if len(X) > 0 else "N/A"
    print(f"Data Loaded Successfully.\n Samples: {len(X)}\n Shape: {shape}")

    if len(X) == 0:
        print("CRITICAL ERROR: No data loaded. Aborting.")
        return None
    if len(y) != len(X):
        print(f"CRITICAL ERROR: {len(X)} samples but {len(y)} labels. Aborting.")
        return None
    return X, y


def _build_loaders(X, y):
    """Split 80/20 (fixed seed 42) and wrap both parts in DataLoaders.

    Returns:
        ``(train_loader, test_loader)``; only the training loader shuffles.
    """
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
    print(f" Train Set: {X_train.shape[0]} samples")
    print(f" Test Set: {X_test.shape[0]} samples")

    # Page-locked host memory only speeds up host->GPU copies; on CPU it is
    # wasted work and PyTorch warns about it.
    pin = DEVICE == "cuda"
    train_dataset = TensorDataset(torch.tensor(X_train), torch.tensor(y_train))
    test_dataset = TensorDataset(torch.tensor(X_test), torch.tensor(y_test))
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, pin_memory=pin)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, pin_memory=pin)
    print(" DataLoaders ready.")
    return train_loader, test_loader


def _train_one_epoch(model, batches, criterion, optimizer):
    """Run one optimisation pass over *batches*.

    *batches* yields ``(inputs, labels)`` pairs and must provide
    ``set_postfix`` (a tqdm bar or ``_NoProgressBar``).

    Returns:
        ``(mean_loss_per_sample, accuracy_percent)`` over the epoch.
    """
    model.train()
    loss_sum = 0.0
    correct = 0
    total = 0
    for inputs, labels in batches:
        inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        batch_size = labels.size(0)
        # Fetch the scalar once: each .item() forces a device->host sync.
        batch_loss = loss.item()
        loss_sum += batch_loss * batch_size
        _, predicted = outputs.max(1)
        correct += predicted.eq(labels).sum().item()
        total += batch_size
        batches.set_postfix({"Loss": f"{batch_loss:.4f}", "Acc": f"{_percent(correct, total):.2f}%"})

    mean_loss = loss_sum / total if total else 0.0
    return mean_loss, _percent(correct, total)


def _evaluate(model, loader):
    """Return the classification accuracy (percent) of *model* on *loader*."""
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
            _, predicted = model(inputs).max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
    return _percent(correct, total)


def train():
    """Train an EEGConformer on the data in ``./data`` and checkpoint the best model.

    Pipeline:
        1. load samples via ``CognitiveLoadDataset``;
        2. 80/20 train/test split;
        3. build the model, AdamW and a cosine-warm-restart scheduler;
        4. train for ``EPOCHS`` epochs.

    After each epoch the weights with the highest held-out accuracy so far are
    written to ``models/best_model.pth``. Reads the module-level
    ``BATCH_SIZE``, ``EPOCHS``, ``LEARNING_RATE`` and ``DEVICE``. Progress is
    printed to stdout; returns ``None`` (early, with a message, if the data
    cannot be used).
    """
    print("=" * 60)
    print(" COGNITIVE LOAD DETECTION SYSTEM - TRAINING ENGINE ")
    print("=" * 60)
    print(f"Device: {DEVICE.upper()}")
    print("Precision: Full (FP32) - Max Accuracy")

    # 1. Load data
    print_step("1/4", "Loading & Verifying Dataset...")
    arrays = _load_arrays("data")
    if arrays is None:
        return
    X, y = arrays

    # The model is built with channels=X.shape[1], time_points=X.shape[2],
    # so the data must be 3-D: (samples, channels, time_points).
    if X.ndim != 3:
        print(f"CRITICAL ERROR: expected 3-D data (samples, channels, time_points), got shape {X.shape}. Aborting.")
        return

    # 2. Preprocessing
    print_step("2/4", "Preprocessing & Splitting...")
    train_loader, test_loader = _build_loaders(X, y)

    # 3. Model setup
    print_step("3/4", "Initializing EEGConformer Model...")
    if y.min() < 0:
        print(f"CRITICAL ERROR: negative class label {int(y.min())} found. Aborting.")
        return
    # CrossEntropyLoss requires targets in [0, n_classes). Counting distinct
    # labels under-sizes the head when labels skip values (e.g. {1, 2, 3}).
    n_classes = int(y.max()) + 1
    n_distinct = len(np.unique(y))
    print(f" Detected Classes: {n_classes}")
    if n_distinct != n_classes:
        print(f" [Warn] Only {n_distinct} distinct labels present; output layer sized to {n_classes}.")

    model = EEGConformer(n_classes=n_classes, channels=X.shape[1], time_points=X.shape[2])
    model = model.to(DEVICE)
    model.float()  # Enforce FP32

    if int(torch.__version__.split('.')[0]) >= 2 and DEVICE == 'cuda':
        # torch.compile is intentionally not used: it caused issues with some
        # custom models/ops in this project.
        print(" [Info] torch.compile() disabled for stability.")

    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
    optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=0.01)
    # Cosine annealing with warm restarts: first cycle is 50 epochs, and each
    # subsequent cycle is twice as long (T_mult=2).
    scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=50, T_mult=2)

    # 4. Training loop
    print_step("4/4", "Starting Training Phase...")
    make_progress_bar = _progress_bar_factory()
    os.makedirs("models", exist_ok=True)
    checkpoint_path = os.path.join("models", "best_model.pth")
    best_acc = 0.0

    for epoch in range(EPOCHS):
        start = time.perf_counter()
        pbar = make_progress_bar(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}", unit="batch")
        train_loss, train_acc = _train_one_epoch(model, pbar, criterion, optimizer)
        val_acc = _evaluate(model, test_loader)
        epoch_time = time.perf_counter() - start

        print(
            f" --> Epoch {epoch+1} Summary: Train Loss: {train_loss:.4f} | "
            f"Train Acc: {train_acc:.2f}% | Val Acc: {val_acc:.2f}% | Time: {epoch_time:.1f}s"
        )

        # NOTE(review): the same held-out split drives checkpoint selection and
        # the reported best accuracy, so that figure is optimistic. A separate
        # validation split would give an unbiased test score.
        # Always save after the first epoch so a checkpoint exists even if
        # accuracy never rises above 0%.
        if epoch == 0 or val_acc > best_acc:
            best_acc = val_acc
            torch.save(model.state_dict(), checkpoint_path)
            print(f" [+] New Best Model Saved! ({best_acc:.2f}%)")

        # Stepped once per epoch; this scheduler does not use val_acc.
        scheduler.step()

    print("\n" + "=" * 60)
    print(f"TRAINING COMPLETE. Best Accuracy: {best_acc:.2f}%")
    print("=" * 60)
if __name__ == "__main__":
    # Start training only when run as a script, not when this module is imported.
    train()