Spaces:
Running
Running
File size: 5,186 Bytes
7be1ca4 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 | {
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "d8d2437e",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from PIL import Image\n",
"from RealESRGAN import RealESRGAN\n",
"import gradio as gr\n",
"import numpy as np\n",
"import tempfile\n",
"import time\n",
"\n",
"# Run on the GPU when CUDA is available, otherwise fall back to the CPU.\n",
"device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "871d9b94",
"metadata": {},
"outputs": [],
"source": [
"def load_model(scale):\n",
"    \"\"\"Create a RealESRGAN upscaler for `scale` on `device` and load its weights.\n",
"\n",
"    Weights are read from weights/RealESRGAN_x<scale>.pth. Loading is first\n",
"    attempted with download=True; if that raises, the error is printed and the\n",
"    load is retried with download=False (an error from the retry propagates).\n",
"    \"\"\"\n",
"    weights_file = f'weights/RealESRGAN_x{scale}.pth'\n",
"    upscaler = RealESRGAN(device, scale=scale)\n",
"    try:\n",
"        upscaler.load_weights(weights_file, download=True)\n",
"        print(f\"Weights for scale {scale} loaded successfully.\")\n",
"    except Exception as err:\n",
"        print(f\"Error loading weights for scale {scale}: {err}\")\n",
"        upscaler.load_weights(weights_file, download=False)\n",
"    return upscaler\n",
"\n",
"# One upscaler per selectable factor, loaded once at startup in 2x, 4x, 8x order.\n",
"model2 = load_model(scale=2)\n",
"model4 = load_model(scale=4)\n",
"model8 = load_model(scale=8)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c891d4b3",
"metadata": {},
"outputs": [],
"source": [
"def enhance_image(image, scale):\n",
"    \"\"\"Upscale a PIL image with the RealESRGAN model that matches `scale`.\n",
"\n",
"    '2x' selects model2, '4x' selects model4, and any other value selects model8.\n",
"    Progress and timing are printed. On any exception the error is printed and\n",
"    the input `image` is returned unchanged, so callers always get an image back.\n",
"    \"\"\"\n",
"    try:\n",
"        print(f\"Enhancing image with scale {scale}...\")\n",
"        t0 = time.time()\n",
"        rgb_array = np.array(image.convert('RGB'))\n",
"        print(f\"Image converted to numpy array: shape {rgb_array.shape}, dtype {rgb_array.dtype}\")\n",
"\n",
"        upscaler = model2 if scale == '2x' else (model4 if scale == '4x' else model8)\n",
"        upscaled = upscaler.predict(rgb_array)\n",
"\n",
"        output = Image.fromarray(np.uint8(upscaled))\n",
"        print(f\"Image enhanced in {time.time() - t0:.2f} seconds\")\n",
"        return output\n",
"    except Exception as exc:\n",
"        print(f\"Error enhancing image: {exc}\")\n",
"        return image\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9073bff6",
"metadata": {},
"outputs": [],
"source": [
"def muda_dpi(input_image, dpi):\n",
"    \"\"\"Return `input_image` re-encoded as PNG with `dpi` x `dpi` resolution metadata.\n",
"\n",
"    The uint8 RGB array is written to an in-memory PNG with the requested DPI\n",
"    and reopened, so Pillow reports it via image.info['dpi']. Pixels are unchanged.\n",
"\n",
"    Raises:\n",
"        ValueError: if `dpi` is missing or not positive.\n",
"    \"\"\"\n",
"    import io  # local import keeps this cell independent of the import cell\n",
"\n",
"    if dpi is None or dpi <= 0:\n",
"        raise ValueError(f\"DPI must be a positive number, got {dpi!r}\")\n",
"    image = Image.fromarray(input_image.astype('uint8'), 'RGB')\n",
"    # In-memory round trip instead of NamedTemporaryFile(delete=False),\n",
"    # which left one stray file on disk for every call.\n",
"    buffer = io.BytesIO()\n",
"    image.save(buffer, format='PNG', dpi=(dpi, dpi))\n",
"    buffer.seek(0)\n",
"    return Image.open(buffer)\n",
"\n",
"def resize_image(input_image, width, height):\n",
"    \"\"\"Return `input_image` resized to exactly `width` x `height` pixels.\n",
"\n",
"    Both dimensions are rounded to the nearest integer because gr.Number inputs\n",
"    may arrive as floats (e.g. 512.0), which Pillow's Image.resize rejects.\n",
"    The aspect ratio is not preserved. The result is round-tripped through an\n",
"    in-memory PNG.\n",
"\n",
"    Raises:\n",
"        ValueError: if either dimension rounds to less than 1 pixel.\n",
"    \"\"\"\n",
"    import io  # local import keeps this cell independent of the import cell\n",
"\n",
"    size = (int(round(width)), int(round(height)))\n",
"    if min(size) < 1:\n",
"        raise ValueError(f\"Width and height must be at least 1 pixel, got {size}\")\n",
"    image = Image.fromarray(input_image.astype('uint8'), 'RGB')\n",
"    resized_image = image.resize(size)\n",
"    # In-memory round trip instead of NamedTemporaryFile(delete=False),\n",
"    # which left one stray file on disk for every call.\n",
"    buffer = io.BytesIO()\n",
"    resized_image.save(buffer, format='PNG')\n",
"    buffer.seek(0)\n",
"    return Image.open(buffer)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e470926d",
"metadata": {},
"outputs": [],
"source": [
"def process_image(input_image, enhance, scale, adjust_dpi, dpi, resize, width, height):\n",
"    \"\"\"Run the optional enhance -> DPI -> resize pipeline and save the result.\n",
"\n",
"    Parameters follow the Gradio inputs in order: the uploaded image array, the\n",
"    'Enhance Image' checkbox and model choice, the 'Apply DPI' checkbox and DPI\n",
"    value, and the 'Apply Resize' checkbox with target width and height.\n",
"\n",
"    Returns:\n",
"        (final PIL image, path of a PNG copy for the download component).\n",
"    \"\"\"\n",
"    if input_image is None:\n",
"        # Submitting without an upload would otherwise fail with an opaque AttributeError.\n",
"        raise gr.Error(\"Please upload an image first.\")\n",
"    original_image = Image.fromarray(input_image.astype('uint8'), 'RGB')\n",
"\n",
"    if enhance:\n",
"        original_image = enhance_image(original_image, scale)\n",
"\n",
"    if adjust_dpi:\n",
"        original_image = muda_dpi(np.array(original_image), dpi)\n",
"\n",
"    if resize:\n",
"        original_image = resize_image(np.array(original_image), width, height)\n",
"\n",
"    # Pillow writes PNG DPI only when save() is given dpi=..., and the resize\n",
"    # step discards the metadata anyway, so pass it explicitly; otherwise the\n",
"    # downloaded file never carries the requested DPI.\n",
"    save_options = {'dpi': (dpi, dpi)} if adjust_dpi else {}\n",
"\n",
"    # delete=False: the path is handed to the gr.File output, so the file must\n",
"    # outlive this call. Close the handle before saving so the path can be\n",
"    # reopened on every OS.\n",
"    with tempfile.NamedTemporaryFile(delete=False, suffix='.png') as temp_file:\n",
"        output_path = temp_file.name\n",
"    original_image.save(output_path, format='PNG', **save_options)\n",
"    return original_image, output_path\n",
"\n",
"# The inputs list maps positionally onto process_image's parameters.\n",
"iface = gr.Interface(\n",
"    fn=process_image,\n",
"    inputs=[\n",
"        gr.Image(label=\"Upload\"),\n",
"        gr.Checkbox(label=\"Enhance Image\"),\n",
"        gr.Radio(['2x', '4x', '8x'], type=\"value\", value='2x', label='Select Resolution model'),\n",
"        gr.Checkbox(label=\"Apply DPI\"),\n",
"        gr.Number(label=\"DPI\", value=300),\n",
"        gr.Checkbox(label=\"Apply Resize\"),\n",
"        # precision=0 makes Gradio deliver ints; Pillow's resize rejects float sizes.\n",
"        gr.Number(label=\"Width\", value=512, precision=0),\n",
"        gr.Number(label=\"Height\", value=512, precision=0)\n",
"    ],\n",
"    outputs=[\n",
"        gr.Image(label=\"Final Image\"),\n",
"        gr.File(label=\"Download Final Image\")\n",
"    ],\n",
"    title=\"Image Enhancer\",\n",
"    # The CPU notice is only accurate when the notebook actually fell back to the CPU.\n",
"    description=(\n",
"        \"Sorry for the inconvenience. The model is currently running on the CPU, which might affect performance. We appreciate your understanding.\"\n",
"        if device.type == 'cpu' else None\n",
"    ),\n",
"    theme=\"Yntec/HaleyCH_Theme_Orange\"\n",
")\n",
"\n",
"iface.launch(debug=True)\n"
]
}
],
"metadata": {},
"nbformat": 4,
"nbformat_minor": 5
} |