File size: 3,463 Bytes
13442a3
dfef42f
af22224
dfef42f
 
7e82443
dfef42f
 
 
99311fd
13442a3
99311fd
 
 
 
 
 
 
13442a3
dfef42f
 
 
 
 
 
 
99311fd
 
 
 
 
 
 
 
 
dfef42f
8ca5a6f
dfef42f
df2c452
dfef42f
 
 
 
 
145e5fc
dfef42f
145e5fc
7e82443
 
 
dfef42f
eb4468f
145e5fc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bfec533
 
 
 
5f922b6
bfec533
 
145e5fc
 
 
 
 
 
 
 
 
 
 
 
 
dfef42f
145e5fc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import gradio as gr
import torch, torchvision
from torchvision import transforms
import torch.nn.functional as F
import numpy as np
from time import time, ctime
from PIL import Image, ImageColor
from diffusers import DDPMPipeline
from diffusers import DDIMScheduler
from tqdm import tqdm

# Pick the compute device: Apple-silicon GPU (MPS) first, then CUDA,
# falling back to the CPU when neither accelerator is available.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Hub repository holding the trained diffusion pipeline. Both the UNet
# (through the DDPM pipeline) and the DDIM scheduler config are read from it.
pipeline_name = 'WiNE-iNEFF/Minecraft-Skin-Diffusion'
image_pipe = DDPMPipeline.from_pretrained(pipeline_name)
image_pipe = image_pipe.to(device)

# generate() samples with this DDIM scheduler rather than the pipeline's own
# scheduler, using 40 denoising steps per image.
scheduler = DDIMScheduler.from_pretrained(pipeline_name)
scheduler.set_timesteps(40)

def show_images_save(x, save_path="test.png"):
    """Tile a batch of images into a grid, optionally save it, and return it as PIL.

    Args:
        x: Image tensor in the model's output range (-1, 1), in a layout that
            ``torchvision.utils.make_grid`` accepts (``(B, C, H, W)`` or
            ``(C, H, W)``).
        save_path: File the grid is written to. Defaults to ``"test.png"``
            for backward compatibility. Pass ``None`` to skip the disk write
            (useful when serving concurrent requests, where every call would
            otherwise overwrite the same file).

    Returns:
        The grid as a ``PIL.Image`` (mode inferred from the channel count,
        e.g. RGBA for 4-channel input).
    """
    x = x * 0.5 + 0.5  # Map from (-1, 1) back to (0, 1)
    grid = torchvision.utils.make_grid(x, nrow=4)
    # Channels-last for PIL, clamped and scaled to the 0..255 byte range.
    grid_im = grid.detach().cpu().permute(1, 2, 0).clip(0, 1) * 255
    # Round before the cast: astype(uint8) alone truncates (254.9 -> 254).
    grid_im = Image.fromarray(np.array(grid_im.round()).astype(np.uint8))
    if save_path is not None:
        grid_im.save(save_path)
    return grid_im

def generate(batch_size=1):
    """Sample images from pure noise with the DDIM scheduler and return a PIL grid.

    Args:
        batch_size: Number of images denoised together in one pass. The
            default of 1 reproduces the original single-image behaviour.
            Larger batches are tiled into one grid by ``show_images_save``.

    Returns:
        The PIL image produced by ``show_images_save`` for the final samples.
    """
    # Draw the starting noise directly on the target device instead of
    # allocating it on the CPU and copying it over. The (4, 64, 64) per-image
    # shape is hard-coded; presumably it matches the UNet's in_channels and
    # sample_size. NOTE(review): confirm against the model config.
    x = torch.randn(batch_size, 4, 64, 64, device=device)
    # Pure inference: disable autograd for the whole sampling loop.
    with torch.no_grad():
        for t in scheduler.timesteps:
            model_input = scheduler.scale_model_input(x, t)
            noise_pred = image_pipe.unet(model_input, t)["sample"]
            x = scheduler.step(noise_pred, t, x).prev_sample
    return show_images_save(x)

def ex():
    """Log the current time and return four independently generated skins.

    Each skin comes from its own ``generate()`` call; the calls run one after
    another, and the results are returned as a 4-tuple for the four outputs.
    """
    print(ctime(time()))
    return tuple(generate() for _ in range(4))

# Page-wide CSS. Raw strings keep the backslashes in the escaped selector
# `.min-h-\[15rem\]` without triggering Python's invalid-escape warning
# ("\[" is not a valid string escape); the resulting CSS text is unchanged.
demo = gr.Blocks(
    css=(
        r"#img_size {max-height: 128px} "
        r".container {max-width: 730px; margin: auto;} "
        r".min-h-\[15rem\]{min-height: 5rem !important;}"
    )
)

with demo:
    # Title and short description of the demo.
    gr.HTML(
        """
            <div style="text-align: center; margin: 0 auto;">
              <div style="display: inline-flex;align-items: center;gap: 0.8rem;font-size: 1.75rem;">
                <h1 style="font-weight: 900; margin-bottom: 7px;margin-top:5px">
                  Minecraft Skin Diffusion
                </h1>
              </div>
              <p style="margin-bottom: 10px; font-size: 94%; line-height: 23px;">
                Gradio demo for Minecraft Skin Diffusion. This is a simple unconditional diffusion model that helps you generate skins for the game Minecraft.
              </p>
            </div>
        """
    )
    # 2x2 grid of output slots, one per image returned by ex().
    # NOTE(review): `.style(...)` and `gr.Image(shape=...)` are Gradio 3.x
    # APIs that were removed in Gradio 4.0; they only work while the Space
    # pins a 3.x release.
    with gr.Column():
        with gr.Row().style(equal_height=True):
            out = gr.Image(shape=(64, 64), image_mode='RGBA', type='pil', elem_id='img_size')
            out2 = gr.Image(shape=(64, 64), image_mode='RGBA', type='pil', elem_id='img_size')
        with gr.Row().style(equal_height=True):
            out3 = gr.Image(shape=(64, 64), image_mode='RGBA', type='pil', elem_id='img_size')
            out4 = gr.Image(shape=(64, 64), image_mode='RGBA', type='pil', elem_id='img_size')
    # Clicking "Generate" runs ex(); its four return values fill the slots in order.
    greet_btn = gr.Button("Generate")
    greet_btn.click(fn=ex, inputs=None, outputs=[out, out2, out3, out4])
    # Credits footer. NOTE(review): the badge image is served by a third-party
    # host (glitch.me); verify it still resolves.
    gr.HTML(
        """
                <div class="footer">
                    <div style='text-align: center;'>Minecraft Skin Diffusion by <a href='https://twitter.com/wine_ineff' target='_blank'>Artsem Holub (WiNE-iNEFF)</a> | 
                      <center>
                        <img src='https://visitor-badge.glitch.me/badge?page_id=WiNE-iNEFF_MinecraftSkin-Diffusion' alt='visitor badge'>
                      </center>
                    </div>
               </div>
           """
    )

# Start the web server only when this file is executed as a script
# (e.g. `python app.py`), so importing the module has no launch side effect.
if __name__ == "__main__":
    demo.launch()