File size: 16,703 Bytes
1f5470c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 |
"""Utilities related to image handling."""
import io
import pathlib
import warnings
import numpy as np
from keras.src import backend
from keras.src.api_export import keras_export
# PIL is an optional dependency. When it cannot be imported, both names
# below are set to `None`; `array_to_img` and `load_img` check for that and
# raise `ImportError` when they are called.
try:
    from PIL import Image as pil_image
except ImportError:
    pil_image = None
    pil_image_resampling = None
else:
    # Use the `Image.Resampling` namespace when this PIL build exposes it,
    # otherwise read the filter constants from the `Image` module itself.
    pil_image_resampling = getattr(pil_image, "Resampling", pil_image)

if pil_image_resampling is not None:
    # Maps the interpolation names accepted by `load_img` to PIL filters.
    PIL_INTERPOLATION_METHODS = {
        method: getattr(pil_image_resampling, method.upper())
        for method in (
            "nearest",
            "bilinear",
            "bicubic",
            "hamming",
            "box",
            "lanczos",
        )
    }
@keras_export(
    [
        "keras.utils.array_to_img",
        "keras.preprocessing.image.array_to_img",
    ]
)
def array_to_img(x, data_format=None, scale=True, dtype=None):
    """Converts a 3D NumPy array to a PIL Image instance.

    Example:

    ```python
    img = np.random.random(size=(100, 100, 3))
    pil_img = keras.utils.array_to_img(img)
    ```

    Args:
        x: Input data, in any form that can be converted to a NumPy array.
        data_format: Image data format, can be either `"channels_first"` or
            `"channels_last"`. Defaults to `None`, in which case the global
            setting `keras.backend.image_data_format()` is used (unless you
            changed it, it defaults to `"channels_last"`).
        scale: Whether to rescale the image such that minimum and maximum
            values are 0 and 255 respectively. Defaults to `True`.
        dtype: Dtype used for the intermediate NumPy array. `None` means the
            global setting `keras.backend.floatx()` is used (unless you
            changed it, it defaults to `"float32"`). Defaults to `None`.

    Returns:
        A PIL Image instance. Arrays with 4 channels become `"RGBA"`, with
        3 channels `"RGB"`, and with 1 channel `"L"` (or `"I"` when a value
        exceeds 255).

    Raises:
        ImportError: If PIL is not available.
        ValueError: If `x` does not have rank 3, or if its number of
            channels is not 1, 3 or 4.
    """
    data_format = backend.standardize_data_format(data_format)
    if dtype is None:
        dtype = backend.floatx()
    if pil_image is None:
        raise ImportError(
            "Could not import PIL.Image. "
            "The use of `array_to_img` requires PIL."
        )
    x = np.asarray(x, dtype=dtype)
    if x.ndim != 3:
        raise ValueError(
            "Expected image array to have rank 3 (single image). "
            f"Got array with shape: {x.shape}"
        )

    # The conversions below index the channel axis last, so move it there
    # when the input is (channel, height, width).
    if data_format == "channels_first":
        x = x.transpose(1, 2, 0)
    if scale:
        # The subtraction already yields a new array, so the caller's data
        # is never modified.
        x = x - np.min(x)
        x_max = np.max(x)
        if x_max != 0:
            # Out-of-place on purpose: an in-place `/=` raises on integer
            # dtypes because the float quotient cannot be cast back.
            x = x / x_max
        x = x * 255
    if x.shape[2] == 4:
        # RGBA
        return pil_image.fromarray(x.astype("uint8"), "RGBA")
    elif x.shape[2] == 3:
        # RGB
        return pil_image.fromarray(x.astype("uint8"), "RGB")
    elif x.shape[2] == 1:
        # grayscale
        if np.max(x) > 255:
            # 32-bit signed integer grayscale image. PIL mode "I"
            return pil_image.fromarray(x[:, :, 0].astype("int32"), "I")
        return pil_image.fromarray(x[:, :, 0].astype("uint8"), "L")
    else:
        raise ValueError(f"Unsupported channel number: {x.shape[2]}")
