# Hugging Face Space app (Hub status at scrape time: Sleeping)
import os
import shutil
from pathlib import Path

# One-time setup (kept for reference, disabled): install deps and clone the
# model repo locally instead of streaming it from the Hub.
'''
os.system("pip install -U huggingface_hub")
os.system("pip install -U diffusers")
if os.path.exists("wuerstchen"):
    shutil.rmtree("wuerstchen")
os.system("git clone https://huggingface.co/warp-ai/wuerstchen")
if os.path.exists("wuerstchen/.git"):
    shutil.rmtree("wuerstchen/.git")
'''

import sys
import gradio as gr
import numpy as np
import torch
import random
from diffusers import AutoPipelineForText2Image
from diffusers.pipelines.wuerstchen.pipeline_wuerstchen_prior import DEFAULT_STAGE_C_TIMESTEPS

# Use GPU when available; float32 also runs (slowly) on CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Alternative (disabled): load from the locally cloned checkout above.
'''
assert os.path.exists("wuerstchen")
pipe = AutoPipelineForText2Image.from_pretrained(Path("wuerstchen"), local_files_only = True,
                                                 torch_dtype=torch.float32)
'''

# Load the Wuerstchen text-to-image pipeline straight from the Hub.
pipe = AutoPipelineForText2Image.from_pretrained("warp-ai/wuerstchen",
                                                 torch_dtype=torch.float32)
pipe.to(device)
# NOTE(review): disabling the safety checker returns unfiltered outputs.
pipe.safety_checker = None

# Smoke-test call (disabled). ~9 min per sample on 2 CPU cores.
'''
caption = "Anthropomorphic cat dressed as a fire fighter"
images = pipe(
    caption,
    width=512,
    height=512,
    prior_timesteps=DEFAULT_STAGE_C_TIMESTEPS,  #### length of 30
    prior_guidance_scale=4.0,
    num_images_per_prompt=1,
    num_inference_steps=6,  #### default num of 12, 6 favour
).images
'''
def process(prompt, num_samples, image_resolution, sample_steps, seed,):
    """Generate ``num_samples`` square images for ``prompt``.

    Parameters
    ----------
    prompt : str
        Text prompt passed to the Wuerstchen pipeline.
    num_samples : int
        Number of images to generate (one pipeline call each).
    image_resolution : int
        Output width and height in pixels.
    sample_steps : int
        Decoder inference steps (the pipeline default is 12).
    seed : int
        RNG seed; ``-1`` selects a random seed in [0, 65535].

    Returns
    -------
    list[numpy.ndarray]
        Generated images as H x W x 3 arrays, suitable for ``gr.Gallery``.
    """
    with torch.no_grad():
        if seed == -1:
            seed = random.randint(0, 65535)
        # Fix: the seed was previously computed but never applied (the
        # generator line was commented out), so the Seed slider had no
        # effect. Seed a generator and hand it to the pipeline.
        generator = torch.Generator(device=device).manual_seed(seed)
        H = image_resolution
        W = image_resolution
        images = []
        for i in range(num_samples):
            image = pipe(
                prompt,
                prior_timesteps=DEFAULT_STAGE_C_TIMESTEPS,
                prior_guidance_scale=4.0,
                num_inference_steps=sample_steps,
                num_images_per_prompt=1,
                generator=generator,
                height=H, width=W).images[0]
            images.append(np.asarray(image))
        return images
| #return [255 - detected_map] + results | |
# Build the Gradio UI. `.queue()` serialises requests — inference is slow,
# especially on CPU, so concurrent requests must wait their turn.
block = gr.Blocks().queue()
with block:
    with gr.Row():
        gr.Markdown("## Rapid Diffusion model from warp-ai/wuerstchen")
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(label="Prompt", value="Anthropomorphic cat dressed as a fire fighter")
            run_button = gr.Button(label="Run")
            with gr.Accordion("Advanced options", open=False):
                num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1)
                # step=256 keeps resolutions compatible with the model's
                # latent downscaling (256 / 512 / 768 only).
                image_resolution = gr.Slider(label="Image Resolution", minimum=256, maximum=768, value=512, step=256)
                sample_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=6, step=1)
                # -1 means "pick a random seed" (handled inside process()).
                seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True)
        with gr.Column():
            result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')
    ips = [prompt, num_samples, image_resolution, sample_steps, seed]
    run_button.click(fn=process, inputs=ips, outputs=[result_gallery], show_progress=True)
    gr.Examples(
        [
            ["A glass of cola, 8k", 1, 512, 8, 10],
            ["Anthropomorphic cat dressed as a fire fighter", 1, 512, 8, 20],
        ],
        inputs=[prompt, num_samples, image_resolution, sample_steps, seed],
        label="Examples",
    )
# 0.0.0.0 exposes the server on all interfaces (required for Spaces hosting).
block.launch(server_name='0.0.0.0')