Spaces:
Running
on
Zero
Running
on
Zero
| import math | |
| import numpy as np | |
| from contextlib import nullcontext | |
| from PIL import Image | |
| from einops import rearrange | |
| import torch | |
| from torch import autocast | |
| from .ldm.models.diffusion.ddim import DDIMSampler | |
def sample_model(input_im, model, sampler, precision, h, w, ddim_steps, n_samples, scale,
                 ddim_eta, x, y, z):
    """Generate ``n_samples`` images of ``input_im`` conditioned on a pose vector.

    The cross-attention conditioning is built as follows:

    * Take the image embedding from ``model.get_learned_conditioning``.
    * Tile it ``n_samples`` times.
    * Concatenate the 4-element vector ``[x, sin(y), cos(y), z]`` on the last axis.
    * Pass the result through ``model.cc_projection``.

    The concat conditioning is the first-stage latent of ``input_im`` (the mode of
    ``model.encode_first_stage``), repeated ``n_samples`` times. The sampler's latents
    are decoded with ``model.decode_first_stage`` and mapped to ``[0, 1]``.

    Args:
        input_im: Conditioning image passed to ``get_learned_conditioning`` and
            ``encode_first_stage``. Presumably a batch-of-one image tensor;
            TODO confirm the expected shape and value range against callers.
        model: Latent diffusion model exposing ``ema_scope``,
            ``get_learned_conditioning``, ``cc_projection``, ``encode_first_stage``
            and ``decode_first_stage``.
        sampler: Object whose ``sample(...)`` returns a ``(samples, _)`` pair.
        precision: ``'autocast'`` runs under ``torch.autocast('cuda')``. Any other
            value runs without autocast.
        h, w: Output height and width in pixels. The sampled latent shape is
            ``[4, h // 8, w // 8]``.
        ddim_steps: Number of sampler steps (passed as ``S``).
        n_samples: Number of samples to draw (the batch size).
        scale: Classifier-free guidance scale. With ``1.0``, no unconditional
            conditioning is passed (``uc=None``).
        ddim_eta: Sampler ``eta``.
        x, y, z: Pose parameters. ``y`` enters only through ``sin``/``cos``, so it is
            presumably an angle in radians; TODO confirm.

    Returns:
        A CPU tensor ``clamp((decoded + 1) / 2, 0, 1)`` with one entry per sample.
    """
    precision_scope = autocast if precision == 'autocast' else nullcontext
    with precision_scope('cuda'):
        with model.ema_scope():
            c = model.get_learned_conditioning(input_im).tile(n_samples, 1, 1)
            device = c.device

            # Append the pose vector to every (tiled) embedding before projecting.
            T = torch.tensor([x, math.sin(y), math.cos(y), z])
            T = T[None, None, :].repeat(n_samples, 1, 1).to(device)
            c = torch.cat([c, T], dim=-1)
            c = model.cc_projection(c)

            # Encode the input image exactly once and reuse the latent. Previously it was
            # encoded twice, and the first result was discarded.
            c_concat = model.encode_first_stage(input_im.to(device)).mode().detach()

            cond = {
                'c_crossattn': [c],
                'c_concat': [c_concat.repeat(n_samples, 1, 1, 1)],
            }

            if scale != 1.0:
                # All-zero "unconditional" inputs for classifier-free guidance.
                uc = {
                    'c_concat': [torch.zeros(n_samples, 4, h // 8, w // 8).to(device)],
                    'c_crossattn': [torch.zeros_like(c)],
                }
            else:
                uc = None

            shape = [4, h // 8, w // 8]
            samples_ddim, _ = sampler.sample(S=ddim_steps,
                                             conditioning=cond,
                                             batch_size=n_samples,
                                             shape=shape,
                                             verbose=False,
                                             unconditional_guidance_scale=scale,
                                             unconditional_conditioning=uc,
                                             eta=ddim_eta,
                                             x_T=None)
            x_samples_ddim = model.decode_first_stage(samples_ddim)
            # Map the decoder output from [-1, 1] to [0, 1] and move it to the CPU.
            return torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0).cpu()
def sample_images(
    model,
    input_im,
    x=0.,
    y=0.,
    z=0.,
    scale=3.0,
    n_samples=4,
    ddim_steps=50,
    ddim_eta=1.0,
    precision='fp32',
    h=256,
    w=256,
):
    """Draw ``n_samples`` images for ``input_im`` and return them as uint8 arrays.

    This is a convenience wrapper. It builds a fresh ``DDIMSampler`` around ``model``,
    delegates the sampling to ``sample_model`` and converts each resulting sample
    to a NumPy array.

    Args:
        model: Diffusion model passed to both ``DDIMSampler`` and ``sample_model``.
        input_im: Conditioning image forwarded to ``sample_model``.
        x, y, z: Pose parameters forwarded to ``sample_model``.
        scale: Classifier-free guidance scale.
        n_samples: Number of images to generate.
        ddim_steps: Number of sampler steps.
        ddim_eta: Sampler ``eta``.
        precision: ``'autocast'`` enables CUDA autocast inside ``sample_model``.
        h, w: Output height and width in pixels.

    Returns:
        A list with one ``np.uint8`` array per sample. Each array has its axes
        reordered from ``c h w`` to ``h w c`` and its values scaled by 255.
        ``astype`` truncates the values rather than rounding them.
    """
    ddim = DDIMSampler(model)
    batch = sample_model(input_im, model, ddim, precision, h, w,
                         ddim_steps, n_samples, scale, ddim_eta, x, y, z)
    # Channels-last and 0..255 for image consumers; astype truncates toward zero.
    return [
        (255. * rearrange(img.cpu().numpy(), 'c h w -> h w c')).astype(np.uint8)
        for img in batch
    ]