"""Process a directory of images through NisabaRelief and save as PNG."""

import argparse
from collections import deque
from pathlib import Path

from PIL import Image
from rich.console import Console
from rich.markup import escape
from rich.progress import (
    BarColumn,
    MofNCompleteColumn,
    Progress,
    ProgressColumn,
    SpinnerColumn,
    Task,
    TextColumn,
    TimeElapsedColumn,
)
from rich.text import Text

from nisaba_relief import NisabaRelief
from nisaba_relief.constants import MAX_TILE, MIN_IMAGE_DIMENSION

# Disable Pillow's decompression-bomb limit so very large inputs can be opened.
# Only run this script on trusted image directories.
Image.MAX_IMAGE_PIXELS = None

IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".tif", ".tiff", ".bmp", ".webp"}

# After downsampling, both dimensions are floored to a multiple of this value
# (presumably a model/tiling alignment requirement -- TODO confirm).
_SIZE_MULTIPLE = 16


class SimpleTimeRemainingColumn(ProgressColumn):
    """Estimate remaining time from the average duration of recent steps.

    The estimate is the mean per-step duration over the last ``window``
    sampled steps multiplied by the number of steps still remaining. The
    first completed step is not sampled because it is measured from task
    start rather than from a previous step. The rendered text is only
    recomputed when ``task.completed`` increases, so the display stays
    stable between steps.

    Args:
        window: Number of most recent per-step durations to average.
    """

    def __init__(self, window: int = 10) -> None:
        super().__init__()
        self._last_completed: float = 0
        self._last_elapsed: float = 0.0
        # deque(maxlen=...) discards the oldest sample automatically.
        self._durations: deque[float] = deque(maxlen=window)
        self._window: int = window
        self._cached: Text = self._placeholder()

    @staticmethod
    def _placeholder() -> Text:
        """Return the text shown while no estimate is available."""
        return Text("-:--:--", style="progress.remaining")

    @staticmethod
    def _format_seconds(seconds: float) -> str:
        """Format *seconds* as ``H:MM:SS`` (or ``M:SS`` under one hour)."""
        hours, rem = divmod(int(seconds), 3600)
        minutes, secs = divmod(rem, 60)
        if hours:
            return f"{hours}:{minutes:02d}:{secs:02d}"
        return f"{minutes}:{secs:02d}"

    def render(self, task: Task) -> Text:
        """Return the cached ETA, recomputing it only after new completions."""
        if task.completed <= self._last_completed:
            return self._cached

        elapsed = task.finished_time if task.finished else task.elapsed
        if not elapsed or not task.completed:
            self._last_completed = task.completed
            self._cached = self._placeholder()
            return self._cached

        steps = task.completed - self._last_completed
        # Skip the first sample: it spans from task start, not from a prior step.
        if steps > 0 and self._last_completed > 0:
            self._durations.append((elapsed - self._last_elapsed) / steps)
        self._last_completed = task.completed
        self._last_elapsed = elapsed

        # Without samples or a known total there is nothing to extrapolate.
        if not self._durations or task.total is None:
            self._cached = self._placeholder()
            return self._cached

        avg = sum(self._durations) / len(self._durations)
        remaining = max(task.total - task.completed, 0)
        self._cached = Text(
            self._format_seconds(avg * remaining), style="progress.remaining"
        )
        return self._cached


def _build_parser() -> argparse.ArgumentParser:
    """Return the command-line parser for this script."""
    parser = argparse.ArgumentParser(
        description="Process images through NisabaRelief and save as PNG."
    )
    parser.add_argument(
        "--input-dir", type=Path, required=True, help="Source image directory"
    )
    parser.add_argument(
        "--output-dir",
        type=Path,
        required=True,
        help="Destination directory (created if needed)",
    )
    parser.add_argument(
        "--max-size",
        type=int,
        default=MAX_TILE * 5,
        help="Downsample images larger than this before processing (default: %(default)s)",
    )
    parser.add_argument(
        "--min-size",
        type=int,
        default=1536,
        help="Skip images where max dimension < this (default: %(default)s)",
    )
    parser.add_argument("--seed", type=int, default=None, help="Reproducibility seed")
    parser.add_argument(
        "--weights-dir", type=Path, default=None, help="Local weights directory"
    )
    parser.add_argument("--batch-size", type=int, default=None, help="Tile batch size")
    parser.add_argument(
        "--num-steps", type=int, default=2, help="Solver steps (default: %(default)s)"
    )
    parser.add_argument(
        "--device", default="cuda", help="Torch device (default: %(default)s)"
    )
    parser.add_argument(
        "--overwrite", action="store_true", help="Re-process even if output file exists"
    )
    return parser


