WiNE-iNEFF commited on
Commit
ea60a6a
·
1 Parent(s): 4695b0d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -9
app.py CHANGED
@@ -6,7 +6,7 @@ import numpy as np
6
  from time import time, ctime
7
  from PIL import Image, ImageColor
8
  from diffusers import DDPMPipeline
9
- from diffusers import DDIMScheduler
10
  from tqdm import tqdm
11
 
12
  device = (
@@ -20,9 +20,16 @@ device = (
20
  pipeline_name = 'WiNE-iNEFF/Minecraft-Skin-Diffusion'
21
  image_pipe = DDPMPipeline.from_pretrained(pipeline_name).to(device)
22
 
23
- # Set up the scheduler
24
- scheduler = DDIMScheduler.from_pretrained(pipeline_name)
25
- scheduler.set_timesteps(num_inference_steps=40)
 
 
 
 
 
 
 
26
 
27
  def show_images_save(x):
28
  """Given a batch of images x, make a grid and convert to PIL"""
@@ -30,10 +37,13 @@ def show_images_save(x):
30
  grid = torchvision.utils.make_grid(x, nrow=4)
31
  grid_im = grid.detach().cpu().permute(1, 2, 0).clip(0, 1) * 255
32
  grid_im = Image.fromarray(np.array(grid_im).astype(np.uint8))
33
- grid_im.save(f"test.png")
34
  return grid_im
35
 
36
- def generate():
 
 
 
 
37
  x = torch.randn(1, 4, 64, 64).to(device)
38
  # Minimal sampling loop
39
  for i, t in enumerate(scheduler.timesteps):
@@ -44,10 +54,10 @@ def generate():
44
  # View the results
45
  return show_images_save(x)
46
 
47
- def ex():
48
  t = time()
49
  print(ctime(t))
50
- return generate(), generate(), generate(), generate()
51
 
52
  demo = gr.Blocks(css="#img_size {max-height: 128px} .container {max-width: 730px; margin: auto;} .min-h-\[15rem\]{min-height: 5rem !important;}")
53
 
@@ -67,6 +77,7 @@ with demo:
67
  """
68
  )
69
  with gr.Column():
 
70
  with gr.Row().style(equal_height=True):
71
  out = gr.Image(shape=(64,64), image_mode='RGBA', type='pil', elem_id='img_size')
72
  out2 = gr.Image(shape=(64,64), image_mode='RGBA', type='pil', elem_id='img_size')
@@ -74,7 +85,7 @@ with demo:
74
  out3 = gr.Image(shape=(64,64), image_mode='RGBA', type='pil', elem_id='img_size')
75
  out4 = gr.Image(shape=(64,64), image_mode='RGBA', type='pil', elem_id='img_size')
76
  greet_btn = gr.Button("Generate")
77
- greet_btn.click(fn=ex, inputs=None, outputs=[out, out2, out3, out4])
78
  gr.HTML(
79
  """
80
  <div class="footer">
 
6
  from time import time, ctime
7
  from PIL import Image, ImageColor
8
  from diffusers import DDPMPipeline
9
+ from diffusers import DDIMScheduler, PNDMScheduler
10
  from tqdm import tqdm
11
 
12
  device = (
 
20
  pipeline_name = 'WiNE-iNEFF/Minecraft-Skin-Diffusion'
21
  image_pipe = DDPMPipeline.from_pretrained(pipeline_name).to(device)
22
 
23
+ class Model:
24
+ def __init__(self, name, code):
25
+ self.name = name
26
+ self.path = path
27
+
28
+ model = [
29
+ Model("DDIMScheduler", "scheduler = DDIMScheduler.from_pretrained(pipeline_name)"),
30
+ Model("PNDMScheduler", "scheduler = PNDMScheduler.from_pretrained(pipeline_name)")]
31
+
32
+ current_model = model[0]
33
 
34
  def show_images_save(x):
35
  """Given a batch of images x, make a grid and convert to PIL"""
 
37
  grid = torchvision.utils.make_grid(x, nrow=4)
38
  grid_im = grid.detach().cpu().permute(1, 2, 0).clip(0, 1) * 255
39
  grid_im = Image.fromarray(np.array(grid_im).astype(np.uint8))
 
40
  return grid_im
41
 
42
+ def generate(schedul):
43
+ for i in model:
44
+ if schedul.name == i.name:
45
+ scheduler = i.code
46
+ scheduler.set_timesteps(num_inference_steps=40)
47
  x = torch.randn(1, 4, 64, 64).to(device)
48
  # Minimal sampling loop
49
  for i, t in enumerate(scheduler.timesteps):
 
54
  # View the results
55
  return show_images_save(x)
56
 
57
+ def ex(scheduler):
58
  t = time()
59
  print(ctime(t))
60
+ return generate(scheduler), generate(scheduler), generate(scheduler), generate(scheduler)
61
 
62
  demo = gr.Blocks(css="#img_size {max-height: 128px} .container {max-width: 730px; margin: auto;} .min-h-\[15rem\]{min-height: 5rem !important;}")
63
 
 
77
  """
78
  )
79
  with gr.Column():
80
+ model_name = gr.Dropdown(label="Base Scheduler", choices=[m.name for m in model], value=current_model.name)
81
  with gr.Row().style(equal_height=True):
82
  out = gr.Image(shape=(64,64), image_mode='RGBA', type='pil', elem_id='img_size')
83
  out2 = gr.Image(shape=(64,64), image_mode='RGBA', type='pil', elem_id='img_size')
 
85
  out3 = gr.Image(shape=(64,64), image_mode='RGBA', type='pil', elem_id='img_size')
86
  out4 = gr.Image(shape=(64,64), image_mode='RGBA', type='pil', elem_id='img_size')
87
  greet_btn = gr.Button("Generate")
88
+ greet_btn.click(fn=ex, inputs=[model_name], outputs=[out, out2, out3, out4])
89
  gr.HTML(
90
  """
91
  <div class="footer">