hvoss-techfak commited on
Commit
d677a26
·
1 Parent(s): c7a2403

hopefully fixed spaces gpu bug?

Browse files
Files changed (2) hide show
  1. app.py +46 -42
  2. auto_forge.py +1087 -0
app.py CHANGED
@@ -780,52 +780,56 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
780
  import threading
781
 
782
  class Worker(threading.Thread):
783
- def __init__(self, cmd, log_path):
784
- super().__init__(daemon=True)
785
- self.cmd, self.log_path = cmd, log_path
786
- self.returncode = None
787
- self.exc = None
788
-
789
- def run(self):
790
- try:
791
- # Import joblib's parallel_backend to match previous environment setup
792
- from joblib import parallel_backend
793
-
794
- # Import the autoforge high-level module which contains main().
795
- # No fallback allowed; require this module to exist in the package.
 
 
 
 
 
796
  try:
797
- autoforge_main = importlib.import_module("autoforge.auto_forge")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
798
  except Exception as e:
799
- # Log and fail the worker if import fails.
800
- with open(self.log_path, "a", encoding="utf-8") as lf:
801
- lf.write(f"\nERROR: Could not import autoforge.auto_forge: {e}\n")
802
  self.exc = e
803
- self.returncode = -1
804
- return
805
-
806
- # Redirect stdout/stderr into the provided log file while running
807
- with open(self.log_path, "a", encoding="utf-8") as lf, \
808
- redirect_stdout(lf), redirect_stderr(lf), parallel_backend("threading", n_jobs=-1):
809
- try:
810
- # Ensure CLI-like argv for the autoforge parser
811
- sys.argv = ["autoforge"] + (self.cmd[1:] if len(self.cmd) > 1 else [])
812
- autoforge_main.main()
813
- # If main exits normally, treat as success (0)
814
- self.returncode = 0
815
- except SystemExit as e:
816
- # Preserve exit code behavior
817
- self.returncode = e.code if isinstance(e.code, int) or e.code is None else 0
818
- except Exception as e:
819
- # Write exception to log and mark failure
820
- lf.write(f"\nERROR: {e}\n")
821
- self.exc = e
822
  self.returncode = -1
823
- except Exception as e_outer:
824
- # If import or log file operations fail, persist exception and code
825
- self.exc = e_outer
826
- with open(self.log_path, "a", encoding="utf-8") as lf:
827
- lf.write("\nERROR: {}\n".format(exc_text(e_outer)))
828
- self.returncode = -1
 
 
829
 
830
  try:
831
  worker = Worker(command, log_file)
 
780
  import threading
781
 
782
  class Worker(threading.Thread):
783
+ def __init__(self, cmd, log_path):
784
+ super().__init__(daemon=True)
785
+ self.cmd, self.log_path = cmd, log_path
786
+ self.returncode = None
787
+ self.exc = None
788
+
789
+ def run(self):
790
+ """Import and run the local `auto_forge.py` module in-process.
791
+
792
+ We load the script from the project dir as a fresh module using
793
+ importlib.util.spec_from_file_location to ensure decorators like
794
+ @spaces.GPU are executed at import time. Stdout/stderr are redirected
795
+ to the run log to preserve the live console stream.
796
+ """
797
+ try:
798
+ # Import or reload the package module so decorators execute on import.
799
+ with open(self.log_path, "a", encoding="utf-8") as lf, \
800
+ redirect_stdout(lf), redirect_stderr(lf):
801
  try:
802
+ if "autoforge.auto_forge" in sys.modules:
803
+ module = importlib.reload(sys.modules["autoforge.auto_forge"])
804
+ else:
805
+ module = importlib.import_module("autoforge.auto_forge")
806
+
807
+ # Provide CLI-like argv for the parser inside the module
808
+ sys.argv = ["autoforge"] + (self.cmd[1:] if len(self.cmd) > 1 else [])
809
+
810
+ if hasattr(module, "main"):
811
+ try:
812
+ module.main()
813
+ self.returncode = 0
814
+ except SystemExit as se:
815
+ self.returncode = se.code if isinstance(se.code, int) or se.code is None else 0
816
+ else:
817
+ raise AttributeError("autoforge.auto_forge does not expose a main() function")
818
  except Exception as e:
819
+ lf.write(f"\nERROR while importing/running autoforge.auto_forge: {exc_text(e)}\n")
 
 
820
  self.exc = e
821
+ if isinstance(e, SystemExit):
822
+ self.returncode = e.code if isinstance(e.code, int) or e.code is None else 1
823
+ else:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
824
  self.returncode = -1
825
+ except Exception as outer_e:
826
+ self.exc = outer_e
827
+ try:
828
+ with open(self.log_path, "a", encoding="utf-8") as lf:
829
+ lf.write(f"\nERROR loading autoforge.auto_forge: {exc_text(outer_e)}\n")
830
+ except Exception:
831
+ pass
832
+ self.returncode = -1
833
 
834
  try:
835
  worker = Worker(command, log_file)
