# eeg-cognitive-load / src/inference_loop.py
# Uploaded by dodo-2100 using huggingface_hub (commit 2afe0cd, verified).
import torch
import numpy as np
import time
import sys
import os
# Make the local `src` package importable regardless of the launch directory.
# The original behaviour (appending the current working directory) is kept,
# and the project root -- the parent of this file's `src/` directory -- is
# added as well, so running e.g. `python inference_loop.py` from inside
# `src/` can still resolve `from src.model import ...`.
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
for _import_root in (os.getcwd(), PROJECT_ROOT):
    if _import_root not in sys.path:  # avoid piling up duplicate entries
        sys.path.append(_import_root)
from src.model import EEGConformer
# Mock function to simulate receiving real-time EEG chunk
def get_realtime_eeg_chunk(channels=64, time_points=1000):
# In a real system, this would pull from an LSL stream or hardware buffer
return torch.randn(1, channels, time_points)
def adapt_curriculum(load_level):
    """Announce the curriculum adjustment for a predicted cognitive-load class.

    The Core Adaptation Logic (Proposal 2).

    Args:
        load_level: Predicted class index. ``2`` means high load / confusion,
            ``1`` means medium load; any other value is treated as low load
            or rest.

    Side effects:
        Prints one timestamped (``HH:MM:SS``) action line to stdout.
    """
    stamp = time.strftime("%H:%M:%S")
    if load_level == 2:  # High Load / Confusion
        action = "🔴 High Cognitive Load detected! -> ACTION: Simplifying Content / Switching to Video."
    elif load_level == 1:  # Medium
        action = "🟡 Medium Load. -> ACTION: Maintain current difficulty."
    else:  # Low / Rest (catch-all for every other value)
        action = "🟢 Low Load / Relaxed. -> ACTION: Increase difficulty / Present new topic."
    print(f"[{stamp}] {action}")
def _build_model(device, n_classes, channels, time_points):
    """Create a freshly initialised EEGConformer on *device*."""
    return EEGConformer(n_classes=n_classes, channels=channels, time_points=time_points).to(device)


def _load_model(device, weights_path, n_classes, channels, time_points):
    """Return an eval-mode FP32 model, using trained weights when they load.

    Falls back to a freshly initialised model when *weights_path* does not
    exist or fails to load. On a load failure the model is rebuilt, because
    ``load_state_dict`` can copy some tensors before raising; otherwise the
    "random weights" model would silently contain a partial checkpoint.
    """
    model = _build_model(device, n_classes, channels, time_points)
    if os.path.exists(weights_path):
        print(f"Loading trained model from {weights_path}...")
        try:
            # NOTE(review): torch.load unpickles the file -- only load trusted
            # checkpoints (consider weights_only=True on torch >= 1.13).
            state_dict = torch.load(weights_path, map_location=device)
            model.load_state_dict(state_dict)
        except Exception as err:  # deliberate best-effort fallback, but no longer silent
            print(f"Failed to load weights ({type(err).__name__}: {err}). Using random weights for simulation.")
            model = _build_model(device, n_classes, channels, time_points)
    else:
        print("No trained model found. Using random initialized model for simulation.")
    model.eval()
    model.float()  # FP32
    return model


def run_inference_loop(
    *,
    weights_path="models/best_model.pth",
    n_events=10,
    interval_s=1.0,
    n_classes=3,
    channels=64,
    time_points=1000,
):
    """Run the simulated closed-loop adaptation demo.

    For each event: acquire one EEG window, classify it with the
    EEGConformer, print the predicted class with its softmax confidence, and
    hand the class to ``adapt_curriculum``. Ctrl-C stops the loop cleanly.

    Args (all keyword-only; defaults reproduce the original hard-coded run):
        weights_path: Checkpoint (state_dict) to load if present.
        n_events: Number of simulated events to process.
        interval_s: Seconds to wait between consecutive events.
        n_classes: Number of load classes the model predicts.
        channels: EEG channels per window (model input and mock source).
        time_points: Samples per channel (model input and mock source).
    """
    print("--- Starting Closed-Loop Adaptation System (Simulation) ---")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = _load_model(device, weights_path, n_classes, channels, time_points)

    try:
        print("Listening for EEG stream...")
        for event in range(1, n_events + 1):
            # 1. Acquire Data
            eeg_chunk = get_realtime_eeg_chunk(channels=channels, time_points=time_points).to(device)
            # 2. Predict
            with torch.no_grad():
                logits = model(eeg_chunk)
                probabilities = torch.softmax(logits, dim=1)
                prediction = int(torch.argmax(probabilities, dim=1).item())
                confidence = probabilities[0, prediction].item()
            # 3. Adapt
            print(f"\nEvent {event}: Predicted Class {prediction} (Confidence: {confidence:.2f})")
            adapt_curriculum(prediction)
            if event < n_events:  # no need to wait after the final event
                time.sleep(interval_s)  # Simulate time gap between processing
    except KeyboardInterrupt:
        print("Stopped.")
if __name__ == "__main__":
    # Start the simulated closed-loop demo only when this file is executed
    # directly as a script; importing the module has no side effects here.
    run_inference_loop()