| import torch
|
| import numpy as np
|
| import time
|
| import sys
|
| import os
|
|
|
|
|
# Put the current working directory on the import path so that the
# `src.model` import below can be resolved.
# NOTE(review): presumably the script is launched from the project root,
# where the `src` package lives — confirm.
sys.path.append(os.getcwd())
|
|
|
| from src.model import EEGConformer
|
|
|
|
|
def get_realtime_eeg_chunk(channels=64, time_points=1000):
    """Return one simulated EEG window as a random tensor.

    This is a stand-in for a real acquisition source: the values come from
    ``torch.randn`` (standard-normal noise), not from any device.

    Args:
        channels: Size of the second dimension of the returned tensor.
        time_points: Size of the third dimension of the returned tensor.

    Returns:
        A tensor of shape ``(1, channels, time_points)``; the leading
        dimension of size 1 is what the caller feeds to the model as-is.
    """
    chunk_shape = (1, channels, time_points)
    return torch.randn(chunk_shape)
|
|
|
def adapt_curriculum(load_level):
    """Print the curriculum action for a predicted cognitive-load level.

    The core adaptation logic (Proposal 2). Exactly one line is printed,
    prefixed with the current local time as ``[HH:MM:SS]``.

    Args:
        load_level: Predicted load class. ``2`` selects the high-load
            action and ``1`` the medium-load action; every other value
            falls through to the low-load action.
    """
    stamp = time.strftime("%H:%M:%S")

    # Choose the message first, then emit it from a single print call.
    if load_level == 2:
        action = "🔴 High Cognitive Load detected! -> ACTION: Simplifying Content / Switching to Video."
    elif load_level == 1:
        action = "🟡 Medium Load. -> ACTION: Maintain current difficulty."
    else:
        action = "🟢 Low Load / Relaxed. -> ACTION: Increase difficulty / Present new topic."

    print(f"[{stamp}] {action}")
|
|
|
def run_inference_loop():
    """Run the simulated closed-loop adaptation demo.

    Builds an ``EEGConformer`` with three output classes, loads weights
    from ``models/best_model.pth`` when that file exists (otherwise, or if
    loading fails, the freshly constructed model is used), then classifies
    ten simulated EEG chunks one second apart. For each chunk the predicted
    class and its softmax confidence are printed and the prediction is
    passed to ``adapt_curriculum``. Ctrl-C stops the loop early.
    """
    print("--- Starting Closed-Loop Adaptation System (Simulation) ---")

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Bind the input geometry once so the model and the simulated stream
    # cannot drift apart if one of the numbers is edited later.
    n_channels, n_time_points = 64, 1000
    model = EEGConformer(
        n_classes=3, channels=n_channels, time_points=n_time_points
    ).to(device)

    weights_path = "models/best_model.pth"
    if os.path.exists(weights_path):
        print(f"Loading trained model from {weights_path}...")
        try:
            # NOTE(review): torch.load unpickles the file, so only load
            # checkpoints from a trusted source; consider
            # `weights_only=True` if the installed torch supports it.
            state_dict = torch.load(weights_path, map_location=device)
            model.load_state_dict(state_dict)
        except Exception as exc:
            # `Exception` (not a bare `except`) so that Ctrl-C and
            # SystemExit are not mistaken for a failed weight load.
            print("Failed to load weights. Using random weights for simulation.")
            print(f"  Reason: {type(exc).__name__}: {exc}")
    else:
        print("No trained model found. Using random initialized model for simulation.")

    model.eval()
    model.float()

    try:
        print("Listening for EEG stream...")
        for i in range(10):
            eeg_chunk = get_realtime_eeg_chunk(
                channels=n_channels, time_points=n_time_points
            ).to(device)

            # Inference only: no gradients are needed.
            with torch.no_grad():
                output = model(eeg_chunk)
                probabilities = torch.softmax(output, dim=1)
                prediction = torch.argmax(probabilities, dim=1).item()
                confidence = probabilities[0][prediction].item()

            print(f"\nEvent {i+1}: Predicted Class {prediction} (Confidence: {confidence:.2f})")
            adapt_curriculum(prediction)

            # Pace the simulation at one event per second.
            time.sleep(1)
    except KeyboardInterrupt:
        print("Stopped.")
|
|
|
# Start the simulation only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    run_inference_loop()
|
|
|