# minidreamer/tests/test_cem_planner.py
# Synced from commit f6d8768 ("Sync local repo state (#1)") by PatrykT.
import torch
from minidreamer.planning.cem import DiscreteCEMPlanner
class DummyWorldModel:
    """Stub world model that scores action sequences by agreement with a target.

    It exposes only a ``device`` attribute and ``score_action_sequences``.
    Presumably that is the part of the world-model interface that
    ``DiscreteCEMPlanner`` uses (TODO confirm against the planner).

    Args:
        target: Action sequence that gets the best possible score (0). The
            default ``(1, 2, 0, 1)`` is the sequence the CEM planner test in
            this module expects the planner to recover.
    """

    def __init__(self, target=(1, 2, 0, 1)):
        self.device = torch.device("cpu")
        # Build the tensor once; it is moved to the candidates' device per call.
        self.target = torch.as_tensor(target)

    def score_action_sequences(self, state, action_sequences, discount=0.99, use_done_mask=True):
        """Score each candidate as minus its number of positions that differ from the target.

        Args:
            state: Ignored; accepted only for interface compatibility.
            action_sequences: Integer tensor whose last dimension matches (or
                broadcasts against) ``len(self.target)``, e.g.
                ``(candidates, horizon)``. Leading dimensions are preserved.
            discount: Ignored.
            use_done_mask: Ignored.

        Returns:
            A dict with key ``"scores"``. Its value is a float tensor of shape
            ``action_sequences.shape[:-1]``. A perfect match scores 0, and each
            mismatching position subtracts 1.
        """
        target = self.target.to(action_sequences.device)
        mismatches = (action_sequences != target).float().sum(dim=-1)
        return {"scores": -mismatches}
def test_discrete_cem_planner_finds_high_scoring_sequence():
    """CEM planning against a toy scorer should recover the single optimal sequence.

    ``DummyWorldModel`` subtracts one point for every position that differs
    from ``[1, 2, 0, 1]``, so that exact sequence is the unique optimum.
    ``output.action`` is expected to equal its first element.
    """
    torch.manual_seed(0)

    best_sequence = [1, 2, 0, 1]
    cem_settings = {
        "action_dim": 3,
        "horizon": 4,
        "candidates": 512,
        "elites": 64,
        "iterations": 5,
        "discount": 1.0,
        "use_done_mask": False,
    }
    planner = DiscreteCEMPlanner(world_model=DummyWorldModel(), **cem_settings)

    # The stub ignores the state, so any placeholder object will do.
    result = planner.plan(state=object())

    assert result.action == best_sequence[0]
    assert result.sequence == best_sequence