OptimizationArena / tests/test_validator.py
mmkuznecov's picture
first version
6551a95
import textwrap

import pytest

from arena.optimizers.registry import get_example
from arena.submissions.validator import (
    SubmissionValidationError,
    validate_user_code_ast,
)


def test_validator_accepts_example_optimizer():
    """The registry's "Momentum SGD" example must pass AST validation.

    The returned report is expected to list ``MomentumSGD`` among the
    optimizer classes found in the example's source.
    """
    example = get_example("Momentum SGD")
    validation_report = validate_user_code_ast(example.code)
    discovered = validation_report.optimizer_classes
    assert "MomentumSGD" in discovered


def test_validator_rejects_builtin_optimizer_wrapper():
    """A submission whose ``build_optimizer`` returns a built-in optimizer is rejected.

    The snippet defines a custom ``Optimizer`` subclass (``BadOptimizer``), but
    ``build_optimizer`` ignores it and returns ``torch.optim.Adam``. Validation
    must raise ``SubmissionValidationError``.
    """
    # The submission must be syntactically valid Python. If its indentation
    # were broken, the validator would fail on parsing, and the test could pass
    # without ever exercising the built-in-optimizer check. dedent() strips the
    # common leading indent so the snippet can be written indented in this file.
    code = textwrap.dedent(
        """
        import torch
        from torch.optim import Optimizer
        class BadOptimizer(Optimizer):
            def __init__(self, params):
                super().__init__(params, {})
            def step(self, closure=None):
                return None
        def build_optimizer(params, config):
            return torch.optim.Adam(params, lr=1e-3)
        """
    )
    # NOTE(review): consider adding pytest.raises(..., match=...) once the
    # validator's message for this rule is known. That would pin the rejection
    # to this reason rather than to any validation failure.
    with pytest.raises(SubmissionValidationError):
        validate_user_code_ast(code)