def _plan_work(
    input_images: list[Path],
    output_dir: Path,
    *,
    min_size: int,
    overwrite: bool,
    console: Console,
) -> tuple[list[tuple[Path, Path]], int]:
    """Choose which ``(src, dst)`` pairs to process and report what was skipped.

    Returns:
        The list of pairs to process and the number of inputs skipped
        because their output already exists.
    """
    to_process: list[tuple[Path, Path]] = []
    owners: dict[Path, Path] = {}
    skipped_existing = skipped_small = skipped_unreadable = skipped_duplicate = 0

    for src in input_images:
        dst = output_dir / (src.stem + ".png")
        # Inputs sharing a stem (e.g. a.jpg and a.tif) would write the same
        # output file; the first one in sorted order owns the name.
        if dst in owners:
            console.print(
                f"[yellow]Skipping[/yellow] [cyan]{src.name}[/cyan]: "
                f"output {dst.name} already belongs to {owners[dst].name}"
            )
            skipped_duplicate += 1
            continue
        owners[dst] = src

        if not overwrite and dst.exists():
            skipped_existing += 1
            continue

        # Only the header is read here; one bad file must not abort the batch.
        try:
            with Image.open(src) as img:
                size = img.size
        except OSError as err:  # includes PIL.UnidentifiedImageError
            console.print(
                f"[yellow]Skipping unreadable image[/yellow] [cyan]{src.name}[/cyan]: "
                f"{escape(str(err))}"
            )
            skipped_unreadable += 1
            continue

        if max(size) < min_size or min(size) < MIN_IMAGE_DIMENSION:
            skipped_small += 1
            continue
        to_process.append((src, dst))

    if skipped_existing:
        console.print(
            f"[dim]Skipping {skipped_existing} already-processed image(s)[/dim]"
        )
    if skipped_small:
        console.print(
            f"[dim]Skipping {skipped_small} image(s) smaller than {min_size}px "
            f"(longest side) or {MIN_IMAGE_DIMENSION}px (shortest side)[/dim]"
        )
    if skipped_unreadable:
        console.print(f"[dim]Skipping {skipped_unreadable} unreadable image(s)[/dim]")
    if skipped_duplicate:
        console.print(
            f"[dim]Skipping {skipped_duplicate} image(s) with a duplicate output name[/dim]"
        )
    return to_process, skipped_existing


def _downscale(image: Image.Image, max_size: int) -> Image.Image:
    """Shrink *image* so its longest side is at most *max_size*.

    Images already within the limit are returned unchanged. Otherwise both
    dimensions are floored to a multiple of ``_SIZE_MULTIPLE``. They never
    go below that value, because a zero-sized dimension would make
    ``resize`` fail.
    """
    longest = max(image.size)
    if longest <= max_size:
        return image
    scale = max_size / longest
    new_size = tuple(
        max(_SIZE_MULTIPLE, round(dim * scale) // _SIZE_MULTIPLE * _SIZE_MULTIPLE)
        for dim in image.size
    )
    return image.resize(new_size, Image.LANCZOS)


def _save_png_atomic(image: Image.Image, dst: Path) -> None:
    """Write *image* to *dst* as PNG via a temporary file plus rename.

    The skip-if-exists check treats any existing output as finished. A
    half-written file left by a crash or Ctrl-C would otherwise be skipped
    on every later run.
    """
    tmp = dst.with_name(dst.name + ".partial")
    try:
        image.save(tmp, format="PNG")
        tmp.replace(dst)
    finally:
        # Only still present when the save or rename failed part-way.
        tmp.unlink(missing_ok=True)


def main() -> int:
    """Run the batch conversion.

    Returns:
        The process exit status: 0 on success, 1 when the input directory
        is missing or contains no images.
    """
    args = _build_parser().parse_args()
    console = Console()

    input_dir: Path = args.input_dir
    output_dir: Path = args.output_dir

    if not input_dir.is_dir():
        console.print(f"[red]Input directory not found:[/red] [cyan]{input_dir}[/cyan]")
        return 1

    input_images = sorted(
        p
        for p in input_dir.iterdir()
        if p.suffix.lower() in IMAGE_EXTENSIONS and p.is_file()
    )
    if not input_images:
        console.print(f"[red]No images found in[/red] [cyan]{input_dir}[/cyan]")
        return 1

    output_dir.mkdir(parents=True, exist_ok=True)

    to_process, skipped_existing = _plan_work(
        input_images,
        output_dir,
        min_size=args.min_size,
        overwrite=args.overwrite,
        console=console,
    )
    if not to_process:
        if skipped_existing == len(input_images):
            console.print("[green]All images already processed.[/green]")
        else:
            console.print("[green]Nothing to process.[/green]")
        return 0

    console.print(
        f"Processing [bold]{len(to_process)}[/bold] / {len(input_images)} images "
        f"[dim]({input_dir} → {output_dir})[/dim]"
    )

    # Optional settings are only forwarded when given, so the model's own
    # defaults apply otherwise.
    model_kwargs = {"num_steps": args.num_steps, "device": args.device}
    if args.seed is not None:
        model_kwargs["seed"] = args.seed
    if args.weights_dir is not None:
        model_kwargs["weights_dir"] = args.weights_dir
    if args.batch_size is not None:
        model_kwargs["batch_size"] = args.batch_size
    model = NisabaRelief(**model_kwargs)

    progress = Progress(
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
        BarColumn(),
        MofNCompleteColumn(),
        TimeElapsedColumn(),
        TextColumn("eta"),
        SimpleTimeRemainingColumn(),
    )

    with progress:
        task = progress.add_task("Processing", total=len(to_process))
        for src, dst in to_process:
            progress.update(task, description=f"[cyan]{src.name}[/cyan]")
            # convert() returns a new image, so the source file can be closed
            # as soon as it has been decoded.
            with Image.open(src) as img:
                image = img.convert("RGB")
            original_size = image.size
            image = _downscale(image, args.max_size)

            result = model.process(image, show_pbar=False)
            # Restore the input resolution if the image was downsampled.
            if result.size != original_size:
                result = result.resize(original_size, Image.LANCZOS)
            _save_png_atomic(result, dst)
            progress.advance(task)

    console.print(
        f"[green]Done.[/green] {len(to_process)} image(s) saved to [cyan]{output_dir}[/cyan]"
    )
    return 0


if __name__ == "__main__":
    raise SystemExit(main())