import torch

from minidreamer.planning.cem import DiscreteCEMPlanner

# Single source of truth for the sequence the dummy model rewards. The test's
# horizon and expected outputs are derived from it so they cannot drift apart.
TARGET_SEQUENCE = (1, 2, 0, 1)


class DummyWorldModel:
    """Stand-in world model whose score is maximal at ``TARGET_SEQUENCE``.

    Provides only a ``device`` attribute and ``score_action_sequences``.
    """

    def __init__(self):
        # Exposed for the planner's use; the scoring below follows the device
        # of the incoming action tensor rather than this attribute.
        self.device = torch.device("cpu")

    def score_action_sequences(
        self, state, action_sequences, discount=0.99, use_done_mask=True
    ):
        """Score each candidate by its negated Hamming distance to the target.

        Args:
            state: Ignored by this stub.
            action_sequences: Integer action tensor compared position-wise
                against ``TARGET_SEQUENCE`` along its last dimension, so that
                dimension should have length ``len(TARGET_SEQUENCE)``.
                Presumably shaped (candidates, horizon) -- confirm against
                the planner.
            discount: Ignored; accepted for interface compatibility.
            use_done_mask: Ignored; accepted for interface compatibility.

        Returns:
            ``{"scores": tensor}`` with one float per sequence (reduced over
            the last dimension): 0.0 for an exact match with the target,
            otherwise minus the number of mismatched positions.
        """
        target = torch.tensor(TARGET_SEQUENCE, device=action_sequences.device)
        mismatches = (action_sequences != target).float().sum(dim=-1)
        return {"scores": -mismatches}


def test_discrete_cem_planner_finds_high_scoring_sequence():
    """CEM planning should recover the dummy model's unique optimal sequence.

    With ``action_dim=3``, every entry of ``TARGET_SEQUENCE`` is a valid
    action, so only the exact target reaches the maximal score of 0.
    """
    # Seed the global torch RNG so the planner's sampling is reproducible
    # (assumes the planner draws from torch's default generator).
    torch.manual_seed(0)
    planner = DiscreteCEMPlanner(
        world_model=DummyWorldModel(),
        action_dim=3,
        horizon=len(TARGET_SEQUENCE),
        candidates=512,
        elites=64,
        iterations=5,
        discount=1.0,
        use_done_mask=False,
    )
    # The dummy ignores ``state``, so an opaque placeholder suffices
    # (assuming the planner only forwards it to the world model).
    output = planner.plan(state=object())

    expected = list(TARGET_SEQUENCE)
    # The first planned action should be the first step of the optimum.
    assert output.action == expected[0]
    assert output.sequence == expected