Spaces:
Runtime error
Runtime error
| import os | |
| import numpy as np | |
| import torch as th | |
| from imageio import imread | |
| from skimage.transform import resize as imresize | |
| from PIL import Image | |
| from decomp_diffusion.model_and_diffusion_util import * | |
| from decomp_diffusion.diffusion.respace import SpacedDiffusion | |
| from decomp_diffusion.gen_image import * | |
| from download import download_model | |
| from upsampling import get_pipeline, upscale_image | |
| import gradio as gr | |
| # from huggingface_hub import login | |
# Fix randomness: seed torch's global RNG, then numpy's, both with 0, so
# random draws made through these libraries are reproducible across runs.
for _seed_fn in (th.manual_seed, np.random.seed):
    _seed_fn(0)
del _seed_fn
def get_pil_im(im, resolution=64):
    """Resize an image array and convert it to a ``(1, 3, resolution, resolution)`` tensor.

    Despite the name, ``im`` is an array (e.g. from a ``gr.Image(type='numpy')``
    input), not a PIL image.

    Args:
        im: ``(H, W)`` grayscale, ``(H, W, 1)``, or ``(H, W, C)`` array with
            ``C >= 3``. Channels beyond the first three (e.g. alpha) are dropped;
            single-channel input is replicated to three channels.
        resolution: side length of the square output.

    Returns:
        Float ``th.Tensor`` of shape ``(1, 3, resolution, resolution)``. Values are
        whatever ``skimage.transform.resize`` produces (by default it converts
        integer images to floats, e.g. uint8 -> [0, 1]).
    """
    im = np.asarray(im)
    if im.ndim == 2:
        # Grayscale (H, W): add a channel axis so the RGB handling below applies.
        im = im[:, :, None]
    if im.shape[-1] == 1:
        # Single channel: replicate to RGB; otherwise `[:, :, :3]` would keep 1 channel.
        im = np.repeat(im, 3, axis=-1)
    # resize keeps the trailing channel axis when given only (rows, cols).
    im = imresize(im, (resolution, resolution))[:, :, :3]
    # HWC -> CHW, then add a leading batch axis.
    im = th.Tensor(im).permute(2, 0, 1)[None, :, :, :].contiguous()
    return im
# generate image components and reconstruction
def gen_image_and_components(model, gd, im, num_components=4, sample_method='ddim', batch_size=1, image_size=64, device='cuda', num_images=1):
    """Generate rows of [original image, each component, reconstruction].

    Each component image is sampled with ``latent_index`` set to that component's
    index; the reconstruction is sampled with ``latent_index=None``.

    Args:
        model: model exposing ``encode_latent``; its latent conditions sampling.
        gd: diffusion object providing ``p_sample_loop`` / ``ddim_sample_loop``
            (and ``_wrap_model``, used for DDIM).
        im: input image array, preprocessed by ``get_pil_im``.
        num_components: number of components rendered individually.
        sample_method: ``'ddpm'`` or ``'ddim'``.
        batch_size: number of samples kept from each sampling call.
        image_size: side length used for the input resize and for sampling.
        device: device the input and samples live on.
        num_images: number of rows to generate; every row is included in the
            returned grid (one grid row per image).

    Returns:
        The ``make_grid`` result (presumably torchvision's, C x H x W) with
        ``num_images`` rows.

    Raises:
        AssertionError: if ``sample_method`` is not ``'ddpm'`` or ``'ddim'``.
        ValueError: if ``num_images`` is less than 1.
    """
    assert sample_method in ('ddpm', 'ddim')
    if num_images < 1:
        raise ValueError(f'num_images must be >= 1, got {num_images}')
    sample_loop_func = gd.p_sample_loop if sample_method == 'ddpm' else gd.ddim_sample_loop

    # Pure inference: no_grad stops autograd from recording the encoder pass.
    with th.no_grad():
        orig_img = get_pil_im(im, resolution=image_size).to(device)
        latent = model.encode_latent(orig_img)
        if sample_method == 'ddim':
            model = gd._wrap_model(model)
        model_kwargs = {'latent': latent}
        shape = (batch_size, 3, image_size, image_size)

        rows = []
        for _ in range(num_images):
            row = [orig_img]
            # Indices 0..num_components-1 render single components; None renders
            # the full reconstruction, appended last.
            for latent_index in [*range(num_components), None]:
                model_kwargs['latent_index'] = latent_index
                sample = sample_loop_func(
                    model,
                    shape,
                    device=device,
                    clip_denoised=True,
                    progress=True,
                    model_kwargs=model_kwargs,
                    cond_fn=None,
                )[:batch_size]
                row.append(sample)
            rows.append(th.cat(row, dim=0))

    # One grid row per generated image; for num_images=1 this is a single row,
    # exactly as before.
    row_len = rows[0].shape[0]
    samples = th.cat(rows, dim=0).cpu()
    grid = make_grid(samples, nrow=row_len, padding=0)
    return grid
