# HenryZhang's picture
# Upload folder using huggingface_hub
# 810379d verified
#!/usr/bin/env python3
import time
import sys
import queue
import inspect
import torch
from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
from lerobot.datasets.utils import hw_to_dataset_features
from lerobot.policies.factory import make_pre_post_processors
from lerobot.policies.utils import build_inference_frame, make_robot_action
from lerobot.robots.so101_follower import SO101FollowerConfig, SO101Follower
from lerobot.teleoperators.so101_leader import SO101LeaderConfig
from lerobot.teleoperators import make_teleoperator_from_config
from lerobot.datasets.lerobot_dataset import LeRobotDataset
# -------------------------
# CONFIG
# -------------------------
# Hardware wiring: serial ports of the follower/leader arms and OpenCV camera indices.
FOLLOWER_PORT = "/dev/ttyACM1"
LEADER_PORT = "/dev/ttyACM2"
TOP_CAM_INDEX = 4
WRIST_CAM_INDEX = 9

# Policy checkpoint, the language instruction given to it, and control-loop settings.
MODEL_ID = "lerobot/smolvla_base"
TASK = "Pick up the red block."
ROBOT_TYPE = "so101_follower"
FPS = 20  # control-loop and dataset recording rate (Hz)
POLICY_SCALE = 1  # multiplier applied to tensor actions returned by the policy
EPISODE_SECONDS = 10.0  # wall duration of one policy episode

# ---- Recording / Hub ----
# A local-time timestamp gives every run its own dataset repo id.
curr_time = time.strftime("%Y%m%d_%H%M%S")
DATASET_REPO_ID = "HenryZhang/so101_smolvla_eval_" + curr_time
DATASET_ROOT = None  # forwarded as root= to LeRobotDataset.create
USE_VIDEOS = True
PUSH_TO_HUB_ON_EXIT = True
PRIVATE_ON_HUB = False
DATASET_TAGS = ["LeRobot", "so101", "smolvla", "policy-eval"]
# -------------------------
def log(msg):
    """Write *msg* as one line to stdout and flush right away so progress is visible live."""
    print(str(msg), end="\n", file=sys.stdout, flush=True)
def start_enter_listener(cmd_q: "queue.Queue[str]"):
    """Press Enter to start one episode.

    Blocks reading lines from ``sys.stdin`` and puts ``"start_episode"`` on
    *cmd_q* for every line read. Meant to run on a daemon thread.

    Returns when stdin reaches end-of-file or can no longer be read. At EOF
    ``readline()`` returns ``""`` immediately on every call. Without the EOF
    check this loop would spin and flood the queue with start commands.
    """
    while True:
        try:
            line = sys.stdin.readline()
        except (OSError, ValueError):
            # stdin closed or detached (ValueError covers "I/O operation on
            # closed file" and decode errors): nothing left to listen for.
            break
        if not line:
            # EOF: no more Enter presses can ever arrive.
            break
        cmd_q.put("start_episode")
def send_leader_action(robot, leader_action):
    """Mirror the leader arm's reading onto the follower.

    Only keys that appear in ``robot.action_features`` are forwarded (in that
    order), each cast to ``float``. Non-dict readings are ignored, and nothing
    is sent when no keys overlap.
    """
    if isinstance(leader_action, dict):
        filtered = {}
        for key in robot.action_features.keys():
            if key in leader_action:
                filtered[key] = float(leader_action[key])
        if len(filtered) > 0:
            robot.send_action(filtered)
def _import_build_dataset_frame():
try:
from lerobot.common.datasets.utils import build_dataset_frame
return build_dataset_frame
except Exception:
from lerobot.datasets.utils import build_dataset_frame
return build_dataset_frame
def create_dataset(repo_id, fps, root, robot_type, features, use_videos, num_cameras):
    """Create an empty ``LeRobotDataset`` for recording, tolerating lerobot API differences.

    Args:
        repo_id: Hub repo id of the new dataset.
        fps: Recording frame rate.
        root: Local root directory passed to ``LeRobotDataset.create`` (may be None).
        robot_type: Robot type stored in the dataset metadata.
        features: Dataset feature spec (action + observation features).
        use_videos: Whether camera streams are stored as videos.
        num_cameras: Number of cameras; sizes the image-writer thread pool.

    Returns:
        The created dataset.

    ``single_task`` and ``exist_ok`` are only passed when ``create`` declares
    them. This avoids retrying ``create`` on a TypeError, which could mask a
    TypeError raised inside ``create`` after it had already partially run.
    """
    num_threads = 4 * max(num_cameras, 1)
    kwargs = dict(
        repo_id=repo_id,
        fps=fps,
        root=root,
        robot_type=robot_type,
        features=features,
        use_videos=use_videos,
        image_writer_processes=0,
        image_writer_threads=num_threads,
    )
    try:
        create_params = inspect.signature(LeRobotDataset.create).parameters
    except (TypeError, ValueError):
        # Signature not introspectable: pass only the core arguments.
        create_params = {}
    if "single_task" in create_params:
        kwargs["single_task"] = TASK
    if "exist_ok" in create_params:
        kwargs["exist_ok"] = True
    ds = LeRobotDataset.create(**kwargs)
    # image_writer_threads > 0 was requested above, so create() may already have
    # started an image writer. Starting another would replace it without stopping
    # its threads, so only start one when none is attached.
    if (
        num_cameras > 0
        and hasattr(ds, "start_image_writer")
        and getattr(ds, "image_writer", None) is None
    ):
        ds.start_image_writer(num_processes=0, num_threads=4 * num_cameras)
    log(f"[INFO] Dataset ready: {repo_id}")
    return ds
