ivus-segmentation / scripts /finetune /shared /preview_augmentations.py
Aditya2162's picture
Upload folder using huggingface_hub
1d197a4 verified
#!/usr/bin/env python3
"""Generate a 4x4 augmentation preview grid without training.
Examples:
python -u scripts/finetune/shared/preview_augmentations.py \
--image-path data/roboflow/frame_01_0001_001_png.rf.19fb74b147f4e2ea7aeeeba9a8f9bb60.jpg
python -u scripts/finetune/shared/preview_augmentations.py \
--mode bifurcation --image-path data/roboflow/frame_01_0001_001_png.rf.19fb74b147f4e2ea7aeeeba9a8f9bb60.jpg
"""
from __future__ import annotations
import argparse
from pathlib import Path
import numpy as np
import tensorflow as tf
from PIL import Image
# Intensity offset subtracted from grayscale inputs to center them before
# augmentation (used by the _prepare_* helpers below).
# NOTE(review): presumably the training-set mean pixel value — confirm against
# the training pipeline.
IMG_MEAN = tf.constant([60.3486], dtype=tf.float32)
def _load_gray(image_path: Path) -> np.ndarray:
    """Load an image from disk as a 2-D grayscale ``uint8`` array.

    Args:
        image_path: Path of the image file to read.

    Returns:
        The image converted to 8-bit single-channel ("L") mode, as a NumPy
        array of dtype ``uint8``.
    """
    # Image.open is lazy and keeps the file handle open; the context manager
    # closes it deterministically. convert() decodes the pixels into a new,
    # independent image while the file is still open.
    with Image.open(image_path) as img:
        gray = img.convert("L")
    return np.asarray(gray, dtype=np.uint8)
def _prepare_lumen_input(gray: np.ndarray) -> tf.Tensor:
    """Turn a grayscale image into a mean-centered, 3-channel float batch.

    Args:
        gray: Grayscale image array; a leading batch axis and a trailing
            channel axis are added here.

    Returns:
        A float32 tensor with a batch axis of size 1 and the single gray
        channel (minus ``IMG_MEAN``) repeated three times on the last axis.
    """
    batched = tf.convert_to_tensor(gray[np.newaxis, ...], dtype=tf.float32)
    # Center first, then add the channel axis that is tiled to three channels.
    centered = tf.expand_dims(batched - IMG_MEAN, axis=-1)
    return tf.tile(centered, [1, 1, 1, 3])
def _prepare_bif_input(gray: np.ndarray, center_images: bool = False) -> tf.Tensor:
    """Turn a grayscale image into a 3-channel float batch, optionally centered.

    Args:
        gray: Grayscale image array; a leading batch axis and a trailing
            channel axis are added here.
        center_images: When true, subtract ``IMG_MEAN`` from every pixel
            before the channel axis is added.

    Returns:
        A float32 tensor with a batch axis of size 1 and the gray channel
        repeated three times on the last axis.
    """
    batched = tf.convert_to_tensor(gray[np.newaxis, ...], dtype=tf.float32)
    pixels = batched - IMG_MEAN if center_images else batched
    with_channel = tf.expand_dims(pixels, axis=-1)
    return tf.tile(with_channel, [1, 1, 1, 3])
