michaelriedl commited on
Commit
4e39abd
·
1 Parent(s): 14fc1e9

Added medium model

Browse files
Files changed (1) hide show
  1. app.py +81 -6
app.py CHANGED
@@ -25,7 +25,7 @@ def generate_small(color_indexed: bool, color_num: int) -> list:
25
  List of PIL images.
26
  """
27
  # Get the latent dimension
28
- latent_dim = model.model.latent_dim
29
  # Initialize the list of images
30
  images_list = []
31
  # Generate MAX_IMAGES images
@@ -34,7 +34,7 @@ def generate_small(color_indexed: bool, color_num: int) -> list:
34
  latents = torch.randn((1, latent_dim))
35
  # Generate the image
36
  with torch.no_grad():
37
- generated_image = model(latents)
38
  # Clamp the image to [0, 1]
39
  generated_image = generated_image.clamp_(0.0, 1.0).cpu().numpy()
40
 
@@ -58,14 +58,69 @@ def generate_small(color_indexed: bool, color_num: int) -> list:
58
  return images_list
59
 
60
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
  # Create the demo interface
62
  demo = gr.Blocks()
63
 
64
- # Create the model
65
- model = AutoModel.from_pretrained(
66
  "michaelriedl/MonsterForge-small", trust_remote_code=True
67
  )
68
- model.eval()
 
 
 
 
 
 
69
 
70
  # Create the interface
71
  with demo:
@@ -102,7 +157,27 @@ with demo:
102
  outputs=gallery_small,
103
  )
104
  with gr.TabItem("Medium Sprite"):
105
- pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
106
  gr.HTML(
107
  """
108
  <div class="footer">
 
25
  List of PIL images.
26
  """
27
  # Get the latent dimension
28
+ latent_dim = model_small.model.latent_dim
29
  # Initialize the list of images
30
  images_list = []
31
  # Generate MAX_IMAGES images
 
34
  latents = torch.randn((1, latent_dim))
35
  # Generate the image
36
  with torch.no_grad():
37
+ generated_image = model_small(latents)
38
  # Clamp the image to [0, 1]
39
  generated_image = generated_image.clamp_(0.0, 1.0).cpu().numpy()
40
 
 
58
  return images_list
59
 
60
 
61
+ def generate_med(color_indexed: bool, color_num: int) -> list:
62
+ """Generates a medium sprite.
63
+
64
+ Parameters
65
+ ----------
66
+ color_indexed : bool
67
+ Whether to use color indexing.
68
+ color_num : int
69
+ Number of colors in the palette.
70
+
71
+ Returns
72
+ -------
73
+ list
74
+ List of PIL images.
75
+ """
76
+ # Get the latent dimension
77
+ latent_dim = model_med.model.latent_dim
78
+ # Initialize the list of images
79
+ images_list = []
80
+ # Generate MAX_IMAGES images
81
+ for _ in range(MAX_IMAGES):
82
+ # Generate a random latent vector
83
+ latents = torch.randn((1, latent_dim))
84
+ # Generate the image
85
+ with torch.no_grad():
86
+ generated_image = model_med(latents)
87
+ # Clamp the image to [0, 1]
88
+ generated_image = generated_image.clamp_(0.0, 1.0).cpu().numpy()
89
+
90
+ # Convert the generated image to PIL image
91
+ color_image = Image.fromarray(
92
+ np.uint8(generated_image[0] * 255).transpose(1, 2, 0), "RGBA"
93
+ )
94
+
95
+ # Convert to color indexed image if needed
96
+ if color_indexed:
97
+ # Convert using adaptive palette of given color depth
98
+ color_image_indexed = color_image.convert(
99
+ "P", palette=Image.ADAPTIVE, colors=color_num
100
+ )
101
+ # Add the color indexed image to the list
102
+ images_list.append(color_image_indexed)
103
+
104
+ # Add the image to the list
105
+ images_list.append(color_image)
106
+
107
+ return images_list
108
+
109
+
110
  # Create the demo interface
111
  demo = gr.Blocks()
112
 
113
+ # Create the small model
114
+ model_small = AutoModel.from_pretrained(
115
  "michaelriedl/MonsterForge-small", trust_remote_code=True
116
  )
117
+ model_small.eval()
118
+
119
+ # Create the medium model
120
+ model_med = AutoModel.from_pretrained(
121
+ "michaelriedl/MonsterForge-medium", trust_remote_code=True
122
+ )
123
+ model_med.eval()
124
 
125
  # Create the interface
126
  with demo:
 
157
  outputs=gallery_small,
158
  )
159
  with gr.TabItem("Medium Sprite"):
160
+ with gr.Column():
161
+ with gr.Row():
162
+ gallery_med = gr.Gallery(
163
+ columns=4,
164
+ object_fit="scale-down",
165
+ )
166
+ with gr.Row():
167
+ color_index_med = gr.Checkbox(label="Color indexed", value=False)
168
+ color_num_med = gr.Slider(
169
+ minimum=8,
170
+ maximum=32,
171
+ value=32,
172
+ step=4,
173
+ label="Number of colors in the palette",
174
+ )
175
+ gen_btn_med = gr.Button("Generate")
176
+ gen_btn_med.click(
177
+ fn=generate_med,
178
+ inputs=[color_index_med, color_num_med],
179
+ outputs=gallery_med,
180
+ )
181
  gr.HTML(
182
  """
183
  <div class="footer">