boatbomber committed on
Commit
3f8604c
·
1 Parent(s): 34f26f8

Refactor process_images and improve filtering out bad inputs

Browse files
Files changed (1) hide show
  1. dev_scripts/process_images.py +197 -53
dev_scripts/process_images.py CHANGED
@@ -1,6 +1,8 @@
1
  """Process a directory of images through NisabaRelief and save as PNG."""
2
 
3
  import argparse
 
 
4
  from pathlib import Path
5
 
6
  from PIL import Image
@@ -22,26 +24,36 @@ from nisaba_relief.constants import MAX_TILE, MIN_IMAGE_DIMENSION
22
 
23
  Image.MAX_IMAGE_PIXELS = None
24
 
25
- IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".tif", ".tiff", ".bmp", ".webp"}
 
 
 
 
 
 
 
26
 
27
 
28
  class SimpleTimeRemainingColumn(ProgressColumn):
29
- """Estimates remaining time from the average duration of the last 10 iterations.
30
 
31
- Only recomputes when a new step completes so the display is stable.
 
32
  """
33
 
34
- def __init__(self, window: int = 10) -> None:
35
  super().__init__()
36
  self._last_completed: float = 0
37
  self._last_elapsed: float = 0.0
38
  self._durations: list[float] = []
39
- self._window: int = window
40
  self._cached: Text = Text("-:--:--", style="progress.remaining")
41
 
42
  def render(self, task: Task) -> Text:
43
  if task.completed <= self._last_completed:
44
  return self._cached
 
 
45
  elapsed = task.finished_time if task.finished else task.elapsed
46
  if not elapsed or not task.completed:
47
  self._last_completed = task.completed
@@ -52,7 +64,7 @@ class SimpleTimeRemainingColumn(ProgressColumn):
52
  if steps > 0 and self._last_completed > 0:
53
  per_step = step_duration / steps
54
  self._durations.append(per_step)
55
- if len(self._durations) > self._window:
56
  self._durations = self._durations[-self._window :]
57
  self._last_completed = task.completed
58
  self._last_elapsed = elapsed
@@ -73,7 +85,81 @@ class SimpleTimeRemainingColumn(ProgressColumn):
73
  return self._cached
74
 
75
 
76
- def main():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
  parser = argparse.ArgumentParser(
78
  description="Process images through NisabaRelief and save as PNG."
79
  )
@@ -81,26 +167,112 @@ def main():
81
  "--input-dir", type=Path, required=True, help="Source image directory"
82
  )
83
  parser.add_argument(
84
- "--output-dir", type=Path, required=True, help="Destination directory (created if needed)"
 
 
 
85
  )
86
  parser.add_argument(
87
- "--max-size", type=int, default=MAX_TILE * 5,
 
 
88
  help="Downsample images larger than this before processing (default: %(default)s)",
89
  )
90
  parser.add_argument(
91
- "--min-size", type=int, default=1536,
 
 
92
  help="Skip images where max dimension < this (default: %(default)s)",
93
  )
 
 
 
 
 
 
94
  parser.add_argument("--seed", type=int, default=None, help="Reproducibility seed")
95
- parser.add_argument("--weights-dir", type=Path, default=None, help="Local weights directory")
 
 
96
  parser.add_argument("--batch-size", type=int, default=None, help="Tile batch size")
97
- parser.add_argument("--num-steps", type=int, default=2, help="Solver steps (default: %(default)s)")
98
- parser.add_argument("--device", default="cuda", help="Torch device (default: %(default)s)")
 
 
 
 
99
  parser.add_argument(
100
  "--overwrite", action="store_true", help="Re-process even if output file exists"
101
  )
102
- args = parser.parse_args()
 
103
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
104
  console = Console()
105
 
106
  input_dir: Path = args.input_dir
@@ -119,28 +291,20 @@ def main():
119
 
120
  output_dir.mkdir(parents=True, exist_ok=True)
121
 
