# OptimizationArena — tests/test_tasks.py
# Committed by mmkuznecov in 6551a95 ("first version").
import torch
from arena.tasks.registry import TASKS
def test_all_tasks_create_and_backward_on_cpu():
    """Smoke-test every registered task on CPU.

    For each task in ``TASKS``, the test:

    1. Builds an instance on the CPU device with ``seed=0`` and an empty config.
    2. Evaluates its loss and asserts the loss is finite.
    3. Runs a backward pass.
    4. Asserts that the instance exposes at least one parameter.

    Each assertion message carries the task's name, so a failure identifies
    the offending task.
    """
    # Guard against a vacuous pass: with an empty registry the loop below
    # would never run and the test would succeed without checking anything.
    assert TASKS, "task registry is empty"
    device = torch.device("cpu")
    for task in TASKS.values():
        instance = task.create(device=device, seed=0, config={})
        loss = instance.loss_fn()
        # NOTE(review): ``.item()`` assumes ``loss`` is a single-element tensor.
        assert torch.isfinite(loss).item(), task.name
        loss.backward()
        assert len(instance.params) > 0, task.name