File size: 366 Bytes
6551a95
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
import torch

from arena.tasks.registry import TASKS


def test_all_tasks_create_and_backward_on_cpu():
    """Smoke-test every registered task on CPU.

    For each task in ``TASKS``: build an instance on the CPU with ``seed=0``
    and an empty config, evaluate its loss, check that the loss is a finite
    single-element tensor, back-propagate through it, and check that the
    instance exposes at least one parameter.
    """
    # With an empty registry the loop below would be a no-op and the test
    # would pass vacuously; fail loudly instead.
    assert TASKS, "task registry is empty"
    device = torch.device("cpu")
    for task in TASKS.values():
        instance = task.create(device=device, seed=0, config={})
        loss = instance.loss_fn()
        # Both .item() and backward() without an explicit gradient need a
        # single-element tensor; check up front so a failure names the task.
        assert loss.numel() == 1, (
            f"{task.name}: loss has shape {tuple(loss.shape)}, expected a scalar"
        )
        assert torch.isfinite(loss).item(), (
            f"{task.name}: loss is not finite ({loss.item()})"
        )
        loss.backward()
        assert len(instance.params) > 0, f"{task.name}: task exposes no params"