add info
Browse files
app.py
CHANGED
|
@@ -4,13 +4,10 @@ import gradio as gr
|
|
| 4 |
from PIL import Image
|
| 5 |
from cli import iterative_refinement
|
| 6 |
from viz import grid_of_images_default
|
| 7 |
-
# from subprocess
|
| 8 |
-
# subprocess.call("download_models.sh", shell=True)
|
| 9 |
models = {
|
| 10 |
"convae": torch.load("convae.th", map_location="cpu"),
|
| 11 |
"deep_convae": torch.load("deep_convae.th", map_location="cpu"),
|
| 12 |
}
|
| 13 |
-
|
| 14 |
def gen(model, seed, nb_iter, nb_samples, width, height):
|
| 15 |
torch.manual_seed(int(seed))
|
| 16 |
bs = 64
|
|
@@ -26,9 +23,17 @@ def gen(model, seed, nb_iter, nb_samples, width, height):
|
|
| 26 |
grid = (grid*255).astype("uint8")
|
| 27 |
return Image.fromarray(grid)
|
| 28 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
iface = gr.Interface(
|
| 30 |
fn=gen,
|
| 31 |
-
inputs=[
|
|
|
|
|
|
|
|
|
|
| 32 |
outputs="image"
|
| 33 |
)
|
| 34 |
iface.launch()
|
|
|
|
| 4 |
from PIL import Image
|
| 5 |
from cli import iterative_refinement
|
| 6 |
from viz import grid_of_images_default
|
|
|
|
|
|
|
| 7 |
models = {
|
| 8 |
"convae": torch.load("convae.th", map_location="cpu"),
|
| 9 |
"deep_convae": torch.load("deep_convae.th", map_location="cpu"),
|
| 10 |
}
|
|
|
|
| 11 |
def gen(model, seed, nb_iter, nb_samples, width, height):
|
| 12 |
torch.manual_seed(int(seed))
|
| 13 |
bs = 64
|
|
|
|
| 23 |
grid = (grid*255).astype("uint8")
|
| 24 |
return Image.fromarray(grid)
|
| 25 |
|
| 26 |
+
text = """
|
| 27 |
+
Interface with ConvAE model (from [here](https://arxiv.org/pdf/1606.04345.pdf)) and DeepConvAE model (from [here](https://tel.archives-ouvertes.fr/tel-01838272/file/75406_CHERTI_2018_diffusion.pdf), Section 10.1 with `L=3`)
|
| 28 |
+
|
| 29 |
+
These models were trained on MNIST only (digits), but were found to generate new kinds of symbols, see the references for more details.
|
| 30 |
+
"""
|
| 31 |
iface = gr.Interface(
|
| 32 |
fn=gen,
|
| 33 |
+
inputs=[
|
| 34 |
+
gr.Markdown(text),
|
| 35 |
+
gr.Dropdown(list(models.keys()), value="deep_convae"), gr.Number(value=0), gr.Number(value=20), gr.Number(value=1), gr.Number(value=28), gr.Number(value=28)
|
| 36 |
+
],
|
| 37 |
outputs="image"
|
| 38 |
)
|
| 39 |
iface.launch()
|