def _augment_seg_batch(images: tf.Tensor, seed: int) -> tf.Tensor:
    """Apply one seeded geometric + photometric augmentation to an image batch.

    The pipeline is: random left/right flip, random up/down flip, random
    multiple-of-90-degree rotation, a small random shear/projective warp,
    random brightness, random contrast, and a final clip to [-255, 255].
    All randomness comes from stateless ops keyed on ``seed``, so the same
    ``seed`` always produces the same augmentation.

    Args:
        images: Batch of images; indexed as (batch, height, width, ...) and
            passed to ``ImageProjectiveTransformV3``.
        seed: Integer seed; each random draw uses ``[seed, stream]`` with a
            fixed per-draw stream index.

    Returns:
        The augmented batch.
    """

    def _key(stream: int) -> tf.Tensor:
        # Every random draw gets its own (seed, stream) pair.
        return tf.convert_to_tensor([seed, stream], dtype=tf.int32)

    def _symmetric(stream: int, bound: float) -> tf.Tensor:
        # Scalar drawn uniformly from [-bound, bound).
        return tf.random.stateless_uniform(
            [], seed=_key(stream), minval=-bound, maxval=bound
        )

    # --- Flips ---------------------------------------------------------
    do_lr = tf.random.stateless_uniform([], seed=_key(0)) > 0.5
    after_lr = tf.cond(
        do_lr, lambda: tf.image.flip_left_right(images), lambda: images
    )
    do_ud = tf.random.stateless_uniform([], seed=_key(1)) > 0.5
    after_ud = tf.cond(
        do_ud, lambda: tf.image.flip_up_down(after_lr), lambda: after_lr
    )

    # --- Rotation by 0/90/180/270 degrees ------------------------------
    quarter_turns = tf.random.stateless_uniform(
        [], seed=_key(2), minval=0, maxval=4, dtype=tf.int32
    )
    rotated = tf.image.rot90(after_ud, k=quarter_turns)

    # --- Shear + mild projective warp ----------------------------------
    shear_x = _symmetric(30, 0.08)
    shear_y = _symmetric(31, 0.08)
    persp_0 = _symmetric(32, 8e-4)
    persp_1 = _symmetric(33, 8e-4)
    one = tf.constant(1.0, tf.float32)
    zero = tf.constant(0.0, tf.float32)
    # Flattened 8-parameter projective transform [a0, a1, a2, b0, b1, b2, c0, c1].
    row = tf.stack([one, shear_x, zero, shear_y, one, zero, persp_0, persp_1])
    # The same transform is used for every image in the batch.
    per_image = tf.tile(tf.expand_dims(row, axis=0), [tf.shape(rotated)[0], 1])
    warped = tf.raw_ops.ImageProjectiveTransformV3(
        images=rotated,
        transforms=per_image,
        output_shape=tf.shape(rotated)[1:3],
        interpolation="BILINEAR",
        fill_mode="REFLECT",
        fill_value=0.0,
    )

    # --- Photometric jitter --------------------------------------------
    brightened = tf.image.stateless_random_brightness(
        warped, max_delta=10.0, seed=_key(3)
    )
    contrasted = tf.image.stateless_random_contrast(
        brightened, lower=0.9, upper=1.1, seed=_key(4)
    )
    return tf.clip_by_value(contrasted, -255.0, 255.0)
def _augment_cls_batch(images: tf.Tensor, seed: int) -> tf.Tensor:
    """Apply one seeded augmentation to an image batch.

    Steps, in order: optional left/right flip, optional up/down flip,
    rotation by a random multiple of 90 degrees, a small shear/projective
    warp shared by the whole batch, brightness jitter, contrast jitter, and
    clipping to [-255, 255]. Every random value is produced by a stateless op
    seeded with ``[seed, stream]``, so results are reproducible per ``seed``.

    Args:
        images: Batch of images; indexed as (batch, height, width, ...) and
            passed to ``ImageProjectiveTransformV3``.
        seed: Integer seed combined with a fixed stream index per draw.

    Returns:
        The augmented batch.
    """

    def _stream(index: int) -> tf.Tensor:
        return tf.convert_to_tensor([seed, index], dtype=tf.int32)

    # All draws are stateless, so they can be taken up front.
    flip_horizontal = tf.random.stateless_uniform([], seed=_stream(0)) > 0.5
    flip_vertical = tf.random.stateless_uniform([], seed=_stream(1)) > 0.5
    n_rot = tf.random.stateless_uniform(
        [], seed=_stream(2), minval=0, maxval=4, dtype=tf.int32
    )
    sx = tf.random.stateless_uniform(
        [], seed=_stream(30), minval=-0.08, maxval=0.08
    )
    sy = tf.random.stateless_uniform(
        [], seed=_stream(31), minval=-0.08, maxval=0.08
    )
    c0 = tf.random.stateless_uniform(
        [], seed=_stream(32), minval=-8e-4, maxval=8e-4
    )
    c1 = tf.random.stateless_uniform(
        [], seed=_stream(33), minval=-8e-4, maxval=8e-4
    )

    # Geometric part: flips, then rotation.
    out = tf.cond(
        flip_horizontal,
        lambda: tf.image.flip_left_right(images),
        lambda: images,
    )
    flipped = tf.cond(
        flip_vertical,
        lambda: tf.image.flip_up_down(out),
        lambda: out,
    )
    turned = tf.image.rot90(flipped, k=n_rot)

    # Flattened projective matrix [a0, a1, a2, b0, b1, b2, c0, c1]: identity
    # plus small off-diagonal and projective terms.
    unit = tf.constant(1.0, tf.float32)
    nil = tf.constant(0.0, tf.float32)
    matrix = tf.stack([unit, sx, nil, sy, unit, nil, c0, c1])
    batch_size = tf.shape(turned)[0]
    batch_matrices = tf.tile(matrix[tf.newaxis, :], [batch_size, 1])
    warped = tf.raw_ops.ImageProjectiveTransformV3(
        images=turned,
        transforms=batch_matrices,
        output_shape=tf.shape(turned)[1:3],
        interpolation="BILINEAR",
        fill_mode="REFLECT",
        fill_value=0.0,
    )

    # Photometric part: brightness, contrast, then clip.
    warped = tf.image.stateless_random_brightness(
        warped, max_delta=10.0, seed=_stream(3)
    )
    warped = tf.image.stateless_random_contrast(
        warped, lower=0.9, upper=1.1, seed=_stream(4)
    )
    return tf.clip_by_value(warped, -255.0, 255.0)
