File size: 7,135 Bytes
3050f1b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
"""Process a directory of images through NisabaRelief and save as PNG."""

import argparse
from pathlib import Path

from PIL import Image
from rich.console import Console
from rich.progress import (
    BarColumn,
    MofNCompleteColumn,
    Progress,
    ProgressColumn,
    SpinnerColumn,
    Task,
    TextColumn,
    TimeElapsedColumn,
)
from rich.text import Text

from nisaba_relief import NisabaRelief
from nisaba_relief.constants import MAX_TILE, MIN_IMAGE_DIMENSION

# Disable Pillow's decompression-bomb guard (None turns the pixel-count check
# off) so very large source images can be opened without a
# DecompressionBombWarning/Error. Only do this for trusted input.
Image.MAX_IMAGE_PIXELS = None

# Filename suffixes treated as input images; main() compares them against
# ``Path.suffix.lower()``, so entries must be lowercase and include the dot.
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".tif", ".tiff", ".bmp", ".webp"}


class SimpleTimeRemainingColumn(ProgressColumn):
    """Estimate remaining time from the average duration of recent steps.

    The estimate is the mean per-step duration over the last ``window``
    measured steps, multiplied by the number of steps still outstanding. It
    is only recomputed when a task's ``completed`` count increases, so the
    displayed value stays stable between steps instead of changing on every
    refresh. The first completed step is deliberately not measured (likely
    because it absorbs one-off start-up cost), so a placeholder is shown
    until at least two steps have completed.

    Bookkeeping is kept per ``task.id``, so one column instance can serve
    several tasks in the same ``Progress`` display.

    Args:
        window: Number of most recent per-step durations to average. Must be
            at least 1.

    Raises:
        ValueError: If ``window`` is less than 1.
    """

    _PLACEHOLDER = "-:--:--"
    _STYLE = "progress.remaining"

    class _TaskState:
        """Mutable per-task bookkeeping for the ETA estimate."""

        __slots__ = ("last_completed", "last_elapsed", "durations", "cached")

        def __init__(self, cached: Text) -> None:
            # ``completed`` value seen at the last recomputation.
            self.last_completed: float = 0
            # Task elapsed time (seconds) at the last recomputation.
            self.last_elapsed: float = 0.0
            # Most recent per-step durations (seconds), oldest first.
            self.durations: list[float] = []
            # Text returned while progress has not advanced.
            self.cached: Text = cached

    def __init__(self, window: int = 10) -> None:
        super().__init__()
        if window < 1:
            raise ValueError(f"window must be >= 1, got {window}")
        self._window: int = window
        # Keyed by ``task.id`` so interleaved renders of different tasks do
        # not mix up each other's step timings.
        self._states: dict[int, SimpleTimeRemainingColumn._TaskState] = {}

    def _placeholder(self) -> Text:
        """Return the text shown when no estimate is available."""
        return Text(self._PLACEHOLDER, style=self._STYLE)

    def render(self, task: Task) -> Text:
        """Return the ETA for *task*, recomputing only when it has advanced."""
        state = self._states.get(task.id)
        if state is None:
            state = self._states[task.id] = self._TaskState(self._placeholder())

        # No new progress since the last recomputation: keep the display stable.
        if task.completed <= state.last_completed:
            return state.cached

        # Once finished, Rich freezes the task's elapsed time in ``finished_time``.
        elapsed = task.finished_time if task.finished else task.elapsed
        if not elapsed or not task.completed:
            state.last_completed = task.completed
            state.cached = self._placeholder()
            return state.cached

        steps = task.completed - state.last_completed
        # The very first step (last_completed == 0) is not recorded.
        if steps > 0 and state.last_completed > 0:
            state.durations.append((elapsed - state.last_elapsed) / steps)
            # Keep only the newest ``window`` samples.
            del state.durations[: -self._window]
        state.last_completed = task.completed
        state.last_elapsed = elapsed

        # Indeterminate tasks (total is None) have no meaningful ETA.
        if not state.durations or task.total is None:
            state.cached = self._placeholder()
            return state.cached

        avg = sum(state.durations) / len(state.durations)
        # Clamp so over-completion never yields a negative ETA.
        remaining = max(0.0, task.total - task.completed)
        hours, rem = divmod(int(avg * remaining), 3600)
        minutes, seconds = divmod(rem, 60)
        if hours:
            label = f"{hours}:{minutes:02d}:{seconds:02d}"
        else:
            label = f"{minutes}:{seconds:02d}"
        state.cached = Text(label, style=self._STYLE)
        return state.cached