def dataset_add_frame_compat(dataset, frame, task):
    """Append one frame to *dataset*, passing *task* the way this lerobot version expects.

    If ``dataset.add_frame`` declares a ``task`` parameter, the task goes in as a
    keyword. Otherwise it is stored in the frame dict under ``"task"``.

    Only the signature check is guarded. Errors raised by ``add_frame`` itself
    propagate rather than triggering a second insertion of the same frame. The
    caller's *frame* dict is not mutated.
    """
    try:
        takes_task_kwarg = "task" in inspect.signature(dataset.add_frame).parameters
    except (TypeError, ValueError):
        # Signature not introspectable: use the in-frame "task" convention.
        takes_task_kwarg = False
    if takes_task_kwarg:
        dataset.add_frame(frame, task=task)
    else:
        dataset.add_frame({**frame, "task": task})
def dataset_push_compat(dataset, repo_id, tags, private):
    """Upload *dataset* to the Hugging Face Hub across lerobot API versions.

    *repo_id* is passed only when ``push_to_hub`` declares a parameter with that
    name. It is passed by keyword so it cannot land in an unrelated first
    positional parameter (e.g. a branch name). Otherwise the dataset's own repo
    id is used. Upload errors propagate instead of triggering a second push.
    """
    try:
        push_params = inspect.signature(dataset.push_to_hub).parameters
    except (TypeError, ValueError):
        push_params = {}
    if "repo_id" in push_params:
        dataset.push_to_hub(repo_id=repo_id, tags=tags, private=private)
    else:
        dataset.push_to_hub(tags=tags, private=private)
def _disconnect_quietly(device, label):
    """Call ``device.disconnect()`` best-effort, logging a failure instead of raising.

    Used during shutdown so a failure releasing one arm does not prevent releasing
    the other or saving/uploading the dataset. ``None`` is ignored.
    """
    if device is None:
        return
    try:
        device.disconnect()
    except Exception as exc:  # cleanup path: report and keep going
        log(f"[WARN] Failed to disconnect {label}: {exc!r}")


def _drain_queue(cmd_q):
    """Discard every command currently waiting in *cmd_q* without blocking."""
    while True:
        try:
            cmd_q.get_nowait()
        except queue.Empty:
            return


