Spaces:
Running
on
Zero
Running
on
Zero
| # Authors: Hui Ren (rhfeiyang.github.io) | |
| import spaces | |
| import os | |
| import gradio as gr | |
| from diffusers import DiffusionPipeline | |
| import matplotlib.pyplot as plt | |
| import torch | |
| from PIL import Image | |
# Pick the compute device and a dtype it handles well: bfloat16 on CUDA,
# float32 on CPU. Half precision (float16) lacks kernels for, or is very slow
# on, many CPU ops in PyTorch, so the previous float16-on-CPU fallback made
# the pipeline fail or crawl when no GPU was present.
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.bfloat16 if device == "cuda" else torch.float32
print(f"Using {device} device, dtype={dtype}")

# Load the Art-Free Diffusion weights once at import time; the request
# handlers reuse this single `pipe` (its unet, tokenizer, text encoder, VAE
# and scheduler) for every call.
pipe = DiffusionPipeline.from_pretrained(
    "rhfeiyang/art-free-diffusion-v1",
    torch_dtype=dtype,
).to(device)
| from inference import get_lora_network, inference, get_validation_dataloader | |
# Dropdown label -> adapter directory name under data/Art_adapters/.
# The "None" entry is a sentinel label/value meaning "run without an adapter".
# Insertion order is preserved (dict keeps the order of the pairs below).
lora_map = dict([
    ("None", "None"),
    ("Andre Derain (fauvism)", "andre-derain_subset1"),
    ("Vincent van Gogh (post impressionism)", "van_gogh_subset1"),
    ("Andy Warhol (pop art)", "andy_subset1"),
    ("Walter Battiss", "walter-battiss_subset2"),
    ("Camille Corot (realism)", "camille-corot_subset1"),
    ("Claude Monet (impressionism)", "monet_subset2"),
    ("Pablo Picasso (cubism)", "picasso_subset1"),
    ("Jackson Pollock", "jackson-pollock_subset1"),
    ("Gerhard Richter (abstract expressionism)", "gerhard-richter_subset1"),
    ("M.C. Escher", "m.c.-escher_subset1"),
    ("Albert Gleizes", "albert-gleizes_subset1"),
    ("Hokusai (ukiyo-e)", "katsushika-hokusai_subset1"),
    ("Wassily Kandinsky", "kandinsky_subset1"),
    ("Gustav Klimt (art nouveau)", "klimt_subset3"),
    ("Roy Lichtenstein", "roy-lichtenstein_subset1"),
    ("Henri Matisse (abstract expressionism)", "henri-matisse_subset1"),
    ("Joan Miro", "joan-miro_subset2"),
])
def _resolve_adapter(adapter_choice: str):
    """Map a dropdown label to ``(adapter_path, style_prompt)``.

    Looks the label up in ``lora_map`` (raises ``KeyError`` for unknown
    labels). If the mapped value is ``None`` or the sentinel ``"None"``, that
    value is returned unchanged as the path together with no style prompt.
    Otherwise the value is expanded to its adapter checkpoint under
    ``data/Art_adapters/`` and paired with the ``"sks art"`` style prompt.
    """
    adapter_name = lora_map[adapter_choice]
    if adapter_name in (None, "None"):
        return adapter_name, None
    adapter_path = (
        f"data/Art_adapters/{adapter_name}/adapter_alpha1.0_rank1_all_up_1000steps.pt"
    )
    return adapter_path, "sks art"


def _load_reference(ref_image):
    """Wrap the uploaded reference image as a one-element list of PIL images.

    Raises:
        gr.Error: if no reference image was supplied (e.g. the user cleared
            the upload). Without this check ``Image.fromarray(None)`` fails
            with an opaque error instead of a message shown in the UI.
    """
    if ref_image is None:
        raise gr.Error("Please upload a reference image first.")
    # The Gradio image arrives as a NumPy array; the dataloader is given PIL.
    return [Image.fromarray(ref_image)]


