Spaces:
Running on Zero
Running on Zero
| import torch | |
| from unittest.mock import patch | |
| from src.sample import sampling | |
class DummyModel:
    """Minimal stand-in for a model wrapper used by the KSampler tests.

    Only ``get_model_object`` is provided; it returns a tiny model-sampling
    stub rather than a real model component.
    """

    class _ModelSampling:
        """Model-sampling stub exposing a short, fixed sigma schedule."""

        def __init__(self):
            # 20 evenly spaced sigmas from 1 down to 0 -- kept small so any
            # code that walks the schedule stays cheap. Uses the module-level
            # ``torch`` import; a function-local re-import is unnecessary.
            self.sigmas = torch.linspace(1, 0, 20)
            self.sigma_min = 0.0
            self.sigma_max = 1.0

        def timestep(self, sigma):
            """Identity mapping: the stub uses ``sigma`` itself as the timestep."""
            return sigma

    def get_model_object(self, name):
        """Return a fresh model-sampling stub.

        ``name`` is accepted for interface compatibility but ignored: every
        lookup yields a new ``_ModelSampling`` instance.
        """
        return self._ModelSampling()
def test_ksampler_sets_scheduler_before_set_steps(monkeypatch):
    """KSampler.sample must assign ``scheduler`` before it calls ``set_steps``.

    ``ks.set_steps`` is replaced with a recorder that snapshots
    ``ks.scheduler`` on every call, and ``sampling.common_ksampler`` is stubbed
    to a no-op so no real sampling runs.
    """
    ks = sampling.KSampler()
    seen_schedulers = []

    def fake_set_steps(*args, **kwargs):
        # Accept any arguments so the test only depends on *when* set_steps is
        # called, not on its exact signature. getattr turns a missing
        # attribute into a recorded None (reported by the assertion below)
        # instead of an AttributeError raised from deep inside sample().
        seen_schedulers.append(getattr(ks, 'scheduler', None))

    # Patched on the instance, so fake_set_steps is called without ``self``.
    monkeypatch.setattr(ks, 'set_steps', fake_set_steps)
    # Stub out the heavy sampling path; its return value is not inspected.
    monkeypatch.setattr(sampling, 'common_ksampler', lambda *args, **kwargs: {})

    ks.sample(model=DummyModel(), steps=20, scheduler='simple', sampler_name='euler')

    # Distinguish "never called" from "called too early" so failures are clear.
    assert seen_schedulers, "KSampler.sample never called set_steps"
    assert all(s == 'simple' for s in seen_schedulers), (
        "KSampler.sample must set the provided scheduler before calling set_steps "
        f"(observed: {seen_schedulers!r})"
    )