# thoughtlink/src/integration.py
# Uploaded by Srilekha23 using huggingface_hub (commit 832948a, verified).
"""Phase 8: Brain-Robot Interface integration."""
import sys
import os
import time
import threading
import json
import numpy as np
from pathlib import Path
# Directory two levels above this file (the parent of the directory holding it).
PROJECT_ROOT = Path(__file__).resolve().parent.parent

# Put the brain-robot-interface sources at the front of the import path,
# unless that entry is already present.
BRI_SRC = str(PROJECT_ROOT.joinpath("brain-robot-interface", "src"))
if BRI_SRC not in sys.path:
    sys.path[:0] = [BRI_SRC]
# Import our pipeline (src/ is already on path when run from src/)
from pipeline import ThoughtLinkPipeline
def _import_bri():
    """Import the ``bri`` package on demand and hand back its entry points.

    The import happens at call time, so importing this module does not by
    itself import ``bri``.

    Returns:
        The tuple ``(Action, Controller)`` as exported by ``bri``.

    Raises:
        ImportError: propagated unchanged if ``bri`` or either name is missing.
    """
    from bri import Action as action_cls, Controller as controller_cls

    return action_cls, controller_cls
ACTION_STR_MAP = None  # cache, filled in by the first _get_action_map() call


def _get_action_map():
    """Return the dict mapping decoded action names to ``Action`` members.

    The dict is built on the first call (which triggers the ``bri`` import)
    and cached in the module-level ``ACTION_STR_MAP``; subsequent calls return
    that same object.
    """
    global ACTION_STR_MAP
    if ACTION_STR_MAP is None:
        action_cls, _unused = _import_bri()
        names = ("FORWARD", "BACKWARD", "LEFT", "RIGHT", "STOP")
        ACTION_STR_MAP = {name: getattr(action_cls, name) for name in names}
    return ACTION_STR_MAP
def bci_policy_loop(ctrl, pipeline, npz_file, stop_event, set_action_fn):
    """Decode an .npz recording and forward each decoded action to the robot.

    Output is paced to simulate real time: every decoded window is followed by
    a pause of 0.5 s minus the reported decode latency. During that pause the
    current action is re-sent roughly every 0.1 s.

    Args:
        ctrl: Robot controller. Not used directly here; actions are delivered
            through ``set_action_fn``.
        pipeline: Object whose ``process_file(npz_file)`` yields
            ``(action_str, confidence, latency_ms, phase)`` tuples.
        npz_file: Path of the recording handed to ``pipeline.process_file``.
        stop_event: Object with ``is_set()``; once set, the loop ends early.
        set_action_fn: Callable that receives the ``Action`` to apply.

    A STOP action is always sent before returning, whether the recording ran
    to completion or the loop was interrupted via ``stop_event``.
    """
    window_step_s = 0.5  # simulated real-time duration of one window step
    resend_interval_s = 0.1  # upper bound on the gap between re-sends

    action_map = _get_action_map()
    Action, _ = _import_bri()
    current_action = Action.STOP

    for action_str, confidence, latency_ms, phase in pipeline.process_file(npz_file):
        if stop_event.is_set():
            break

        # Unknown action strings fall back to STOP.
        action = action_map.get(action_str, Action.STOP)
        set_action_fn(action)
        current_action = action
        print(f" Action: {action_str:8s} [{phase:10s}] Confidence: {confidence:.2f} Latency: {latency_ms:.1f}ms")

        # Pace the output. A monotonic clock is used so that wall-clock
        # adjustments (NTP, manual changes) cannot stretch or cut the pause.
        pause_s = max(0, window_step_s - (latency_ms / 1000.0))
        deadline = time.monotonic() + pause_s
        while not stop_event.is_set():
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                break
            set_action_fn(current_action)  # re-send to prevent hold_s timeout
            time.sleep(min(resend_interval_s, remaining))

    set_action_fn(action_map.get("STOP", current_action))
    print("\n Recording complete. Robot stopped.")