122
- to_process = []
123
- skipped_existing = 0
124
- skipped_small = 0
125
- for src in input_images:
126
- dst = output_dir / (src.stem + ".png")
127
- if not args.overwrite and dst.exists():
128
- skipped_existing += 1
129
- continue
130
- with Image.open(src) as img:
131
- if max(img.size) < args.min_size or min(img.size) < MIN_IMAGE_DIMENSION:
132
- skipped_small += 1
133
- continue
134
- to_process.append((src, dst))
135
 
136
  if skipped_existing:
137
  console.print(
138
  f"[dim]Skipping {skipped_existing} already-processed image(s)[/dim]"
139
  )
140
- if skipped_small:
141
- console.print(
142
- f"[dim]Skipping {skipped_small} image(s) smaller than {args.min_size}px[/dim]"
143
- )
 
144
 
145
  if not to_process:
146
  console.print("[green]All images already processed.[/green]")
@@ -160,32 +324,12 @@ def main():
160
  model_kwargs["batch_size"] = args.batch_size
161
  model = NisabaRelief(**model_kwargs)
162
 
163
- progress = Progress(
164
- SpinnerColumn(),
165
- TextColumn("[progress.description]{task.description}"),
166
- BarColumn(),
167
- MofNCompleteColumn(),
168
- TimeElapsedColumn(),
169
- TextColumn("eta"),
170
- SimpleTimeRemainingColumn(),
171
- )
172
  with progress:
173
  task = progress.add_task("Processing", total=len(to_process))
174
  for src, dst in to_process:
175
  progress.update(task, description=f"[cyan]{src.name}[/cyan]")
176
- image = Image.open(src).convert("RGB")
177
- original_size = image.size
178
- if max(image.size) > args.max_size:
179
- scale = args.max_size / max(image.size)
180
- new_size = (
181
- round(image.width * scale) // 16 * 16,
182
- round(image.height * scale) // 16 * 16,
183
- )
184
- image = image.resize(new_size, Image.LANCZOS)
185
- result = model.process(image, show_pbar=False)
186
- if result.size != original_size:
187
- result = result.resize(original_size, Image.LANCZOS)
188
- result.save(dst)
189
  progress.advance(task)
190
 
191
  console.print(
 
1
  """Process a directory of images through NisabaRelief and save as PNG."""
2
 
3
  import argparse
4
+ import warnings
5
+ from concurrent.futures import ThreadPoolExecutor, as_completed
6
  from pathlib import Path
7
 
8
  from PIL import Image
 
24
 
25
  Image.MAX_IMAGE_PIXELS = None
26
 
27
+ IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".tif", ".tiff", ".bmp", ".webp", ".gif"}
28
+
29
+ SKIP_LABELS = {
30
+ "small": "image(s) smaller than {min_size}px",
31
+ "empty": "mostly-empty image(s)",
32
+ "bw": "black-and-white image(s)",
33
+ "corrupt": "corrupt/truncated image(s)",
34
+ }
35
 
36
 
37
  class SimpleTimeRemainingColumn(ProgressColumn):
38
+ """Estimates remaining time from the average duration of recent iterations.
39
 
40
+ The window is 0.5% of the task total (minimum 1, maximum 200). Only recomputes when a new
41
+ step completes so the display is stable.
42
  """
43
 
44
+ def __init__(self) -> None:
45
  super().__init__()
46
  self._last_completed: float = 0
47
  self._last_elapsed: float = 0.0
48
  self._durations: list[float] = []
49
+ self._window: int = 0
50
  self._cached: Text = Text("-:--:--", style="progress.remaining")
51
 
52
  def render(self, task: Task) -> Text:
53
  if task.completed <= self._last_completed:
54
  return self._cached
55
+ if not self._window and task.total:
56
+ self._window = min(max(1, int(task.total * 0.005)), 200)
57
  elapsed = task.finished_time if task.finished else task.elapsed
58
  if not elapsed or not task.completed:
