vinhtong97's picture
Upload folder using huggingface_hub
d382778 verified
"""SAMPLING ONLY."""
import torch
from .dpm_solver_v3 import NoiseScheduleVP, model_wrapper, DPM_Solver_v3
class DPMSolverv3Sampler:
def __init__(self, ckp_path, stats_dir, model, steps, guidance_scale, **kwargs):
super().__init__()
self.model = model
to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device)
self.alphas_cumprod = to_torch(model.alphas_cumprod)
self.device = self.model.betas.device
self.guidance_scale = guidance_scale
self.ns = NoiseScheduleVP("discrete", alphas_cumprod=self.alphas_cumprod)
assert stats_dir is not None, f"No statistics file found in {stats_base}."
print("Use statistics", stats_dir)
self.dpm_solver_v3 = DPM_Solver_v3(
statistics_dir=stats_dir,
noise_schedule=self.ns,
steps=steps,
t_start=None,
t_end=None,
skip_type="time_uniform",
degenerated=False,
device=self.device,
)
self.steps = steps
@torch.no_grad()
def sample(
self,
batch_size,
shape,
conditioning=None,
x_T=None,
unconditional_conditioning=None,
use_corrector=False,
half=False,
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
**kwargs,
):
if conditioning is not None:
if isinstance(conditioning, dict):
cbs = conditioning[list(conditioning.keys())[0]].shape[0]
if cbs != batch_size:
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
else:
if conditioning.shape[0] != batch_size:
print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
# sampling
C, H, W = shape
size = (batch_size, C, H, W)
if x_T is None:
img = torch.randn(size, device=self.device)
else:
img = x_T
if conditioning is None:
model_fn = model_wrapper(
lambda x, t, c: self.model.apply_model(x, t, c),
self.ns,
model_type="noise",
guidance_type="uncond",
)
ORDER = 3
else:
model_fn = model_wrapper(
lambda x, t, c: self.model.apply_model(x, t, c),
self.ns,
model_type="noise",
guidance_type="classifier-free",
condition=conditioning,
unconditional_condition=unconditional_conditioning,
guidance_scale=self.guidance_scale,
)
ORDER = 2
x = self.dpm_solver_v3.sample(
img,
model_fn,
order=ORDER,
p_pseudo=False,
c_pseudo=True,
lower_order_final=True,
use_corrector=use_corrector,
half=half,
)
return x.to(self.device), None