Spaces:
Running
on
Zero
Running
on
Zero
update
Browse files- hf_demo.py +162 -82
- hf_demo_test.ipynb +188 -125
- utils/train_util.py +2 -1
hf_demo.py
CHANGED
|
@@ -16,70 +16,93 @@ pipe = DiffusionPipeline.from_pretrained("rhfeiyang/art-free-diffusion-v1",
|
|
| 16 |
from inference import get_lora_network, inference, get_validation_dataloader
|
| 17 |
lora_map = {
|
| 18 |
"None": "None",
|
| 19 |
-
"Andre Derain": "andre-derain_subset1",
|
| 20 |
-
"Vincent van Gogh": "van_gogh_subset1",
|
| 21 |
-
"Andy Warhol": "andy_subset1",
|
| 22 |
"Walter Battiss": "walter-battiss_subset2",
|
| 23 |
-
"Camille Corot": "camille-corot_subset1",
|
| 24 |
-
"Claude Monet": "monet_subset2",
|
| 25 |
-
"Pablo Picasso": "picasso_subset1",
|
| 26 |
"Jackson Pollock": "jackson-pollock_subset1",
|
| 27 |
-
"Gerhard Richter": "gerhard-richter_subset1",
|
| 28 |
"M.C. Escher": "m.c.-escher_subset1",
|
| 29 |
"Albert Gleizes": "albert-gleizes_subset1",
|
| 30 |
-
"Hokusai": "katsushika-hokusai_subset1",
|
| 31 |
"Wassily Kandinsky": "kandinsky_subset1",
|
| 32 |
-
"Gustav Klimt": "klimt_subset3",
|
| 33 |
"Roy Lichtenstein": "roy-lichtenstein_subset1",
|
| 34 |
-
"Henri Matisse": "henri-matisse_subset1",
|
| 35 |
"Joan Miro": "joan-miro_subset2",
|
| 36 |
}
|
|
|
|
| 37 |
@spaces.GPU
|
| 38 |
-
def
|
| 39 |
adapter_path = lora_map[adapter_choice]
|
| 40 |
if adapter_path not in [None, "None"]:
|
| 41 |
adapter_path = f"data/Art_adapters/{adapter_path}/adapter_alpha1.0_rank1_all_up_1000steps.pt"
|
| 42 |
style_prompt="sks art"
|
| 43 |
else:
|
| 44 |
style_prompt=None
|
| 45 |
-
prompts = [prompt]
|
| 46 |
infer_loader = get_validation_dataloader(prompts,num_workers=0)
|
| 47 |
network = get_lora_network(pipe.unet, adapter_path, weight_dtype=dtype)["network"]
|
| 48 |
|
| 49 |
pred_images = inference(network, pipe.tokenizer, pipe.text_encoder, pipe.vae, pipe.unet, pipe.scheduler, infer_loader,
|
| 50 |
-
height=512, width=512, scales=[
|
| 51 |
save_dir=None, seed=seed,steps=steps, guidance_scale=guidance_scale,
|
| 52 |
start_noise=-1, show=False, style_prompt=style_prompt, no_load=True,
|
| 53 |
-
from_scratch=True, device=device, weight_dtype=dtype)[0][1.0]
|
| 54 |
return pred_images
|
|
|
|
| 55 |
@spaces.GPU
|
| 56 |
-
def
|
| 57 |
-
|
| 58 |
-
|
|
|
|
|
|
|
|
|
|
| 59 |
pred_images = inference(network, pipe.tokenizer, pipe.text_encoder, pipe.vae, pipe.unet, pipe.scheduler, infer_loader,
|
| 60 |
-
height=512, width=512, scales=[0
|
| 61 |
-
save_dir=None, seed=seed,steps=
|
| 62 |
-
start_noise
|
| 63 |
-
from_scratch=
|
| 64 |
return pred_images
|
| 65 |
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
#
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 83 |
|
| 84 |
|
| 85 |
|
|
@@ -92,62 +115,119 @@ with block:
|
|
| 92 |
gr.Markdown("(More features in development...)")
|
| 93 |
with gr.Row():
|
| 94 |
text = gr.Textbox(
|
| 95 |
-
label="Enter your prompt",
|
| 96 |
max_lines=2,
|
| 97 |
-
placeholder="Enter your prompt",
|
| 98 |
-
container=
|
| 99 |
value="Park with cherry blossom trees, picnicker’s and a clear blue pond.",
|
| 100 |
)
|
| 101 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 102 |
|
| 103 |
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
label="
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
advanced_button = gr.Button("Advanced options", elem_id="advanced-btn")
|
| 113 |
-
|
| 114 |
-
with gr.Row(elem_id="advanced-options"):
|
| 115 |
adapter_choice = gr.Dropdown(
|
| 116 |
label="Select Art Adapter",
|
| 117 |
-
choices=[
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
value="
|
|
|
|
| 124 |
)
|
| 125 |
-
# print(adapter_choice[0])
|
| 126 |
-
# lora_path = lora_map[adapter_choice.value]
|
| 127 |
-
# if lora_path is not None:
|
| 128 |
-
# lora_path = f"data/Art_adapters/{lora_path}/adapter_alpha1.0_rank1_all_up_1000steps.pt"
|
| 129 |
|
| 130 |
-
|
| 131 |
steps = gr.Slider(label="Steps", minimum=1, maximum=50, value=20, step=1)
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
seed = gr.Slider(
|
| 137 |
-
label="Seed",
|
| 138 |
-
minimum=0,
|
| 139 |
-
maximum=2147483647,
|
| 140 |
-
step=1,
|
| 141 |
-
randomize=True,
|
| 142 |
-
)
|
| 143 |
|
| 144 |
-
gr.on([text.submit, btn.click], demo_inference_gen, inputs=[adapter_choice, text, samples, seed, steps, scale], outputs=gallery)
|
| 145 |
-
advanced_button.click(
|
| 146 |
-
None,
|
| 147 |
-
[],
|
| 148 |
-
text,
|
| 149 |
-
)
|
| 150 |
|
|
|
|
|
|
|
| 151 |
|
|
|
|
|
|
|
| 152 |
|
| 153 |
-
block.launch()
|
|
|
|
| 16 |
from inference import get_lora_network, inference, get_validation_dataloader
|
| 17 |
lora_map = {
|
| 18 |
"None": "None",
|
| 19 |
+
"Andre Derain (fauvism)": "andre-derain_subset1",
|
| 20 |
+
"Vincent van Gogh (post impressionism)": "van_gogh_subset1",
|
| 21 |
+
"Andy Warhol (pop art)": "andy_subset1",
|
| 22 |
"Walter Battiss": "walter-battiss_subset2",
|
| 23 |
+
"Camille Corot (realism)": "camille-corot_subset1",
|
| 24 |
+
"Claude Monet (impressionism)": "monet_subset2",
|
| 25 |
+
"Pablo Picasso (cubism)": "picasso_subset1",
|
| 26 |
"Jackson Pollock": "jackson-pollock_subset1",
|
| 27 |
+
"Gerhard Richter (abstract expressionism)": "gerhard-richter_subset1",
|
| 28 |
"M.C. Escher": "m.c.-escher_subset1",
|
| 29 |
"Albert Gleizes": "albert-gleizes_subset1",
|
| 30 |
+
"Hokusai (ukiyo-e)": "katsushika-hokusai_subset1",
|
| 31 |
"Wassily Kandinsky": "kandinsky_subset1",
|
| 32 |
+
"Gustav Klimt (art nouveau)": "klimt_subset3",
|
| 33 |
"Roy Lichtenstein": "roy-lichtenstein_subset1",
|
| 34 |
+
"Henri Matisse (abstract expressionism)": "henri-matisse_subset1",
|
| 35 |
"Joan Miro": "joan-miro_subset2",
|
| 36 |
}
|
| 37 |
+
|
| 38 |
@spaces.GPU
|
| 39 |
+
def demo_inference_gen_artistic(adapter_choice:str, prompt:str, seed:int=0, steps=50, guidance_scale=7.5, adapter_scale=1.0):
|
| 40 |
adapter_path = lora_map[adapter_choice]
|
| 41 |
if adapter_path not in [None, "None"]:
|
| 42 |
adapter_path = f"data/Art_adapters/{adapter_path}/adapter_alpha1.0_rank1_all_up_1000steps.pt"
|
| 43 |
style_prompt="sks art"
|
| 44 |
else:
|
| 45 |
style_prompt=None
|
| 46 |
+
prompts = [prompt]
|
| 47 |
infer_loader = get_validation_dataloader(prompts,num_workers=0)
|
| 48 |
network = get_lora_network(pipe.unet, adapter_path, weight_dtype=dtype)["network"]
|
| 49 |
|
| 50 |
pred_images = inference(network, pipe.tokenizer, pipe.text_encoder, pipe.vae, pipe.unet, pipe.scheduler, infer_loader,
|
| 51 |
+
height=512, width=512, scales=[adapter_scale],
|
| 52 |
save_dir=None, seed=seed,steps=steps, guidance_scale=guidance_scale,
|
| 53 |
start_noise=-1, show=False, style_prompt=style_prompt, no_load=True,
|
| 54 |
+
from_scratch=True, device=device, weight_dtype=dtype)[0][1.0][0]
|
| 55 |
return pred_images
|
| 56 |
+
|
| 57 |
@spaces.GPU
|
| 58 |
+
def demo_inference_gen_ori( prompt:str, seed:int=0, steps=50, guidance_scale=7.5):
|
| 59 |
+
style_prompt=None
|
| 60 |
+
prompts = [prompt]
|
| 61 |
+
infer_loader = get_validation_dataloader(prompts,num_workers=0)
|
| 62 |
+
network = get_lora_network(pipe.unet, "None", weight_dtype=dtype)["network"]
|
| 63 |
+
|
| 64 |
pred_images = inference(network, pipe.tokenizer, pipe.text_encoder, pipe.vae, pipe.unet, pipe.scheduler, infer_loader,
|
| 65 |
+
height=512, width=512, scales=[0.0],
|
| 66 |
+
save_dir=None, seed=seed,steps=steps, guidance_scale=guidance_scale,
|
| 67 |
+
start_noise=-1, show=False, style_prompt=style_prompt, no_load=True,
|
| 68 |
+
from_scratch=True, device=device, weight_dtype=dtype)[0][0.0][0]
|
| 69 |
return pred_images
|
| 70 |
|
| 71 |
+
|
| 72 |
+
@spaces.GPU
|
| 73 |
+
def demo_inference_stylization_ori(ref_image, prompt:str, seed:int=0, steps=50, guidance_scale=7.5, start_noise=800):
|
| 74 |
+
style_prompt=None
|
| 75 |
+
prompts = [prompt]
|
| 76 |
+
# convert np to pil
|
| 77 |
+
ref_image = [Image.fromarray(ref_image)]
|
| 78 |
+
network = get_lora_network(pipe.unet, "None", weight_dtype=dtype)["network"]
|
| 79 |
+
infer_loader = get_validation_dataloader(prompts, ref_image,num_workers=0)
|
| 80 |
+
pred_images = inference(network, pipe.tokenizer, pipe.text_encoder, pipe.vae, pipe.unet, pipe.scheduler, infer_loader,
|
| 81 |
+
height=512, width=512, scales=[0.0],
|
| 82 |
+
save_dir=None, seed=seed,steps=steps, guidance_scale=guidance_scale,
|
| 83 |
+
start_noise=start_noise, show=False, style_prompt=style_prompt, no_load=True,
|
| 84 |
+
from_scratch=False, device=device, weight_dtype=dtype)[0][0.0][0]
|
| 85 |
+
return pred_images
|
| 86 |
+
|
| 87 |
+
@spaces.GPU
|
| 88 |
+
def demo_inference_stylization_artistic(ref_image, adapter_choice:str, prompt:str, seed:int=0, steps=50, guidance_scale=7.5, adapter_scale=1.0,start_noise=800):
|
| 89 |
+
adapter_path = lora_map[adapter_choice]
|
| 90 |
+
if adapter_path not in [None, "None"]:
|
| 91 |
+
adapter_path = f"data/Art_adapters/{adapter_path}/adapter_alpha1.0_rank1_all_up_1000steps.pt"
|
| 92 |
+
style_prompt="sks art"
|
| 93 |
+
else:
|
| 94 |
+
style_prompt=None
|
| 95 |
+
prompts = [prompt]
|
| 96 |
+
# convert np to pil
|
| 97 |
+
ref_image = [Image.fromarray(ref_image)]
|
| 98 |
+
network = get_lora_network(pipe.unet, adapter_path, weight_dtype=dtype)["network"]
|
| 99 |
+
infer_loader = get_validation_dataloader(prompts, ref_image,num_workers=0)
|
| 100 |
+
pred_images = inference(network, pipe.tokenizer, pipe.text_encoder, pipe.vae, pipe.unet, pipe.scheduler, infer_loader,
|
| 101 |
+
height=512, width=512, scales=[adapter_scale],
|
| 102 |
+
save_dir=None, seed=seed,steps=steps, guidance_scale=guidance_scale,
|
| 103 |
+
start_noise=start_noise, show=False, style_prompt=style_prompt, no_load=True,
|
| 104 |
+
from_scratch=False, device=device, weight_dtype=dtype)[0][1.0][0]
|
| 105 |
+
return pred_images
|
| 106 |
|
| 107 |
|
| 108 |
|
|
|
|
| 115 |
gr.Markdown("(More features in development...)")
|
| 116 |
with gr.Row():
|
| 117 |
text = gr.Textbox(
|
| 118 |
+
label="Enter your prompt(long and detailed would be better):",
|
| 119 |
max_lines=2,
|
| 120 |
+
placeholder="Enter your prompt(long and detailed would be better)",
|
| 121 |
+
container=True,
|
| 122 |
value="Park with cherry blossom trees, picnicker’s and a clear blue pond.",
|
| 123 |
)
|
| 124 |
|
| 125 |
+
with gr.Tab('Generation'):
|
| 126 |
+
with gr.Row():
|
| 127 |
+
with gr.Column():
|
| 128 |
+
# gr.Markdown("## Art-Free Generation")
|
| 129 |
+
# gr.Markdown("Generate images from text prompts.")
|
| 130 |
+
|
| 131 |
+
gallery_gen_ori = gr.Image(
|
| 132 |
+
label="W/O Adapter",
|
| 133 |
+
show_label=True,
|
| 134 |
+
elem_id="gallery",
|
| 135 |
+
height="auto"
|
| 136 |
+
)
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
with gr.Column():
|
| 140 |
+
# gr.Markdown("## Art-Free Generation")
|
| 141 |
+
# gr.Markdown("Generate images from text prompts.")
|
| 142 |
+
gallery_gen_art = gr.Image(
|
| 143 |
+
label="W/ Adapter",
|
| 144 |
+
show_label=True,
|
| 145 |
+
elem_id="gallery",
|
| 146 |
+
height="auto"
|
| 147 |
+
)
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
with gr.Row():
|
| 151 |
+
btn_gen_ori = gr.Button("Art-Free Generate", scale=1)
|
| 152 |
+
btn_gen_art = gr.Button("Artistic Generate", scale=1)
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
with gr.Tab('Stylization'):
|
| 156 |
+
with gr.Row():
|
| 157 |
+
|
| 158 |
+
with gr.Column():
|
| 159 |
+
# gr.Markdown("## Art-Free Generation")
|
| 160 |
+
# gr.Markdown("Generate images from text prompts.")
|
| 161 |
+
|
| 162 |
+
gallery_stylization_ref = gr.Image(
|
| 163 |
+
label="Ref Image",
|
| 164 |
+
show_label=True,
|
| 165 |
+
elem_id="gallery",
|
| 166 |
+
height="auto",
|
| 167 |
+
scale=1,
|
| 168 |
+
)
|
| 169 |
+
with gr.Column(scale=2):
|
| 170 |
+
with gr.Row():
|
| 171 |
+
with gr.Column():
|
| 172 |
+
# gr.Markdown("## Art-Free Generation")
|
| 173 |
+
# gr.Markdown("Generate images from text prompts.")
|
| 174 |
+
|
| 175 |
+
gallery_stylization_ori = gr.Image(
|
| 176 |
+
label="W/O Adapter",
|
| 177 |
+
show_label=True,
|
| 178 |
+
elem_id="gallery",
|
| 179 |
+
height="auto",
|
| 180 |
+
scale=1,
|
| 181 |
+
)
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
with gr.Column():
|
| 185 |
+
# gr.Markdown("## Art-Free Generation")
|
| 186 |
+
# gr.Markdown("Generate images from text prompts.")
|
| 187 |
+
gallery_stylization_art = gr.Image(
|
| 188 |
+
label="W/ Adapter",
|
| 189 |
+
show_label=True,
|
| 190 |
+
elem_id="gallery",
|
| 191 |
+
height="auto",
|
| 192 |
+
scale=1,
|
| 193 |
+
)
|
| 194 |
+
start_timestep = gr.Slider(label="Adapter Timestep", minimum=0, maximum=1000, value=800, step=1)
|
| 195 |
+
with gr.Row():
|
| 196 |
+
btn_style_ori = gr.Button("Art-Free Stylization", scale=1)
|
| 197 |
+
btn_style_art = gr.Button("Artistic Stylization", scale=1)
|
| 198 |
|
| 199 |
|
| 200 |
+
with gr.Row():
|
| 201 |
+
# with gr.Column():
|
| 202 |
+
# samples = gr.Slider(label="Images", minimum=1, maximum=4, value=1, step=1, scale=1)
|
| 203 |
+
scale = gr.Slider(
|
| 204 |
+
label="Guidance Scale", minimum=0, maximum=20, value=7.5, step=0.1
|
| 205 |
+
)
|
| 206 |
+
# with gr.Column():
|
|
|
|
|
|
|
|
|
|
|
|
|
| 207 |
adapter_choice = gr.Dropdown(
|
| 208 |
label="Select Art Adapter",
|
| 209 |
+
choices=[ "Andre Derain (fauvism)","Vincent van Gogh (post impressionism)","Andy Warhol (pop art)",
|
| 210 |
+
"Camille Corot (realism)", "Claude Monet (impressionism)", "Pablo Picasso (cubism)", "Gerhard Richter (abstract expressionism)",
|
| 211 |
+
"Hokusai (ukiyo-e)", "Gustav Klimt (art nouveau)", "Henri Matisse (abstract expressionism)",
|
| 212 |
+
"Walter Battiss", "Jackson Pollock", "M.C. Escher", "Albert Gleizes", "Wassily Kandinsky",
|
| 213 |
+
"Roy Lichtenstein", "Joan Miro"
|
| 214 |
+
],
|
| 215 |
+
value="Andre Derain (fauvism)",
|
| 216 |
+
scale=1
|
| 217 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 218 |
|
| 219 |
+
with gr.Row():
|
| 220 |
steps = gr.Slider(label="Steps", minimum=1, maximum=50, value=20, step=1)
|
| 221 |
+
adapter_scale = gr.Slider(label="Stylization Scale", minimum=0, maximum=1.5, value=1., step=0.1, scale=1)
|
| 222 |
+
|
| 223 |
+
with gr.Row():
|
| 224 |
+
seed = gr.Slider(label="Seed",minimum=0,maximum=2147483647,step=1,randomize=True,scale=1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 225 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 226 |
|
| 227 |
+
gr.on([btn_gen_ori.click], demo_inference_gen_ori, inputs=[text, seed, steps, scale], outputs=gallery_gen_ori)
|
| 228 |
+
gr.on([btn_gen_art.click], demo_inference_gen_artistic, inputs=[adapter_choice, text, seed, steps, scale, adapter_scale], outputs=gallery_gen_art)
|
| 229 |
|
| 230 |
+
gr.on([btn_style_ori.click], demo_inference_stylization_ori, inputs=[gallery_stylization_ref, text, seed, steps, scale, start_timestep], outputs=gallery_stylization_ori)
|
| 231 |
+
gr.on([btn_style_art.click], demo_inference_stylization_artistic, inputs=[gallery_stylization_ref, adapter_choice, text, seed, steps, scale, adapter_scale, start_timestep], outputs=gallery_stylization_art)
|
| 232 |
|
| 233 |
+
block.launch(sharing=True)
|
hf_demo_test.ipynb
CHANGED
|
@@ -45,7 +45,9 @@
|
|
| 45 |
},
|
| 46 |
"outputs": [],
|
| 47 |
"source": [
|
| 48 |
-
"os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"1\""
|
|
|
|
|
|
|
| 49 |
]
|
| 50 |
},
|
| 51 |
{
|
|
@@ -70,7 +72,7 @@
|
|
| 70 |
{
|
| 71 |
"data": {
|
| 72 |
"application/vnd.jupyter.widget-view+json": {
|
| 73 |
-
"model_id": "
|
| 74 |
"version_major": 2,
|
| 75 |
"version_minor": 0
|
| 76 |
},
|
|
@@ -83,8 +85,8 @@
|
|
| 83 |
}
|
| 84 |
],
|
| 85 |
"source": [
|
| 86 |
-
"pipe = DiffusionPipeline.from_pretrained(\"rhfeiyang/art-free-diffusion-v1\"
|
| 87 |
-
"device
|
| 88 |
]
|
| 89 |
},
|
| 90 |
{
|
|
@@ -102,77 +104,105 @@
|
|
| 102 |
"from inference import get_lora_network, inference, get_validation_dataloader\n",
|
| 103 |
"lora_map = {\n",
|
| 104 |
" \"None\": \"None\",\n",
|
| 105 |
-
" \"Andre Derain\": \"andre-derain_subset1\",\n",
|
| 106 |
-
" \"Vincent van Gogh\": \"van_gogh_subset1\",\n",
|
| 107 |
-
" \"Andy Warhol\": \"andy_subset1\",\n",
|
| 108 |
" \"Walter Battiss\": \"walter-battiss_subset2\",\n",
|
| 109 |
-
" \"Camille Corot\": \"camille-corot_subset1\",\n",
|
| 110 |
-
" \"Claude Monet\": \"monet_subset2\",\n",
|
| 111 |
-
" \"Pablo Picasso\": \"picasso_subset1\",\n",
|
| 112 |
" \"Jackson Pollock\": \"jackson-pollock_subset1\",\n",
|
| 113 |
-
" \"Gerhard Richter\": \"gerhard-richter_subset1\",\n",
|
| 114 |
" \"M.C. Escher\": \"m.c.-escher_subset1\",\n",
|
| 115 |
" \"Albert Gleizes\": \"albert-gleizes_subset1\",\n",
|
| 116 |
-
" \"Hokusai\": \"katsushika-hokusai_subset1\",\n",
|
| 117 |
" \"Wassily Kandinsky\": \"kandinsky_subset1\",\n",
|
| 118 |
-
" \"Gustav Klimt\": \"klimt_subset3\",\n",
|
| 119 |
" \"Roy Lichtenstein\": \"roy-lichtenstein_subset1\",\n",
|
| 120 |
-
" \"Henri Matisse\": \"henri-matisse_subset1\",\n",
|
| 121 |
" \"Joan Miro\": \"joan-miro_subset2\",\n",
|
| 122 |
"}\n",
|
| 123 |
"\n",
|
| 124 |
-
"
|
|
|
|
|
|
|
| 125 |
" adapter_path = lora_map[adapter_choice]\n",
|
| 126 |
" if adapter_path not in [None, \"None\"]:\n",
|
| 127 |
" adapter_path = f\"data/Art_adapters/{adapter_path}/adapter_alpha1.0_rank1_all_up_1000steps.pt\"\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 128 |
"\n",
|
| 129 |
-
" prompts = [prompt]*samples\n",
|
| 130 |
-
" infer_loader = get_validation_dataloader(prompts)\n",
|
| 131 |
-
" network = get_lora_network(pipe.unet, adapter_path)[\"network\"]\n",
|
| 132 |
" pred_images = inference(network, pipe.tokenizer, pipe.text_encoder, pipe.vae, pipe.unet, pipe.scheduler, infer_loader,\n",
|
| 133 |
-
" height=512, width=512, scales=[
|
| 134 |
" save_dir=None, seed=seed,steps=steps, guidance_scale=guidance_scale,\n",
|
| 135 |
-
" start_noise=-1, show=False, style_prompt
|
| 136 |
-
" from_scratch=True)[0][1.0]\n",
|
| 137 |
" return pred_images\n",
|
| 138 |
"\n",
|
| 139 |
-
"
|
| 140 |
-
"
|
| 141 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
| 142 |
" pred_images = inference(network, pipe.tokenizer, pipe.text_encoder, pipe.vae, pipe.unet, pipe.scheduler, infer_loader,\n",
|
| 143 |
-
" height=512, width=512, scales=[0
|
| 144 |
-
" save_dir=None, seed=seed,steps=
|
| 145 |
-
" start_noise
|
| 146 |
-
" from_scratch=
|
| 147 |
" return pred_images\n",
|
| 148 |
"\n",
|
| 149 |
-
"
|
| 150 |
-
"
|
| 151 |
-
"
|
| 152 |
-
"
|
| 153 |
-
"
|
| 154 |
-
"#
|
| 155 |
-
"
|
| 156 |
-
"
|
| 157 |
-
"
|
| 158 |
-
"
|
| 159 |
-
"
|
| 160 |
-
"
|
| 161 |
-
"
|
| 162 |
-
"
|
| 163 |
-
"
|
| 164 |
-
"
|
| 165 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 166 |
]
|
| 167 |
},
|
| 168 |
{
|
| 169 |
"cell_type": "code",
|
| 170 |
-
"execution_count":
|
| 171 |
"id": "aa33e9d104023847",
|
| 172 |
"metadata": {
|
| 173 |
"ExecuteTime": {
|
| 174 |
-
"end_time": "2024-12-
|
| 175 |
-
"start_time": "2024-12-
|
| 176 |
}
|
| 177 |
},
|
| 178 |
"outputs": [
|
|
@@ -180,9 +210,10 @@
|
|
| 180 |
"name": "stdout",
|
| 181 |
"output_type": "stream",
|
| 182 |
"text": [
|
| 183 |
-
"
|
| 184 |
-
"
|
| 185 |
-
"
|
|
|
|
| 186 |
"\n",
|
| 187 |
"This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)\n"
|
| 188 |
]
|
|
@@ -190,7 +221,7 @@
|
|
| 190 |
{
|
| 191 |
"data": {
|
| 192 |
"text/html": [
|
| 193 |
-
"<div><iframe src=\"https://
|
| 194 |
],
|
| 195 |
"text/plain": [
|
| 196 |
"<IPython.core.display.HTML object>"
|
|
@@ -203,103 +234,135 @@
|
|
| 203 |
"data": {
|
| 204 |
"text/plain": []
|
| 205 |
},
|
| 206 |
-
"execution_count":
|
| 207 |
"metadata": {},
|
| 208 |
"output_type": "execute_result"
|
| 209 |
-
},
|
| 210 |
-
{
|
| 211 |
-
"name": "stdout",
|
| 212 |
-
"output_type": "stream",
|
| 213 |
-
"text": [
|
| 214 |
-
"Train method: None\n",
|
| 215 |
-
"Rank: 1, Alpha: 1\n",
|
| 216 |
-
"create LoRA for U-Net: 0 modules.\n",
|
| 217 |
-
"save dir: None\n",
|
| 218 |
-
"['Park with cherry blossom trees, picnicker’s and a clear blue pond in the style of sks art'], seed=949192390\n"
|
| 219 |
-
]
|
| 220 |
-
},
|
| 221 |
-
{
|
| 222 |
-
"name": "stderr",
|
| 223 |
-
"output_type": "stream",
|
| 224 |
-
"text": [
|
| 225 |
-
"/data/vision/torralba/selfmanaged/torralba/scratch/jomat/sam_dataset/miniforge3/envs/diffusion/lib/python3.9/site-packages/torch/nn/modules/conv.py:456: UserWarning: Plan failed with a cudnnException: CUDNN_BACKEND_EXECUTION_PLAN_DESCRIPTOR: cudnnFinalize Descriptor Failed cudnn_status: CUDNN_STATUS_NOT_SUPPORTED (Triggered internally at /opt/conda/conda-bld/pytorch_1712608883701/work/aten/src/ATen/native/cudnn/Conv_v8.cpp:919.)\n",
|
| 226 |
-
" return F.conv2d(input, weight, bias, self.stride,\n",
|
| 227 |
-
"\n",
|
| 228 |
-
"00%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 21/21 [00:03<00:00, 6.90it/s]"
|
| 229 |
-
]
|
| 230 |
-
},
|
| 231 |
-
{
|
| 232 |
-
"name": "stdout",
|
| 233 |
-
"output_type": "stream",
|
| 234 |
-
"text": [
|
| 235 |
-
"Time taken for one batch, Art Adapter scale=1.0: 3.2747044563293457\n"
|
| 236 |
-
]
|
| 237 |
}
|
| 238 |
],
|
| 239 |
"source": [
|
| 240 |
"block = gr.Blocks()\n",
|
| 241 |
"# Direct infer\n",
|
|
|
|
| 242 |
"with block:\n",
|
| 243 |
" with gr.Group():\n",
|
| 244 |
" gr.Markdown(\" # Art-Free Diffusion Demo\")\n",
|
|
|
|
| 245 |
" with gr.Row():\n",
|
| 246 |
" text = gr.Textbox(\n",
|
| 247 |
-
" label=\"Enter your prompt
|
| 248 |
" max_lines=2,\n",
|
| 249 |
-
" placeholder=\"Enter your prompt\",\n",
|
| 250 |
-
" container=
|
| 251 |
" value=\"Park with cherry blossom trees, picnicker’s and a clear blue pond.\",\n",
|
| 252 |
" )\n",
|
| 253 |
-
" \n",
|
| 254 |
"\n",
|
| 255 |
-
"
|
| 256 |
-
"
|
| 257 |
-
"
|
| 258 |
-
"
|
| 259 |
-
"
|
| 260 |
-
" elem_id=\"gallery\",\n",
|
| 261 |
-
" columns=[2],\n",
|
| 262 |
-
" )\n",
|
| 263 |
"\n",
|
| 264 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 265 |
"\n",
|
| 266 |
-
" with gr.Row(elem_id=\"advanced-options\"):\n",
|
| 267 |
-
" adapter_choice = gr.Dropdown(\n",
|
| 268 |
-
" label=\"Choose adapter\",\n",
|
| 269 |
-
" choices=[\"None\", \"Andre Derain\",\"Vincent van Gogh\",\"Andy Warhol\", \"Walter Battiss\",\n",
|
| 270 |
-
" \"Camille Corot\", \"Claude Monet\", \"Pablo Picasso\",\n",
|
| 271 |
-
" \"Jackson Pollock\", \"Gerhard Richter\", \"M.C. Escher\",\n",
|
| 272 |
-
" \"Albert Gleizes\", \"Hokusai\", \"Wassily Kandinsky\", \"Gustav Klimt\", \"Roy Lichtenstein\",\n",
|
| 273 |
-
" \"Henri Matisse\", \"Joan Miro\"\n",
|
| 274 |
-
" ],\n",
|
| 275 |
-
" value=\"None\"\n",
|
| 276 |
-
" )\n",
|
| 277 |
-
" # print(adapter_choice[0])\n",
|
| 278 |
-
" # lora_path = lora_map[adapter_choice.value]\n",
|
| 279 |
-
" # if lora_path is not None:\n",
|
| 280 |
-
" # lora_path = f\"data/Art_adapters/{lora_path}/adapter_alpha1.0_rank1_all_up_1000steps.pt\"\n",
|
| 281 |
"\n",
|
| 282 |
-
"
|
| 283 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 284 |
" scale = gr.Slider(\n",
|
| 285 |
-
" label=\"Guidance Scale\", minimum=0, maximum=
|
| 286 |
" )\n",
|
| 287 |
-
"
|
| 288 |
-
"
|
| 289 |
-
" label=\"
|
| 290 |
-
"
|
| 291 |
-
"
|
| 292 |
-
"
|
| 293 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
| 294 |
" )\n",
|
| 295 |
"\n",
|
| 296 |
-
" gr.
|
| 297 |
-
"
|
| 298 |
-
"
|
| 299 |
-
"
|
| 300 |
-
"
|
| 301 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
| 302 |
"\n",
|
|
|
|
|
|
|
| 303 |
"\n",
|
| 304 |
"block.launch(share=True)"
|
| 305 |
]
|
|
|
|
| 45 |
},
|
| 46 |
"outputs": [],
|
| 47 |
"source": [
|
| 48 |
+
"os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"1\"\n",
|
| 49 |
+
"device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
|
| 50 |
+
"dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float16"
|
| 51 |
]
|
| 52 |
},
|
| 53 |
{
|
|
|
|
| 72 |
{
|
| 73 |
"data": {
|
| 74 |
"application/vnd.jupyter.widget-view+json": {
|
| 75 |
+
"model_id": "acc42f294243439798e4d77d1a59296d",
|
| 76 |
"version_major": 2,
|
| 77 |
"version_minor": 0
|
| 78 |
},
|
|
|
|
| 85 |
}
|
| 86 |
],
|
| 87 |
"source": [
|
| 88 |
+
"pipe = DiffusionPipeline.from_pretrained(\"rhfeiyang/art-free-diffusion-v1\",\n",
|
| 89 |
+
" torch_dtype=dtype).to(device)"
|
| 90 |
]
|
| 91 |
},
|
| 92 |
{
|
|
|
|
| 104 |
"from inference import get_lora_network, inference, get_validation_dataloader\n",
|
| 105 |
"lora_map = {\n",
|
| 106 |
" \"None\": \"None\",\n",
|
| 107 |
+
" \"Andre Derain (fauvism)\": \"andre-derain_subset1\",\n",
|
| 108 |
+
" \"Vincent van Gogh (post impressionism)\": \"van_gogh_subset1\",\n",
|
| 109 |
+
" \"Andy Warhol (pop art)\": \"andy_subset1\",\n",
|
| 110 |
" \"Walter Battiss\": \"walter-battiss_subset2\",\n",
|
| 111 |
+
" \"Camille Corot (realism)\": \"camille-corot_subset1\",\n",
|
| 112 |
+
" \"Claude Monet (impressionism)\": \"monet_subset2\",\n",
|
| 113 |
+
" \"Pablo Picasso (cubism)\": \"picasso_subset1\",\n",
|
| 114 |
" \"Jackson Pollock\": \"jackson-pollock_subset1\",\n",
|
| 115 |
+
" \"Gerhard Richter (abstract expressionism)\": \"gerhard-richter_subset1\",\n",
|
| 116 |
" \"M.C. Escher\": \"m.c.-escher_subset1\",\n",
|
| 117 |
" \"Albert Gleizes\": \"albert-gleizes_subset1\",\n",
|
| 118 |
+
" \"Hokusai (ukiyo-e)\": \"katsushika-hokusai_subset1\",\n",
|
| 119 |
" \"Wassily Kandinsky\": \"kandinsky_subset1\",\n",
|
| 120 |
+
" \"Gustav Klimt (art nouveau)\": \"klimt_subset3\",\n",
|
| 121 |
" \"Roy Lichtenstein\": \"roy-lichtenstein_subset1\",\n",
|
| 122 |
+
" \"Henri Matisse (abstract expressionism)\": \"henri-matisse_subset1\",\n",
|
| 123 |
" \"Joan Miro\": \"joan-miro_subset2\",\n",
|
| 124 |
"}\n",
|
| 125 |
"\n",
|
| 126 |
+
"\n",
|
| 127 |
+
"\n",
|
| 128 |
+
"def demo_inference_gen_artistic(adapter_choice:str, prompt:str, seed:int=0, steps=50, guidance_scale=7.5, adapter_scale=1.0):\n",
|
| 129 |
" adapter_path = lora_map[adapter_choice]\n",
|
| 130 |
" if adapter_path not in [None, \"None\"]:\n",
|
| 131 |
" adapter_path = f\"data/Art_adapters/{adapter_path}/adapter_alpha1.0_rank1_all_up_1000steps.pt\"\n",
|
| 132 |
+
" style_prompt=\"sks art\"\n",
|
| 133 |
+
" else:\n",
|
| 134 |
+
" style_prompt=None\n",
|
| 135 |
+
" prompts = [prompt]\n",
|
| 136 |
+
" infer_loader = get_validation_dataloader(prompts,num_workers=0)\n",
|
| 137 |
+
" network = get_lora_network(pipe.unet, adapter_path, weight_dtype=dtype)[\"network\"]\n",
|
| 138 |
"\n",
|
|
|
|
|
|
|
|
|
|
| 139 |
" pred_images = inference(network, pipe.tokenizer, pipe.text_encoder, pipe.vae, pipe.unet, pipe.scheduler, infer_loader,\n",
|
| 140 |
+
" height=512, width=512, scales=[adapter_scale],\n",
|
| 141 |
" save_dir=None, seed=seed,steps=steps, guidance_scale=guidance_scale,\n",
|
| 142 |
+
" start_noise=-1, show=False, style_prompt=style_prompt, no_load=True,\n",
|
| 143 |
+
" from_scratch=True, device=device, weight_dtype=dtype)[0][1.0][0]\n",
|
| 144 |
" return pred_images\n",
|
| 145 |
"\n",
|
| 146 |
+
"\n",
|
| 147 |
+
"def demo_inference_gen_ori( prompt:str, seed:int=0, steps=50, guidance_scale=7.5):\n",
|
| 148 |
+
" style_prompt=None\n",
|
| 149 |
+
" prompts = [prompt]\n",
|
| 150 |
+
" infer_loader = get_validation_dataloader(prompts,num_workers=0)\n",
|
| 151 |
+
" network = get_lora_network(pipe.unet, \"None\", weight_dtype=dtype)[\"network\"]\n",
|
| 152 |
+
"\n",
|
| 153 |
" pred_images = inference(network, pipe.tokenizer, pipe.text_encoder, pipe.vae, pipe.unet, pipe.scheduler, infer_loader,\n",
|
| 154 |
+
" height=512, width=512, scales=[0.0],\n",
|
| 155 |
+
" save_dir=None, seed=seed,steps=steps, guidance_scale=guidance_scale,\n",
|
| 156 |
+
" start_noise=-1, show=False, style_prompt=style_prompt, no_load=True,\n",
|
| 157 |
+
" from_scratch=True, device=device, weight_dtype=dtype)[0][0.0][0]\n",
|
| 158 |
" return pred_images\n",
|
| 159 |
"\n",
|
| 160 |
+
"\n",
|
| 161 |
+
"\n",
|
| 162 |
+
"def demo_inference_stylization_ori(ref_image, prompt:str, seed:int=0, steps=50, guidance_scale=7.5, start_noise=800):\n",
|
| 163 |
+
" style_prompt=None\n",
|
| 164 |
+
" prompts = [prompt]\n",
|
| 165 |
+
" # convert np to pil\n",
|
| 166 |
+
" ref_image = [Image.fromarray(ref_image)]\n",
|
| 167 |
+
" network = get_lora_network(pipe.unet, \"None\", weight_dtype=dtype)[\"network\"]\n",
|
| 168 |
+
" infer_loader = get_validation_dataloader(prompts, ref_image,num_workers=0)\n",
|
| 169 |
+
" pred_images = inference(network, pipe.tokenizer, pipe.text_encoder, pipe.vae, pipe.unet, pipe.scheduler, infer_loader,\n",
|
| 170 |
+
" height=512, width=512, scales=[0.0],\n",
|
| 171 |
+
" save_dir=None, seed=seed,steps=steps, guidance_scale=guidance_scale,\n",
|
| 172 |
+
" start_noise=start_noise, show=False, style_prompt=style_prompt, no_load=True,\n",
|
| 173 |
+
" from_scratch=False, device=device, weight_dtype=dtype)[0][0.0][0]\n",
|
| 174 |
+
" return pred_images\n",
|
| 175 |
+
"\n",
|
| 176 |
+
"\n",
|
| 177 |
+
"def demo_inference_stylization_artistic(ref_image, adapter_choice:str, prompt:str, seed:int=0, steps=50, guidance_scale=7.5, adapter_scale=1.0,start_noise=800):\n",
|
| 178 |
+
" adapter_path = lora_map[adapter_choice]\n",
|
| 179 |
+
" if adapter_path not in [None, \"None\"]:\n",
|
| 180 |
+
" adapter_path = f\"data/Art_adapters/{adapter_path}/adapter_alpha1.0_rank1_all_up_1000steps.pt\"\n",
|
| 181 |
+
" style_prompt=\"sks art\"\n",
|
| 182 |
+
" else:\n",
|
| 183 |
+
" style_prompt=None\n",
|
| 184 |
+
" prompts = [prompt]\n",
|
| 185 |
+
" # convert np to pil\n",
|
| 186 |
+
" ref_image = [Image.fromarray(ref_image)]\n",
|
| 187 |
+
" network = get_lora_network(pipe.unet, adapter_path, weight_dtype=dtype)[\"network\"]\n",
|
| 188 |
+
" infer_loader = get_validation_dataloader(prompts, ref_image,num_workers=0)\n",
|
| 189 |
+
" pred_images = inference(network, pipe.tokenizer, pipe.text_encoder, pipe.vae, pipe.unet, pipe.scheduler, infer_loader,\n",
|
| 190 |
+
" height=512, width=512, scales=[adapter_scale],\n",
|
| 191 |
+
" save_dir=None, seed=seed,steps=steps, guidance_scale=guidance_scale,\n",
|
| 192 |
+
" start_noise=start_noise, show=False, style_prompt=style_prompt, no_load=True,\n",
|
| 193 |
+
" from_scratch=False, device=device, weight_dtype=dtype)[0][1.0][0]\n",
|
| 194 |
+
" return pred_images\n",
|
| 195 |
+
"\n"
|
| 196 |
]
|
| 197 |
},
|
| 198 |
{
|
| 199 |
"cell_type": "code",
|
| 200 |
+
"execution_count": 15,
|
| 201 |
"id": "aa33e9d104023847",
|
| 202 |
"metadata": {
|
| 203 |
"ExecuteTime": {
|
| 204 |
+
"end_time": "2024-12-10T02:56:13.419303Z",
|
| 205 |
+
"start_time": "2024-12-10T02:56:13.002796Z"
|
| 206 |
}
|
| 207 |
},
|
| 208 |
"outputs": [
|
|
|
|
| 210 |
"name": "stdout",
|
| 211 |
"output_type": "stream",
|
| 212 |
"text": [
|
| 213 |
+
"Running on local URL: http://127.0.0.1:7869\n",
|
| 214 |
+
"\n",
|
| 215 |
+
"Thanks for being a Gradio user! If you have questions or feedback, please join our Discord server and chat with us: https://discord.gg/feTf9x3ZSB\n",
|
| 216 |
+
"Running on public URL: https://0fd0c028b349b76a72.gradio.live\n",
|
| 217 |
"\n",
|
| 218 |
"This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)\n"
|
| 219 |
]
|
|
|
|
| 221 |
{
|
| 222 |
"data": {
|
| 223 |
"text/html": [
|
| 224 |
+
"<div><iframe src=\"https://0fd0c028b349b76a72.gradio.live\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
|
| 225 |
],
|
| 226 |
"text/plain": [
|
| 227 |
"<IPython.core.display.HTML object>"
|
|
|
|
| 234 |
"data": {
|
| 235 |
"text/plain": []
|
| 236 |
},
|
| 237 |
+
"execution_count": 15,
|
| 238 |
"metadata": {},
|
| 239 |
"output_type": "execute_result"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 240 |
}
|
| 241 |
],
|
| 242 |
"source": [
|
| 243 |
"block = gr.Blocks()\n",
|
| 244 |
"# Direct infer\n",
|
| 245 |
+
"# Direct infer\n",
|
| 246 |
"with block:\n",
|
| 247 |
" with gr.Group():\n",
|
| 248 |
" gr.Markdown(\" # Art-Free Diffusion Demo\")\n",
|
| 249 |
+
" gr.Markdown(\"(More features in development...)\")\n",
|
| 250 |
" with gr.Row():\n",
|
| 251 |
" text = gr.Textbox(\n",
|
| 252 |
+
" label=\"Enter your prompt(long and detailed would be better):\",\n",
|
| 253 |
" max_lines=2,\n",
|
| 254 |
+
" placeholder=\"Enter your prompt(long and detailed would be better)\",\n",
|
| 255 |
+
" container=True,\n",
|
| 256 |
" value=\"Park with cherry blossom trees, picnicker’s and a clear blue pond.\",\n",
|
| 257 |
" )\n",
|
|
|
|
| 258 |
"\n",
|
| 259 |
+
" with gr.Tab('Generation'):\n",
|
| 260 |
+
" with gr.Row():\n",
|
| 261 |
+
" with gr.Column():\n",
|
| 262 |
+
" # gr.Markdown(\"## Art-Free Generation\")\n",
|
| 263 |
+
" # gr.Markdown(\"Generate images from text prompts.\")\n",
|
|
|
|
|
|
|
|
|
|
| 264 |
"\n",
|
| 265 |
+
" gallery_gen_ori = gr.Image(\n",
|
| 266 |
+
" label=\"W/O Adapter\",\n",
|
| 267 |
+
" show_label=True,\n",
|
| 268 |
+
" elem_id=\"gallery\",\n",
|
| 269 |
+
" height=\"auto\"\n",
|
| 270 |
+
" )\n",
|
| 271 |
"\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 272 |
"\n",
|
| 273 |
+
" with gr.Column():\n",
|
| 274 |
+
" # gr.Markdown(\"## Art-Free Generation\")\n",
|
| 275 |
+
" # gr.Markdown(\"Generate images from text prompts.\")\n",
|
| 276 |
+
" gallery_gen_art = gr.Image(\n",
|
| 277 |
+
" label=\"W/ Adapter\",\n",
|
| 278 |
+
" show_label=True,\n",
|
| 279 |
+
" elem_id=\"gallery\",\n",
|
| 280 |
+
" height=\"auto\"\n",
|
| 281 |
+
" )\n",
|
| 282 |
+
"\n",
|
| 283 |
+
"\n",
|
| 284 |
+
" with gr.Row():\n",
|
| 285 |
+
" btn_gen_ori = gr.Button(\"Art-Free Generate\", scale=1)\n",
|
| 286 |
+
" btn_gen_art = gr.Button(\"Artistic Generate\", scale=1)\n",
|
| 287 |
+
"\n",
|
| 288 |
+
"\n",
|
| 289 |
+
" with gr.Tab('Stylization'):\n",
|
| 290 |
+
" with gr.Row():\n",
|
| 291 |
+
"\n",
|
| 292 |
+
" with gr.Column():\n",
|
| 293 |
+
" # gr.Markdown(\"## Art-Free Generation\")\n",
|
| 294 |
+
" # gr.Markdown(\"Generate images from text prompts.\")\n",
|
| 295 |
+
"\n",
|
| 296 |
+
" gallery_stylization_ref = gr.Image(\n",
|
| 297 |
+
" label=\"Ref Image\",\n",
|
| 298 |
+
" show_label=True,\n",
|
| 299 |
+
" elem_id=\"gallery\",\n",
|
| 300 |
+
" height=\"auto\",\n",
|
| 301 |
+
" scale=1,\n",
|
| 302 |
+
" )\n",
|
| 303 |
+
" with gr.Column(scale=2):\n",
|
| 304 |
+
" with gr.Row():\n",
|
| 305 |
+
" with gr.Column():\n",
|
| 306 |
+
" # gr.Markdown(\"## Art-Free Generation\")\n",
|
| 307 |
+
" # gr.Markdown(\"Generate images from text prompts.\")\n",
|
| 308 |
+
" \n",
|
| 309 |
+
" gallery_stylization_ori = gr.Image(\n",
|
| 310 |
+
" label=\"W/O Adapter\",\n",
|
| 311 |
+
" show_label=True,\n",
|
| 312 |
+
" elem_id=\"gallery\",\n",
|
| 313 |
+
" height=\"auto\",\n",
|
| 314 |
+
" scale=1,\n",
|
| 315 |
+
" )\n",
|
| 316 |
+
" \n",
|
| 317 |
+
" \n",
|
| 318 |
+
" with gr.Column():\n",
|
| 319 |
+
" # gr.Markdown(\"## Art-Free Generation\")\n",
|
| 320 |
+
" # gr.Markdown(\"Generate images from text prompts.\")\n",
|
| 321 |
+
" gallery_stylization_art = gr.Image(\n",
|
| 322 |
+
" label=\"W/ Adapter\",\n",
|
| 323 |
+
" show_label=True,\n",
|
| 324 |
+
" elem_id=\"gallery\",\n",
|
| 325 |
+
" height=\"auto\",\n",
|
| 326 |
+
" scale=1,\n",
|
| 327 |
+
" )\n",
|
| 328 |
+
" start_timestep = gr.Slider(label=\"Adapter Timestep\", minimum=0, maximum=1000, value=800, step=1)\n",
|
| 329 |
+
" with gr.Row():\n",
|
| 330 |
+
" btn_style_ori = gr.Button(\"Art-Free Stylization\", scale=1)\n",
|
| 331 |
+
" btn_style_art = gr.Button(\"Artistic Stylization\", scale=1)\n",
|
| 332 |
+
"\n",
|
| 333 |
+
"\n",
|
| 334 |
+
" with gr.Row():\n",
|
| 335 |
+
" # with gr.Column():\n",
|
| 336 |
+
" # samples = gr.Slider(label=\"Images\", minimum=1, maximum=4, value=1, step=1, scale=1)\n",
|
| 337 |
" scale = gr.Slider(\n",
|
| 338 |
+
" label=\"Guidance Scale\", minimum=0, maximum=20, value=7.5, step=0.1\n",
|
| 339 |
" )\n",
|
| 340 |
+
" # with gr.Column():\n",
|
| 341 |
+
" adapter_choice = gr.Dropdown(\n",
|
| 342 |
+
" label=\"Select Art Adapter\",\n",
|
| 343 |
+
" choices=[ \"Andre Derain (fauvism)\",\"Vincent van Gogh (post impressionism)\",\"Andy Warhol (pop art)\",\n",
|
| 344 |
+
" \"Camille Corot (realism)\", \"Claude Monet (impressionism)\", \"Pablo Picasso (cubism)\", \"Gerhard Richter (abstract expressionism)\",\n",
|
| 345 |
+
" \"Hokusai (ukiyo-e)\", \"Gustav Klimt (art nouveau)\", \"Henri Matisse (abstract expressionism)\",\n",
|
| 346 |
+
" \"Walter Battiss\", \"Jackson Pollock\", \"M.C. Escher\", \"Albert Gleizes\", \"Wassily Kandinsky\",\n",
|
| 347 |
+
" \"Roy Lichtenstein\", \"Joan Miro\"\n",
|
| 348 |
+
" ],\n",
|
| 349 |
+
" value=\"Andre Derain (fauvism)\",\n",
|
| 350 |
+
" scale=1\n",
|
| 351 |
" )\n",
|
| 352 |
"\n",
|
| 353 |
+
" with gr.Row():\n",
|
| 354 |
+
" steps = gr.Slider(label=\"Steps\", minimum=1, maximum=50, value=20, step=1)\n",
|
| 355 |
+
" adapter_scale = gr.Slider(label=\"Stylization Scale\", minimum=0, maximum=1.5, value=1., step=0.1, scale=1)\n",
|
| 356 |
+
"\n",
|
| 357 |
+
" with gr.Row():\n",
|
| 358 |
+
" seed = gr.Slider(label=\"Seed\",minimum=0,maximum=2147483647,step=1,randomize=True,scale=1)\n",
|
| 359 |
+
"\n",
|
| 360 |
+
"\n",
|
| 361 |
+
" gr.on([btn_gen_ori.click], demo_inference_gen_ori, inputs=[text, seed, steps, scale], outputs=gallery_gen_ori)\n",
|
| 362 |
+
" gr.on([btn_gen_art.click], demo_inference_gen_artistic, inputs=[adapter_choice, text, seed, steps, scale, adapter_scale], outputs=gallery_gen_art)\n",
|
| 363 |
"\n",
|
| 364 |
+
" gr.on([btn_style_ori.click], demo_inference_stylization_ori, inputs=[gallery_stylization_ref, text, seed, steps, scale, start_timestep], outputs=gallery_stylization_ori)\n",
|
| 365 |
+
" gr.on([btn_style_art.click], demo_inference_stylization_artistic, inputs=[gallery_stylization_ref, adapter_choice, text, seed, steps, scale, adapter_scale, start_timestep], outputs=gallery_stylization_art)\n",
|
| 366 |
"\n",
|
| 367 |
"block.launch(share=True)"
|
| 368 |
]
|
utils/train_util.py
CHANGED
|
@@ -249,7 +249,8 @@ def get_noisy_image(
|
|
| 249 |
image = img
|
| 250 |
# im_orig = image
|
| 251 |
device = vae.device
|
| 252 |
-
|
|
|
|
| 253 |
|
| 254 |
init_latents = vae.encode(image).latent_dist.sample(None)
|
| 255 |
init_latents = vae.config.scaling_factor * init_latents
|
|
|
|
| 249 |
image = img
|
| 250 |
# im_orig = image
|
| 251 |
device = vae.device
|
| 252 |
+
weight_dtype = vae.dtype
|
| 253 |
+
image = image_processor.preprocess(image).to(device).to(weight_dtype).to(weight_dtype)
|
| 254 |
|
| 255 |
init_latents = vae.encode(image).latent_dist.sample(None)
|
| 256 |
init_latents = vae.config.scaling_factor * init_latents
|