def run_bci_sim(npz_file, model_dir=None):
    """Load the models, start the MuJoCo simulation and run the BCI loop.

    The decoding loop runs in a daemon thread while the calling thread keeps
    the viewer camera pointed at the robot. When the run ends, the metrics are
    written to ``<PROJECT_ROOT>/results/demo_metrics.json``.

    Args:
        npz_file: Path of the recording to decode.
        model_dir: Directory containing ``stage1_binary.pkl`` and
            ``stage2_direction.pkl``. Defaults to ``<PROJECT_ROOT>/models``.

    Returns:
        Whatever ``pipeline.get_metrics()`` returns (it is also dumped as JSON).
    """
    if model_dir is None:
        model_dir = str(PROJECT_ROOT / "models")

    _, Controller = _import_bri()

    pipeline = ThoughtLinkPipeline()
    pipeline.load_models(
        os.path.join(model_dir, "stage1_binary.pkl"),
        os.path.join(model_dir, "stage2_direction.pkl"),
    )

    # Bundle dir: must point to the BRI repo's bundles
    bundle_dir = str(PROJECT_ROOT / "brain-robot-interface" / "bundles" / "g1_mjlab")
    ctrl = Controller(
        backend="sim",
        hold_s=1.0,
        forward_speed=0.4,
        yaw_rate=1.0,
        smooth_alpha=0.3,
        bundle_dir=bundle_dir,
    )
    print(" Starting MuJoCo simulation (G1 humanoid)...")
    ctrl.start()

    # Configure camera: zoomed out, slightly above and in front.
    # NOTE(review): this reaches into private attributes of the controller
    # backend (_viewer, _data); confirm they stay available across BRI versions.
    viewer = ctrl._backend._viewer
    if viewer is not None:
        viewer.cam.distance = 5.0  # far enough to see full body
        viewer.cam.azimuth = 180.0  # looking from in front of the robot
        viewer.cam.elevation = -25.0  # slightly above
        # Initial lookat at the robot's root body position
        data = ctrl._backend._data
        if data is not None:
            viewer.cam.lookat[0] = float(data.qpos[0])
            viewer.cam.lookat[1] = float(data.qpos[1])
            viewer.cam.lookat[2] = 0.8  # roughly torso height
    print(" MuJoCo viewer launched. Decoding brain signals...")

    stop_event = threading.Event()
    thread = threading.Thread(
        target=bci_policy_loop,
        args=(ctrl, pipeline, npz_file, stop_event, ctrl.set_action),
        daemon=True,
    )
    thread.start()
    try:
        while thread.is_alive():
            # Track the robot: update camera lookat to follow root body
            data = ctrl._backend._data
            vw = ctrl._backend._viewer
            if data is not None and vw is not None and vw.is_running():
                vw.cam.lookat[0] = float(data.qpos[0])
                vw.cam.lookat[1] = float(data.qpos[1])
                vw.cam.lookat[2] = 0.8
            time.sleep(0.05)
    except KeyboardInterrupt:
        print("\n Interrupted by user.")
    finally:
        # Always signal the worker before stopping the controller, so that an
        # unexpected error in the loop above cannot leave the thread sending
        # actions to a controller that is being shut down. Setting the event
        # is a no-op when the worker has already finished.
        stop_event.set()
        thread.join(timeout=2.0)
        ctrl.stop()

    metrics = pipeline.get_metrics()
    results_dir = os.path.join(str(PROJECT_ROOT), "results")
    os.makedirs(results_dir, exist_ok=True)
    with open(os.path.join(results_dir, "demo_metrics.json"), "w", encoding="utf-8") as f:
        json.dump(metrics, f, indent=2)
    return metrics
