Spaces:
Running
Running
added progress bar
Browse files- app.py +23 -50
- load_model.py +7 -4
app.py
CHANGED
|
@@ -10,8 +10,6 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
| 10 |
device = "mps" if torch.backends.mps.is_available() else device
|
| 11 |
|
| 12 |
image_size = 128
|
| 13 |
-
upscale = False
|
| 14 |
-
clicked = False
|
| 15 |
|
| 16 |
|
| 17 |
transform = transforms.Compose(
|
|
@@ -23,49 +21,26 @@ transform = transforms.Compose(
|
|
| 23 |
)
|
| 24 |
|
| 25 |
|
| 26 |
-
def
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
sketch = transforms.ToTensor()(sketch).to(device)
|
| 36 |
-
scribbles = transforms.ToTensor()(scribbles).to(device)
|
| 37 |
-
|
| 38 |
-
scribble_where_grey_mask = torch.eq(scribbles, grey_tensor)
|
| 39 |
-
|
| 40 |
-
merged = torch.where(scribble_where_grey_mask, sketch, scribbles)
|
| 41 |
-
|
| 42 |
-
return transforms.Lambda(lambda t: (t * 2) - 1)(sketch), transforms.Lambda(
|
| 43 |
-
lambda t: (t * 2) - 1
|
| 44 |
-
)(merged)
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
def process_images(sketch, scribbles, sampling_steps, is_scribbles, seed_nr, upscale):
|
| 48 |
-
global clicked
|
| 49 |
-
clicked = True
|
| 50 |
w, h = sketch.size
|
| 51 |
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
else:
|
| 56 |
-
sketch = transform(sketch.convert("RGB"))
|
| 57 |
-
scribbles = transform(scribbles.convert("RGB"))
|
| 58 |
|
| 59 |
if upscale:
|
| 60 |
-
|
| 61 |
-
sample(sketch, scribbles, sampling_steps, seed_nr)
|
| 62 |
)
|
| 63 |
-
|
| 64 |
-
return output
|
| 65 |
else:
|
| 66 |
-
|
| 67 |
-
clicked = False
|
| 68 |
-
return output
|
| 69 |
|
| 70 |
|
| 71 |
theme = gr.themes.Monochrome()
|
|
@@ -87,7 +62,7 @@ with gr.Blocks(theme=theme) as demo:
|
|
| 87 |
"By default the scribbles are assumed to be merged with the sketch, if they appear on a grey background check the box below. "
|
| 88 |
"</p>"
|
| 89 |
)
|
| 90 |
-
|
| 91 |
with gr.Column():
|
| 92 |
output = gr.Image(type="pil", label="Output")
|
| 93 |
upscale_info = gr.Markdown(
|
|
@@ -96,14 +71,12 @@ with gr.Blocks(theme=theme) as demo:
|
|
| 96 |
"</p>"
|
| 97 |
)
|
| 98 |
upscale_button = gr.Checkbox(label="Stretch", value=False)
|
|
|
|
| 99 |
with gr.Row():
|
| 100 |
with gr.Column():
|
| 101 |
seed_slider = gr.Number(
|
| 102 |
-
label="Random Seed 🎲",
|
| 103 |
-
value=random.randint(
|
| 104 |
-
1,
|
| 105 |
-
1000,
|
| 106 |
-
),
|
| 107 |
)
|
| 108 |
|
| 109 |
with gr.Column():
|
|
@@ -111,12 +84,12 @@ with gr.Blocks(theme=theme) as demo:
|
|
| 111 |
minimum=1,
|
| 112 |
maximum=250,
|
| 113 |
step=1,
|
| 114 |
-
label="DDPM Sampling Steps 🔄",
|
| 115 |
value=50,
|
| 116 |
)
|
| 117 |
|
| 118 |
with gr.Row():
|
| 119 |
-
generate_button = gr.Button(value="Generate"
|
| 120 |
with gr.Row():
|
| 121 |
generate_info = gr.Markdown(
|
| 122 |
"<p style='text-align: center; font-size: 16px;'>"
|
|
@@ -130,13 +103,13 @@ with gr.Blocks(theme=theme) as demo:
|
|
| 130 |
sketch_input,
|
| 131 |
scribbles_input,
|
| 132 |
sampling_slider,
|
| 133 |
-
is_scribbles,
|
| 134 |
seed_slider,
|
| 135 |
upscale_button,
|
| 136 |
],
|
| 137 |
outputs=output,
|
| 138 |
-
|
|
|
|
| 139 |
)
|
| 140 |
|
| 141 |
if __name__ == "__main__":
|
| 142 |
-
demo.launch(max_threads=1)
|
|
|
|
| 10 |
device = "mps" if torch.backends.mps.is_available() else device
|
| 11 |
|
| 12 |
image_size = 128
|
|
|
|
|
|
|
| 13 |
|
| 14 |
|
| 15 |
transform = transforms.Compose(
|
|
|
|
| 21 |
)
|
| 22 |
|
| 23 |
|
| 24 |
+
def process_images(
|
| 25 |
+
sketch,
|
| 26 |
+
scribbles,
|
| 27 |
+
sampling_steps,
|
| 28 |
+
seed_nr,
|
| 29 |
+
upscale,
|
| 30 |
+
progress=gr.Progress(),
|
| 31 |
+
):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
w, h = sketch.size
|
| 33 |
|
| 34 |
+
sketch = transform(sketch.convert("RGB"))
|
| 35 |
+
scribbles = transform(scribbles.convert("RGB"))
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
|
| 37 |
if upscale:
|
| 38 |
+
return transforms.Resize((h, w))(
|
| 39 |
+
sample(sketch, scribbles, sampling_steps, seed_nr, progress)
|
| 40 |
)
|
| 41 |
+
|
|
|
|
| 42 |
else:
|
| 43 |
+
return sample(sketch, scribbles, sampling_steps, seed_nr, progress)
|
|
|
|
|
|
|
| 44 |
|
| 45 |
|
| 46 |
theme = gr.themes.Monochrome()
|
|
|
|
| 62 |
"By default the scribbles are assumed to be merged with the sketch, if they appear on a grey background check the box below. "
|
| 63 |
"</p>"
|
| 64 |
)
|
| 65 |
+
|
| 66 |
with gr.Column():
|
| 67 |
output = gr.Image(type="pil", label="Output")
|
| 68 |
upscale_info = gr.Markdown(
|
|
|
|
| 71 |
"</p>"
|
| 72 |
)
|
| 73 |
upscale_button = gr.Checkbox(label="Stretch", value=False)
|
| 74 |
+
|
| 75 |
with gr.Row():
|
| 76 |
with gr.Column():
|
| 77 |
seed_slider = gr.Number(
|
| 78 |
+
label="Random Seed 🎲 (if the image generated is not to your liking, simply use another seed)",
|
| 79 |
+
value=random.randint(0, 10000),
|
|
|
|
|
|
|
|
|
|
| 80 |
)
|
| 81 |
|
| 82 |
with gr.Column():
|
|
|
|
| 84 |
minimum=1,
|
| 85 |
maximum=250,
|
| 86 |
step=1,
|
| 87 |
+
label="DDPM Sampling Steps 🔄 (the higher the number of steps the higher the quality of the images)",
|
| 88 |
value=50,
|
| 89 |
)
|
| 90 |
|
| 91 |
with gr.Row():
|
| 92 |
+
generate_button = gr.Button(value="Generate")
|
| 93 |
with gr.Row():
|
| 94 |
generate_info = gr.Markdown(
|
| 95 |
"<p style='text-align: center; font-size: 16px;'>"
|
|
|
|
| 103 |
sketch_input,
|
| 104 |
scribbles_input,
|
| 105 |
sampling_slider,
|
|
|
|
| 106 |
seed_slider,
|
| 107 |
upscale_button,
|
| 108 |
],
|
| 109 |
outputs=output,
|
| 110 |
+
concurrency_limit=1,
|
| 111 |
+
trigger_mode="once",
|
| 112 |
)
|
| 113 |
|
| 114 |
if __name__ == "__main__":
|
| 115 |
+
demo.queue().launch(max_threads=1)
|
load_model.py
CHANGED
|
@@ -8,6 +8,7 @@ from torchvision import transforms
|
|
| 8 |
import pathlib
|
| 9 |
from torchvision.utils import save_image
|
| 10 |
from safetensors.torch import load_model, save_model
|
|
|
|
| 11 |
|
| 12 |
|
| 13 |
denoising_timesteps = 4000
|
|
@@ -61,7 +62,7 @@ else:
|
|
| 61 |
raise Exception("No model files found in the folder.")
|
| 62 |
|
| 63 |
|
| 64 |
-
def sample(sketch, scribbles, sampling_steps, seed_nr):
|
| 65 |
torch.manual_seed(seed_nr)
|
| 66 |
|
| 67 |
noise_scheduler = DDPMScheduler(
|
|
@@ -80,9 +81,9 @@ def sample(sketch, scribbles, sampling_steps, seed_nr):
|
|
| 80 |
|
| 81 |
noise_for_plain = torch.randn_like(sketch, device=device)
|
| 82 |
|
| 83 |
-
for
|
| 84 |
-
|
| 85 |
-
|
| 86 |
):
|
| 87 |
noise_for_plain = noise_scheduler.scale_model_input(noise_for_plain, t).to(
|
| 88 |
device
|
|
@@ -105,6 +106,8 @@ def sample(sketch, scribbles, sampling_steps, seed_nr):
|
|
| 105 |
noise_for_plain,
|
| 106 |
).prev_sample
|
| 107 |
|
|
|
|
|
|
|
| 108 |
sample = torch.clamp((noise_for_plain / 2) + 0.5, 0, 1)
|
| 109 |
|
| 110 |
return transforms.ToPILImage()(sample[0].cpu())
|
|
|
|
| 8 |
import pathlib
|
| 9 |
from torchvision.utils import save_image
|
| 10 |
from safetensors.torch import load_model, save_model
|
| 11 |
+
import time as tm
|
| 12 |
|
| 13 |
|
| 14 |
denoising_timesteps = 4000
|
|
|
|
| 62 |
raise Exception("No model files found in the folder.")
|
| 63 |
|
| 64 |
|
| 65 |
+
def sample(sketch, scribbles, sampling_steps, seed_nr, progress):
|
| 66 |
torch.manual_seed(seed_nr)
|
| 67 |
|
| 68 |
noise_scheduler = DDPMScheduler(
|
|
|
|
| 81 |
|
| 82 |
noise_for_plain = torch.randn_like(sketch, device=device)
|
| 83 |
|
| 84 |
+
for t in progress.tqdm(
|
| 85 |
+
noise_scheduler.timesteps,
|
| 86 |
+
desc="Sampling",
|
| 87 |
):
|
| 88 |
noise_for_plain = noise_scheduler.scale_model_input(noise_for_plain, t).to(
|
| 89 |
device
|
|
|
|
| 106 |
noise_for_plain,
|
| 107 |
).prev_sample
|
| 108 |
|
| 109 |
+
tm.sleep(0.01)
|
| 110 |
+
|
| 111 |
sample = torch.clamp((noise_for_plain / 2) + 0.5, 0, 1)
|
| 112 |
|
| 113 |
return transforms.ToPILImage()(sample[0].cpu())
|