def _tile_to_uint8(tile: np.ndarray) -> np.ndarray:
    """Convert one augmented tile to a displayable ``uint8`` grayscale image.

    The tile is min-max normalised over its *finite* values to the 0-255
    range. Only the first channel of a 3-D (height, width, channels) tile is
    used, including the single-channel case, so the result is 2-D.

    Non-finite pixels are handled explicitly instead of reaching the
    ``uint8`` cast (whose result is undefined for NaN/inf): NaN and -inf map
    to the darkest value, +inf to the brightest.

    Args:
        tile: Image data of any numeric dtype; 2-D, or 3-D with channels last.

    Returns:
        ``uint8`` array. All zeros if the tile has no finite value; the
        clipped original values if the tile is (nearly) constant.
    """
    x = np.asarray(tile, dtype=np.float32)
    if x.ndim == 3:
        # Keep only the first channel so the result is always 2-D.
        x = x[..., 0]

    finite = np.isfinite(x)
    if not finite.any():
        return np.zeros_like(x, dtype=np.uint8)

    # Range over finite pixels only; nanmin/nanmax would still return +/-inf.
    mn = float(x[finite].min())
    mx = float(x[finite].max())
    # Pull non-finite pixels into the finite range before any integer cast.
    x = np.nan_to_num(x, nan=mn, posinf=mx, neginf=mn)

    if mx <= mn + 1e-6:
        # Flat tile: nothing to stretch, just clamp into the uint8 range.
        return np.clip(x, 0, 255).astype(np.uint8)
    x = (x - mn) / (mx - mn)
    return (x * 255.0).clip(0, 255).astype(np.uint8)
def _build_grid(tiles: list[np.ndarray], h: int, w: int) -> Image.Image:
    """Paste up to 16 tiles into a 4x4 grayscale contact sheet.

    Args:
        tiles: Tile arrays in row-major display order; only the first 16
            are used.
        h: Height in pixels of one grid cell.
        w: Width in pixels of one grid cell.

    Returns:
        An "L"-mode image of size ``(4 * w, 4 * h)``; cells without a tile
        stay black.
    """
    canvas = Image.new("L", (4 * w, 4 * h), color=0)
    for index, tile in enumerate(tiles[:16]):
        row, col = divmod(index, 4)
        cell = Image.fromarray(_tile_to_uint8(tile), mode="L")
        # Top-left corner of the cell, in (x, y) pixel coordinates.
        canvas.paste(cell, (col * w, row * h))
    return canvas
def main() -> None:
    """Parse CLI arguments, augment one image 16 times and save a 4x4 grid.

    The input image is loaded as grayscale and prepared for the selected
    mode ("lumen" or "bifurcation"). It is then augmented with seeds
    ``seed, seed + 1, ..., seed + 15``, and the results are written as a
    single grid image to ``--output-path`` (parent directories are created).

    Raises:
        FileNotFoundError: If ``--image-path`` does not exist.
    """
    parser = argparse.ArgumentParser(
        description=__doc__,
        # The module docstring contains multi-line example commands; the
        # default formatter would re-wrap them into one paragraph in --help.
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument("--image-path", type=Path, required=True)
    parser.add_argument("--mode", type=str, default="lumen", choices=["lumen", "bifurcation"])
    parser.add_argument("--seed", type=int, default=7)
    parser.add_argument(
        "--output-path",
        type=Path,
        default=Path("output/augment_preview/augment_preview_4x4.png"),
    )
    parser.add_argument(
        "--center-images",
        action="store_true",
        help="For bifurcation custom-model style preview (subtract IMG_MEAN before aug).",
    )
    args = parser.parse_args()

    if not args.image_path.exists():
        raise FileNotFoundError(f"Image not found: {args.image_path}")

    gray = _load_gray(args.image_path)
    h, w = int(gray.shape[0]), int(gray.shape[1])

    # Lumen inputs are always mean-centered; bifurcation inputs only when
    # --center-images is given.
    if args.mode == "lumen":
        base = _prepare_lumen_input(gray)
        aug_fn = _augment_seg_batch
    else:
        base = _prepare_bif_input(gray, center_images=args.center_images)
        aug_fn = _augment_cls_batch

    # One tile per consecutive seed; [0] drops the batch axis of size 1.
    tiles: list[np.ndarray] = [
        aug_fn(base, seed=args.seed + i).numpy()[0] for i in range(16)
    ]

    grid = _build_grid(tiles, h=h, w=w)
    args.output_path.parent.mkdir(parents=True, exist_ok=True)
    grid.save(args.output_path)
    print(f"Saved augmentation grid: {args.output_path}")
# Run the preview only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()