def run_fallback_demo(npz_file, model_dir=None):
    """Fallback if MuJoCo fails: matplotlib top-down grid visualization.

    Decoded actions are replayed on a 10x10 2-D map next to a running plot of
    the decoder confidence. The final figure is saved to
    ``<PROJECT_ROOT>/results/fallback_demo.png`` and then shown.

    Args:
        npz_file: Path of the recording to decode. Its ``label`` entry must
            hold a dict with a ``"label"`` key (the ground-truth label).
        model_dir: Directory containing ``stage1_binary.pkl`` and
            ``stage2_direction.pkl``. Defaults to ``<PROJECT_ROOT>/models``.

    Returns:
        Whatever ``pipeline.get_metrics()`` returns.
    """
    import matplotlib
    matplotlib.use("TkAgg")
    import matplotlib.pyplot as plt

    if model_dir is None:
        model_dir = str(PROJECT_ROOT / "models")

    pipeline = ThoughtLinkPipeline()
    pipeline.load_models(
        os.path.join(model_dir, "stage1_binary.pkl"),
        os.path.join(model_dir, "stage2_direction.pkl"),
    )

    # Robot state
    x, y = 5.0, 5.0
    heading = 90  # degrees
    trail = [(x, y)]
    speed = 0.3
    turn_rate = 20

    fig, (ax_map, ax_conf) = plt.subplots(1, 2, figsize=(14, 6))
    plt.ion()
    confidences_log = []

    # NOTE: allow_pickle=True unpickles objects stored in the file; only open
    # recordings from a trusted source. The context manager closes the
    # underlying file handle, which the archive object otherwise keeps open.
    with np.load(npz_file, allow_pickle=True) as arr:
        label_info = arr["label"].item()
    gt_label = label_info["label"]

    for action_str, confidence, latency_ms, phase in pipeline.process_file(npz_file):
        # Translate along the heading, or rotate in place.
        if action_str == "FORWARD":
            x += speed * np.cos(np.radians(heading))
            y += speed * np.sin(np.radians(heading))
        elif action_str == "BACKWARD":
            x -= speed * np.cos(np.radians(heading))
            y -= speed * np.sin(np.radians(heading))
        elif action_str == "LEFT":
            heading += turn_rate
        elif action_str == "RIGHT":
            heading -= turn_rate
        trail.append((x, y))
        confidences_log.append(confidence)

        # Left panel: trail, current position and heading arrow.
        ax_map.clear()
        ax_map.set_xlim(0, 10)
        ax_map.set_ylim(0, 10)
        ax_map.set_aspect("equal")
        ax_map.set_title(f"Robot Position | GT: {gt_label} | Action: {action_str} | Conf: {confidence:.2f}")
        trail_arr = np.array(trail)
        ax_map.plot(trail_arr[:, 0], trail_arr[:, 1], "b-", alpha=0.3, linewidth=1)
        ax_map.plot(x, y, "ro", markersize=12)
        dx = 0.5 * np.cos(np.radians(heading))
        dy = 0.5 * np.sin(np.radians(heading))
        ax_map.arrow(x, y, dx, dy, head_width=0.15, head_length=0.1, fc="red", ec="red")

        # Right panel: confidence history.
        ax_conf.clear()
        ax_conf.plot(confidences_log, "g-")
        ax_conf.set_title("Confidence Over Time")
        ax_conf.set_xlabel("Window")
        ax_conf.set_ylabel("Confidence")
        ax_conf.set_ylim(0, 1)

        plt.tight_layout()
        plt.pause(0.3)

    plt.ioff()
    results_dir = os.path.join(str(PROJECT_ROOT), "results")
    os.makedirs(results_dir, exist_ok=True)
    # Save the figure drawn above explicitly rather than the "current" figure.
    fig.savefig(os.path.join(results_dir, "fallback_demo.png"), dpi=150)
    print(" Fallback demo saved to results/fallback_demo.png")
    plt.show()
    return pipeline.get_metrics()
def run_headless_demo(npz_file, model_dir=None):
    """Headless mode: just print actions to console.

    Prints the ground-truth label, subject id and duration stored in the
    recording, then one line per decoded window.

    Args:
        npz_file: Path of the recording to decode. Its ``label`` entry must
            hold a dict with ``"label"``, ``"subject_id"`` and ``"duration"``
            keys.
        model_dir: Directory containing ``stage1_binary.pkl`` and
            ``stage2_direction.pkl``. Defaults to ``<PROJECT_ROOT>/models``.

    Returns:
        Whatever ``pipeline.get_metrics()`` returns.
    """
    if model_dir is None:
        model_dir = str(PROJECT_ROOT / "models")

    pipeline = ThoughtLinkPipeline()
    pipeline.load_models(
        os.path.join(model_dir, "stage1_binary.pkl"),
        os.path.join(model_dir, "stage2_direction.pkl"),
    )

    # NOTE: allow_pickle=True unpickles objects stored in the file; only open
    # recordings from a trusted source. The context manager closes the
    # underlying file handle, which the archive object otherwise keeps open.
    with np.load(npz_file, allow_pickle=True) as arr:
        label_info = arr["label"].item()

    print(f" Ground Truth: {label_info['label']}")
    print(f" Subject: {label_info['subject_id']}")
    print(f" Duration: {label_info['duration']:.1f}s\n")

    for action, conf, lat, phase in pipeline.process_file(npz_file):
        print(f" Action: {action:8s} [{phase:10s}] Confidence: {conf:.2f} Latency: {lat:.1f}ms")

    metrics = pipeline.get_metrics()
    return metrics
if __name__ == "__main__":
    # Script entry point: run the headless demo on the recording named by the
    # first command-line argument, or on the default path under data/ when no
    # argument is given.
    import sys as _sys

    _cli_args = _sys.argv[1:]
    if _cli_args:
        npz = _cli_args[0]
    else:
        npz = str(PROJECT_ROOT / "data" / "0b2dbd41-10.npz")

    print("Testing headless mode...")
    metrics = run_headless_demo(npz)
    print(f"\nMetrics: {metrics}")