"""Smoke test: every registered task can be created on CPU and back-propagated."""

import torch

from arena.tasks.registry import TASKS


def test_all_tasks_create_and_backward_on_cpu():
    """Create each task in ``TASKS`` on CPU, evaluate its loss, and run backward.

    For every registered task this checks that:

    * ``task.create(device=cpu, seed=0, config={})`` returns an instance,
    * ``instance.loss_fn()`` yields a single-element tensor whose value is finite,
    * ``loss.backward()`` completes without raising,
    * ``instance.params`` is non-empty.

    Every assertion message includes ``task.name`` so a failure identifies the
    offending task.
    """
    # Without this guard an empty registry turns the loop into a no-op and the
    # test passes vacuously.
    assert TASKS, "task registry TASKS is empty; no tasks were exercised"

    cpu = torch.device("cpu")  # loop-invariant; build once
    for task in TASKS.values():
        instance = task.create(device=cpu, seed=0, config={})
        loss = instance.loss_fn()

        # .item() and an argument-less backward() both require exactly one
        # element; assert it up front so the failure names the task.
        assert loss.numel() == 1, (
            f"{task.name}: loss must have exactly one element, "
            f"got shape {tuple(loss.shape)}"
        )
        assert torch.isfinite(loss).item(), f"{task.name}: loss is not finite ({loss})"

        loss.backward()

        assert len(instance.params) > 0, f"{task.name}: instance exposes no params"