# NisabaRelief / dev_scripts/process_images.py
# boatbomber — "Initial release" (commit 3050f1b)
"""Process a directory of images through NisabaRelief and save as PNG."""
import argparse
from collections import deque
from pathlib import Path

from PIL import Image
from rich.console import Console
from rich.progress import (
    BarColumn,
    MofNCompleteColumn,
    Progress,
    ProgressColumn,
    SpinnerColumn,
    Task,
    TextColumn,
    TimeElapsedColumn,
)
from rich.text import Text

from nisaba_relief import NisabaRelief
from nisaba_relief.constants import MAX_TILE, MIN_IMAGE_DIMENSION
# Disable Pillow's decompression-bomb size check (None means "no limit"),
# presumably so that very large source scans can be opened.
# NOTE(review): this also removes protection against maliciously huge images —
# only run this script on trusted inputs.
Image.MAX_IMAGE_PIXELS = None
# Lower-case file suffixes (including the leading dot) that are treated as
# input images when scanning the input directory.
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".tif", ".tiff", ".bmp", ".webp"}
class SimpleTimeRemainingColumn(ProgressColumn):
    """Estimate remaining time from the average duration of recent iterations.

    Each time the task's ``completed`` count increases, the time per step since
    the previous update is recorded. The ETA is the mean of the last ``window``
    recorded durations multiplied by the number of steps remaining. The value is
    only recomputed when a new step completes, so the display is stable between
    steps instead of changing on every refresh.

    The first completed step only establishes a timing baseline and is not
    averaged in (presumably so one-off warm-up cost on the first item does not
    skew the estimate).

    Args:
        window: Number of most recent per-step durations to average. Must be
            at least 1.

    Raises:
        ValueError: If ``window`` is less than 1.
    """

    def __init__(self, window: int = 10) -> None:
        if window < 1:
            # A non-positive window could never produce a meaningful average.
            raise ValueError(f"window must be >= 1, got {window}")
        super().__init__()
        self._last_completed: float = 0
        self._last_elapsed: float = 0.0
        # Bounded FIFO: appending past ``window`` items drops the oldest one.
        self._durations: deque[float] = deque(maxlen=window)
        self._window: int = window
        self._cached: Text = self._placeholder()

    @staticmethod
    def _placeholder() -> Text:
        """Return the text shown while no estimate is available."""
        return Text("-:--:--", style="progress.remaining")

    def render(self, task: Task) -> Text:
        """Return the ETA for ``task``, recomputing it only after new progress."""
        if task.completed <= self._last_completed:
            return self._cached

        # Use the frozen finish time once the task is done, live time otherwise.
        elapsed = task.finished_time if task.finished else task.elapsed
        if not elapsed:
            # Not started yet (or no measurable time): advance the baseline
            # without recording a duration.
            self._last_completed = task.completed
            self._last_elapsed = elapsed or 0.0
            self._cached = self._placeholder()
            return self._cached

        # Strictly positive: the early return above handles completed <= last.
        steps = task.completed - self._last_completed
        if self._last_completed > 0:
            # Several steps may have finished between refreshes; spread the
            # measured time evenly over them.
            self._durations.append((elapsed - self._last_elapsed) / steps)
        self._last_completed = task.completed
        self._last_elapsed = elapsed

        # ``Task.remaining`` is None when the task has no total (indeterminate).
        remaining = task.remaining
        if not self._durations or remaining is None:
            self._cached = self._placeholder()
            return self._cached

        avg = sum(self._durations) / len(self._durations)
        eta_seconds = int(avg * max(remaining, 0))
        hours, rem = divmod(eta_seconds, 3600)
        minutes, seconds = divmod(rem, 60)
        label = (
            f"{hours}:{minutes:02d}:{seconds:02d}"
            if hours
            else f"{minutes}:{seconds:02d}"
        )
        self._cached = Text(label, style="progress.remaining")
        return self._cached
def _parse_args() -> argparse.Namespace:
    """Define the command-line interface and return the parsed arguments."""
    parser = argparse.ArgumentParser(
        description="Process images through NisabaRelief and save as PNG."
    )
    parser.add_argument(
        "--input-dir", type=Path, required=True, help="Source image directory"
    )
    parser.add_argument(
        "--output-dir",
        type=Path,
        required=True,
        help="Destination directory (created if needed)",
    )
    parser.add_argument(
        "--max-size",
        type=int,
        default=MAX_TILE * 5,
        help="Downsample images larger than this before processing (default: %(default)s)",
    )
    parser.add_argument(
        "--min-size",
        type=int,
        default=1536,
        help="Skip images where max dimension < this (default: %(default)s)",
    )
    parser.add_argument("--seed", type=int, default=None, help="Reproducibility seed")
    parser.add_argument(
        "--weights-dir", type=Path, default=None, help="Local weights directory"
    )
    parser.add_argument("--batch-size", type=int, default=None, help="Tile batch size")
    parser.add_argument(
        "--num-steps", type=int, default=2, help="Solver steps (default: %(default)s)"
    )
    parser.add_argument(
        "--device", default="cuda", help="Torch device (default: %(default)s)"
    )
    parser.add_argument(
        "--overwrite", action="store_true", help="Re-process even if output file exists"
    )
    return parser.parse_args()