59
  self._last_completed = task.completed
 
64
  if steps > 0 and self._last_completed > 0:
65
  per_step = step_duration / steps
66
  self._durations.append(per_step)
67
+ if self._window and len(self._durations) > self._window:
68
  self._durations = self._durations[-self._window :]
69
  self._last_completed = task.completed
70
  self._last_elapsed = elapsed
 
85
  return self._cached
86
 
87
 
88
def _make_progress(label: str) -> Progress:
    """Create a Progress display using this script's shared column layout.

    ``label`` is the text handed to the TextColumn that sits between the
    spinner and the bar.
    """
    columns = (
        SpinnerColumn(),
        TextColumn(label),
        BarColumn(),
        MofNCompleteColumn(),
        TimeElapsedColumn(),
        TextColumn("eta"),
        SimpleTimeRemainingColumn(),
    )
    return Progress(*columns)
99
+
100
+
101
def _classify_histogram(
    img: Image.Image,
    uniform_threshold: float,
    sat_threshold: float = 0.03,
    mid_threshold: float = 0.28,
    sample_size: int = 256,
) -> str | None:
    """Return a skip reason ("bw" or "empty") based on histograms, else None.

    The image is shrunk in place to fit within ``sample_size`` x ``sample_size``
    and then checked twice:

    * "bw": the share of pixels with saturation of 31 or more is below
      ``sat_threshold`` and the share with luminance in 45..204 is below
      ``mid_threshold`` (eg: lineart, text screenshots).
    * "empty": the densest 11-level-wide luminance band is centred at level 10
      or above and holds at least ``uniform_threshold`` of all pixels. This
      check is skipped when ``uniform_threshold`` is 1 or more.
    """
    bounds = (sample_size, sample_size)
    # JPEG: draft() asks the decoder for a reduced-resolution decode.
    # Other formats: draft() does nothing and thumbnail() performs the resize.
    img.draft("RGB", bounds)
    img.thumbnail(bounds, Image.NEAREST)

    luma = img.convert("L").histogram()
    pixel_count = sum(luma)

    # Black-and-white check: few saturated pixels and few midtones.
    saturation = img.convert("HSV").split()[1].histogram()
    saturated_share = sum(saturation[31:]) / pixel_count
    if saturated_share < sat_threshold:
        midtone_share = sum(luma[45:205]) / pixel_count
        if midtone_share < mid_threshold:
            return "bw"

    if uniform_threshold >= 1:
        return None

    # Dominant-colour check: prefix sums give every +/-5 window total directly.
    radius = 5
    prefix = [0]
    for level_count in luma:
        prefix.append(prefix[-1] + level_count)

    def band_count(center: int) -> int:
        return prefix[center + radius + 1] - prefix[center - radius]

    # max() keeps the first maximal centre, i.e. the lowest level on ties.
    densest = max(range(radius, 256 - radius), key=band_count)
    if densest >= 10 and band_count(densest) / pixel_count >= uniform_threshold:
        return "empty"
    return None
143
+
144
+
145
def _check_image(
    src: Path, dst: Path, min_size: int, max_uniform: float
) -> tuple[Path, Path, str]:
    """Classify a single image for filtering.

    Returns ``(src, dst, status)`` where ``status`` is one of:

    * "small": the largest side is below ``min_size`` or the smallest side is
      below ``MIN_IMAGE_DIMENSION``.
    * "bw" / "empty": the reason reported by ``_classify_histogram``.
    * "corrupt": opening or decoding raised ``OSError`` or ``SyntaxError``.
    * "process": none of the above applied.
    """
    try:
        with warnings.catch_warnings():
            # The filter must be installed before Image.open() runs. With both
            # context managers on one `with` line the image was opened first,
            # so UserWarnings emitted while opening were not suppressed.
            # NOTE(review): catch_warnings() swaps process-wide warning state
            # and is documented as not thread-safe, yet this function is
            # submitted to a thread pool elsewhere in this file - consider
            # installing the filter once in the caller instead.
            warnings.simplefilter("ignore", UserWarning)
            with Image.open(src) as img:
                if max(img.size) < min_size or min(img.size) < MIN_IMAGE_DIMENSION:
                    return src, dst, "small"
                reason = _classify_histogram(img, max_uniform)
                if reason:
                    return src, dst, reason
    except (OSError, SyntaxError):
        return src, dst, "corrupt"
    return src, dst, "process"
