Spaces:
Sleeping
Sleeping
File size: 807 Bytes
6551a95 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 | import pytest
from arena.optimizers.registry import get_example
from arena.submissions.validator import (
SubmissionValidationError,
validate_user_code_ast,
)
def test_validator_accepts_example_optimizer():
    """The bundled "Momentum SGD" example must pass AST validation.

    The validation report for it is expected to list ``MomentumSGD``
    among the optimizer classes it found.
    """
    example = get_example("Momentum SGD")
    validation_report = validate_user_code_ast(example.code)
    assert "MomentumSGD" in validation_report.optimizer_classes
def test_validator_rejects_builtin_optimizer_wrapper():
    """Reject a submission whose factory returns a stock torch optimizer.

    The snippet defines its own ``Optimizer`` subclass, but its
    ``build_optimizer`` returns ``torch.optim.Adam`` instead.
    ``validate_user_code_ast`` must raise ``SubmissionValidationError``
    for it.
    """
    import ast

    code = """
import torch
from torch.optim import Optimizer

class BadOptimizer(Optimizer):
    def __init__(self, params):
        super().__init__(params, {})

    def step(self, closure=None):
        return None

def build_optimizer(params, config):
    return torch.optim.Adam(params, lr=1e-3)
"""
    # Precondition: the snippet must be valid Python. If the validator also
    # reports syntax errors as SubmissionValidationError (not visible here),
    # a broken snippet would otherwise make this test pass for the wrong
    # reason; ast.parse raises SyntaxError and fails the test instead.
    ast.parse(code)

    with pytest.raises(SubmissionValidationError):
        validate_user_code_ast(code)
|