| # def decompose_image(im): | |
| # sample_method = 'ddim' | |
| # result = gen_image_and_components(clevr_model, GD[sample_method], im, sample_method=sample_method, num_images=1, device=device) | |
| # return result.permute(1, 2, 0).numpy() | |
# load diffusion
# Diffusion objects keyed by sampling method: 'ddpm' uses the full schedule,
# 'ddim' uses a respaced subset of its timesteps.
GD = {}
diffusion_kwargs = diffusion_defaults()
gd = create_gaussian_diffusion(**diffusion_kwargs)
GD['ddpm'] = gd

# set up ddim sampling on roughly `desired_timesteps` evenly spaced steps
desired_timesteps = 50
num_timesteps = diffusion_kwargs['steps']
# Clamp to 1: with fewer than `desired_timesteps` steps the integer division is
# 0 and range() would raise "ValueError: range() arg 3 must not be zero".
spacing = max(1, num_timesteps // desired_timesteps)
# NOTE(review): the inclusive upper bound also adds `num_timesteps` itself, which
# is past the last valid index of a length-`num_timesteps` schedule; kept as-is
# until SpacedDiffusion's handling of it is confirmed.
spaced_ts = list(range(0, num_timesteps + 1, spacing))
betas = get_named_beta_schedule(diffusion_kwargs['noise_schedule'], num_timesteps)
# SpacedDiffusion is given explicit betas, so the schedule-description keys are
# removed before the kwargs are splatted into it.
diffusion_kwargs['betas'] = betas
del diffusion_kwargs['steps'], diffusion_kwargs['noise_schedule']
gd = SpacedDiffusion(spaced_ts, rescale_timesteps=True, original_num_steps=num_timesteps, **diffusion_kwargs)
GD['ddim'] = gd
def combine_components_slice(model, gd, im1, im2, indices=None, sample_method='ddim', device='cuda', num_images=4, model_kwargs=None, desc='', save_dir='', dataset='clevr', image_size=64):
    """Combine two images by mixing their latent components, then sample once.

    Component ``k`` of the combined latent is taken from ``im1`` where the mask
    entry ``k`` equals 1 and from ``im2`` otherwise.

    Args:
        model: model exposing ``encode_latent`` and ``num_components``.
        gd: diffusion object providing ``p_sample_loop`` / ``ddim_sample_loop``.
        im1, im2: input image arrays, preprocessed by ``get_pil_im``.
        indices: component mask. ``None`` takes the first ``num_components // 2``
            components from ``im1`` and the rest from ``im2``; a comma-separated
            string such as ``'1,0,1,0'``; or any sequence/tensor of 0/1 (or bool)
            values, one per component.
        sample_method: ``'ddpm'`` or ``'ddim'``.
        device: device the latents and samples live on.
        num_images, desc, save_dir, dataset: unused; kept for API compatibility.
        model_kwargs: extra model keyword arguments. Copied, never mutated.
        image_size: side length used for the input resize and for sampling.

    Returns:
        The first sample of the ``(1, 3, image_size, image_size)`` sampling call,
        moved to CPU.

    Raises:
        AssertionError: on an unknown ``sample_method`` or a mask whose length
            differs from ``model.num_components``.
    """
    assert sample_method in ('ddpm', 'ddim')
    num_comps = model.num_components

    # Build a (num_comps, 1) boolean mask: True -> take the component from im1.
    if indices is None:
        half = num_comps // 2
        # First half from im1, the remainder from im2 (also correct for odd counts).
        indices = [1] * half + [0] * (num_comps - half)
    elif isinstance(indices, str):
        indices = [int(ind) for ind in indices.split(',')]
    mask = th.as_tensor(indices).reshape(-1, 1) == 1
    assert mask.shape[0] == num_comps, (
        f'expected {num_comps} component indices, got {mask.shape[0]}')
    mask = mask.to(device)

    # Copy so neither the caller's dict nor a shared default gets mutated.
    kwargs = dict(model_kwargs) if model_kwargs else {}
    sample_loop_func = gd.p_sample_loop if sample_method == 'ddpm' else gd.ddim_sample_loop

    # Pure inference: no_grad stops autograd from recording the encoder passes.
    with th.no_grad():
        im1 = get_pil_im(im1, resolution=image_size).to(device)
        im2 = get_pil_im(im2, resolution=image_size).to(device)
        # One row per component so the (num_comps, 1) mask broadcasts across it.
        latent1 = model.encode_latent(im1).reshape(num_comps, -1).to(device)
        latent2 = model.encode_latent(im2).reshape(num_comps, -1).to(device)
        kwargs['latent'] = th.where(mask, latent1, latent2).reshape(1, -1)

        if sample_method == 'ddim':
            model = gd._wrap_model(model)
        sample = sample_loop_func(
            model,
            (1, 3, image_size, image_size),
            device=device,
            clip_denoised=True,
            progress=True,
            model_kwargs=kwargs,
            cond_fn=None,
        )[:1]
    return sample[0].cpu()
def decompose_image_demo(im, model):
    """Gradio handler: decompose ``im`` with the model selected in the UI.

    Args:
        im: image array from the ``gr.Image(type='numpy')`` input, or ``None``
            when the user has not provided one.
        model: key into ``MODELS`` (``'CLEVR'`` or ``'CelebA-HQ'``).

    Returns:
        The decomposition grid (original, components, reconstruction) as an
        H x W x C numpy array.

    Raises:
        gr.Error: if no input image was provided, so Gradio shows a readable
            message instead of a stack trace.
    """
    if im is None:
        raise gr.Error('Please provide an input image.')
    sample_method = 'ddim'
    result = gen_image_and_components(MODELS[model], GD[sample_method], im, sample_method=sample_method, num_images=1, device=device)
    # CHW tensor -> HWC array for the gr.Image output.
    return result.permute(1, 2, 0).numpy()
def combine_images_demo(im1, im2, model):
    """Gradio handler: mix latent components of two images with the selected model.

    Uses the fixed mask ``'1,0,1,0'``: components flagged 1 come from ``im1``,
    the others from ``im2``.

    Args:
        im1, im2: image arrays from the ``gr.Image(type='numpy')`` inputs, or
            ``None`` when missing.
        model: key into ``MODELS`` (``'CLEVR'`` or ``'CelebA-HQ'``).

    Returns:
        The combined image as an H x W x C numpy array.

    Raises:
        gr.Error: if either input image is missing.
    """
    if im1 is None or im2 is None:
        raise gr.Error('Please provide both input images.')
    sample_method = 'ddim'
    result = combine_components_slice(MODELS[model], GD[sample_method], im1, im2, indices='1,0,1,0', sample_method=sample_method, num_images=1, device=device)
    # CHW tensor -> HWC array for the gr.Image output.
    result = result.permute(1, 2, 0).numpy()
    # Optional CelebA-HQ upscaling (disabled; would need `pipe = get_pipeline()`):
    # if model == 'CelebA-HQ':
    #     return upscale_image(result, pipe)
    return result
def load_model(dataset, extra_kwargs={}, device='cuda'):
    """Build the diffusion model for ``dataset`` and load its checkpoint.

    The checkpoint path comes from ``download_model(dataset)``. The architecture
    is ``unet_model_defaults()`` with ``extra_kwargs`` overriding matching keys.
    The model is put in eval mode and moved to ``device`` before the weights
    (loaded onto CPU first) are copied in.

    Returns:
        The loaded model.
    """
    checkpoint_file = download_model(dataset)
    config = unet_model_defaults()
    config.update(extra_kwargs)  # caller-supplied values win over defaults
    net = create_diffusion_model(**config)
    net.eval().to(device)
    print(f'loading from {checkpoint_file}')
    state = th.load(checkpoint_file, map_location='cpu')
    net.load_state_dict(state)
    return net
# Use the GPU when torch can see one; otherwise fall back to CPU.
device = 'cpu'
if th.cuda.is_available():
    device = 'cuda'

# Both checkpoints are loaded at import time.
clevr_model = load_model('clevr', extra_kwargs={'emb_dim': 64, 'enc_channels': 128}, device=device)
celeb_model = load_model('celebahq', extra_kwargs={'enc_channels': 128}, device=device)

# Maps the UI's model radio-button labels to the loaded models.
MODELS = {'CLEVR': clevr_model, 'CelebA-HQ': celeb_model}
| # pipe = get_pipeline() | |
# Gradio UI: a decomposition section and an image-combination section.
with gr.Blocks() as demo:
    gr.Markdown(
        """<h1 style="text-align: center;"><b>Unsupervised Compositional Image Decomposition with Diffusion Models
</b> - <a href="https://jsu27.github.io/decomp-diffusion-web/">Project Page</a></h1>""")
    gr.Markdown(
        """<p style="font-size: 18px;">We introduce Decomp Diffusion, an unsupervised approach that discovers compositional concepts from images, represented by diffusion models.
</p>""")

    # --- Decomposition: one image + model -> components and reconstruction ---
    gr.Markdown(
        """<br> <h4>Decomposition and reconstruction of images</h4>""")
    with gr.Row():
        with gr.Column():
            with gr.Row():
                decomp_input = gr.Image(type='numpy', label='Input')
            with gr.Row():
                decomp_model = gr.Radio(
                    ['CLEVR', 'CelebA-HQ'], type="value", label='Model',
                    value='CLEVR')
            with gr.Row():
                # Example paths are relative to the directory the app is launched from.
                decomp_examples = [['sample_images/clevr_im_10.png', 'CLEVR'],
                                   ['sample_images/celebahq_im_15.jpg', 'CelebA-HQ']]
                decomp_img_examples = gr.Examples(
                    examples=decomp_examples,
                    inputs=[decomp_input, decomp_model]
                )
        with gr.Column():
            decomp_output = gr.Image(type='numpy')
            decomp_button = gr.Button("Generate")

    # --- Combination: two images + model -> one image mixing their components ---
    gr.Markdown(
        """<br> <h4>Combination of images</h4>""")
    # equal_height is a constructor argument: Block.style() was deprecated in
    # Gradio 3.x and removed in 4.0, where `.style(...)` raises AttributeError.
    with gr.Row(equal_height=True):
        with gr.Column(scale=2):
            with gr.Row():
                with gr.Column():
                    comb_input1 = gr.Image(type='numpy', label='Input 1')
                with gr.Column():
                    comb_input2 = gr.Image(type='numpy', label='Input 2')
            with gr.Row():
                comb_model = gr.Radio(
                    ['CLEVR', 'CelebA-HQ'], type="value", label='Model',
                    value='CLEVR')
            with gr.Row():
                comb_examples = [['sample_images/clevr_im_10.png', 'sample_images/clevr_im_25.png', 'CLEVR'],
                                 ['sample_images/celebahq_im_15.jpg', 'sample_images/celebahq_im_21.jpg', 'CelebA-HQ']]
                comb_img_examples = gr.Examples(
                    examples=comb_examples,
                    inputs=[comb_input1, comb_input2, comb_model]
                )
        with gr.Column(scale=1):
            comb_output = gr.Image(type='numpy')
            comb_button = gr.Button("Generate")

    # Event listeners must be registered inside the Blocks context.
    decomp_button.click(decompose_image_demo,
                        inputs=[decomp_input, decomp_model],
                        outputs=decomp_output)
    comb_button.click(combine_images_demo,
                      inputs=[comb_input1, comb_input2, comb_model],
                      outputs=comb_output)

demo.launch(debug=True)