160
+
161
+
162
+ def _parse_args() -> argparse.Namespace:
163
  parser = argparse.ArgumentParser(
164
  description="Process images through NisabaRelief and save as PNG."
165
  )
 
167
  "--input-dir", type=Path, required=True, help="Source image directory"
168
  )
169
  parser.add_argument(
170
+ "--output-dir",
171
+ type=Path,
172
+ required=True,
173
+ help="Destination directory (created if needed)",
174
  )
175
  parser.add_argument(
176
+ "--max-size",
177
+ type=int,
178
+ default=MAX_TILE * 5,
179
  help="Downsample images larger than this before processing (default: %(default)s)",
180
  )
181
  parser.add_argument(
182
+ "--min-size",
183
+ type=int,
184
+ default=1536,
185
  help="Skip images where max dimension < this (default: %(default)s)",
186
  )
187
+ parser.add_argument(
188
+ "--max-uniform",
189
+ type=float,
190
+ default=0.65,
191
+ help="Skip images where this fraction of pixels share a single non-black color (default: %(default)s, set to 1 to disable)",
192
+ )
193
  parser.add_argument("--seed", type=int, default=None, help="Reproducibility seed")
194
+ parser.add_argument(
195
+ "--weights-dir", type=Path, default=None, help="Local weights directory"
196
+ )
197
  parser.add_argument("--batch-size", type=int, default=None, help="Tile batch size")
198
+ parser.add_argument(
199
+ "--num-steps", type=int, default=2, help="Solver steps (default: %(default)s)"
200
+ )
201
+ parser.add_argument(
202
+ "--device", default="cuda", help="Torch device (default: %(default)s)"
203
+ )
204
  parser.add_argument(
205
  "--overwrite", action="store_true", help="Re-process even if output file exists"
206
  )
207
+ return parser.parse_args()
208
+
209
 
210
def _gather_candidates(
    input_images: list[Path], output_dir: Path, overwrite: bool
) -> tuple[list[tuple[Path, Path]], int]:
    """Pair every input with its PNG destination and drop the ones already done.

    Returns ``(candidates, skipped_existing)``: the ``(src, dst)`` pairs that
    still need work, and how many inputs were skipped because ``dst`` already
    exists and ``overwrite`` is false.
    """
    pending: list[tuple[Path, Path]] = []
    already_done = 0
    with _make_progress("Gathering candidates") as bar:
        scan = bar.add_task("Scanning", total=len(input_images))
        for source in input_images:
            target = output_dir / f"{source.stem}.png"
            # With overwrite set, the filesystem is never consulted.
            if overwrite or not target.exists():
                pending.append((source, target))
            else:
                already_done += 1
            bar.advance(scan)
    return pending, already_done