@keras_export(
    [
        "keras.utils.img_to_array",
        "keras.preprocessing.image.img_to_array",
    ]
)
def img_to_array(img, data_format=None, dtype=None):
    """Converts a PIL Image instance to a 3D NumPy array.

    Example:

    ```python
    img_data = np.random.random(size=(100, 100, 3))
    img = keras.utils.array_to_img(img_data)
    array = keras.utils.img_to_array(img)
    ```

    Args:
        img: Input PIL Image instance.
        data_format: Image data format, can be either `"channels_first"` or
            `"channels_last"`. Defaults to `None`, in which case the global
            setting `keras.backend.image_data_format()` is used (unless you
            changed it, it defaults to `"channels_last"`).
        dtype: Dtype to use. `None` means the global setting
            `keras.backend.floatx()` is used (unless you changed it, it
            defaults to `"float32"`).

    Returns:
        A 3D NumPy array. A 2D (single-channel) image gains a channel axis
        of size 1, placed according to `data_format`.

    Raises:
        ValueError: If the converted array is neither 2D nor 3D.
    """
    data_format = backend.standardize_data_format(data_format)
    if dtype is None:
        dtype = backend.floatx()

    # `np.asarray` yields (height, width[, channel]) for a PIL image.
    pixels = np.asarray(img, dtype=dtype)
    channels_first = data_format == "channels_first"
    rank = pixels.ndim

    if rank == 3:
        # Already has a channel axis; only its position may need to change.
        return pixels.transpose(2, 0, 1) if channels_first else pixels
    if rank == 2:
        # Single-channel image: add the missing channel axis.
        rows, cols = pixels.shape
        target_shape = (1, rows, cols) if channels_first else (rows, cols, 1)
        return pixels.reshape(target_shape)
    raise ValueError(f"Unsupported image shape: {pixels.shape}")
@keras_export(["keras.utils.save_img", "keras.preprocessing.image.save_img"])
def save_img(path, x, data_format=None, file_format=None, scale=True, **kwargs):
    """Saves an image stored as a NumPy array to a path or file object.

    Args:
        path: Path or file object.
        x: NumPy array.
        data_format: Image data format, either `"channels_first"` or
            `"channels_last"`.
        file_format: Optional file format override. If omitted, the format to
            use is determined from the filename extension. If a file object
            was used instead of a filename, this parameter should always be
            used. `"jpg"` is accepted (in any letter case) as an alias of
            `"jpeg"`.
        scale: Whether to rescale image values to be within `[0, 255]`.
        **kwargs: Additional keyword arguments passed to `PIL.Image.save()`.
    """
    data_format = backend.standardize_data_format(data_format)
    img = array_to_img(x, data_format=data_format, scale=scale)

    # Work out whether the image is about to be written as JPEG, either
    # because the caller said so or because the filename says so.
    jpeg_suffixes = (".jpg", ".jpeg", ".jpe", ".jfif")
    if isinstance(file_format, str):
        lowered = file_format.lower()
        if lowered == "jpg":
            # PIL only knows this format under the name "JPEG"; passing
            # "jpg" through to `Image.save` would fail.
            file_format = lowered = "jpeg"
        saving_as_jpeg = lowered == "jpeg"
    elif file_format is None and isinstance(path, (str, pathlib.PurePath)):
        suffix = pathlib.PurePath(path).suffix.lower()
        saving_as_jpeg = suffix in jpeg_suffixes
    else:
        saving_as_jpeg = False

    if img.mode == "RGBA" and saving_as_jpeg:
        warnings.warn(
            "The JPG format does not support RGBA images, converting to RGB."
        )
        img = img.convert("RGB")
    img.save(path, format=file_format, **kwargs)
