{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "0b7d6c1e-3f2a-4c8e-9a41-2e5f7d9c1a01",
   "metadata": {},
   "source": [
    "# AsymmetricAutoencoderKL round-trip\n",
    "\n",
    "Runs `generated.png` through the `AsymmetricAutoencoderKL` checkpoint loaded from `vae` and saves the reconstruction to `test.png`.\n",
    "\n",
    "- The image is center-cropped so its height and width are multiples of 8 before it is passed to the VAE.\n",
    "- Requires a CUDA device; the model and the input run in `float16`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4f62bfd9-5396-48e2-aac7-bdf639cab345",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torchvision import transforms, utils\n",
    "\n",
    "from diffusers import AsymmetricAutoencoderKL\n",
    "from diffusers.utils import load_image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a3c9e2f4-6b1d-4e7a-8c5f-9d2b7e4a1c02",
   "metadata": {},
   "outputs": [],
   "source": [
    "VAE_PATH = \"vae\"  # passed to from_pretrained (local directory or hub id)\n",
    "INPUT_IMAGE_PATH = \"generated.png\"\n",
    "OUTPUT_IMAGE_PATH = \"test.png\"\n",
    "\n",
    "DEVICE = \"cuda\"\n",
    "DTYPE = torch.float16"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5e8b1d3a-2c47-4f96-b0a8-7c1e6d3f9b03",
   "metadata": {},
   "source": [
    "## Helpers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c2f7a9e1-8d34-4b5c-a6e0-1f9d4b7c2e04",
   "metadata": {},
   "outputs": [],
   "source": [
    "def crop_image_to_nearest_divisible_by_8(img: torch.Tensor, multiple: int = 8) -> torch.Tensor:\n",
    "    \"\"\"Center-crop an image tensor so its height and width are multiples of `multiple`.\n",
    "\n",
    "    Args:\n",
    "        img: Image tensor whose last two dimensions are (height, width), e.g. the\n",
    "            (C, H, W) output of `transforms.ToTensor()`; a leading batch dim also works.\n",
    "        multiple: Both spatial sizes are rounded down to a multiple of this value.\n",
    "\n",
    "    Returns:\n",
    "        `img` itself if both sizes are already multiples of `multiple`, otherwise a\n",
    "        center crop of size (H - H % multiple, W - W % multiple). Cropping does not\n",
    "        resample, so dtype and value range are unchanged.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `multiple` is not positive, or the image is smaller than\n",
    "            `multiple` along either spatial dimension.\n",
    "    \"\"\"\n",
    "    if multiple <= 0:\n",
    "        raise ValueError(f\"multiple must be positive, got {multiple}\")\n",
    "\n",
    "    height, width = img.shape[-2:]\n",
    "    new_height = height - height % multiple\n",
    "    new_width = width - width % multiple\n",
    "    if (new_height, new_width) == (height, width):\n",
    "        return img\n",
    "    if new_height == 0 or new_width == 0:\n",
    "        raise ValueError(\n",
    "            f\"Image of size {height}x{width} is smaller than {multiple} px along one side; cannot crop.\"\n",
    "        )\n",
    "\n",
    "    # CenterCrop only takes a size. Cropping never resamples, so there is no\n",
    "    # interpolation mode to pass (doing so raises TypeError).\n",
    "    return transforms.CenterCrop((new_height, new_width))(img)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9d4e6b2c-1a58-4e3f-87c9-3b0a5f8e6d05",
   "metadata": {},
   "source": [
    "## Load the VAE\n",
    "\n",
    "Loading this checkpoint logs a warning that the `block_out_channels` and `force_upcast` entries of its `config.json` are not expected by `AsymmetricAutoencoderKL` and are ignored."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e6a1f3c8-4b92-4d07-9e5a-6c3b8d1f4a06",
   "metadata": {},
   "outputs": [],
   "source": [
    "vae = AsymmetricAutoencoderKL.from_pretrained(VAE_PATH, torch_dtype=DTYPE).to(DEVICE).eval()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1f8c4a7d-6e23-4b9a-b5d1-8a2e0c7f3b07",
   "metadata": {},
   "source": [
    "## Prepare the input\n",
    "\n",
    "`ToTensor` converts the PIL image to a float32 `(C, H, W)` tensor in `[0, 1]`. The tensor is then cropped, given a batch dimension, and moved to `DEVICE` / `DTYPE`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b4d2e8f6-3c71-4a5e-9f0b-2d6a9c4e1f08",
   "metadata": {},
   "outputs": [],
   "source": [
    "to_tensor = transforms.ToTensor()\n",
    "\n",
    "input_image = load_image(INPUT_IMAGE_PATH)\n",
    "input_tensor = (\n",
    "    crop_image_to_nearest_divisible_by_8(to_tensor(input_image))\n",
    "    .unsqueeze(0)\n",
    "    .to(DEVICE, dtype=DTYPE)\n",
    ")\n",
    "input_tensor.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8a3f5c9b-7d16-4e2a-a4c8-5e1b3f7d9c09",
   "metadata": {},
   "source": [
    "## Reconstruct and save"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d7b1c4e9-5f83-4a6d-8b2e-4f0c6a9d3e10",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inference only: skip autograd bookkeeping so activations are not retained.\n",
    "with torch.inference_mode():\n",
    "    reconstruction = vae(input_tensor).sample\n",
    "\n",
    "# NOTE(review): the input is in [0, 1] and the output is saved without rescaling. Diffusers\n",
    "# pipelines usually feed VAEs [-1, 1] images; confirm which range this checkpoint expects.\n",
    "# Upcast so save_image's 0-255 quantization happens in float32 rather than float16.\n",
    "utils.save_image(reconstruction.float(), OUTPUT_IMAGE_PATH)\n",
    "reconstruction.shape"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}