| """Tests for OFO controller internals: PrimalBatchOptimizer, VoltageDualVariables. |
| |
| Focuses on edge cases, projections, NaN handling, and shape validation, |
| not trivial arithmetic.""" |
|
|
| from __future__ import annotations |
|
|
| import math |
|
|
| import numpy as np |
| import pytest |
| from mlenergy_data.modeling import LogisticModel |
|
|
| from openg2g.controller.ofo import ( |
| OFOConfig, |
| PrimalBatchOptimizer, |
| VoltageDualVariables, |
| ) |
| from openg2g.datacenter.config import InferenceModelSpec |
|
|
|
|
def _trivial_logistic(
    L: float = 100.0, x0: float = 5.0, k: float = 1.0, b0: float = 0.0
) -> LogisticModel:
    """Build a LogisticModel from explicit, known parameters.

    The arguments are forwarded verbatim as the ``"L"``, ``"x0"``, ``"k"`` and
    ``"b0"`` entries of ``LogisticModel.from_dict``. They are presumably the
    usual logistic amplitude / midpoint / steepness / offset; confirm against
    ``mlenergy_data.modeling`` before relying on that reading.

    Args:
        L: Value for the ``"L"`` key.
        x0: Value for the ``"x0"`` key.
        k: Value for the ``"k"`` key.
        b0: Value for the ``"b0"`` key.

    Returns:
        The model constructed by ``LogisticModel.from_dict``.
    """
    return LogisticModel.from_dict({"L": L, "x0": x0, "k": k, "b0": b0})
|
|
|
|
def _make_primal(
    *,
    feasible_batch_sizes: list[int] | None = None,
    config: OFOConfig | None = None,
) -> PrimalBatchOptimizer:
    """Construct a single-model PrimalBatchOptimizer for unit tests.

    The only model is labeled ``"M"`` (10 replicas, 1 GPU each, initial batch
    128, ITL deadline 0.1 s). One ``_trivial_logistic()`` fit is shared as its
    power, latency and throughput curve.

    Args:
        feasible_batch_sizes: Allowed batch sizes; ``None`` selects
            ``[8, 16, 32, 64, 128]``.
        config: Optimizer configuration; ``None`` selects ``OFOConfig()``.
    """
    sizes = [8, 16, 32, 64, 128] if feasible_batch_sizes is None else feasible_batch_sizes
    cfg = OFOConfig() if config is None else config

    label = "M"
    spec = InferenceModelSpec(
        model_label=label,
        num_replicas=10,
        gpus_per_replica=1,
        initial_batch_size=128,
        itl_deadline_s=0.1,
    )
    shared_fit = _trivial_logistic()

    return PrimalBatchOptimizer(
        models=[spec],
        feasible_batch_sizes=sizes,
        power_fits={label: shared_fit},
        latency_fits={label: shared_fit},
        throughput_fits={label: shared_fit},
        config=cfg,
    )
|
|
|
|
def _step_no_gradient(p: PrimalBatchOptimizer) -> dict[str, int]:
    """Run one ``p.step`` for model ``"M"`` with all external gradient inputs zeroed.

    The voltage dual difference and the sensitivity matrix are all zeros, the
    latency dual is 0.0 and the replica count is 0.0. Together these are meant
    to make the step a no-op for the continuous state.

    This helper does NOT set any weights (e.g. ``w_throughput``, ``w_switch``).
    Those come from ``p``'s own config, so callers wanting zero weights must
    build ``p`` with such a config.

    Returns:
        Whatever ``p.step`` returns, used by callers as model label -> batch size.
    """
    # Row count of the voltage/sensitivity inputs; the sensitivity matrix has
    # 3 columns (presumably one per phase -- TODO confirm against the controller).
    num_rows = 6
    return p.step(
        voltage_dual_diff=np.zeros(num_rows),
        sensitivity_matrix=np.zeros((num_rows, 3)),
        phase_share_by_model={},
        latency_dual_by_model={"M": 0.0},
        replica_count_by_model={"M": 0.0},
    )
|
|
|
|
class TestBatchDiscretization:
    """Round-trip between the continuous log2 batch state and discrete batch sizes."""

    # Parametrized so every batch size is reported as its own test case; the
    # previous in-test loop stopped at the first failing size.
    @pytest.mark.parametrize("batch_size", [8, 16, 32, 64, 128])
    def test_each_batch_size_round_trips(self, batch_size: int) -> None:
        """Initializing to any feasible batch size and taking a zero-gradient
        step should return that same batch size (log2 -> discrete round-trip)."""
        p = _make_primal(feasible_batch_sizes=[8, 16, 32, 64, 128])
        p.init_from_batches({"M": batch_size})
        assert _step_no_gradient(p)["M"] == batch_size

    def test_single_batch_set(self) -> None:
        """With only one feasible batch size, discretization should always
        return it regardless of continuous x."""
        p = _make_primal(feasible_batch_sizes=[64])
        p.init_from_batches({"M": 64})
        assert _step_no_gradient(p)["M"] == 64
|
|
|
|
class TestInitFromBatches:
    """Behavior of ``PrimalBatchOptimizer.init_from_batches``."""

    def test_missing_model_defaults_to_max(self) -> None:
        """A model absent from the init mapping should start at log2 of the
        largest feasible batch size."""
        sizes = [8, 16, 32, 64, 128]
        optimizer = _make_primal(feasible_batch_sizes=sizes)
        optimizer.init_from_batches({})
        expected = math.log2(max(sizes))
        assert optimizer.log_batch_size_by_model["M"] == pytest.approx(expected)

    def test_x_prev_matches_x(self) -> None:
        """Right after init, the previous and current continuous states should
        coincide (no switching cost on the first step)."""
        optimizer = _make_primal(feasible_batch_sizes=[8, 16, 32, 64, 128])
        optimizer.init_from_batches({"M": 64})
        previous = optimizer.prev_log_batch_size_by_model["M"]
        current = optimizer.log_batch_size_by_model["M"]
        assert previous == current