def _downscaled_size(width: int, height: int, max_size: int) -> tuple[int, int]:
    """Scale (width, height) so the longer side is about ``max_size``.

    Each side is rounded down to a multiple of 16 (presumably a model alignment
    requirement — confirm), but never below 16. Without that floor, a very
    elongated image could collapse to a zero-sized side and make the resize fail.
    """
    scale = max_size / max(width, height)
    return (
        max(16, round(width * scale) // 16 * 16),
        max(16, round(height * scale) // 16 * 16),
    )


def _collect_jobs(
    input_images: list[Path],
    output_dir: Path,
    min_size: int,
    overwrite: bool,
    console: "Console",
) -> list[tuple[Path, Path]]:
    """Pair each input with its PNG destination, dropping inputs needing no work.

    An input is skipped when any of the following holds:

    - its output already exists and ``overwrite`` is False;
    - an earlier queued input maps to the same output file;
    - PIL cannot open it;
    - it is too small.

    A summary line is printed for each skip category that occurred, plus a
    final message when nothing is left to process.

    Returns:
        The (source, destination) pairs still to process, in input order.
    """
    jobs: list[tuple[Path, Path]] = []
    claimed: set[Path] = set()
    skipped_existing = skipped_small = skipped_unreadable = skipped_duplicate = 0

    for src in input_images:
        dst = output_dir / (src.stem + ".png")
        if dst in claimed:
            # e.g. "a.jpg" and "a.png" both map to "a.png"; processing both
            # would silently overwrite the first result with the second.
            skipped_duplicate += 1
            continue
        if not overwrite and dst.exists():
            skipped_existing += 1
            continue
        try:
            with Image.open(src) as img:
                size = img.size
        except OSError:
            # PIL.UnidentifiedImageError subclasses OSError; this also covers
            # unreadable or permission-denied files.
            skipped_unreadable += 1
            continue
        if max(size) < min_size or min(size) < MIN_IMAGE_DIMENSION:
            skipped_small += 1
            continue
        claimed.add(dst)
        jobs.append((src, dst))

    if skipped_existing:
        console.print(
            f"[dim]Skipping {skipped_existing} already-processed image(s)[/dim]"
        )
    if skipped_small:
        console.print(
            f"[dim]Skipping {skipped_small} image(s) smaller than {min_size}px "
            f"(or with a side under {MIN_IMAGE_DIMENSION}px)[/dim]"
        )
    if skipped_unreadable:
        console.print(
            f"[yellow]Skipping {skipped_unreadable} unreadable image(s)[/yellow]"
        )
    if skipped_duplicate:
        console.print(
            f"[yellow]Skipping {skipped_duplicate} image(s) whose output name "
            f"collides with an earlier input[/yellow]"
        )
    if not jobs:
        if skipped_existing == len(input_images):
            console.print("[green]All images already processed.[/green]")
        else:
            console.print("[yellow]No images left to process.[/yellow]")
    return jobs


def _build_model(args: argparse.Namespace):
    """Create the NisabaRelief model from the parsed CLI options."""
    model_kwargs = {"num_steps": args.num_steps, "device": args.device}
    optional = {
        "seed": args.seed,
        "weights_dir": args.weights_dir,
        "batch_size": args.batch_size,
    }
    # Forward only options the user actually set, so the constructor's own
    # defaults apply otherwise.
    model_kwargs.update({k: v for k, v in optional.items() if v is not None})
    return NisabaRelief(**model_kwargs)


def _process_image(model, src: Path, max_size: int):
    """Run ``model`` on one image file and return the result at the source size."""
    # Load and convert inside the context manager so the file handle is closed.
    with Image.open(src) as img:
        image = img.convert("RGB")
    original_size = image.size
    if max(original_size) > max_size:
        # NOTE(review): for very elongated images the shorter side may end up
        # below MIN_IMAGE_DIMENSION after this; the pre-scan only checks the
        # original size.
        image = image.resize(
            _downscaled_size(*original_size, max_size), Image.LANCZOS
        )
    result = model.process(image, show_pbar=False)
    if result.size != original_size:
        result = result.resize(original_size, Image.LANCZOS)
    return result


def main():
    """Process every image in --input-dir with NisabaRelief; save PNGs to --output-dir.

    Prints an error and returns early if the input directory is missing or
    contains no images with a recognised extension.
    """
    args = _parse_args()
    console = Console()
    input_dir: Path = args.input_dir
    output_dir: Path = args.output_dir

    if not input_dir.is_dir():
        console.print(f"[red]Input directory not found:[/red] [cyan]{input_dir}[/cyan]")
        return

    # is_file() keeps directories that happen to have an image-like suffix out.
    input_images = sorted(
        p
        for p in input_dir.iterdir()
        if p.is_file() and p.suffix.lower() in IMAGE_EXTENSIONS
    )
    if not input_images:
        console.print(f"[red]No images found in[/red] [cyan]{input_dir}[/cyan]")
        return

    output_dir.mkdir(parents=True, exist_ok=True)
    jobs = _collect_jobs(
        input_images, output_dir, args.min_size, args.overwrite, console
    )
    if not jobs:
        return

    console.print(
        f"Processing [bold]{len(jobs)}[/bold] / {len(input_images)} images "
        f"[dim]({input_dir} -> {output_dir})[/dim]"
    )
    model = _build_model(args)

    progress = Progress(
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
        BarColumn(),
        MofNCompleteColumn(),
        TimeElapsedColumn(),
        TextColumn("eta"),
        SimpleTimeRemainingColumn(),
    )
    with progress:
        task = progress.add_task("Processing", total=len(jobs))
        for src, dst in jobs:
            progress.update(task, description=f"[cyan]{src.name}[/cyan]")
            _process_image(model, src, args.max_size).save(dst)
            progress.advance(task)

    console.print(
        f"[green]Done.[/green] {len(jobs)} image(s) saved to [cyan]{output_dir}[/cyan]"
    )
# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()