# LightDiffusion-Next: tests/unit/test_sampling.py
# Commit b701455 (Aatricks): "Deploy ZeroGPU Gradio Space snapshot"
import torch
from unittest.mock import patch
from src.sample import sampling
class DummyModel:
    """Lightweight model stand-in that only implements ``get_model_object``.

    It lets ``KSampler.sample`` be called in a test without loading a real
    model.
    """

    def get_model_object(self, name):
        """Build and return a fresh sampling stub.

        ``name`` is accepted for interface compatibility but is not used:
        every call returns a new stub instance, whatever ``name`` is.
        """

        class _SamplingStub:
            """Tiny sampling object with a short, cheap sigma schedule."""

            def __init__(self):
                # Only 20 sigmas, descending from 1 to 0, so that any
                # downstream computation stays cheap.
                schedule = torch.linspace(1, 0, 20)
                self.sigmas = schedule
                self.sigma_min, self.sigma_max = 0.0, 1.0

            def timestep(self, sigma):
                """Identity mapping: the timestep is the sigma unchanged."""
                return sigma

        stub = _SamplingStub()
        return stub
def test_ksampler_sets_scheduler_before_set_steps(monkeypatch):
    """``KSampler.sample`` must assign the scheduler before calling ``set_steps``.

    ``set_steps`` is replaced on the instance with a recorder that snapshots
    ``ks.scheduler`` on every call. ``sampling.common_ksampler`` is stubbed to
    return ``{}`` so no real sampling work happens. The test then checks that:

    * ``set_steps`` was actually reached, with a dedicated failure message,
      rather than reporting a misleading "scheduler not set" error; and
    * every ``set_steps`` call already saw the scheduler passed to ``sample``.
      An earlier call made with a stale scheduler would be hidden if only the
      last call were recorded.
    """
    ks = sampling.KSampler()
    seen_schedulers = []

    def fake_set_steps(*args, **kwargs):
        # Snapshot the scheduler visible at call time. The arguments don't
        # matter for this test, so accept any call shape instead of failing
        # with a TypeError if sample() passes extra/other arguments.
        seen_schedulers.append(ks.scheduler)

    # Patch set_steps on this instance only, skipping its heavy logic.
    monkeypatch.setattr(ks, "set_steps", fake_set_steps)
    # Stub common_ksampler so sample() does no further sampling work.
    monkeypatch.setattr(sampling, "common_ksampler", lambda *args, **kwargs: {})

    # Call sample with the scheduler passed explicitly.
    ks.sample(model=DummyModel(), steps=20, scheduler="simple", sampler_name="euler")

    assert seen_schedulers, "KSampler.sample did not call set_steps"
    assert all(seen == "simple" for seen in seen_schedulers), (
        "KSampler.sample must set the provided scheduler before calling set_steps; "
        f"set_steps observed schedulers {seen_schedulers!r}"
    )