|
|
|
|
class TestVoltageDualUpdate:
    """Projected updates of ``VoltageDualVariables`` and their validation."""

    def test_no_violation_stays_zero(self) -> None:
        """Voltages within bounds should leave both duals at zero."""
        vd = VoltageDualVariables(3, OFOConfig(v_min=0.95, v_max=1.05, voltage_dual_step_size=1.0))
        vd.update(np.array([1.0, 1.0, 1.0]))
        np.testing.assert_array_equal(vd.dual_undervoltage, np.zeros(3))
        np.testing.assert_array_equal(vd.dual_overvoltage, np.zeros(3))

    def test_dual_projects_to_nonneg(self) -> None:
        """After a violation clears, the dual should be clamped to zero
        (not go negative), preserving the [.]+ projection."""
        vd = VoltageDualVariables(1, OFOConfig(v_min=0.95, v_max=1.05, voltage_dual_step_size=1.0))
        vd.update(np.array([0.93]))
        # Precondition: the undervoltage must first drive the dual above zero,
        # otherwise the final "== 0.0" check would pass without exercising the
        # projection at all.
        assert vd.dual_undervoltage[0] > 0.0
        vd.update(np.array([1.0]))
        assert vd.dual_undervoltage[0] == 0.0

    def test_eta_sign_convention(self) -> None:
        """eta = dual_overvoltage - dual_undervoltage should be negative for undervoltage
        (drives power down) and positive for overvoltage."""
        vd = VoltageDualVariables(2, OFOConfig(v_min=0.95, v_max=1.05, voltage_dual_step_size=1.0))
        vd.update(np.array([0.93, 1.07]))
        eta = vd.dual_difference()
        assert eta[0] < 0
        assert eta[1] > 0

    def test_shape_mismatch_raises(self) -> None:
        """Passing a v_hat with wrong length should raise ValueError."""
        vd = VoltageDualVariables(3, OFOConfig())
        with pytest.raises(ValueError, match="len 2 but duals have len 3"):
            vd.update(np.array([1.0, 1.0]))
|
|
|
|
class TestPrimalStep:
    """Single ``PrimalBatchOptimizer.step`` calls under controlled gradient inputs."""

    @staticmethod
    def _config(primal_step_size: float, voltage_gradient_scale: float = 0.0) -> OFOConfig:
        """Return an OFOConfig with throughput and switching weights set to zero."""
        return OFOConfig(
            primal_step_size=primal_step_size,
            w_throughput=0.0,
            w_switch=0.0,
            voltage_gradient_scale=voltage_gradient_scale,
        )

    @staticmethod
    def _step_with_latency_dual(p: PrimalBatchOptimizer, mu: float) -> dict[str, int]:
        """Step model "M" with zero voltage inputs and zero replicas, varying only mu."""
        return p.step(
            voltage_dual_diff=np.zeros(6),
            sensitivity_matrix=np.zeros((6, 3)),
            phase_share_by_model={},
            latency_dual_by_model={"M": mu},
            replica_count_by_model={"M": 0.0},
        )

    def test_zero_gradient_no_change(self) -> None:
        """With all gradient weights and duals set to zero, the batch size
        should remain unchanged after a step."""
        p = _make_primal(
            feasible_batch_sizes=[8, 16, 32, 64, 128],
            config=self._config(0.1),
        )
        p.init_from_batches({"M": 32})
        batch = _step_no_gradient(p)
        assert batch["M"] == 32

    def test_latency_dual_pushes_batch_down(self) -> None:
        """A positive latency dual (mu > 0) should decrease x, since the
        logistic latency fit has dL/dx > 0 (latency increases with batch)."""
        p = _make_primal(
            feasible_batch_sizes=[8, 16, 32, 64, 128],
            config=self._config(0.5),
        )
        p.init_from_batches({"M": 64})
        x_before = p.log_batch_size_by_model["M"]
        self._step_with_latency_dual(p, 5.0)
        assert p.log_batch_size_by_model["M"] < x_before

    def test_nan_mu_treated_as_zero(self) -> None:
        """NaN mu (from zero-replica models) should be treated as zero,
        producing no latency gradient contribution."""
        p = _make_primal(
            feasible_batch_sizes=[8, 16, 32, 64, 128],
            config=self._config(0.1),
        )
        p.init_from_batches({"M": 32})
        batch = self._step_with_latency_dual(p, float("nan"))
        assert batch["M"] == 32
        # The discretized output alone is not enough: a NaN leaking into the
        # continuous state would likely propagate into every later step.
        assert math.isfinite(p.log_batch_size_by_model["M"])

    def test_projection_clamps_with_huge_gradient(self) -> None:
        """Even with an extreme gradient, x should stay within
        [log2(min_batch), log2(max_batch)] after projection."""
        p = _make_primal(
            feasible_batch_sizes=[8, 128],
            config=self._config(100.0, voltage_gradient_scale=1e9),
        )
        p.init_from_batches({"M": 64})
        p.step(
            voltage_dual_diff=np.ones(6) * 100.0,
            sensitivity_matrix=np.ones((6, 3)),
            phase_share_by_model={"M": np.array([0.5, 0.3, 0.2])},
            latency_dual_by_model={"M": 0.0},
            replica_count_by_model={"M": 100.0},
        )
        assert p.log_batch_size_by_model["M"] >= math.log2(8)
        assert p.log_batch_size_by_model["M"] <= math.log2(128)
|
|