Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| from diffusers import StableDiffusionControlNetPipeline, ControlNetModel | |
| from diffusers import UniPCMultistepScheduler | |
| import torch | |
| import torchvision.transforms as T | |
| import torchvision.transforms.v2 as T2 | |
| import cv2 | |
| from PIL import Image | |
# Target (height, width) of the conditioning image handed to the ControlNet.
output_res = (768, 768)

# Preprocessing for the conditioning (style) image, applied in order:
#   1. rescale relative to `output_res` by a random factor in [0.5, 3.0],
#   2. take a random `output_res` crop, padding symmetrically if too small,
#   3. convert to a float tensor,
#   4. normalize every channel with (x - 0.5) / 0.5.
# NOTE(review): steps 1-2 are random, so one input image can produce a
# different conditioning crop on each call. Confirm this training-style
# augmentation is really wanted at inference time.
_conditioning_steps = [
    T2.ScaleJitter(target_size=output_res, scale_range=(0.5, 3.0)),
    T2.RandomCrop(size=output_res, pad_if_needed=True, padding_mode="symmetric"),
    T.ToTensor(),
    T.Normalize(mean=[0.5], std=[0.5]),
]
conditioning_image_transforms = T.Compose(_conditioning_steps)
# Use the GPU when one is present. Half precision is used only there: many
# CPU kernels have no float16 implementation, so on CPU we fall back to fp32.
_device = "cuda" if torch.cuda.is_available() else "cpu"
_dtype = torch.float16 if _device == "cuda" else torch.float32

# ControlNet checkpoint saved as Flax weights; `from_flax=True` asks diffusers
# to load (convert) them into the PyTorch model.
cnet = ControlNetModel.from_pretrained(
    "./models/catcon-controlnet-wd",
    torch_dtype=_dtype,
    from_flax=True,
)

# The full text-to-image pipeline (base model + ControlNet). This must be a
# StableDiffusionControlNetPipeline: ControlNetModel.from_pretrained only
# yields a ControlNet, which cannot be called with a prompt.
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    "./models/wd-1-5-b2",
    controlnet=cnet,
    torch_dtype=_dtype,
)
# Apply the UniPC multistep scheduler. The file imports it but never used it.
# The base model's scheduler configuration is reused.
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
pipe = pipe.to(_device)

# Dedicated RNG for the diffusion noise. `torch.manual_seed(0)` would instead
# reseed and return torch's global generator, sharing state with every other
# random op (e.g. the random crop in the conditioning transforms).
# Its state advances on each use, so repeated requests give varied results.
generator = torch.Generator().manual_seed(0)
def infer(prompt, negative_prompt, image):
    """Generate one stylized image from a prompt and a conditioning image.

    Args:
        prompt: Text prompt describing the image to generate.
        negative_prompt: Text describing what to avoid. It is forwarded to
            the pipeline. An empty string behaves like no negative prompt.
        image: Conditioning (style) image from the Gradio image input.
            Either a PIL image or an array accepted by ``Image.fromarray``
            (Gradio's default image type is a numpy array).

    Returns:
        The first (and only) generated ``PIL.Image.Image``.

    Raises:
        gr.Error: If no conditioning image was provided.
    """
    if image is None:
        raise gr.Error("Please provide a conditioning image.")
    # NOTE: torchvision v2 transforms act on PIL images / tensors and leave
    # other inputs (such as numpy arrays) untouched. Converting here ensures
    # the scale-jitter and crop actually run. RGB guards against RGBA/L inputs.
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)
    image = image.convert("RGB")

    # Add a leading batch dimension: (C, H, W) -> (1, C, H, W).
    cond_input = conditioning_image_transforms(image).unsqueeze(0)
    output = pipe(
        prompt,
        image=cond_input,
        negative_prompt=negative_prompt,
        generator=generator,
        num_images_per_prompt=1,
        num_inference_steps=20,
    )
    # `output[0]` would be the whole `images` list; return the single image.
    return output.images[0]
# Text shown at the top of the demo page.
title = "Categorical Conditioning Controlnet for One-Shot Image Stylization."
description = "This is a demo on ControlNet which generates images based on the style of the conditioning input."
# Each inner list is one example. Its elements map, in order, to the
# components in `inputs`: prompt, negative prompt, conditioning image path.
examples = [["1girl, green hair, sweater, looking at viewer, upper body, beanie, outdoors, watercolor, night, turtleneck", "low quality", "wikipe_cond_1.png"]]

# Build and launch exactly one Interface: launch() blocks, so any Interface
# launched before this one would keep it from ever starting.
demo = gr.Interface(
    fn=infer,
    inputs=[
        gr.Textbox(label="Prompt"),
        gr.Textbox(label="Negative prompt"),
        # Deliver a PIL image so the torchvision v2 transforms operate on it.
        gr.Image(type="pil", label="Conditioning image"),
    ],
    outputs="image",
    title=title,
    description=description,
    examples=examples,
    theme="gradio/soft",
)

if __name__ == "__main__":
    demo.launch()