Spaces:
Running
on
Zero
Running
on
Zero
| # Authors: Hui Ren (rhfeiyang.github.io) | |
| import spaces | |
| import os | |
| import gradio as gr | |
| from diffusers import DiffusionPipeline | |
| import matplotlib.pyplot as plt | |
| import torch | |
| from PIL import Image | |
# Pick the compute device and dtype, then load the Art-Free Diffusion pipeline once at import time.
device = "cuda" if torch.cuda.is_available() else "cpu"
# bfloat16 on GPU for memory savings. On CPU use float32: half-precision
# (float16) kernels on CPU are missing or very slow for many ops, which made
# the previous float16 CPU fallback unreliable.
dtype = torch.bfloat16 if device == "cuda" else torch.float32
print(f"Using {device} device, dtype={dtype}")
pipe = DiffusionPipeline.from_pretrained("rhfeiyang/art-free-diffusion-v1",
                                         torch_dtype=dtype).to(device)
| from inference import get_lora_network, inference, get_validation_dataloader | |
# Maps the artist name shown in the UI to its adapter directory under
# data/Art_adapters/. The "None" entry is the sentinel value that
# demo_inference_gen treats as "generate without an adapter".
lora_map = dict(
    [
        ("None", "None"),
        ("Andre Derain", "andre-derain_subset1"),
        ("Vincent van Gogh", "van_gogh_subset1"),
        ("Andy Warhol", "andy_subset1"),
        ("Walter Battiss", "walter-battiss_subset2"),
        ("Camille Corot", "camille-corot_subset1"),
        ("Claude Monet", "monet_subset2"),
        ("Pablo Picasso", "picasso_subset1"),
        ("Jackson Pollock", "jackson-pollock_subset1"),
        ("Gerhard Richter", "gerhard-richter_subset1"),
        ("M.C. Escher", "m.c.-escher_subset1"),
        ("Albert Gleizes", "albert-gleizes_subset1"),
        ("Hokusai", "katsushika-hokusai_subset1"),
        ("Wassily Kandinsky", "kandinsky_subset1"),
        ("Gustav Klimt", "klimt_subset3"),
        ("Roy Lichtenstein", "roy-lichtenstein_subset1"),
        ("Henri Matisse", "henri-matisse_subset1"),
        ("Joan Miro", "joan-miro_subset2"),
    ]
)
@spaces.GPU  # required for GPU work on ZeroGPU Spaces; documented as a no-op elsewhere
def demo_inference_gen(adapter_choice: str, prompt: str, samples: int = 1, seed: int = 0,
                       steps: int = 50, guidance_scale: float = 7.5):
    """Generate images for *prompt*, optionally in the style of an art adapter.

    Args:
        adapter_choice: Artist display name; must be a key of ``lora_map``.
            ``"None"`` generates without an adapter.
        prompt: Text prompt, repeated once per requested sample.
        samples: Number of images to generate; must be >= 1.
        seed: Random seed forwarded to ``inference``.
        steps: Number of denoising steps forwarded to ``inference``.
        guidance_scale: Guidance scale forwarded to ``inference``.

    Returns:
        The ``[0][1.0]`` entry of ``inference``'s result, i.e. the outputs for
        adapter scale 1.0. Presumably a list of PIL images, as consumed by the
        gallery; TODO confirm against ``inference``.

    Raises:
        KeyError: if *adapter_choice* is not a key of ``lora_map``.
        ValueError: if *samples* is less than 1.
    """
    # Slider values are not guaranteed to arrive as ints. List repetition,
    # step counts and seeding all need integers, so normalize here.
    samples = int(samples)
    seed = int(seed)
    steps = int(steps)
    guidance_scale = float(guidance_scale)
    if samples < 1:
        raise ValueError(f"samples must be >= 1, got {samples}")

    adapter_name = lora_map[adapter_choice]
    if adapter_name not in (None, "None"):
        adapter_path = f"data/Art_adapters/{adapter_name}/adapter_alpha1.0_rank1_all_up_1000steps.pt"
        # Extra prompt text for adapter runs; presumably the adapter's trained
        # trigger phrase -- TODO confirm.
        style_prompt = "sks art"
    else:
        # No adapter: the sentinel value itself is handed to get_lora_network.
        adapter_path = adapter_name
        style_prompt = None

    prompts = [prompt] * samples
    infer_loader = get_validation_dataloader(prompts, num_workers=0)
    network = get_lora_network(pipe.unet, adapter_path, weight_dtype=dtype)["network"]
    pred_images = inference(network, pipe.tokenizer, pipe.text_encoder, pipe.vae, pipe.unet, pipe.scheduler, infer_loader,
                            height=512, width=512, scales=[1.0],
                            save_dir=None, seed=seed, steps=steps, guidance_scale=guidance_scale,
                            start_noise=-1, show=False, style_prompt=style_prompt, no_load=True,
                            from_scratch=True, device=device, weight_dtype=dtype)[0][1.0]
    return pred_images
