Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| from transformers import CLIPTextModel, CLIPTokenizer | |
| from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler | |
| from PIL import Image | |
| import PIL | |
| import torch | |
| import numpy as np | |
import os

# Hub repo id (or local directory) whose "vae" subfolder is loaded.
# It can be overridden with the VAE_MODEL_PATH environment variable without editing code.
# NOTE(review): if the default repo is no longer reachable on the Hub, startup
# fails on the next line. In that case, point VAE_MODEL_PATH at any diffusers
# checkpoint that has a "vae" subfolder.
model_path = os.environ.get("VAE_MODEL_PATH", "Linaqruf/anything-v3.0")
vae = AutoencoderKL.from_pretrained(model_path, subfolder="vae")
print(f"vae loaded from {model_path}")
def snap(w, h, d=64, area=640 * 640):
    """Pick a (width, height) close to the aspect ratio w:h that fits a pixel budget.

    The source size is first scaled down (never up) so its area is at most
    ``area``, and each side is floored to a multiple of ``d``. The candidates
    are that size plus the variants with one ``d`` step added to the width,
    the height, or both. Among candidates with positive sides and an area
    within ``area``, the one whose aspect ratio is closest to ``w / h`` is
    returned.

    Args:
        w: Source width in pixels (must be > 0).
        h: Source height in pixels (must be > 0).
        d: Granularity; the returned sides are positive multiples of ``d``.
        area: Maximum allowed ``width * height`` of the result.

    Returns:
        A ``(width, height)`` tuple. For extremely elongated inputs where no
        candidate fits the budget, ``(d, d)`` is returned.

    Raises:
        ValueError: If ``w`` or ``h`` is not positive.
    """
    if w <= 0 or h <= 0:
        raise ValueError(f"image dimensions must be positive, got {w}x{h}")
    s = min(1.0, (area / w / h) ** 0.5)

    def err(a, b):
        # Relative mismatch between two aspect ratios: 0 when equal, tending to 1 as they diverge.
        return 1 - min(a, b) / max(a, b)

    sw, sh = (int((x * s) // d * d) for x in (w, h))
    candidates = [
        (ww, hh)
        for ww, hh in [(sw, sh), (sw, sh + d), (sw + d, sh), (sw + d, sh + d)]
        # A side floors to 0 when the (scaled) source side is smaller than d.
        # Such a size would divide by zero below and cannot be cropped to.
        if ww > 0 and hh > 0 and ww * hh <= area
    ]
    if not candidates:
        # Only reachable for extremely elongated inputs: every positive
        # candidate exceeds the budget, so fall back to the smallest valid size.
        return d, d
    return min(candidates, key=lambda wh: err(w / h, wh[0] / wh[1]))
def center_crop_image(image, hx, wx):
    """Center-crop ``image`` to the aspect ratio wx:hx, then resize it to (wx, hx).

    Args:
        image: PIL image to crop.
        hx: Target height in pixels.
        wx: Target width in pixels.

    Returns:
        A new PIL image of size ``(wx, hx)``, resampled with Lanczos.
    """
    src_w, src_h = image.size
    if src_w / src_h > wx / hx:
        # Source is wider than the target: keep the full height and trim
        # equal amounts from the left and right edges.
        keep_w = src_h * wx / hx
        box = ((src_w - keep_w) / 2, 0, (src_w + keep_w) / 2, src_h)
    else:
        # Source is taller (or equal): keep the full width and trim equal
        # amounts from the top and bottom edges.
        keep_h = src_w * hx / wx
        box = (0, (src_h - keep_h) / 2, src_w, (src_h + keep_h) / 2)
    # box is (left, top, right, bottom), as PIL's crop expects.
    return image.crop(box).resize((wx, hx), Image.Resampling.LANCZOS)
def preprocess(image):
    """Normalize the accepted image inputs into a single batched tensor.

    * ``torch.Tensor``: returned unchanged.
    * ``PIL.Image.Image`` or a list of them: float32 tensor of shape
      (N, C, H, W) with pixel values mapped from [0, 255] to [-1, 1].
    * List of tensors: concatenated along dim 0.
    * Any other list: returned unchanged.

    NOTE: the PIL images must share one size and have a channel axis
    (e.g. RGB). A single-channel "L" image yields a 2-D array, and the
    channel transpose below then fails.
    """
    if isinstance(image, torch.Tensor):
        return image
    batch = [image] if isinstance(image, PIL.Image.Image) else image
    head = batch[0]
    if isinstance(head, torch.Tensor):
        return torch.cat(batch, dim=0)
    if not isinstance(head, PIL.Image.Image):
        return batch
    # Stack the HWC arrays into NHWC and rescale to [0, 1].
    pixels = np.stack([np.array(pic) for pic in batch], axis=0).astype(np.float32) / 255.0
    # Move channels first (NCHW), then map [0, 1] to [-1, 1].
    pixels = pixels.transpose(0, 3, 1, 2)
    return torch.from_numpy(2.0 * pixels - 1.0)
def numpy_to_pil(images):
    """Turn a float image array, or a batch of them, into a list of PIL images.

    A 3-D input is treated as a single image and given a leading batch axis.
    Values are multiplied by 255, rounded and cast to uint8, so they are
    presumably expected in [0, 1]. A trailing channel axis of size 1 produces
    mode "L" (grayscale) images.
    """
    batch = images[None, ...] if images.ndim == 3 else images
    as_bytes = (batch * 255).round().astype("uint8")
    if as_bytes.shape[-1] != 1:
        return [Image.fromarray(frame) for frame in as_bytes]
    # Grayscale: squeeze away the singleton channel axis so PIL accepts the array.
    return [Image.fromarray(frame.squeeze(), mode="L") for frame in as_bytes]
def postprocess_image(sample: torch.FloatTensor, output_type: str = "pil"):
    """Map a batched NCHW tensor from [-1, 1] to [0, 1] and convert it.

    Args:
        sample: Tensor with layout (N, C, H, W), as implied by the permute below.
        output_type: ``"pt"`` returns the clamped tensor, ``"np"`` returns an
            NHWC numpy array on the CPU, and ``"pil"`` returns a list of PIL images.

    Raises:
        ValueError: If ``output_type`` is not one of the three values above.
    """
    if output_type not in ("pt", "np", "pil"):
        raise ValueError(
            f"output_type={output_type} is not supported. Make sure to choose one of ['pt', 'np', or 'pil']"
        )
    # [-1, 1] -> [0, 1], mirroring diffusers.VaeImageProcessor.denormalize.
    normalized = (sample / 2 + 0.5).clamp(0, 1)
    if output_type == "pt":
        return normalized
    # NCHW tensor -> NHWC numpy array, mirroring diffusers.VaeImageProcessor.pt_to_numpy.
    as_numpy = normalized.cpu().permute(0, 2, 3, 1).numpy()
    return as_numpy if output_type == "np" else numpy_to_pil(as_numpy)
def vae_roundtrip(image, max_resolution: int):
    """Center-crop an image, then encode and decode it with the module-level VAE.

    Args:
        image: PIL image from the Gradio input, or ``None`` when nothing was uploaded.
        max_resolution: Side length of the square pixel budget. The crop keeps
            at most ``max_resolution ** 2`` pixels, with sides that are
            multiples of 64 (``snap``'s default granularity).

    Returns:
        ``(cropped, reconstructed)``: the center-cropped RGB input and its
        VAE reconstruction, both as PIL images.

    Raises:
        gr.Error: If no image was provided. The message is shown in the UI.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")
    # The encoder needs 3 channels. RGBA, palette or grayscale uploads would
    # otherwise reach preprocess with a wrong channel count or as a 2-D array.
    image = image.convert("RGB")
    w, h = image.size
    ww, hh = snap(w, h, area=max_resolution**2)
    cropped = center_crop_image(image, hh, ww)
    pixels = preprocess(cropped)
    with torch.no_grad():
        latent_dist = vae.encode(pixels)[0]
        # Decode the distribution's mean rather than a sample, so the round trip is deterministic.
        decoded = vae.decode(latent_dist.mean, return_dict=False)[0]
    return cropped, postprocess_image(decoded)[0]
import inspect

# Newer Gradio releases renamed Interface's `allow_flagging` keyword to
# `flagging_mode`, and passing the old name fails there. Pass whichever keyword
# the installed version's constructor actually declares.
_interface_params = inspect.signature(gr.Interface.__init__).parameters
_flagging_kwargs = (
    {"flagging_mode": "never"}
    if "flagging_mode" in _interface_params
    else {"allow_flagging": "never"}
)

iface = gr.Interface(
    fn=vae_roundtrip,
    inputs=[gr.Image(type="pil"), gr.Slider(384, 1024, step=64, value=640)],
    outputs=[gr.Image(label="center cropped"), gr.Image(label="after roundtrip")],
    **_flagging_kwargs,
)
iface.launch()