226
+
227
+
228
def _filter_candidates(
    candidates: list[tuple[Path, Path]],
    min_size: int,
    max_uniform: float,
    max_workers: int = 8,
) -> tuple[list[tuple[Path, Path]], dict[str, int]]:
    """Run ``_check_image`` over all candidates in a thread pool.

    Args:
        candidates: ``(src, dst)`` pairs to inspect.
        min_size: Forwarded to ``_check_image``.
        max_uniform: Forwarded to ``_check_image``.
        max_workers: Number of worker threads (default 8).

    Returns:
        ``(to_process, skipped)``: the sorted pairs whose status was "process",
        and a count of skipped images per status string.

    Raises:
        Whatever a worker raised; it is re-raised by ``future.result()`` after
        the queued checks have been cancelled.
    """
    to_process = []
    skipped: dict[str, int] = {}
    executor = ThreadPoolExecutor(max_workers=max_workers)
    try:
        futures = [
            executor.submit(_check_image, src, dst, min_size, max_uniform)
            for src, dst in candidates
        ]
        with _make_progress("Filtering candidates") as progress:
            task = progress.add_task("Filtering", total=len(futures))
            for future in as_completed(futures):
                src, dst, status = future.result()
                if status == "process":
                    to_process.append((src, dst))
                else:
                    skipped[status] = skipped.get(status, 0) + 1
                progress.advance(task)
    except BaseException:
        # Cancel queued work on ANY failure, not only Ctrl-C. Otherwise an
        # unexpected worker exception leaves the pool draining its whole
        # queue in the background.
        executor.shutdown(wait=False, cancel_futures=True)
        raise
    executor.shutdown()
    # Completion order is nondeterministic; sort for a stable processing order.
    to_process.sort()
    return to_process, skipped
255
+
256
+
257
def _process_image(src: Path, dst: Path, model: NisabaRelief, max_size: int) -> None:
    """Load, optionally downsample, run the model, restore size, and save one image.

    Args:
        src: Image file to read.
        dst: Path the result is saved to.
        model: Object whose ``process(image, show_pbar=False)`` produces the result.
        max_size: If the longest side exceeds this, the image is downscaled
            before processing so that its longest side is about ``max_size``.
    """
    # Use a context manager so the source file handle is closed as soon as the
    # pixels have been copied; convert() returns an independent image.
    with Image.open(src) as opened:
        image = opened.convert("RGB")
    original_size = image.size
    if max(image.size) > max_size:
        scale = max_size / max(image.size)
        # Each side is floored to a multiple of 16.
        # NOTE(review): a side that scales to fewer than 16 pixels becomes 0
        # here (extreme aspect ratios) - confirm whether that can occur.
        new_size = (
            round(image.width * scale) // 16 * 16,
            round(image.height * scale) // 16 * 16,
        )
        image = image.resize(new_size, Image.LANCZOS)
    result = model.process(image, show_pbar=False)
    # Bring the output back to the input's original dimensions if they differ.
    if result.size != original_size:
        result = result.resize(original_size, Image.LANCZOS)
    result.save(dst)
272
+
273
+
274
+ def main():
275
+ args = _parse_args()
276
  console = Console()
277
 
278
  input_dir: Path = args.input_dir
 
291
 
292
  output_dir.mkdir(parents=True, exist_ok=True)
293
 
294
+ candidates, skipped_existing = _gather_candidates(
295
+ input_images, output_dir, args.overwrite
296
+ )
297
+ to_process, skipped = _filter_candidates(candidates, args.min_size, args.max_uniform)
 
 
 
 
 
 
 
 
 
298
 
299
  if skipped_existing:
300
  console.print(
301
  f"[dim]Skipping {skipped_existing} already-processed image(s)[/dim]"
302
  )
303
+ for reason, label in SKIP_LABELS.items():
304
+ if count := skipped.get(reason):
305
+ console.print(
306
+ f"[dim]Skipping {count} {label.format(min_size=args.min_size)}[/dim]"
307
+ )
308
 
309
  if not to_process:
310
  console.print("[green]All images already processed.[/green]")
 
324
  model_kwargs["batch_size"] = args.batch_size
325
  model = NisabaRelief(**model_kwargs)
326
 
327
+ progress = _make_progress("[progress.description]{task.description}")
 
 
 
 
 
 
 
 
328
  with progress:
329
  task = progress.add_task("Processing", total=len(to_process))
330
  for src, dst in to_process:
331
  progress.update(task, description=f"[cyan]{src.name}[/cyan]")
332
+ _process_image(src, dst, model, args.max_size)
 
 
 
 
 
 
 
 
 
 
 
 
333
  progress.advance(task)
334
 
335
  console.print(