@keras_export(["keras.utils.load_img", "keras.preprocessing.image.load_img"])
def load_img(
    path,
    color_mode="rgb",
    target_size=None,
    interpolation="nearest",
    keep_aspect_ratio=False,
):
    """Loads an image into PIL format.

    Example:

    ```python
    image = keras.utils.load_img(image_path)
    input_arr = keras.utils.img_to_array(image)
    input_arr = np.array([input_arr])  # Convert single image to a batch.
    predictions = model.predict(input_arr)
    ```

    Args:
        path: Path to the image file (`str`, `bytes` or `pathlib.Path`), or
            an `io.BytesIO` object holding the encoded image.
        color_mode: One of `"grayscale"`, `"rgb"`, `"rgba"`. Default: `"rgb"`.
            The desired image format.
        target_size: Either `None` (default to original size) or tuple of ints
            `(img_height, img_width)`.
        interpolation: Interpolation method used to resample the image if the
            target size is different from that of the loaded image. Supported
            methods are `"nearest"`, `"bilinear"`, `"bicubic"`, `"hamming"`,
            `"box"` and `"lanczos"`. By default, `"nearest"` is used.
        keep_aspect_ratio: Boolean, whether to resize images to a target
            size without aspect ratio distortion. The image is cropped in
            the center with target aspect ratio before resizing.

    Returns:
        A PIL Image instance.

    Raises:
        ImportError: If PIL is not available.
        TypeError: If `path` is not one of the accepted types.
        ValueError: If `color_mode` is not recognized, or if a resize is
            needed and `interpolation` is not a supported method.
    """
    if pil_image is None:
        raise ImportError(
            "Could not import PIL.Image. The use of `load_img` requires PIL."
        )

    if isinstance(path, io.BytesIO):
        img = pil_image.open(path)
    elif isinstance(path, (pathlib.Path, bytes, str)):
        filename = path
        if isinstance(filename, pathlib.Path):
            filename = str(filename.resolve())
        # Read the whole file up front so the handle is closed before
        # returning, while PIL decodes from the in-memory copy.
        with open(filename, "rb") as stream:
            img = pil_image.open(io.BytesIO(stream.read()))
    else:
        raise TypeError(
            f"path should be path-like or io.BytesIO, not {type(path)}"
        )

    # Pick the PIL modes that may be kept as-is and the mode to convert to
    # otherwise.
    if color_mode == "grayscale":
        # 8-bit, 16-bit and 32-bit grayscale images are left untouched;
        # anything else becomes 8-bit grayscale.
        kept_modes, fallback_mode = ("L", "I;16", "I"), "L"
    elif color_mode == "rgba":
        kept_modes, fallback_mode = ("RGBA",), "RGBA"
    elif color_mode == "rgb":
        kept_modes, fallback_mode = ("RGB",), "RGB"
    else:
        raise ValueError('color_mode must be "grayscale", "rgb", or "rgba"')
    if img.mode not in kept_modes:
        img = img.convert(fallback_mode)

    if target_size is None:
        return img

    # PIL sizes are (width, height); `target_size` is (height, width).
    out_size = (target_size[1], target_size[0])
    if img.size == out_size:
        return img

    if interpolation not in PIL_INTERPOLATION_METHODS:
        supported = ", ".join(PIL_INTERPOLATION_METHODS)
        raise ValueError(
            f"Invalid interpolation method {interpolation} specified. "
            f"Supported methods are {supported}"
        )
    resample = PIL_INTERPOLATION_METHODS[interpolation]

    if not keep_aspect_ratio:
        return img.resize(out_size, resample)

    src_width, src_height = img.size
    dst_width, dst_height = out_size
    # Crop extents with the target aspect ratio, capped at the source size.
    crop_height = min(src_height, (src_width * dst_height) // dst_width)
    crop_width = min(src_width, (src_height * dst_width) // dst_height)
    # Center the crop window inside the source image.
    top = (src_height - crop_height) // 2
    left = (src_width - crop_width) // 2
    crop_box = [left, top, left + crop_width, top + crop_height]
    return img.resize(out_size, resample, box=crop_box)
@keras_export("keras.preprocessing.image.smart_resize")
def smart_resize(
    x,
    size,
    interpolation="bilinear",
    data_format="channels_last",
    backend_module=None,
):
    """Resize images to a target size without aspect ratio distortion.

    Image datasets typically yield images that have each a different
    size. However, these images need to be batched before they can be
    processed by Keras layers. To be batched, images need to share the same
    height and width.

    You could simply do, in TF (or JAX equivalent):

    ```python
    size = (200, 200)
    ds = ds.map(lambda img: resize(img, size))
    ```

    However, if you do this, you distort the aspect ratio of your images,
    since in general they do not all have the same aspect ratio as `size`.
    This is fine in many cases, but not always (e.g. for image generation
    models this can be a problem).

    Note that passing the argument `preserve_aspect_ratio=True` to `resize`
    will preserve the aspect ratio, but at the cost of no longer respecting
    the provided target size.

    This calls for:

    ```python
    size = (200, 200)
    ds = ds.map(lambda img: smart_resize(img, size))
    ```

    Your output images will actually be `(200, 200)`, and will not be
    distorted. Instead, the parts of the image that do not fit within the
    target size get cropped out.

    The resizing process is:

    1. Take the largest centered crop of the image that has the same aspect
    ratio as the target size. For instance, if `size=(200, 200)` and the
    input image has size `(340, 500)`, we take a crop of `(340, 340)`
    centered along the width.
    2. Resize the cropped image to the target size. In the example above,
    we resize the `(340, 340)` crop to `(200, 200)`.

    Args:
        x: Input image or batch of images (as a tensor or NumPy array).
            Must be in format `(height, width, channels)`
            or `(batch_size, height, width, channels)` (with the channel
            axis before height and width for `"channels_first"`).
        size: Tuple of `(height, width)` integer. Target size.
        interpolation: String, interpolation to use for resizing.
            Supports `"bilinear"`, `"nearest"`, `"bicubic"`,
            `"lanczos3"`, `"lanczos5"`.
            Defaults to `"bilinear"`.
        data_format: `"channels_last"` or `"channels_first"`. Any value
            other than `"channels_last"` is handled as `"channels_first"`.
        backend_module: Backend module to use (if different from the default
            backend).

    Returns:
        The cropped and resized image(s), with height `size[0]` and width
        `size[1]`. A batched input keeps its leading batch axis.
        If the input image was a NumPy array, the output is a NumPy array,
        and if it was a backend-native tensor,
        the output is a backend-native tensor.

    Raises:
        ValueError: If `size` does not have exactly 2 entries, or if `x`
            is not of rank 3 or 4.
    """
    backend_module = backend_module or backend
    if len(size) != 2:
        raise ValueError(
            f"Expected `size` to be a tuple of 2 integers, but got: {size}."
        )
    img = backend_module.convert_to_tensor(x)
    rank = len(img.shape)
    if rank < 3 or rank > 4:
        raise ValueError(
            "Expected an image array with shape `(height, width, "
            "channels)`, or `(batch_size, height, width, channels)`, but "
            f"got input with incorrect rank, of shape {img.shape}."
        )
    shape = backend_module.shape(img)
    if data_format == "channels_last":
        height, width = shape[-3], shape[-2]
    else:
        height, width = shape[-2], shape[-1]
    target_height, target_width = size

    # Compute the crop extents with the target aspect ratio. Each extent is
    # capped at the input extent and kept at a minimum of 1.
    if isinstance(height, int) and isinstance(width, int):
        # For JAX, we need to keep the slice indices as static integers
        crop_height = int(float(width * target_height) / target_width)
        crop_height = max(min(height, crop_height), 1)
        crop_width = int(float(height * target_width) / target_height)
        crop_width = max(min(width, crop_width), 1)
        crop_box_hstart = int(float(height - crop_height) / 2)
        crop_box_wstart = int(float(width - crop_width) / 2)
    else:
        # Height/width are not plain Python ints here, so the same
        # arithmetic is expressed with backend ops.
        crop_height = backend_module.cast(
            backend_module.cast(width * target_height, "float32")
            / target_width,
            "int32",
        )
        crop_height = backend_module.numpy.minimum(height, crop_height)
        crop_height = backend_module.numpy.maximum(crop_height, 1)
        crop_height = backend_module.cast(crop_height, "int32")
        crop_width = backend_module.cast(
            backend_module.cast(height * target_width, "float32")
            / target_height,
            "int32",
        )
        crop_width = backend_module.numpy.minimum(width, crop_width)
        crop_width = backend_module.numpy.maximum(crop_width, 1)
        crop_width = backend_module.cast(crop_width, "int32")
        crop_box_hstart = backend_module.cast(
            backend_module.cast(height - crop_height, "float32") / 2, "int32"
        )
        crop_box_wstart = backend_module.cast(
            backend_module.cast(width - crop_width, "float32") / 2, "int32"
        )

    # Build the subscript once: the centered window on the height and width
    # axes and a full slice on every other axis. Passing a tuple of slice
    # objects is what the `img[a:b, c:d, :]` syntax does under the hood.
    height_window = slice(crop_box_hstart, crop_box_hstart + crop_height)
    width_window = slice(crop_box_wstart, crop_box_wstart + crop_width)
    everything = slice(None)
    if data_format == "channels_last":
        crop_index = (height_window, width_window, everything)
    else:
        crop_index = (everything, height_window, width_window)
    if rank == 4:
        # Leave the leading batch axis untouched.
        crop_index = (everything,) + crop_index
    img = img[crop_index]

    img = backend_module.image.resize(
        img, size=size, interpolation=interpolation, data_format=data_format
    )
    if isinstance(x, np.ndarray):
        return np.array(img)
    return img
|