auto_forge.py ADDED
@@ -0,0 +1,1087 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """auto_forge.py
2
+
3
+ High-level orchestration module for the AutoForge optimization pipeline.
4
+
5
+ Responsibilities:
6
+ - Parse CLI / config file arguments.
7
+ - Load image and material properties.
8
+ - (Optionally) auto-select a background filament color based on dominant image color.
9
+ - Initialize a height map using one of several strategies (k-means clustering or depth estimation).
10
+ - Build and run the filament optimization loop (differentiable + periodic discretization checks).
11
+ - Optionally prune the solution to respect practical printer constraints (materials, swaps, layers).
12
+ - Export final artifacts: preview PNG, STL(s), swap instructions, project file, metadata.
13
+
14
+ The implementation intentionally keeps side-effects (disk writes / prints) order-stable to
15
+ preserve prior behavior. Helper functions are factored out for readability; no functional
16
+ behavior should have changed relative to the previous monolithic version.
17
+ """
18
+ import argparse
19
+ import sys
20
+ import os
21
+ import traceback
22
+ from typing import Optional, Tuple, List
23
+
24
+ import configargparse
25
+ import cv2
26
+ try:
27
+ import spaces
28
+ except Exception:
29
+ # Provide a minimal shim so @spaces.GPU can be used even when 'spaces' isn't installed.
30
+ def _spaces_noop_decorator(fn=None):
31
+ # Support usage as @spaces.GPU or @spaces.GPU()
32
+ if fn is None:
33
+ def _inner(f):
34
+ return f
35
+ return _inner
36
+ return fn
37
+
38
+ class _DummySpaces:
39
+ GPU = staticmethod(_spaces_noop_decorator)
40
+
41
+ spaces = _DummySpaces()
42
+ import torch
43
+ import numpy as np
44
+ from tqdm import tqdm
45
+
46
+ from autoforge.Helper import PruningHelper
47
+ from autoforge.Helper.FilamentHelper import hex_to_rgb, load_materials
48
+ from autoforge.Helper.Heightmaps.ChristofidesHeightMap import (
49
+ run_init_threads,
50
+ )
51
+
52
+ from autoforge.Helper.ImageHelper import resize_image, imread
53
+ from autoforge.Helper.OtherHelper import set_seed, perform_basic_check, get_device
54
+ from autoforge.Helper.OutputHelper import (
55
+ generate_stl,
56
+ generate_swap_instructions,
57
+ generate_project_file,
58
+ generate_flatforge_stls,
59
+ )
60
+ from autoforge.Modules.Optimizer import FilamentOptimizer
61
+
62
+ # check if we can use torch.set_float32_matmul_precision('high')
63
+ if torch.__version__ >= "2.0.0":
64
+ try:
65
+ torch.set_float32_matmul_precision("high")
66
+ except Exception as e:
67
+ print("Warning: Could not set float32 matmul precision to high. Error:", e)
68
+ pass
69
+
70
+
71
+ def parse_args() -> argparse.Namespace:
72
+ """Create and parse command-line & config-file arguments.
73
+
74
+ Returns:
75
+ argparse.Namespace: Populated arguments structure. Some parameters may be adjusted later
76
+ (e.g., num_init_cluster_layers when -1 to infer from max_layers).
77
+ """
78
+ parser = configargparse.ArgParser()
79
+ parser.add_argument("--config", is_config_file=True, help="Path to config file")
80
+
81
+ parser.add_argument(
82
+ "--input_image", type=str, required=True, help="Path to input image"
83
+ )
84
+ parser.add_argument(
85
+ "--csv_file",
86
+ type=str,
87
+ default="",
88
+ help="Path to CSV file with material data",
89
+ )
90
+ parser.add_argument(
91
+ "--json_file",
92
+ type=str,
93
+ default="",
94
+ help="Path to json file with material data",
95
+ )
96
+ parser.add_argument(
97
+ "--output_folder", type=str, default="output", help="Folder to write outputs"
98
+ )
99
+
100
+ parser.add_argument(
101
+ "--iterations", type=int, default=6000, help="Number of optimization iterations"
102
+ )
103
+
104
+ parser.add_argument(
105
+ "--warmup_fraction",
106
+ type=float,
107
+ default=1.0,
108
+ help="Fraction of iterations for keeping the tau at the initial value",
109
+ )
110
+
111
+ parser.add_argument(
112
+ "--learning_rate_warmup_fraction",
113
+ type=float,
114
+ default=0.01,
115
+ help="Fraction of iterations that the learning rate is increasing (warmup)",
116
+ )
117
+
118
+ parser.add_argument(
119
+ "--init_tau",
120
+ type=float,
121
+ default=1.0,
122
+ help="Initial tau value for Gumbel-Softmax",
123
+ )
124
+
125
+ parser.add_argument(
126
+ "--final_tau",
127
+ type=float,
128
+ default=0.01,
129
+ help="Final tau value for Gumbel-Softmax",
130
+ )
131
+
132
+ parser.add_argument(
133
+ "--learning_rate",
134
+ type=float,
135
+ default=0.015,
136
+ help="Learning rate for optimization",
137
+ )
138
+
139
+ parser.add_argument(
140
+ "--layer_height", type=float, default=0.04, help="Layer thickness in mm"
141
+ )
142
+
143
+ parser.add_argument(
144
+ "--max_layers", type=int, default=75, help="Maximum number of layers"
145
+ )
146
+
147
+ parser.add_argument(
148
+ "--min_layers",
149
+ type=int,
150
+ default=0,
151
+ help="Minimum number of layers. Used for pruning.",
152
+ )
153
+
154
+ parser.add_argument(
155
+ "--background_height",
156
+ type=float,
157
+ default=0.24,
158
+ help="Height of the background in mm",
159
+ )
160
+
161
+ parser.add_argument(
162
+ "--background_color", type=str, default="#000000", help="Background color"
163
+ )
164
+
165
+ parser.add_argument(
166
+ "--auto_background_color",
167
+ default=True,
168
+ help="Automatically set background color to the closest filament color matching the dominant image color. Overrides --background_color.",
169
+ )
170
+
171
+ parser.add_argument(
172
+ "--visualize",
173
+ type=bool,
174
+ default=True,
175
+ help="Enable visualization during optimization",
176
+ action=argparse.BooleanOptionalAction,
177
+ )
178
+
179
+ # Instead of an output_size parameter, we use stl_output_size and nozzle_diameter.
180
+ parser.add_argument(
181
+ "--stl_output_size",
182
+ type=int,
183
+ default=150,
184
+ help="Size of the longest dimension of the output STL file in mm",
185
+ )
186
+
187
+ parser.add_argument(
188
+ "--processing_reduction_factor",
189
+ type=int,
190
+ default=2,
191
+ help="Reduction factor for reducing the processing size compared to the output size (default: 2 - half resolution)",
192
+ )
193
+
194
+ parser.add_argument(
195
+ "--nozzle_diameter",
196
+ type=float,
197
+ default=0.4,
198
+ help="Diameter of the printer nozzle in mm (details smaller than half this value will be ignored)",
199
+ )
200
+
201
+ parser.add_argument(
202
+ "--early_stopping",
203
+ type=int,
204
+ default=2000,
205
+ help="Number of steps without improvement before stopping",
206
+ )
207
+
208
+ parser.add_argument(
209
+ "--perform_pruning",
210
+ type=bool,
211
+ default=True,
212
+ help="Perform pruning after optimization",
213
+ action=argparse.BooleanOptionalAction,
214
+ )
215
+
216
+ parser.add_argument(
217
+ "--fast_pruning",
218
+ type=bool,
219
+ default=True,
220
+ help="Use fast pruning method",
221
+ action=argparse.BooleanOptionalAction,
222
+ )
223
+ parser.add_argument(
224
+ "--fast_pruning_percent",
225
+ type=float,
226
+ default=0.25,
227
+ help="Percentage of increment search for fast pruning",
228
+ )
229
+
230
+ parser.add_argument(
231
+ "--pruning_max_colors",
232
+ type=int,
233
+ default=100,
234
+ help="Max number of colors allowed after pruning",
235
+ )
236
+ parser.add_argument(
237
+ "--pruning_max_swaps",
238
+ type=int,
239
+ default=100,
240
+ help="Max number of swaps allowed after pruning",
241
+ )
242
+
243
+ parser.add_argument(
244
+ "--pruning_max_layer",
245
+ type=int,
246
+ default=75,
247
+ help="Max number of layers allowed after pruning",
248
+ )
249
+
250
+ parser.add_argument(
251
+ "--random_seed",
252
+ type=int,
253
+ default=0,
254
+ help="Specify the random seed, or use 0 for automatic generation",
255
+ )
256
+
257
+ parser.add_argument(
258
+ "--mps",
259
+ action="store_true",
260
+ help="Use the Metal Performance Shaders (MPS) backend, if available.",
261
+ )
262
+
263
+ parser.add_argument(
264
+ "--run_name", type=str, help="Name of the run used for TensorBoard logging"
265
+ )
266
+
267
+ parser.add_argument(
268
+ "--tensorboard", action="store_true", help="Enable TensorBoard logging"
269
+ )
270
+
271
+ parser.add_argument(
272
+ "--num_init_rounds",
273
+ type=int,
274
+ default=16,
275
+ help="Number of rounds to choose the starting height map from.",
276
+ )
277
+
278
+ parser.add_argument(
279
+ "--num_init_cluster_layers",
280
+ type=int,
281
+ default=-1,
282
+ help="Number of layers to cluster the image into.",
283
+ )
284
+
285
+ parser.add_argument(
286
+ "--disable_visualization_for_gradio",
287
+ type=int,
288
+ default=0,
289
+ help="Simple switch to disable the matplotlib render window for gradio rendering.",
290
+ )
291
+
292
+ parser.add_argument(
293
+ "--best_of",
294
+ type=int,
295
+ default=1,
296
+ help="Run the program multiple times and output the best result.",
297
+ )
298
+
299
+ parser.add_argument(
300
+ "--discrete_check",
301
+ type=int,
302
+ default=100,
303
+ help="Modulo how often to check for new discrete results.",
304
+ )
305
+
306
+ parser.add_argument(
307
+ "--flatforge",
308
+ type=bool,
309
+ default=False,
310
+ help="Enable FlatForge mode to generate separate STL files for each color",
311
+ action=argparse.BooleanOptionalAction,
312
+ )
313
+
314
+ parser.add_argument(
315
+ "--cap_layers",
316
+ type=int,
317
+ default=0,
318
+ help="Number of complete clear/transparent layers to add on top in FlatForge mode",
319
+ )
320
+
321
+ # New: choose heightmap initializer
322
+ parser.add_argument(
323
+ "--init_heightmap_method",
324
+ type=str,
325
+ choices=["kmeans", "depth"],
326
+ default="kmeans",
327
+ help="Initializer for the height map: 'kmeans' (fast, default) or 'depth' (requires transformers).",
328
+ )
329
+ # New priority mask argument (optional)
330
+ parser.add_argument(
331
+ "--priority_mask",
332
+ type=str,
333
+ default="",
334
+ help="Optional path to a priority mask image (same dimensions as input image). Non-empty: apply weighted loss (0.1 outside, 1.0 at max inside).",
335
+ )
336
+
337
+ args = parser.parse_args()
338
+ return args
339
+
340
+
341
+ def _compute_dominant_image_color(
342
+ img_rgb: np.ndarray, alpha: Optional[np.ndarray]
343
+ ) -> Optional[Tuple[str, np.ndarray]]:
344
+ """Compute an approximate dominant color of the input image.
345
+
346
+ Strategy:
347
+ - Optionally downscale very large images for efficiency.
348
+ - Ignore (mostly) transparent pixels if alpha channel is provided.
349
+ - Use frequency counts (np.unique) over exact RGB triplets.
350
+
351
+ Args:
352
+ img_rgb: Image array in RGB order (H,W,3) uint8.
353
+ alpha: Optional alpha mask (H,W,1) or (H,W) uint8; pixels <128 are ignored.
354
+
355
+ Returns:
356
+ (hex_color, normalized_rgb) where hex_color is a '#RRGGBB' string and normalized_rgb
357
+ is float32 in [0,1]^3. Returns None if no valid pixels remain.
358
+ """
359
+ try:
360
+ # Downscale if needed (max side 300 px)
361
+ h, w = img_rgb.shape[:2]
362
+ max_side = max(h, w)
363
+ target_side = 300
364
+ alpha_small: Optional[np.ndarray] = None
365
+ if max_side > target_side:
366
+ scale = target_side / max_side
367
+ new_w = max(1, int(w * scale))
368
+ new_h = max(1, int(h * scale))
369
+ img_small = cv2.resize(
370
+ img_rgb, (new_w, new_h), interpolation=cv2.INTER_AREA
371
+ )
372
+ if alpha is not None:
373
+ alpha_small = cv2.resize(
374
+ alpha, (new_w, new_h), interpolation=cv2.INTER_NEAREST
375
+ )
376
+ else:
377
+ img_small = img_rgb
378
+ alpha_small = alpha
379
+ # Build mask for valid pixels (ignore transparent)
380
+ if alpha_small is not None:
381
+ valid_mask = (
382
+ alpha_small[..., 0] if alpha_small.ndim == 3 else alpha_small
383
+ ) >= 128
384
+ else:
385
+ valid_mask = np.ones(img_small.shape[:2], dtype=bool)
386
+ if valid_mask.sum() == 0:
387
+ return None
388
+ pixels = img_small[valid_mask]
389
+ # Use np.unique to find most frequent RGB triplet
390
+ unique_colors, counts = np.unique(
391
+ pixels.reshape(-1, 3), axis=0, return_counts=True
392
+ )
393
+ idx = int(np.argmax(counts))
394
+ dom_rgb_uint8 = unique_colors[idx]
395
+ dom_rgb_norm = dom_rgb_uint8.astype(np.float32) / 255.0
396
+ hex_color = "#" + "".join(f"{c:02X}" for c in dom_rgb_uint8)
397
+ return hex_color, dom_rgb_norm
398
+ except Exception:
399
+ traceback.print_exc()
400
+ return None
401
+
402
+
403
+ def _auto_select_background_color(
404
+ args,
405
+ img_rgb: np.ndarray,
406
+ alpha: Optional[np.ndarray],
407
+ material_colors_np: np.ndarray,
408
+ material_names: List[str],
409
+ colors_list: List[str],
410
+ ) -> None:
411
+ """Optionally override the user-provided background color with a closest material color.
412
+
413
+ When --auto_background_color is set:
414
+ - Determine dominant image color (ignoring transparency).
415
+ - Find closest filament (Euclidean in normalized RGB).
416
+ - Persist metadata to 'auto_background_color.txt'.
417
+
418
+ Side effects: Mutates args.background_color and attaches background_material_* fields.
419
+
420
+ Args:
421
+ args: Global argument namespace (mutated).
422
+ img_rgb: Full-resolution RGB image (uint8).
423
+ alpha: Optional alpha channel for transparency filtering.
424
+ material_colors_np: (N,3) array of filament RGB colors in [0,1].
425
+ material_names: List of filament names.
426
+ colors_list: List of filament hex color strings (#RRGGBB).
427
+ """
428
+ if not args.auto_background_color:
429
+ return
430
+ res = _compute_dominant_image_color(img_rgb, alpha)
431
+ if res is not None:
432
+ dominant_hex, dominant_rgb = res
433
+ diffs = material_colors_np - dominant_rgb[None, :]
434
+ dists = np.linalg.norm(diffs, axis=1)
435
+ closest_idx = int(np.argmin(dists))
436
+ chosen_hex = colors_list[closest_idx]
437
+ print(
438
+ f"Auto background color: dominant image color {dominant_hex} -> closest filament {chosen_hex} (index {closest_idx})."
439
+ )
440
+ args.background_color = chosen_hex
441
+ args.background_material_index = closest_idx
442
+ try:
443
+ args.background_material_name = material_names[closest_idx]
444
+ except Exception:
445
+ args.background_material_name = None
446
+ try:
447
+ with open(
448
+ os.path.join(args.output_folder, "auto_background_color.txt"), "w"
449
+ ) as f:
450
+ f.write(f"dominant_image_color={dominant_hex}\n")
451
+ f.write(f"chosen_filament_color={chosen_hex}\n")
452
+ f.write(f"closest_filament_index={closest_idx}\n")
453
+ if getattr(args, "background_material_name", None):
454
+ f.write(
455
+ f"closest_filament_name={args.background_material_name}\n"
456
+ )
457
+ except Exception:
458
+ traceback.print_exc()
459
+ else:
460
+ print(
461
+ "Warning: Auto background color computation failed; using provided --background_color."
462
+ )
463
+
464
+
465
+ def _prepare_background_and_materials(
466
+ args, device: torch.device, material_colors_np: np.ndarray, material_TDs_np: np.ndarray
467
+ ) -> Tuple[Tuple[int, int, int], torch.Tensor, torch.Tensor, torch.Tensor]:
468
+ """Create torch tensors for materials & background color.
469
+
470
+ Args:
471
+ args: Global arguments (uses background_color hex string).
472
+ device: Torch device for tensor placement.
473
+ material_colors_np: (N,3) float32 array in [0,1].
474
+ material_TDs_np: (N,*) array of material transmission / diffusion parameters.
475
+
476
+ Returns:
477
+ (bgr_tuple_uint8, background_tensor, material_colors_tensor, material_TDs_tensor)
478
+ """
479
+ bgr_tuple = hex_to_rgb(args.background_color)
480
+ background = torch.tensor(bgr_tuple, dtype=torch.float32, device=device)
481
+ material_colors = torch.tensor(
482
+ material_colors_np, dtype=torch.float32, device=device
483
+ )
484
+ material_TDs = torch.tensor(material_TDs_np, dtype=torch.float32, device=device)
485
+ return bgr_tuple, background, material_colors, material_TDs
486
+
487
+
488
+ def _compute_pixel_sizes(args) -> Tuple[int, int]:
489
+ """Derive pixel dimensions for solving vs. output STL size.
490
+
491
+ We oversample relative to nozzle_diameter to capture detail, then optionally downscale
492
+ for the differentiable optimization pass.
493
+
494
+ Returns:
495
+ (computed_output_size, computed_processing_size)
496
+ """
497
+ computed_output_size = int(round(args.stl_output_size * 2 / args.nozzle_diameter))
498
+ computed_processing_size = int(
499
+ round(computed_output_size / args.processing_reduction_factor)
500
+ )
501
+ print(f"Computed solving pixel size: {computed_output_size}")
502
+ return computed_output_size, computed_processing_size
503
+
504
+
505
+ def _load_priority_mask(
506
+ args, output_img_np: np.ndarray, device: torch.device
507
+ ) -> Optional[torch.Tensor]:
508
+ """Load and resize a priority / focus mask if provided.
509
+
510
+ The mask scales heights during initialization and can later weight loss terms.
511
+
512
+ Behavior:
513
+ - Reads image; converts RGBA/RGB to grayscale.
514
+ - Resizes to full-resolution output size.
515
+ - Persists a diagnostic PNG after normalization.
516
+
517
+ Returns:
518
+ focus_map_full: Float32 tensor (H,W) in [0,1] or None if no mask provided.
519
+ """
520
+ focus_map_full = None
521
+ if args.priority_mask != "":
522
+ pm = imread(args.priority_mask, cv2.IMREAD_UNCHANGED)
523
+ if pm.ndim == 3:
524
+ if pm.shape[2] == 4:
525
+ pm = pm[:, :, :3]
526
+ pm = cv2.cvtColor(pm, cv2.COLOR_BGR2GRAY)
527
+ tgt_h, tgt_w = output_img_np.shape[:2]
528
+ pm_resized = cv2.resize(pm, (tgt_w, tgt_h), interpolation=cv2.INTER_LINEAR)
529
+ pm_float = pm_resized.astype(np.float32) / 255.0
530
+ focus_map_full = torch.tensor(pm_float, dtype=torch.float32, device=device)
531
+ cv2.imwrite(
532
+ os.path.join(args.output_folder, "priority_mask_resized.png"),
533
+ (pm_float * 255).astype(np.uint8),
534
+ )
535
+ return focus_map_full
536
+
537
+
538
+ def _initialize_heightmap(
539
+ args,
540
+ output_img_np: np.ndarray,
541
+ bgr_tuple: Tuple[int, int, int],
542
+ material_colors_np: np.ndarray,
543
+ random_seed: int,
544
+ ) -> Tuple[np.ndarray, Optional[np.ndarray], np.ndarray]:
545
+ """Initialize the height map logits & labels using selected method.
546
+
547
+ Methods:
548
+ depth : Uses an external depth estimation model (requires transformers).
549
+ kmeans : Clusters pixel colors into layer assignments (default).
550
+
551
+ Returns:
552
+ pixel_height_logits_init: (H,W) float32 numpy array of raw logits.
553
+ global_logits_init : (L,*) global logits array or None (depth variant may not use it).
554
+ pixel_height_labels : (H,W) int array of discrete initial layer indices.
555
+ """
556
+ print("Initalizing height map. This can take a moment...")
557
+ if args.init_heightmap_method == "depth":
558
+ try:
559
+ from autoforge.Helper.Heightmaps.DepthEstimateHeightMap import (
560
+ init_height_map_depth_color_adjusted,
561
+ )
562
+ except Exception:
563
+ print(
564
+ "Error: depth initializer requested but could not be imported. Install 'transformers' and try again.",
565
+ file=sys.stderr,
566
+ )
567
+ raise
568
+ pixel_height_logits_init, pixel_height_labels = (
569
+ init_height_map_depth_color_adjusted(
570
+ output_img_np,
571
+ args.max_layers,
572
+ random_seed=random_seed,
573
+ focus_map=None,
574
+ )
575
+ )
576
+ global_logits_init = None
577
+ else:
578
+ pixel_height_logits_init, global_logits_init, pixel_height_labels = (
579
+ run_init_threads(
580
+ output_img_np,
581
+ args.max_layers,
582
+ args.layer_height,
583
+ bgr_tuple,
584
+ random_seed=random_seed,
585
+ num_threads=args.num_init_rounds,
586
+ init_method="kmeans",
587
+ cluster_layers=args.num_init_cluster_layers,
588
+ material_colors=material_colors_np,
589
+ focus_map=None,
590
+ )
591
+ )
592
+ return pixel_height_logits_init, global_logits_init, pixel_height_labels
593
+
594
+
595
+ def _prepare_processing_targets(
596
+ output_img_np: np.ndarray,
597
+ computed_processing_size: int,
598
+ device: torch.device,
599
+ focus_map_full: Optional[torch.Tensor],
600
+ ) -> Tuple[np.ndarray, torch.Tensor, Optional[torch.Tensor]]:
601
+ """Create downscaled optimization target & focus map for faster iterations.
602
+
603
+ Args:
604
+ output_img_np: Full-resolution RGB image (float or uint8 expected).
605
+ computed_processing_size: Target square size for processing (maintains aspect via resize helper).
606
+ device: Torch device.
607
+ focus_map_full: Optional full-resolution focus map tensor.
608
+
609
+ Returns:
610
+ processing_img_np : Downscaled numpy image (H_p,W_p,3).
611
+ processing_target : Torch tensor version (float32) on device.
612
+ focus_map_proc : Optional downscaled focus map tensor (H_p,W_p).
613
+ """
614
+ processing_img_np = resize_image(output_img_np, computed_processing_size)
615
+ processing_target = torch.tensor(
616
+ processing_img_np, dtype=torch.float32, device=device
617
+ )
618
+
619
+ focus_map_proc = None
620
+ if focus_map_full is not None:
621
+ fm_proc_np = cv2.resize(
622
+ focus_map_full.cpu().numpy().astype(np.float32),
623
+ (processing_target.shape[1], processing_target.shape[0]),
624
+ interpolation=cv2.INTER_LINEAR,
625
+ )
626
+ focus_map_proc = torch.tensor(fm_proc_np, dtype=torch.float32, device=device)
627
+
628
+ return processing_img_np, processing_target, focus_map_proc
629
+
630
+
631
+ def _build_optimizer(
632
+ args,
633
+ processing_target: torch.Tensor,
634
+ processing_pixel_height_logits_init: np.ndarray,
635
+ processing_pixel_height_labels: np.ndarray,
636
+ global_logits_init,
637
+ material_colors: torch.Tensor,
638
+ material_TDs: torch.Tensor,
639
+ background: torch.Tensor,
640
+ device: torch.device,
641
+ perception_loss_module,
642
+ focus_map_proc: Optional[torch.Tensor],
643
+ ) -> FilamentOptimizer:
644
+ """Instantiate the FilamentOptimizer with initial tensors and configuration.
645
+
646
+ Args mirror the optimizer's constructor; this function simply centralizes assembly.
647
+
648
+ Returns:
649
+ FilamentOptimizer: Ready-to-run optimizer instance.
650
+ """
651
+ optimizer = FilamentOptimizer(
652
+ args=args,
653
+ target=processing_target,
654
+ pixel_height_logits_init=processing_pixel_height_logits_init,
655
+ pixel_height_labels=processing_pixel_height_labels,
656
+ global_logits_init=global_logits_init,
657
+ material_colors=material_colors,
658
+ material_TDs=material_TDs,
659
+ background=background,
660
+ device=device,
661
+ perception_loss_module=perception_loss_module,
662
+ focus_map=focus_map_proc,
663
+ )
664
+ return optimizer
665
+
666
+ @spaces.GPU
667
+ def _run_optimization_loop(optimizer: FilamentOptimizer, args, device: torch.device) -> None:
668
+ """Execute the main gradient-based optimization iterations.
669
+
670
+ Features:
671
+ - Automatic mixed precision (bfloat16 unless MPS).
672
+ - Periodic visualization & tensorboard logging (every 100 iterations).
673
+ - Discrete solution snapshots controlled via --discrete_check.
674
+ - Early stopping after a patience window (--early_stopping).
675
+
676
+ Args:
677
+ optimizer: Configured FilamentOptimizer instance.
678
+ args: Global argument namespace.
679
+ device: Torch device for autocast context.
680
+ """
681
+ print("Starting optimization...")
682
+ tbar = tqdm(range(args.iterations))
683
+ dtype = torch.bfloat16 if not args.mps else torch.float32
684
+ with torch.autocast(device.type, dtype=dtype):
685
+ for i in tbar:
686
+ loss_val = optimizer.step(record_best=i % args.discrete_check == 0)
687
+
688
+ optimizer.visualize(interval=100)
689
+ optimizer.log_to_tensorboard(interval=100)
690
+
691
+ if (i + 1) % 100 == 0:
692
+ tbar.set_description(
693
+ f"Iteration {i + 1}, Loss = {loss_val:.4f}, best validation Loss = {optimizer.best_discrete_loss:.4f}, learning_rate= {optimizer.current_learning_rate:.6f}"
694
+ )
695
+ if (
696
+ optimizer.best_step is not None
697
+ and optimizer.num_steps_done - optimizer.best_step > args.early_stopping
698
+ ):
699
+ print(
700
+ "Early stopping after",
701
+ args.early_stopping,
702
+ "steps without improvement.",
703
+ )
704
+ break
705
+
706
+
707
+
708
def _post_optimize_and_export(
    args,
    optimizer: FilamentOptimizer,
    pixel_height_logits_init: np.ndarray,
    pixel_height_labels: np.ndarray,
    output_target: torch.Tensor,
    alpha: Optional[np.ndarray],
    material_colors_np: np.ndarray,
    material_TDs_np: np.ndarray,
    material_names: List[str],
    bgr_tuple: Tuple[int, int, int],
    device: torch.device,
    focus_map_full: Optional[torch.Tensor],
    focus_map_proc: Optional[torch.Tensor],
) -> float:
    """Finalize solution, optionally prune, and write all output artifacts.

    Steps:
    - Restore full-resolution logits to optimizer and (optionally) height residual.
    - Replace focus map with full-res version if used.
    - Perform pruning (respecting color slots for background & clear in FlatForge mode).
    - Compute final loss estimate and persist to file.
    - Export preview PNG, STL(s), swap instructions & project file.

    Args:
        args: Parsed argument namespace (output folder, layer geometry, mode flags).
        optimizer: Optimizer holding the best parameters found during the run.
        pixel_height_logits_init: Full-resolution initial height logits (numpy).
        pixel_height_labels: Full-resolution integer height-cluster labels.
        output_target: Full-resolution target image tensor on `device`.
        alpha: Optional full-resolution alpha mask (used to cut out STL geometry).
        material_colors_np: Material RGB colors (numpy) for FlatForge export.
        material_TDs_np: Material transmission distances (numpy) for FlatForge export.
        material_names: Human-readable filament names for swap instructions.
        bgr_tuple: Background color tuple (kept for interface symmetry; not read here).
        device: Torch device to move restored tensors onto.
        focus_map_full: Optional full-resolution priority mask.
        focus_map_proc: Optional processing-resolution priority mask.

    Returns:
        float: The final reported loss (post-pruning).
    """
    post_opt_step = 0

    # Log once before swapping in the full-resolution state so the post_opt
    # namespace has a baseline data point.
    optimizer.log_to_tensorboard(
        interval=1, namespace="post_opt", step=(post_opt_step := post_opt_step + 1)
    )

    # Replace the processing-resolution state with the full-resolution one so
    # all exports below are computed at output resolution.
    optimizer.pixel_height_logits = torch.from_numpy(pixel_height_logits_init)
    optimizer.best_params["pixel_height_logits"] = torch.from_numpy(
        pixel_height_logits_init
    ).to(device)
    optimizer.target = output_target
    optimizer.pixel_height_labels = torch.tensor(
        pixel_height_labels, dtype=torch.int32, device=device
    )
    if focus_map_proc is not None and focus_map_full is not None:
        optimizer.focus_map = focus_map_full

    # MPS autocast does not support bfloat16 here; fall back to float32.
    dtype = torch.bfloat16 if not args.mps else torch.float32
    with torch.no_grad():
        with torch.autocast(device.type, dtype=dtype):
            if args.perform_pruning:
                # Adjust pruning_max_colors to account for background and clear filament
                # pruning_max_colors = total filaments needed
                # Need to reserve slots: 1 for background (always), 1 for clear (FlatForge only)
                max_colors_for_pruning = args.pruning_max_colors

                if args.flatforge:
                    # FlatForge: pruning_max_colors = colored + clear + background
                    # Reserve 2 slots (1 clear + 1 background)
                    max_colors_for_pruning = max(1, args.pruning_max_colors - 2)
                else:
                    # Traditional: pruning_max_colors = colored + background
                    # Reserve 1 slot for background
                    max_colors_for_pruning = max(1, args.pruning_max_colors - 1)

                post_opt_step = run_pruning(args, max_colors_for_pruning, optimizer, post_opt_step)

            disc_global, disc_height_image = optimizer.get_discretized_solution(
                best=True
            )

            # Final loss is recomputed from the pruned solution so the reported
            # number matches what is actually exported.
            final_loss = PruningHelper.get_initial_loss(
                optimizer.best_params["global_logits"].shape[0], optimizer
            )
            with open(os.path.join(args.output_folder, "final_loss.txt"), "w") as f:
                f.write(f"{final_loss}")

            print("Done. Saving outputs...")
            comp_disc = optimizer.get_best_discretized_image()
            # Pruning may have reduced the layer count; propagate it back to args
            # so the project file below records the effective value.
            args.max_layers = optimizer.max_layers

            optimizer.log_to_tensorboard(
                interval=1,
                namespace="post_opt",
                step=(post_opt_step := post_opt_step + 1),
            )

            # OpenCV writes BGR, the composite is RGB — convert before saving.
            comp_disc_np = comp_disc.cpu().numpy().astype(np.uint8)
            comp_disc_np = cv2.cvtColor(comp_disc_np, cv2.COLOR_RGB2BGR)
            cv2.imwrite(
                os.path.join(args.output_folder, "final_model.png"), comp_disc_np
            )

            # Generate STL files
            if args.flatforge:
                # FlatForge mode: Generate separate STL files for each color
                print("FlatForge mode enabled. Generating separate STL files...")
                generate_flatforge_stls(
                    disc_global.cpu().numpy(),
                    disc_height_image.cpu().numpy(),
                    material_colors_np,
                    material_names,
                    material_TDs_np,
                    args.layer_height,
                    args.background_height,
                    args.background_color,
                    args.stl_output_size,
                    args.output_folder,
                    cap_layers=args.cap_layers,
                    alpha_mask=alpha,
                )
            else:
                # Traditional mode: Generate single STL file
                stl_filename = os.path.join(args.output_folder, "final_model.stl")
                # Height map is in layer counts; convert to millimeters.
                height_map_mm = (
                    disc_height_image.cpu().numpy().astype(np.float32)
                ) * args.layer_height
                generate_stl(
                    height_map_mm,
                    stl_filename,
                    args.background_height,
                    maximum_x_y_size=args.stl_output_size,
                    alpha_mask=alpha,
                )

            # Swap instructions only make sense for the single-STL workflow.
            if not args.flatforge:
                background_layers = int(args.background_height // args.layer_height)
                swap_instructions = generate_swap_instructions(
                    disc_global.cpu().numpy(),
                    disc_height_image.cpu().numpy(),
                    args.layer_height,
                    background_layers,
                    args.background_height,
                    material_names,
                    getattr(args, "background_material_name", None),
                )
                with open(
                    os.path.join(args.output_folder, "swap_instructions.txt"), "w"
                ) as f:
                    for line in swap_instructions:
                        f.write(line + "\n")

            project_filename = os.path.join(args.output_folder, "project_file.hfp")
            generate_project_file(
                project_filename,
                args,
                disc_global.cpu().numpy(),
                disc_height_image.cpu().numpy(),
                output_target.shape[1],
                output_target.shape[0],
                os.path.join(args.output_folder, "final_model.stl"),
                args.csv_file,
            )

    print("All done. Outputs in:", args.output_folder)
    print("Happy Printing!")
    return final_loss
863
@spaces.GPU
def run_pruning(args, max_colors_for_pruning: int, optimizer: FilamentOptimizer, post_opt_step: int) -> int:
    """Run filament pruning on a (ZeroGPU-allocated) GPU and log the result.

    Decorated with ``@spaces.GPU`` so Hugging Face Spaces attaches a GPU for
    the duration of this call.

    Args:
        args: Parsed argument namespace with pruning parameters.
        max_colors_for_pruning: Color budget after reserving background/clear slots.
        optimizer: The optimizer whose solution is pruned in place.
        post_opt_step: Current TensorBoard step counter for the post_opt namespace.

    Returns:
        int: The incremented TensorBoard step counter.
    """
    optimizer.prune(
        max_colors_allowed=max_colors_for_pruning,
        max_swaps_allowed=args.pruning_max_swaps,
        min_layers_allowed=args.min_layers,
        max_layers_allowed=args.pruning_max_layer,
        search_seed=True,
        fast_pruning=args.fast_pruning,
        fast_pruning_percent=args.fast_pruning_percent,
    )
    post_opt_step += 1
    optimizer.log_to_tensorboard(
        interval=1,
        namespace="post_opt",
        step=post_opt_step,
    )
    return post_opt_step
880
+
881
+
882
def start(args) -> float:
    """Entry point for a single optimization run.

    Orchestrates the entire pipeline:
    - Validation & device selection.
    - Material & image loading (+ optional auto background selection).
    - Resolution computation & resizing.
    - Heightmap initialization.
    - Optimizer construction & iterative optimization loop.
    - Post-processing, pruning, and output generation.

    Args:
        args: Parsed argument namespace.

    Returns:
        float: Final loss value for this run (after pruning/export).
    """
    # -1 means "use as many cluster layers as printable layers".
    if args.num_init_cluster_layers == -1:
        args.num_init_cluster_layers = args.max_layers

    # check if csv or json is given
    if args.csv_file == "" and args.json_file == "":
        print("Error: No CSV or JSON file given. Please provide one of them.")
        sys.exit(1)

    device = get_device(args)

    os.makedirs(args.output_folder, exist_ok=True)

    perform_basic_check(args)

    random_seed = set_seed(args)

    # Load materials (we keep colors_list for potential auto background)
    material_colors_np, material_TDs_np, material_names, colors_list = load_materials(
        args
    )

    # Read input image early (needed for auto background color)
    img = imread(args.input_image, cv2.IMREAD_UNCHANGED)
    alpha = None
    # A 4th channel is treated as alpha; split it off and keep color-only data.
    if img.shape[2] == 4:
        alpha = img[:, :, 3]
        alpha = alpha[..., None]
        img = img[:, :, :3]

    # Convert image from BGR to RGB for color analysis
    img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    # Auto background color selection (optional)
    _auto_select_background_color(
        args, img_rgb, alpha, material_colors_np, material_names, colors_list
    )

    # Prepare background color tensor and material tensors
    bgr_tuple, background, material_colors, material_TDs = _prepare_background_and_materials(
        args, device, material_colors_np, material_TDs_np
    )

    # Compute sizes
    computed_output_size, computed_processing_size = _compute_pixel_sizes(args)

    # Resize alpha if present (match final resolution) after computing size
    if alpha is not None:
        alpha = resize_image(alpha, computed_output_size)

    # For the final resolution
    output_img_np = resize_image(img_rgb, computed_output_size)
    output_target = torch.tensor(output_img_np, dtype=torch.float32, device=device)

    # Priority mask handling (full-res)
    focus_map_full = _load_priority_mask(args, output_img_np, device)

    # Initialize heightmap
    pixel_height_logits_init, global_logits_init, pixel_height_labels = _initialize_heightmap(
        args,
        output_img_np,
        bgr_tuple,
        material_colors_np,
        random_seed,
    )

    # Prepare processing targets and focus map (processing-res)
    processing_img_np, processing_target, focus_map_proc = _prepare_processing_targets(
        output_img_np, computed_processing_size, device, focus_map_full
    )

    # Downscale initial logits/labels to processing resolution
    # (INTER_NEAREST avoids interpolating discrete labels / logits across edges).
    processing_pixel_height_logits_init = cv2.resize(
        src=pixel_height_logits_init,
        interpolation=cv2.INTER_NEAREST,
        dsize=(processing_target.shape[1], processing_target.shape[0]),
    )
    processing_pixel_height_labels = cv2.resize(
        src=pixel_height_labels,
        interpolation=cv2.INTER_NEAREST,
        dsize=(processing_target.shape[1], processing_target.shape[0]),
    )

    # Apply alpha mask to full-res logits (keep original order/behavior)
    # NOTE(review): -13.815512 ≈ logit(1e-6), i.e. sigmoid maps it to ~0 —
    # presumably forcing transparent pixels to near-zero height; confirm
    # against the optimizer's height mapping.
    if alpha is not None:
        pixel_height_logits_init[alpha < 128] = -13.815512

    perception_loss_module = None

    # Build optimizer
    optimizer = _build_optimizer(
        args,
        processing_target,
        processing_pixel_height_logits_init,
        processing_pixel_height_labels,
        global_logits_init,
        material_colors,
        material_TDs,
        background,
        device,
        perception_loss_module,
        focus_map_proc,
    )

    # Run optimization loop
    _run_optimization_loop(optimizer, args, device)

    # Post-process, prune, and export outputs
    final_loss = _post_optimize_and_export(
        args,
        optimizer,
        pixel_height_logits_init,
        pixel_height_labels,
        output_target,
        alpha,
        material_colors_np,
        material_TDs_np,
        material_names,
        bgr_tuple,
        device,
        focus_map_full,
        focus_map_proc,
    )

    return final_loss
1023
+
1024
+
1025
def main() -> None:
    """Support multi-run execution via --best_of; persist best run artifacts.

    If --best_of == 1, simply invokes a single start(). Otherwise:
    - Creates temporary run subfolders.
    - Tracks losses, reports statistics (best / median / std).
    - Moves files from best run folder into the final output folder.

    Exits with status 1 if every run fails (no artifacts to promote).

    Note: Memory is periodically reclaimed (gc + CUDA cache clears + closing
    matplotlib figures).
    """
    args = parse_args()
    final_output_folder = args.output_folder
    run_best_loss = 1000000000
    if args.best_of == 1:
        start(args)
    else:
        # Hoisted out of the per-run loop: importing once is sufficient.
        import gc
        import matplotlib.pyplot as plt

        temp_output_folder = os.path.join(args.output_folder, "temp")
        ret = []
        for i in range(args.best_of):
            try:
                print(f"Run {i + 1}/{args.best_of}")
                run_folder = os.path.join(temp_output_folder, f"run_{i + 1}")
                args.output_folder = run_folder
                os.makedirs(args.output_folder, exist_ok=True)
                run_loss = start(args)
                print(f"Run {i + 1} finished with loss: {run_loss}")
                if run_loss < run_best_loss:
                    run_best_loss = run_loss
                    print(f"New best loss found: {run_best_loss} in run {i + 1}")
                ret.append((run_folder, run_loss))
                # Reclaim GPU + host memory between runs to keep peak usage low.
                torch.cuda.empty_cache()
                gc.collect()
                torch.cuda.empty_cache()
                plt.close("all")
            except Exception:
                # Best-effort: a failed run should not abort the remaining runs.
                traceback.print_exc()

        # BUGFIX: min() on an empty sequence raises ValueError — bail out
        # cleanly if every run failed instead of crashing with a confusing
        # traceback.
        if not ret:
            print("Error: all runs failed; no outputs were produced.")
            sys.exit(1)

        best_run = min(ret, key=lambda x: x[1])
        best_run_folder = best_run[0]
        best_loss = best_run[1]

        losses = [x[1] for x in ret]
        median_loss = np.median(losses)
        std_loss = np.std(losses)
        print(f"Best run folder: {best_run_folder}")
        print(f"Best run loss: {best_loss}")
        print(f"Median loss: {median_loss}")
        print(f"Standard deviation of losses: {std_loss}")

        # Promote the best run's artifacts into the requested output folder.
        if not os.path.exists(final_output_folder):
            os.makedirs(final_output_folder)
        for file in os.listdir(best_run_folder):
            src_file = os.path.join(best_run_folder, file)
            dst_file = os.path.join(final_output_folder, file)
            if os.path.isfile(src_file):
                os.rename(src_file, dst_file)
1085
+
1086
+ if __name__ == "__main__":
1087
+ main()