Update app.py
Browse files
app.py
CHANGED
|
@@ -22,19 +22,21 @@ def sample_latent(batch, key):
|
|
| 22 |
def to_img(normalized):
|
| 23 |
return ((normalized+1)*255./2.).astype(np.uint8)
|
| 24 |
|
|
|
|
|
|
|
| 25 |
def generate_images(previous=None):
|
| 26 |
-
latents = sample_latent(
|
| 27 |
if previous:
|
| 28 |
latents = np.repeat([previous], repeats=16, axis=0) + 0.25 * latents
|
| 29 |
(g_out128, _, _, _, _, _) = generator.apply({'params': g_state['params'], 'batch_stats': g_state['batch_stats']}, latents, training=False)
|
| 30 |
img = np.array(to_img(g_out128))
|
| 31 |
-
for row in range(
|
| 32 |
with st.container():
|
| 33 |
-
for (col_idx, col) in enumerate(st.columns(
|
| 34 |
with col:
|
| 35 |
-
idx = row*
|
| 36 |
st.image(Image.fromarray(img[idx]))
|
| 37 |
-
st.button(label="Generate similar", on_click=generate_images, args=latents[idx])
|
| 38 |
|
| 39 |
st.write("The model and its details are at https://huggingface.co/PrakhAI/AIPlane2")
|
| 40 |
if st.button('Generate Random'):
|
|
|
|
| 22 |
def to_img(normalized):
|
| 23 |
return ((normalized+1)*255./2.).astype(np.uint8)
|
| 24 |
|
| 25 |
+
ROWS = 4
|
| 26 |
+
COLUMNS = 4
|
| 27 |
def generate_images(previous=None):
|
| 28 |
+
latents = sample_latent(ROWS * COLUMNS, jax.random.PRNGKey(int(1_000_000 * time.time())))
|
| 29 |
if previous:
|
| 30 |
latents = np.repeat([previous], repeats=16, axis=0) + 0.25 * latents
|
| 31 |
(g_out128, _, _, _, _, _) = generator.apply({'params': g_state['params'], 'batch_stats': g_state['batch_stats']}, latents, training=False)
|
| 32 |
img = np.array(to_img(g_out128))
|
| 33 |
+
for row in range(ROWS):
|
| 34 |
with st.container():
|
| 35 |
+
for (col_idx, col) in enumerate(st.columns(COLUMNS)):
|
| 36 |
with col:
|
| 37 |
+
idx = row*COLUMNS + col_idx
|
| 38 |
st.image(Image.fromarray(img[idx]))
|
| 39 |
+
st.button(label="Generate similar", on_click=generate_images, args=(latents[idx]))
|
| 40 |
|
| 41 |
st.write("The model and its details are at https://huggingface.co/PrakhAI/AIPlane2")
|
| 42 |
if st.button('Generate Random'):
|