#!/usr/bin/env python3 """Generate a 4x4 augmentation preview grid without training. Examples: python -u scripts/finetune/shared/preview_augmentations.py \ --image-path data/roboflow/frame_01_0001_001_png.rf.19fb74b147f4e2ea7aeeeba9a8f9bb60.jpg python -u scripts/finetune/shared/preview_augmentations.py \ --mode bifurcation --image-path data/roboflow/frame_01_0001_001_png.rf.19fb74b147f4e2ea7aeeeba9a8f9bb60.jpg """ from __future__ import annotations import argparse from pathlib import Path import numpy as np import tensorflow as tf from PIL import Image IMG_MEAN = tf.constant([60.3486], dtype=tf.float32) def _load_gray(image_path: Path) -> np.ndarray: img = Image.open(image_path).convert("L") return np.asarray(img, dtype=np.uint8) def _prepare_lumen_input(gray: np.ndarray) -> tf.Tensor: x = tf.convert_to_tensor(gray[None, ...], dtype=tf.float32) x = x - IMG_MEAN x = tf.expand_dims(x, axis=-1) x = tf.tile(x, [1, 1, 1, 3]) return x def _prepare_bif_input(gray: np.ndarray, center_images: bool = False) -> tf.Tensor: x = tf.convert_to_tensor(gray[None, ...], dtype=tf.float32) if center_images: x = x - IMG_MEAN x = tf.expand_dims(x, axis=-1) x = tf.tile(x, [1, 1, 1, 3]) return x def _augment_seg_batch(images: tf.Tensor, seed: int) -> tf.Tensor: stateless_seed = tf.convert_to_tensor([seed, 0], dtype=tf.int32) flip_lr = tf.random.stateless_uniform([], seed=stateless_seed) > 0.5 images = tf.cond(flip_lr, lambda: tf.image.flip_left_right(images), lambda: images) stateless_seed = tf.convert_to_tensor([seed, 1], dtype=tf.int32) flip_ud = tf.random.stateless_uniform([], seed=stateless_seed) > 0.5 images = tf.cond(flip_ud, lambda: tf.image.flip_up_down(images), lambda: images) stateless_seed = tf.convert_to_tensor([seed, 2], dtype=tf.int32) k = tf.random.stateless_uniform([], seed=stateless_seed, minval=0, maxval=4, dtype=tf.int32) images = tf.image.rot90(images, k=k) shx = tf.random.stateless_uniform([], seed=tf.convert_to_tensor([seed, 30], dtype=tf.int32), minval=-0.08, maxval=0.08) shy = tf.random.stateless_uniform([], seed=tf.convert_to_tensor([seed, 31], dtype=tf.int32), minval=-0.08, maxval=0.08) p0 = tf.random.stateless_uniform([], seed=tf.convert_to_tensor([seed, 32], dtype=tf.int32), minval=-8e-4, maxval=8e-4) p1 = tf.random.stateless_uniform([], seed=tf.convert_to_tensor([seed, 33], dtype=tf.int32), minval=-8e-4, maxval=8e-4) base_t = tf.stack( [ tf.constant(1.0, tf.float32), shx, tf.constant(0.0, tf.float32), shy, tf.constant(1.0, tf.float32), tf.constant(0.0, tf.float32), p0, p1, ] ) transforms = tf.tile(base_t[tf.newaxis, :], [tf.shape(images)[0], 1]) images = tf.raw_ops.ImageProjectiveTransformV3( images=images, transforms=transforms, output_shape=tf.shape(images)[1:3], interpolation="BILINEAR", fill_mode="REFLECT", fill_value=0.0, ) stateless_seed = tf.convert_to_tensor([seed, 3], dtype=tf.int32) images = tf.image.stateless_random_brightness(images, max_delta=10.0, seed=stateless_seed) stateless_seed = tf.convert_to_tensor([seed, 4], dtype=tf.int32) images = tf.image.stateless_random_contrast(images, lower=0.9, upper=1.1, seed=stateless_seed) images = tf.clip_by_value(images, -255.0, 255.0) return images def _augment_cls_batch(images: tf.Tensor, seed: int) -> tf.Tensor: stateless_seed = tf.convert_to_tensor([seed, 0], dtype=tf.int32) flip_lr = tf.random.stateless_uniform([], seed=stateless_seed) > 0.5 images = tf.cond(flip_lr, lambda: tf.image.flip_left_right(images), lambda: images) stateless_seed = tf.convert_to_tensor([seed, 1], dtype=tf.int32) flip_ud = tf.random.stateless_uniform([], seed=stateless_seed) > 0.5 images = tf.cond(flip_ud, lambda: tf.image.flip_up_down(images), lambda: images) stateless_seed = tf.convert_to_tensor([seed, 2], dtype=tf.int32) k = tf.random.stateless_uniform([], seed=stateless_seed, minval=0, maxval=4, dtype=tf.int32) images = tf.image.rot90(images, k=k) shx = tf.random.stateless_uniform([], seed=tf.convert_to_tensor([seed, 30], dtype=tf.int32), minval=-0.08, maxval=0.08) shy = tf.random.stateless_uniform([], seed=tf.convert_to_tensor([seed, 31], dtype=tf.int32), minval=-0.08, maxval=0.08) p0 = tf.random.stateless_uniform([], seed=tf.convert_to_tensor([seed, 32], dtype=tf.int32), minval=-8e-4, maxval=8e-4) p1 = tf.random.stateless_uniform([], seed=tf.convert_to_tensor([seed, 33], dtype=tf.int32), minval=-8e-4, maxval=8e-4) base_t = tf.stack( [ tf.constant(1.0, tf.float32), shx, tf.constant(0.0, tf.float32), shy, tf.constant(1.0, tf.float32), tf.constant(0.0, tf.float32), p0, p1, ] ) transforms = tf.tile(base_t[tf.newaxis, :], [tf.shape(images)[0], 1]) images = tf.raw_ops.ImageProjectiveTransformV3( images=images, transforms=transforms, output_shape=tf.shape(images)[1:3], interpolation="BILINEAR", fill_mode="REFLECT", fill_value=0.0, ) stateless_seed = tf.convert_to_tensor([seed, 3], dtype=tf.int32) images = tf.image.stateless_random_brightness(images, max_delta=10.0, seed=stateless_seed) stateless_seed = tf.convert_to_tensor([seed, 4], dtype=tf.int32) images = tf.image.stateless_random_contrast(images, lower=0.9, upper=1.1, seed=stateless_seed) images = tf.clip_by_value(images, -255.0, 255.0) return images def _tile_to_uint8(tile: np.ndarray) -> np.ndarray: x = np.asarray(tile, dtype=np.float32) if x.ndim == 3 and x.shape[-1] > 1: x = x[..., 0] if not np.isfinite(x).any(): return np.zeros_like(x, dtype=np.uint8) mn = float(np.nanmin(x)) mx = float(np.nanmax(x)) if mx <= mn + 1e-6: return np.clip(x, 0, 255).astype(np.uint8) x = (x - mn) / (mx - mn) return (x * 255.0).clip(0, 255).astype(np.uint8) def _build_grid(tiles: list[np.ndarray], h: int, w: int) -> Image.Image: grid = Image.new("L", (4 * w, 4 * h), color=0) for i, tile in enumerate(tiles[:16]): r = i // 4 c = i % 4 grid.paste(Image.fromarray(_tile_to_uint8(tile), mode="L"), (c * w, r * h)) return grid def main() -> None: parser = argparse.ArgumentParser(description=__doc__) parser.add_argument("--image-path", type=Path, required=True) parser.add_argument("--mode", type=str, default="lumen", choices=["lumen", "bifurcation"]) parser.add_argument("--seed", type=int, default=7) parser.add_argument( "--output-path", type=Path, default=Path("output/augment_preview/augment_preview_4x4.png"), ) parser.add_argument( "--center-images", action="store_true", help="For bifurcation custom-model style preview (subtract IMG_MEAN before aug).", ) args = parser.parse_args() if not args.image_path.exists(): raise FileNotFoundError(f"Image not found: {args.image_path}") gray = _load_gray(args.image_path) h, w = int(gray.shape[0]), int(gray.shape[1]) if args.mode == "lumen": base = _prepare_lumen_input(gray) aug_fn = _augment_seg_batch else: base = _prepare_bif_input(gray, center_images=args.center_images) aug_fn = _augment_cls_batch tiles: list[np.ndarray] = [] for i in range(16): out = aug_fn(base, seed=args.seed + i).numpy()[0] tiles.append(out) grid = _build_grid(tiles, h=h, w=w) args.output_path.parent.mkdir(parents=True, exist_ok=True) grid.save(args.output_path) print(f"Saved augmentation grid: {args.output_path}") if __name__ == "__main__": main()