def _run_inference(prompt, adapter_path, style_prompt, scale, seed, steps,
                   guidance_scale, ref_images=None, start_noise=-1):
    """Run one 512x512 sample through ``inference`` and return the image.

    Args:
        prompt: Text prompt; wrapped as a single-item prompt list.
        adapter_path: Adapter checkpoint path, or the ``"None"`` sentinel.
        style_prompt: Extra style prompt (``"sks art"``) or ``None``.
        scale: Adapter scale; also the key used to pick the result.
        seed, steps: Coerced to ``int`` because Gradio sliders may deliver
            floats, and these are consumed as integer counts/seeds.
        guidance_scale: Classifier-free guidance scale, forwarded as-is.
        ref_images: ``None`` to sample with ``from_scratch=True``; otherwise a
            list of PIL reference images, sampled with ``from_scratch=False``.
        start_noise: Forwarded to ``inference`` (``-1`` for from-scratch).

    Returns:
        Element ``[0][scale][0]`` of the ``inference`` result.

    NOTE(review): ``spaces`` is imported at file top but ``@spaces.GPU`` is
    not applied in the visible code. If GPU allocation on ZeroGPU is not
    handled inside ``inference``, the public handlers likely need that
    decorator; confirm.
    """
    from_scratch = ref_images is None
    prompts = [prompt]
    if from_scratch:
        infer_loader = get_validation_dataloader(prompts, num_workers=0)
    else:
        infer_loader = get_validation_dataloader(prompts, ref_images, num_workers=0)
    network = get_lora_network(pipe.unet, adapter_path, weight_dtype=dtype)["network"]
    outputs = inference(
        network, pipe.tokenizer, pipe.text_encoder, pipe.vae, pipe.unet,
        pipe.scheduler, infer_loader,
        height=512, width=512, scales=[scale],
        save_dir=None, seed=int(seed), steps=int(steps),
        guidance_scale=guidance_scale,
        start_noise=int(start_noise), show=False, style_prompt=style_prompt,
        no_load=True, from_scratch=from_scratch, device=device,
        weight_dtype=dtype,
    )
    return outputs[0][scale][0]


def demo_inference_gen_artistic(adapter_choice: str, prompt: str, seed: int = 0,
                                steps=50, guidance_scale=7.5, adapter_scale=1.0):
    """Text-to-image generation with the art adapter chosen in the dropdown."""
    adapter_path, style_prompt = _resolve_adapter(adapter_choice)
    return _run_inference(prompt, adapter_path, style_prompt, adapter_scale,
                          seed, steps, guidance_scale)


def demo_inference_gen_ori(prompt: str, seed: int = 0, steps=50, guidance_scale=7.5):
    """Text-to-image generation without any adapter (adapter scale 0.0)."""
    return _run_inference(prompt, "None", None, 0.0, seed, steps, guidance_scale)


def demo_inference_stylization_ori(ref_image, prompt: str, seed: int = 0, steps=50,
                                   guidance_scale=7.5, start_noise=800):
    """Image-conditioned sampling from ``ref_image`` without any adapter.

    Raises ``gr.Error`` if ``ref_image`` is missing.
    """
    ref_images = _load_reference(ref_image)
    return _run_inference(prompt, "None", None, 0.0, seed, steps, guidance_scale,
                          ref_images=ref_images, start_noise=start_noise)


def demo_inference_stylization_artistic(ref_image, adapter_choice: str, prompt: str,
                                        seed: int = 0, steps=50, guidance_scale=7.5,
                                        adapter_scale=1.0, start_noise=800):
    """Image-conditioned sampling from ``ref_image`` with the chosen art adapter.

    Raises ``gr.Error`` if ``ref_image`` is missing.
    """
    adapter_path, style_prompt = _resolve_adapter(adapter_choice)
    ref_images = _load_reference(ref_image)
    return _run_inference(prompt, adapter_path, style_prompt, adapter_scale,
                          seed, steps, guidance_scale,
                          ref_images=ref_images, start_noise=start_noise)
# Gradio UI: a shared prompt box, a "Generation" tab and a "Stylization" tab
# (each showing output without vs. with an art adapter), and shared sampling
# controls used by every button.
# NOTE(review): the source's indentation was lost, so the exact container
# nesting below is a reconstruction; confirm against the deployed layout.
# Every component gets its own elem_id: the original reused "gallery" for five
# components, but HTML element ids must be unique within a page.
block = gr.Blocks()