def demo_inference_stylization(adapter_path:str, prompts:list, image:list, start_noise=800,seed:int=0):
    """Run adapter-guided inference starting from existing images.

    Args:
        adapter_path: Path to an adapter checkpoint, passed straight to
            ``get_lora_network``.
        prompts: Prompts given to ``get_validation_dataloader``.
        image: Passed as the second positional argument of
            ``get_validation_dataloader``; presumably the source images to
            restyle -- TODO confirm.
        start_noise: Forwarded to ``inference`` as ``start_noise``. Presumably
            the noise level the denoising starts from (generation here is not
            from scratch) -- TODO confirm.
        seed: Random seed forwarded to ``inference``.

    Returns:
        The ``[0][1.0]`` entry of ``inference``'s result. Both scales 0.0 and
        1.0 are requested, but only the 1.0 output is returned.
    """
    loader = get_validation_dataloader(prompts, image, num_workers=0)
    lora_network = get_lora_network(pipe.unet, adapter_path, weight_dtype=dtype)["network"]
    # Fixed sampling settings for stylization.
    # NOTE(review): show=True here, whereas demo_inference_gen passes show=False;
    # confirm this is intended in a server context.
    sampling_options = dict(
        height=512,
        width=512,
        scales=[0., 1.],
        save_dir=None,
        seed=seed,
        steps=20,
        guidance_scale=7.5,
        start_noise=start_noise,
        show=True,
        style_prompt="sks art",
        no_load=True,
        from_scratch=False,
        device=device,
        weight_dtype=dtype,
    )
    outputs = inference(
        lora_network,
        pipe.tokenizer,
        pipe.text_encoder,
        pipe.vae,
        pipe.unet,
        pipe.scheduler,
        loader,
        **sampling_options,
    )
    by_scale = outputs[0]
    return by_scale[1.0]
| # def infer(prompt, samples, steps, scale, seed): | |
| # generator = torch.Generator(device=device).manual_seed(seed) | |
| # images_list = pipe( # type: ignore | |
| # [prompt] * samples, | |
| # num_inference_steps=steps, | |
| # guidance_scale=scale, | |
| # generator=generator, | |
| # ) | |
| # images = [] | |
| # safe_image = Image.open(r"data/unsafe.png") | |
| # print(images_list) | |
| # for i, image in enumerate(images_list["images"]): # type: ignore | |
| # if images_list["nsfw_content_detected"][i]: # type: ignore | |
| # images.append(safe_image) | |
| # else: | |
| # images.append(image) | |
| # return images | |
# Gradio UI: prompt box + run button, result gallery, and a toggleable row of
# advanced generation options. Launches the app at the end.
block = gr.Blocks()
# Direct infer
with block:
    with gr.Group():
        gr.Markdown(" # Art-Free Diffusion Demo")
        gr.Markdown("(More features in development...)")
        with gr.Row():
            text = gr.Textbox(
                label="Enter your prompt",
                max_lines=2,
                placeholder="Enter your prompt",
                container=False,
                value="Park with cherry blossom trees, picnicker’s and a clear blue pond.",
            )
            btn = gr.Button("Run", scale=0)
        gallery = gr.Gallery(
            label="Generated images",
            show_label=False,
            elem_id="gallery",
            columns=[1],
        )
        advanced_button = gr.Button("Advanced options", elem_id="advanced-btn")
        with gr.Row(elem_id="advanced-options") as advanced_options:
            adapter_choice = gr.Dropdown(
                label="Select Art Adapter",
                # Derived from lora_map so the dropdown can never offer a name
                # that demo_inference_gen cannot look up.
                choices=list(lora_map),
                value="None",
            )
            samples = gr.Slider(label="Images", minimum=1, maximum=4, value=1, step=1)
            steps = gr.Slider(label="Steps", minimum=1, maximum=50, value=20, step=1)
            scale = gr.Slider(
                label="Guidance Scale", minimum=0, maximum=50, value=7.5, step=0.1
            )
            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=2147483647,
                step=1,
                randomize=True,
            )

    # Session state tracking whether the advanced-options row is shown.
    # It starts visible, matching the row's initial rendering.
    advanced_visible = gr.State(True)

    def _toggle_advanced(visible):
        """Flip the advanced-options row's visibility and remember the new state."""
        visible = not visible
        return visible, gr.update(visible=visible)

    # Pressing Enter in the prompt box or clicking Run both trigger generation.
    gr.on([text.submit, btn.click], demo_inference_gen,
          inputs=[adapter_choice, text, samples, seed, steps, scale], outputs=gallery)
    # Previously registered with fn=None and no JS, so clicking did nothing.
    advanced_button.click(
        _toggle_advanced,
        inputs=advanced_visible,
        outputs=[advanced_visible, advanced_options],
    )
block.launch()