def _parse_args() -> argparse.Namespace:
    """Define the command-line interface and parse ``sys.argv``."""
    parser = argparse.ArgumentParser(
        description="Process images through NisabaRelief and save as PNG."
    )
    parser.add_argument(
        "--input-dir", type=Path, required=True, help="Source image directory"
    )
    parser.add_argument(
        "--output-dir",
        type=Path,
        required=True,
        help="Destination directory (created if needed)",
    )
    parser.add_argument(
        "--max-size",
        type=int,
        default=MAX_TILE * 5,
        help="Downsample images larger than this before processing (default: %(default)s)",
    )
    parser.add_argument(
        "--min-size",
        type=int,
        default=1536,
        help="Skip images where max dimension < this (default: %(default)s)",
    )
    parser.add_argument("--seed", type=int, default=None, help="Reproducibility seed")
    parser.add_argument(
        "--weights-dir", type=Path, default=None, help="Local weights directory"
    )
    parser.add_argument("--batch-size", type=int, default=None, help="Tile batch size")
    parser.add_argument(
        "--num-steps", type=int, default=2, help="Solver steps (default: %(default)s)"
    )
    parser.add_argument(
        "--device", default="cuda", help="Torch device (default: %(default)s)"
    )
    parser.add_argument(
        "--overwrite", action="store_true", help="Re-process even if output file exists"
    )
    return parser.parse_args()


def _select_jobs(
    input_images: list[Path],
    output_dir: Path,
    *,
    overwrite: bool,
    min_size: int,
    console: Console,
) -> tuple[list[tuple[Path, Path]], int, int]:
    """Pair inputs with their ``.png`` destinations, dropping those needing no work.

    Returns:
        ``(jobs, skipped_existing, skipped_small)`` where ``jobs`` is a list of
        ``(src, dst)`` pairs. Files Pillow cannot open are reported on
        *console* and skipped instead of aborting the whole run.
    """
    jobs: list[tuple[Path, Path]] = []
    skipped_existing = 0
    skipped_small = 0
    for src in input_images:
        dst = output_dir / (src.stem + ".png")
        if not overwrite and dst.exists():
            skipped_existing += 1
            continue
        try:
            # Image.open is lazy, so this reads the header for the size only.
            with Image.open(src) as img:
                size = img.size
        except OSError:
            # Includes PIL.UnidentifiedImageError (an OSError subclass) and
            # permission / I/O failures.
            console.print(
                f"[yellow]Skipping unreadable image[/yellow] [cyan]{src.name}[/cyan]"
            )
            continue
        if max(size) < min_size or min(size) < MIN_IMAGE_DIMENSION:
            skipped_small += 1
            continue
        jobs.append((src, dst))
    return jobs, skipped_existing, skipped_small


def _build_model(args: argparse.Namespace) -> NisabaRelief:
    """Instantiate NisabaRelief from the parsed CLI arguments.

    Optional settings the user did not supply are omitted so that whatever
    defaults NisabaRelief defines apply instead.
    """
    model_kwargs = {"num_steps": args.num_steps, "device": args.device}
    optional = {
        "seed": args.seed,
        "weights_dir": args.weights_dir,
        "batch_size": args.batch_size,
    }
    model_kwargs.update({k: v for k, v in optional.items() if v is not None})
    return NisabaRelief(**model_kwargs)