with block:
    # Header plus the prompt shared by all four actions.
    with gr.Group():
        gr.Markdown(" # Art-Free Diffusion Demo")
        gr.Markdown("(More features in development...)")
        with gr.Row():
            text = gr.Textbox(
                label="Enter your prompt (long and detailed would be better):",
                max_lines=10,
                placeholder="Enter your prompt (long and detailed would be better)",
                container=True,
                value="A blue bench situated in a park, surrounded by trees and leaves. The bench is positioned under a tree, providing shade and a peaceful atmosphere. There are several benches in the park, with one being closer to the foreground and the others further in the background. A person can be seen in the distance, possibly enjoying the park or taking a walk. The overall scene is serene and inviting, with the bench serving as a focal point in the park's landscape.",
            )

    # Text-to-image: the same prompt/seed without and with the adapter.
    with gr.Tab('Generation'):
        with gr.Row():
            with gr.Column():
                gallery_gen_ori = gr.Image(
                    label="W/O Adapter",
                    show_label=True,
                    elem_id="gallery_gen_ori",
                    height="auto",
                )
            with gr.Column():
                gallery_gen_art = gr.Image(
                    label="W/ Adapter",
                    show_label=True,
                    elem_id="gallery_gen_art",
                    height="auto",
                )
        with gr.Row():
            btn_gen_ori = gr.Button("Art-Free Generate", scale=1)
            btn_gen_art = gr.Button("Artistic Generate", scale=1)

    # Reference-image sampling: reference on the left, both outputs on the right.
    with gr.Tab('Stylization'):
        with gr.Row():
            with gr.Column():
                gallery_stylization_ref = gr.Image(
                    label="Ref Image",
                    show_label=True,
                    elem_id="gallery_stylization_ref",
                    height="auto",
                    scale=1,
                    value="data/003904765.jpg",
                )
            with gr.Column(scale=2):
                with gr.Row():
                    with gr.Column():
                        gallery_stylization_ori = gr.Image(
                            label="W/O Adapter",
                            show_label=True,
                            elem_id="gallery_stylization_ori",
                            height="auto",
                            scale=1,
                        )
                    with gr.Column():
                        gallery_stylization_art = gr.Image(
                            label="W/ Adapter",
                            show_label=True,
                            elem_id="gallery_stylization_art",
                            height="auto",
                            scale=1,
                        )
                # Forwarded as `start_noise` to the stylization handlers.
                start_timestep = gr.Slider(label="Adapter Timestep", minimum=0, maximum=1000, value=800, step=1)
                with gr.Row():
                    btn_style_ori = gr.Button("Art-Free Stylize", scale=1)
                    btn_style_art = gr.Button("Artistic Stylize", scale=1)

    # Sampling controls shared by both tabs.
    with gr.Row():
        scale = gr.Slider(
            label="Guidance Scale", minimum=0, maximum=20, value=7.5, step=0.1
        )
        adapter_choice = gr.Dropdown(
            label="Select Art Adapter",
            choices=[
                "Andre Derain (fauvism)", "Vincent van Gogh (post impressionism)", "Andy Warhol (pop art)",
                "Camille Corot (realism)", "Claude Monet (impressionism)", "Pablo Picasso (cubism)", "Gerhard Richter (abstract expressionism)",
                "Hokusai (ukiyo-e)", "Gustav Klimt (art nouveau)", "Henri Matisse (abstract expressionism)",
                "Walter Battiss", "Jackson Pollock", "M.C. Escher", "Albert Gleizes", "Wassily Kandinsky",
                "Roy Lichtenstein", "Joan Miro",
            ],
            value="Andre Derain (fauvism)",
            scale=1,
        )
    with gr.Row():
        steps = gr.Slider(label="Steps", minimum=1, maximum=50, value=20, step=1)
        adapter_scale = gr.Slider(label="Adapter Scale", minimum=0, maximum=1.5, value=1., step=0.1, scale=1)
    with gr.Row():
        seed = gr.Slider(label="Seed", minimum=0, maximum=2147483647, step=1, randomize=True, scale=1)

    # Wire each button to its handler. Input order must match the handler's
    # positional parameters.
    gr.on([btn_gen_ori.click], demo_inference_gen_ori,
          inputs=[text, seed, steps, scale],
          outputs=gallery_gen_ori)
    gr.on([btn_gen_art.click], demo_inference_gen_artistic,
          inputs=[adapter_choice, text, seed, steps, scale, adapter_scale],
          outputs=gallery_gen_art)
    gr.on([btn_style_ori.click], demo_inference_stylization_ori,
          inputs=[gallery_stylization_ref, text, seed, steps, scale, start_timestep],
          outputs=gallery_stylization_ori)
    gr.on([btn_style_art.click], demo_inference_stylization_artistic,
          inputs=[gallery_stylization_ref, adapter_choice, text, seed, steps, scale, adapter_scale, start_timestep],
          outputs=gallery_stylization_art)

block.launch()
# For a public share link when running locally, use: block.launch(share=True)