File size: 5,186 Bytes
7be1ca4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d8d2437e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from PIL import Image\n",
    "from RealESRGAN import RealESRGAN\n",
    "import gradio as gr\n",
    "import numpy as np\n",
    "import tempfile\n",
    "import time\n",
    "\n",
    "# Prefer a CUDA GPU whenever PyTorch can see one; otherwise run everything on the CPU.\n",
    "device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "871d9b94",
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_model(scale):\n",
    "    \"\"\"Create a RealESRGAN upscaler for `scale` and load its weights.\n",
    "\n",
    "    Weights live at weights/RealESRGAN_x{scale}.pth. Loading is attempted with\n",
    "    download=True first; if that raises, the error is printed and a second\n",
    "    attempt is made with download=False, whose errors propagate to the caller.\n",
    "\n",
    "    Returns:\n",
    "        The RealESRGAN instance with weights loaded.\n",
    "    \"\"\"\n",
    "    checkpoint = f'weights/RealESRGAN_x{scale}.pth'\n",
    "    upscaler = RealESRGAN(device, scale=scale)\n",
    "    try:\n",
    "        upscaler.load_weights(checkpoint, download=True)\n",
    "    except Exception as err:\n",
    "        print(f\"Error loading weights for scale {scale}: {err}\")\n",
    "        upscaler.load_weights(checkpoint, download=False)\n",
    "    else:\n",
    "        print(f\"Weights for scale {scale} loaded successfully.\")\n",
    "    return upscaler\n",
    "\n",
    "\n",
    "# One upscaler per factor offered in the UI, all loaded up front in this order.\n",
    "model2 = load_model(2)\n",
    "model4 = load_model(4)\n",
    "model8 = load_model(8)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c891d4b3",
   "metadata": {},
   "outputs": [],
   "source": [
    "def enhance_image(image, scale):\n",
    "    \"\"\"Upscale a PIL image with the RealESRGAN model selected by `scale`.\n",
    "\n",
    "    Args:\n",
    "        image: PIL image; it is converted to RGB before inference.\n",
    "        scale: '2x', '4x' or '8x' (the choices offered by the UI radio button).\n",
    "\n",
    "    Returns:\n",
    "        The enhanced PIL image. On any failure -- including an unsupported\n",
    "        `scale` -- the error is printed and `image` is returned unchanged, so\n",
    "        the UI still receives a result.\n",
    "    \"\"\"\n",
    "    try:\n",
    "        print(f\"Enhancing image with scale {scale}...\")\n",
    "        start_time = time.perf_counter()\n",
    "        # Explicit lookup so an unexpected value (e.g. None) is reported\n",
    "        # rather than silently running the slowest (8x) model.\n",
    "        models = {'2x': model2, '4x': model4, '8x': model8}\n",
    "        if scale not in models:\n",
    "            raise ValueError(f\"unsupported scale {scale!r}; expected one of {list(models)}\")\n",
    "        image_np = np.array(image.convert('RGB'))\n",
    "        print(f\"Image converted to numpy array: shape {image_np.shape}, dtype {image_np.dtype}\")\n",
    "\n",
    "        result = models[scale].predict(image_np)\n",
    "        enhanced_image = Image.fromarray(np.uint8(result))\n",
    "        print(f\"Image enhanced in {time.perf_counter() - start_time:.2f} seconds\")\n",
    "        return enhanced_image\n",
    "    except Exception as e:\n",
    "        # Best effort: report the problem but keep the pipeline going with the input image.\n",
    "        print(f\"Error enhancing image: {e}\")\n",
    "        return image\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9073bff6",
   "metadata": {},
   "outputs": [],
   "source": [
    "import io\n",
    "\n",
    "\n",
    "def muda_dpi(input_image, dpi):\n",
    "    \"\"\"Return `input_image` re-encoded as PNG with its DPI metadata set to `dpi`.\n",
    "\n",
    "    Args:\n",
    "        input_image: RGB image as a numpy array (cast to uint8).\n",
    "        dpi: Dots per inch, used for both the horizontal and vertical axes.\n",
    "\n",
    "    Returns:\n",
    "        A PIL image decoded from the PNG, so its info['dpi'] carries the\n",
    "        requested value (subject to PNG's pixels-per-metre rounding). Pixel\n",
    "        data is unchanged.\n",
    "    \"\"\"\n",
    "    image = Image.fromarray(input_image.astype('uint8'), 'RGB')\n",
    "    # Round-trip through an in-memory buffer rather than a temp file so that\n",
    "    # no file is left behind on disk for every call.\n",
    "    buffer = io.BytesIO()\n",
    "    image.save(buffer, format='PNG', dpi=(dpi, dpi))\n",
    "    buffer.seek(0)\n",
    "    return Image.open(buffer)\n",
    "\n",
    "\n",
    "def resize_image(input_image, width, height):\n",
    "    \"\"\"Resize an RGB numpy image to `width` x `height` pixels.\n",
    "\n",
    "    Args:\n",
    "        input_image: RGB image as a numpy array (cast to uint8).\n",
    "        width, height: Target size. Floats such as 512.0 (as a gr.Number\n",
    "            input may deliver) are rounded to the nearest integer, since\n",
    "            PIL's resize() needs integer dimensions.\n",
    "\n",
    "    Returns:\n",
    "        The resized PIL image.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if either rounded dimension is not positive.\n",
    "    \"\"\"\n",
    "    size = (int(round(width)), int(round(height)))\n",
    "    if size[0] <= 0 or size[1] <= 0:\n",
    "        raise ValueError(f\"Width and height must be positive, got {size[0]}x{size[1]}\")\n",
    "    image = Image.fromarray(input_image.astype('uint8'), 'RGB')\n",
    "    return image.resize(size)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e470926d",
   "metadata": {},
   "outputs": [],
   "source": [
    "def process_image(input_image, enhance, scale, adjust_dpi, dpi, resize, width, height):\n",
    "    \"\"\"Run the optional enhance -> DPI -> resize pipeline on an uploaded image.\n",
    "\n",
    "    Args:\n",
    "        input_image: Uploaded image as a numpy array, or None if nothing was uploaded.\n",
    "        enhance: Whether to upscale with RealESRGAN (see enhance_image).\n",
    "        scale: '2x', '4x' or '8x', forwarded to enhance_image.\n",
    "        adjust_dpi: Whether to set the image's DPI metadata to `dpi`.\n",
    "        dpi: Target dots per inch.\n",
    "        resize: Whether to resize to `width` x `height` pixels.\n",
    "        width, height: Target size for the resize step.\n",
    "\n",
    "    Returns:\n",
    "        A (final PIL image, path to a PNG copy for download) tuple.\n",
    "\n",
    "    Raises:\n",
    "        gr.Error: if no image was uploaded.\n",
    "    \"\"\"\n",
    "    if input_image is None:\n",
    "        raise gr.Error(\"Please upload an image first.\")\n",
    "\n",
    "    result = Image.fromarray(input_image.astype('uint8'), 'RGB')\n",
    "\n",
    "    if enhance:\n",
    "        result = enhance_image(result, scale)\n",
    "\n",
    "    if adjust_dpi:\n",
    "        result = muda_dpi(np.array(result), dpi)\n",
    "\n",
    "    if resize:\n",
    "        result = resize_image(np.array(result), width, height)\n",
    "\n",
    "    # PIL's PNG writer takes DPI from the save() options, not from image.info,\n",
    "    # and the resize step rebuilds the image from raw pixels anyway; pass it\n",
    "    # explicitly so the downloadable file actually carries the requested DPI.\n",
    "    save_options = {'dpi': (dpi, dpi)} if adjust_dpi else {}\n",
    "    # Write through the handle and let `with` close it; delete=False keeps the\n",
    "    # file on disk so the gr.File output can serve it.\n",
    "    with tempfile.NamedTemporaryFile(delete=False, suffix='.png') as temp_file:\n",
    "        result.save(temp_file, format='PNG', **save_options)\n",
    "    return result, temp_file.name\n",
    "\n",
    "iface = gr.Interface(\n",
    "    fn=process_image,\n",
    "    inputs=[\n",
    "        gr.Image(label=\"Upload\"),\n",
    "        gr.Checkbox(label=\"Enhance Image\"),\n",
    "        gr.Radio(['2x', '4x', '8x'], type=\"value\", value='2x', label='Select Resolution model'),\n",
    "        gr.Checkbox(label=\"Apply DPI\"),\n",
    "        gr.Number(label=\"DPI\", value=300),\n",
    "        gr.Checkbox(label=\"Apply Resize\"),\n",
    "        gr.Number(label=\"Width\", value=512),\n",
    "        gr.Number(label=\"Height\", value=512)\n",
    "    ],\n",
    "    outputs=[\n",
    "        gr.Image(label=\"Final Image\"),\n",
    "        gr.File(label=\"Download Final Image\")\n",
    "    ],\n",
    "    title=\"Image Enhancer\",\n",
    "    description=\"Sorry for the inconvenience. The model is currently running on the CPU, which might affect performance. We appreciate your understanding.\",\n",
    "    theme=\"Yntec/HaleyCH_Theme_Orange\"\n",
    ")\n",
    "\n",
    "iface.launch(debug=True)\n"
   ]
  }
 ],
 "metadata": {},
 "nbformat": 4,
 "nbformat_minor": 5
}