def _downscale(image: Image.Image, max_size: int) -> Image.Image:
    """Shrink *image* so its longest side is at most *max_size*.

    Both sides are rounded down to a multiple of 16 (but never below 16, so
    an extreme aspect ratio cannot produce a zero-sized image). Images that
    already fit are returned unchanged.
    """
    longest = max(image.size)
    if longest <= max_size:
        return image
    scale = max_size / longest
    new_size = (
        max(16, round(image.width * scale) // 16 * 16),
        max(16, round(image.height * scale) // 16 * 16),
    )
    return image.resize(new_size, Image.LANCZOS)


def _process_image(model: NisabaRelief, src: Path, dst: Path, max_size: int) -> None:
    """Run *model* on *src* and write the result to *dst* at the source's size."""
    # Close the source file promptly; convert() loads the pixels into a new image.
    with Image.open(src) as opened:
        image = opened.convert("RGB")
    original_size = image.size
    result = model.process(_downscale(image, max_size), show_pbar=False)
    if result.size != original_size:
        result = result.resize(original_size, Image.LANCZOS)
    # Write to a sibling temp file and rename it into place, so an interrupted
    # run never leaves a truncated PNG that the "already processed" check
    # would later skip.
    partial = dst.with_name(dst.name + ".partial")
    try:
        result.save(partial, format="PNG")
        partial.replace(dst)
    except BaseException:
        partial.unlink(missing_ok=True)
        raise


def main():
    """Run NisabaRelief over every image in ``--input-dir``; save PNGs to ``--output-dir``.

    Each output is named ``<input stem>.png``. Inputs whose output already
    exists are skipped unless ``--overwrite`` is given. Inputs below the size
    thresholds or unreadable by Pillow are also skipped.

    Raises:
        SystemExit: status 1 if the input directory is missing or contains no
            supported images (argparse also exits on invalid arguments).
    """
    args = _parse_args()
    console = Console()

    input_dir: Path = args.input_dir
    output_dir: Path = args.output_dir

    if not input_dir.is_dir():
        console.print(f"[red]Input directory not found:[/red] [cyan]{input_dir}[/cyan]")
        # Non-zero status so calling scripts can detect the failure.
        raise SystemExit(1)

    input_images = sorted(
        p
        for p in input_dir.iterdir()
        if p.suffix.lower() in IMAGE_EXTENSIONS and p.is_file()
    )
    if not input_images:
        console.print(f"[red]No images found in[/red] [cyan]{input_dir}[/cyan]")
        raise SystemExit(1)

    output_dir.mkdir(parents=True, exist_ok=True)

    to_process, skipped_existing, skipped_small = _select_jobs(
        input_images,
        output_dir,
        overwrite=args.overwrite,
        min_size=args.min_size,
        console=console,
    )

    if skipped_existing:
        console.print(
            f"[dim]Skipping {skipped_existing} already-processed image(s)[/dim]"
        )
    if skipped_small:
        console.print(
            f"[dim]Skipping {skipped_small} image(s) smaller than {args.min_size}px "
            f"(or with a side under {MIN_IMAGE_DIMENSION}px)[/dim]"
        )

    if not to_process:
        # Only claim "already processed" when that is actually why nothing is left.
        if skipped_existing == len(input_images):
            console.print("[green]All images already processed.[/green]")
        else:
            console.print("[yellow]No images left to process.[/yellow]")
        return

    console.print(
        f"Processing [bold]{len(to_process)}[/bold] / {len(input_images)} images  "
        f"[dim]({input_dir} -> {output_dir})[/dim]"
    )

    model = _build_model(args)

    progress = Progress(
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
        BarColumn(),
        MofNCompleteColumn(),
        TimeElapsedColumn(),
        TextColumn("eta"),
        SimpleTimeRemainingColumn(),
    )
    with progress:
        task = progress.add_task("Processing", total=len(to_process))
        for src, dst in to_process:
            progress.update(task, description=f"[cyan]{src.name}[/cyan]")
            _process_image(model, src, dst, args.max_size)
            progress.advance(task)

    console.print(
        f"[green]Done.[/green] {len(to_process)} image(s) saved to [cyan]{output_dir}[/cyan]"
    )


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()