def main():
    """Evaluate SmolVLA on an SO-101 follower arm and record each episode.

    Between episodes (RESET mode) the follower mirrors the leader arm so the scene
    can be reset by hand. Pressing Enter starts one POLICY episode lasting
    ``EPISODE_SECONDS``. During it the policy drives the follower at ``FPS`` and
    each (observation, sent action) pair is added to the dataset, which is saved
    when the episode times out.

    On exit (Ctrl+C or an error):
    - both arms are disconnected, each best-effort;
    - the dataset is finalized when the API provides ``finalize``;
    - if ``PUSH_TO_HUB_ON_EXIT`` is set and at least one episode was saved, the
      dataset is pushed to the Hub.
    """
    import threading

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    log(f"[INFO] Device: {device}")
    build_dataset_frame = _import_build_dataset_frame()

    # ---- Load policy ----
    log(f"[INFO] Loading SmolVLA: {MODEL_ID}")
    policy = SmolVLAPolicy.from_pretrained(MODEL_ID).to(device)
    policy.eval()
    preprocess, postprocess = make_pre_post_processors(
        policy.config,
        MODEL_ID,
        preprocessor_overrides={"device_processor": {"device": str(device)}},
        postprocessor_overrides={"device_processor": {"device": str(device)}},
    )

    # ---- Cameras ----
    camera_cfg = {
        "camera1": OpenCVCameraConfig(index_or_path=TOP_CAM_INDEX, width=640, height=480, fps=30),
        "camera2": OpenCVCameraConfig(index_or_path=WRIST_CAM_INDEX, width=640, height=480, fps=30),
    }

    # ---- Robots ----
    robot_cfg = SO101FollowerConfig(port=FOLLOWER_PORT, id="so101_follower_arm", cameras=camera_cfg)
    leader_cfg = SO101LeaderConfig(port=LEADER_PORT, id="so101_leader_arm")
    log("[INFO] Connecting follower...")
    robot = SO101Follower(robot_cfg)
    robot.connect()

    # Everything after the follower connects runs under try/finally, so the arms
    # are released even if leader or dataset setup fails.
    teleop = None
    dataset = None
    saved_episodes = 0
    try:
        log("[INFO] Connecting leader...")
        teleop = make_teleoperator_from_config(leader_cfg)
        teleop.connect()

        # ---- Dataset ----
        action_features = hw_to_dataset_features(robot.action_features, "action", USE_VIDEOS)
        obs_features = hw_to_dataset_features(robot.observation_features, "observation", USE_VIDEOS)
        dataset_features = {**action_features, **obs_features}
        dataset = create_dataset(
            DATASET_REPO_ID,
            FPS,
            DATASET_ROOT,
            robot.name,
            dataset_features,
            USE_VIDEOS,
            len(getattr(robot, "cameras", [])),
        )

        # ---- Enter listener ----
        cmd_q = queue.Queue()
        threading.Thread(target=start_enter_listener, args=(cmd_q,), daemon=True).start()
        log("\n[INFO] Press Enter to run ONE episode. Ctrl+C to exit.\n")

        dt = 1.0 / FPS
        mode = "RESET"
        episode_idx = 0
        episode_end_time = None
        policy.reset()
        while True:
            # perf_counter is monotonic, so loop pacing and the episode deadline are
            # unaffected by system clock adjustments (unlike time.time()).
            t0 = time.perf_counter()
            if mode == "RESET" and not cmd_q.empty():
                cmd_q.get_nowait()
                episode_idx += 1
                policy.reset()
                if hasattr(dataset, "clear_episode_buffer"):
                    dataset.clear_episode_buffer()
                episode_end_time = time.perf_counter() + EPISODE_SECONDS
                mode = "POLICY"
                log(f"[INFO] Episode {episode_idx} START")

            if mode == "RESET":
                # Teleop phase: the follower mirrors the leader so the scene can be reset.
                send_leader_action(robot, teleop.get_action())
            elif time.perf_counter() >= episode_end_time:
                log(f"[INFO] Episode {episode_idx} END — saving...")
                t_save = time.perf_counter()
                dataset.save_episode()
                saved_episodes += 1
                log(f"[INFO] Saved in {time.perf_counter() - t_save:.1f}s")
                # Enter presses made while the episode ran must not start another
                # episode straight away, skipping the reset phase (one press = one episode).
                _drain_queue(cmd_q)
                mode = "RESET"
                episode_end_time = None
            else:
                obs = robot.get_observation()
                obs_frame = build_inference_frame(
                    observation=obs,
                    ds_features=dataset_features,
                    device=device,
                    task=TASK,
                    robot_type=ROBOT_TYPE,
                )
                with torch.no_grad():
                    batch = preprocess(obs_frame)
                    action = policy.select_action(batch)
                    action = postprocess(action)
                if isinstance(action, torch.Tensor):
                    action = action.squeeze(0) * POLICY_SCALE
                robot_action = make_robot_action(action, dataset_features)
                sent_action = robot.send_action(robot_action)
                print("Predicted:", action, "robot:", robot_action, "sent:", sent_action)
                # Record what was actually sent to the motors alongside the observation.
                frame = {
                    **build_dataset_frame(dataset.features, obs, "observation"),
                    **build_dataset_frame(dataset.features, sent_action, "action"),
                }
                dataset_add_frame_compat(dataset, frame, TASK)
            time.sleep(max(0.0, dt - (time.perf_counter() - t0)))
    except KeyboardInterrupt:
        log("\n[INFO] Ctrl+C received.")
    finally:
        # Each disconnect is independent so one failure cannot leave the other arm connected.
        _disconnect_quietly(teleop, "leader")
        _disconnect_quietly(robot, "follower")
        if dataset is not None:
            if hasattr(dataset, "finalize"):
                # NOTE(review): newer lerobot dataset formats expose finalize() to close
                # their writers before upload; older versions lack it — confirm per version.
                dataset.finalize()
            if PUSH_TO_HUB_ON_EXIT:
                if saved_episodes > 0:
                    log("[INFO] Pushing dataset to Hub...")
                    dataset_push_compat(dataset, DATASET_REPO_ID, DATASET_TAGS, PRIVATE_ON_HUB)
                else:
                    log("[INFO] No episodes saved; skipping Hub push.")
        log("[INFO] Done.")


if __name__ == "__main__":
    main()