hvoss-techfak commited on
Commit
99553eb
·
2 Parent(s): 2dd4885 3ac793e

Merge remote-tracking branch 'origin/main'

Browse files
Files changed (3) hide show
  1. app.py +60 -30
  2. auto_forge.py +1089 -0
  3. requirements.txt +1 -1
app.py CHANGED
@@ -1,13 +1,13 @@
1
  import json
2
  import string
 
3
  import uuid
4
  import os
5
  import logging
6
  import zipfile
7
- import importlib
8
  import wandb
9
  from contextlib import redirect_stdout, redirect_stderr
10
- import spaces
11
 
12
 
13
  USE_WANDB = "WANDB_API_KEY" in os.environ
@@ -99,7 +99,7 @@ def get_script_args_info(exclude_args=None):
99
  {
100
  "name": "--iterations",
101
  "type": "number",
102
- "default": 4000,
103
  "help": "Number of optimization iterations",
104
  },
105
  {
@@ -160,7 +160,7 @@ def get_script_args_info(exclude_args=None):
160
  {
161
  "name": "--pruning_max_swaps",
162
  "type": "number",
163
- "default": 20,
164
  "precision": 0,
165
  "help": "Max number of swaps allowed after pruning",
166
  },
@@ -183,7 +183,7 @@ def get_script_args_info(exclude_args=None):
183
  {
184
  "name": "--learning_rate_warmup_fraction",
185
  "type": "slider",
186
- "default": 0.2,
187
  "min": 0.0,
188
  "max": 1.0,
189
  "step": 0.01,
@@ -215,7 +215,7 @@ def get_script_args_info(exclude_args=None):
215
  {
216
  "name": "--num_init_rounds",
217
  "type": "number",
218
- "default": 8,
219
  "precision": 0,
220
  "help": "Number of rounds to choose the starting height map from.",
221
  },
@@ -296,16 +296,23 @@ else:
296
  def run_autoforge_process(cmd, log_path):
297
  from joblib import parallel_backend
298
  cli_args = cmd[1:]
299
- autoforge_main = importlib.import_module("autoforge.__main__")
300
 
301
  exit_code = 0
302
- with open(log_path, "w", buffering=1, encoding="utf-8") as log_f, \
303
- redirect_stdout(log_f), redirect_stderr(log_f), parallel_backend("threading", n_jobs=-1):
 
 
 
 
304
  try:
 
 
 
 
305
  sys.argv = ["autoforge"] + cli_args
306
- autoforge_main.main()
307
  except SystemExit as e:
308
- exit_code = e.code
309
  except Exception as e:
310
  log_f.write(f"\nERROR: {e}\n")
311
  exit_code = -1
@@ -673,7 +680,6 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
673
  visible=False,
674
  )
675
 
676
- @spaces.GPU(duration=150)
677
  def execute_autoforge_script(
678
  current_filaments_df_state_val, input_image, *accordion_param_values
679
  ):
@@ -768,24 +774,48 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
768
  import threading
769
 
770
  class Worker(threading.Thread):
771
- def __init__(self, cmd, log_path):
772
- super().__init__(daemon=True)
773
- self.cmd, self.log_path = cmd, log_path
774
- self.returncode = None
775
- self.exc = None
776
-
777
- def run(self):
778
- try:
779
- self.returncode = run_autoforge_process(self.cmd, self.log_path)
780
- except Exception as e:
781
- self.exc = e
782
- with open(self.log_path, "a", encoding="utf-8") as lf:
783
- lf.write(
784
- "\nERROR: {}. This usually means there was no GPU or the process took too long.\n".format(
785
- exc_text(e)
786
- )
787
- )
788
- self.returncode = -1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
789
 
790
  try:
791
  worker = Worker(command, log_file)
 
1
  import json
2
  import string
3
+ import traceback
4
  import uuid
5
  import os
6
  import logging
7
  import zipfile
 
8
  import wandb
9
  from contextlib import redirect_stdout, redirect_stderr
10
+ import auto_forge
11
 
12
 
13
  USE_WANDB = "WANDB_API_KEY" in os.environ
 
99
  {
100
  "name": "--iterations",
101
  "type": "number",
102
+ "default": 6000,
103
  "help": "Number of optimization iterations",
104
  },
105
  {
 
160
  {
161
  "name": "--pruning_max_swaps",
162
  "type": "number",
163
+ "default": 50,
164
  "precision": 0,
165
  "help": "Max number of swaps allowed after pruning",
166
  },
 
183
  {
184
  "name": "--learning_rate_warmup_fraction",
185
  "type": "slider",
186
+ "default": 0.01,
187
  "min": 0.0,
188
  "max": 1.0,
189
  "step": 0.01,
 
215
  {
216
  "name": "--num_init_rounds",
217
  "type": "number",
218
+ "default": 32,
219
  "precision": 0,
220
  "help": "Number of rounds to choose the starting height map from.",
221
  },
 
296
  def run_autoforge_process(cmd, log_path):
297
  from joblib import parallel_backend
298
  cli_args = cmd[1:]
 
299
 
300
  exit_code = 0
301
+ # Ensure local project dir is first on sys.path so `import auto_forge` imports the file in this repo
302
+ script_dir = os.path.dirname(os.path.abspath(__file__))
303
+ if script_dir not in sys.path:
304
+ sys.path.insert(0, script_dir)
305
+
306
+ with open(log_path, "w", buffering=1, encoding="utf-8") as log_f, redirect_stdout(log_f), redirect_stderr(log_f), parallel_backend("threading", n_jobs=4):
307
  try:
308
+ # Force a fresh import of the local module by removing any cached module
309
+ if "auto_forge" in sys.modules:
310
+ del sys.modules["auto_forge"]
311
+ auto_forge = __import__("auto_forge")
312
  sys.argv = ["autoforge"] + cli_args
313
+ auto_forge.main()
314
  except SystemExit as e:
315
+ exit_code = e.code if isinstance(e.code, int) or e.code is None else 0
316
  except Exception as e:
317
  log_f.write(f"\nERROR: {e}\n")
318
  exit_code = -1
 
680
  visible=False,
681
  )
682
 
 
683
  def execute_autoforge_script(
684
  current_filaments_df_state_val, input_image, *accordion_param_values
685
  ):
 
774
  import threading
775
 
776
  class Worker(threading.Thread):
777
+ def __init__(self, cmd, log_path):
778
+ super().__init__(daemon=True)
779
+ self.cmd, self.log_path = cmd, log_path
780
+ self.returncode = None
781
+ self.exc = None
782
+
783
+ def run(self):
784
+ """Import and run the local `auto_forge.py` module in-process.
785
+
786
+ We load the script from the project dir as a fresh module using
787
+ importlib.util.spec_from_file_location to ensure decorators like
788
+ @spaces.GPU are executed at import time. Stdout/stderr are redirected
789
+ to the run log to preserve the live console stream.
790
+ """
791
+ try:
792
+ # Ensure the project directory is on sys.path so a plain `import auto_forge` finds the local file
793
+ script_dir = os.path.dirname(os.path.abspath(__file__))
794
+ if script_dir not in sys.path:
795
+ sys.path.insert(0, script_dir)
796
+
797
+ with open(self.log_path, "a", encoding="utf-8") as lf, redirect_stdout(lf), redirect_stderr(lf):
798
+ try:
799
+ # Provide argv for the module's CLI parsing and call main()
800
+ sys.argv = ["autoforge"] + (self.cmd[1:] if len(self.cmd) > 1 else [])
801
+ auto_forge.main()
802
+ self.returncode = 0
803
+ except Exception as e:
804
+ lf.write(f"\nERROR while importing/running auto_forge: {exc_text(e)}\n")
805
+ traceback.print_exc()
806
+ self.exc = e
807
+ if isinstance(e, SystemExit):
808
+ self.returncode = e.code if isinstance(e.code, int) or e.code is None else 1
809
+ else:
810
+ self.returncode = -1
811
+ except Exception as outer_e:
812
+ self.exc = outer_e
813
+ try:
814
+ with open(self.log_path, "a", encoding="utf-8") as lf:
815
+ lf.write(f"\nERROR loading autoforge.auto_forge: {exc_text(outer_e)}\n")
816
+ except Exception:
817
+ pass
818
+ self.returncode = -1
819
 
820
  try:
821
  worker = Worker(command, log_file)
auto_forge.py ADDED
@@ -0,0 +1,1089 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """auto_forge.py
2
+
3
+ High-level orchestration module for the AutoForge optimization pipeline.
4
+
5
+ Responsibilities:
6
+ - Parse CLI / config file arguments.
7
+ - Load image and material properties.
8
+ - (Optionally) auto-select a background filament color based on dominant image color.
9
+ - Initialize a height map using one of several strategies (k-means clustering or depth estimation).
10
+ - Build and run the filament optimization loop (differentiable + periodic discretization checks).
11
+ - Optionally prune the solution to respect practical printer constraints (materials, swaps, layers).
12
+ - Export final artifacts: preview PNG, STL(s), swap instructions, project file, metadata.
13
+
14
+ The implementation intentionally keeps side-effects (disk writes / prints) order-stable to
15
+ preserve prior behavior. Helper functions are factored out for readability; no functional
16
+ behavior should have changed relative to the previous monolithic version.
17
+ """
18
+ import argparse
19
+ import sys
20
+ import os
21
+ import traceback
22
+ from typing import Optional, Tuple, List
23
+
24
+ import configargparse
25
+ import cv2
26
+ try:
27
+ import spaces
28
+ except Exception:
29
+ # Provide a minimal shim so @spaces.GPU can be used even when 'spaces' isn't installed.
30
+ def _spaces_noop_decorator(fn=None):
31
+ # Support usage as @spaces.GPU or @spaces.GPU()
32
+ if fn is None:
33
+ def _inner(f):
34
+ return f
35
+ return _inner
36
+ return fn
37
+
38
+ class _DummySpaces:
39
+ GPU = staticmethod(_spaces_noop_decorator)
40
+
41
+ spaces = _DummySpaces()
42
+ import torch
43
+ import numpy as np
44
+ from tqdm import tqdm
45
+
46
+ from autoforge.Helper import PruningHelper
47
+ from autoforge.Helper.FilamentHelper import hex_to_rgb, load_materials
48
+ from autoforge.Helper.Heightmaps.ChristofidesHeightMap import (
49
+ run_init_threads,
50
+ )
51
+
52
+ from autoforge.Helper.ImageHelper import resize_image, imread
53
+ from autoforge.Helper.OtherHelper import set_seed, perform_basic_check, get_device
54
+ from autoforge.Helper.OutputHelper import (
55
+ generate_stl,
56
+ generate_swap_instructions,
57
+ generate_project_file,
58
+ generate_flatforge_stls,
59
+ )
60
+ from autoforge.Modules.Optimizer import FilamentOptimizer
61
+
62
+ # check if we can use torch.set_float32_matmul_precision('high')
63
+ if torch.__version__ >= "2.0.0":
64
+ try:
65
+ torch.set_float32_matmul_precision("high")
66
+ except Exception as e:
67
+ print("Warning: Could not set float32 matmul precision to high. Error:", e)
68
+ pass
69
+
70
+
71
+ def parse_args() -> argparse.Namespace:
72
+ """Create and parse command-line & config-file arguments.
73
+
74
+ Returns:
75
+ argparse.Namespace: Populated arguments structure. Some parameters may be adjusted later
76
+ (e.g., num_init_cluster_layers when -1 to infer from max_layers).
77
+ """
78
+ parser = configargparse.ArgParser()
79
+ parser.add_argument("--config", is_config_file=True, help="Path to config file")
80
+
81
+ parser.add_argument(
82
+ "--input_image", type=str, required=True, help="Path to input image"
83
+ )
84
+ parser.add_argument(
85
+ "--csv_file",
86
+ type=str,
87
+ default="",
88
+ help="Path to CSV file with material data",
89
+ )
90
+ parser.add_argument(
91
+ "--json_file",
92
+ type=str,
93
+ default="",
94
+ help="Path to json file with material data",
95
+ )
96
+ parser.add_argument(
97
+ "--output_folder", type=str, default="output", help="Folder to write outputs"
98
+ )
99
+
100
+ parser.add_argument(
101
+ "--iterations", type=int, default=6000, help="Number of optimization iterations"
102
+ )
103
+
104
+ parser.add_argument(
105
+ "--warmup_fraction",
106
+ type=float,
107
+ default=1.0,
108
+ help="Fraction of iterations for keeping the tau at the initial value",
109
+ )
110
+
111
+ parser.add_argument(
112
+ "--learning_rate_warmup_fraction",
113
+ type=float,
114
+ default=0.01,
115
+ help="Fraction of iterations that the learning rate is increasing (warmup)",
116
+ )
117
+
118
+ parser.add_argument(
119
+ "--init_tau",
120
+ type=float,
121
+ default=1.0,
122
+ help="Initial tau value for Gumbel-Softmax",
123
+ )
124
+
125
+ parser.add_argument(
126
+ "--final_tau",
127
+ type=float,
128
+ default=0.01,
129
+ help="Final tau value for Gumbel-Softmax",
130
+ )
131
+
132
+ parser.add_argument(
133
+ "--learning_rate",
134
+ type=float,
135
+ default=0.015,
136
+ help="Learning rate for optimization",
137
+ )
138
+
139
+ parser.add_argument(
140
+ "--layer_height", type=float, default=0.04, help="Layer thickness in mm"
141
+ )
142
+
143
+ parser.add_argument(
144
+ "--max_layers", type=int, default=75, help="Maximum number of layers"
145
+ )
146
+
147
+ parser.add_argument(
148
+ "--min_layers",
149
+ type=int,
150
+ default=0,
151
+ help="Minimum number of layers. Used for pruning.",
152
+ )
153
+
154
+ parser.add_argument(
155
+ "--background_height",
156
+ type=float,
157
+ default=0.24,
158
+ help="Height of the background in mm",
159
+ )
160
+
161
+ parser.add_argument(
162
+ "--background_color", type=str, default="#000000", help="Background color"
163
+ )
164
+
165
+ parser.add_argument(
166
+ "--auto_background_color",
167
+ default=True,
168
+ help="Automatically set background color to the closest filament color matching the dominant image color. Overrides --background_color.",
169
+ )
170
+
171
+ parser.add_argument(
172
+ "--visualize",
173
+ type=bool,
174
+ default=True,
175
+ help="Enable visualization during optimization",
176
+ action=argparse.BooleanOptionalAction,
177
+ )
178
+
179
+ # Instead of an output_size parameter, we use stl_output_size and nozzle_diameter.
180
+ parser.add_argument(
181
+ "--stl_output_size",
182
+ type=int,
183
+ default=150,
184
+ help="Size of the longest dimension of the output STL file in mm",
185
+ )
186
+
187
+ parser.add_argument(
188
+ "--processing_reduction_factor",
189
+ type=int,
190
+ default=2,
191
+ help="Reduction factor for reducing the processing size compared to the output size (default: 2 - half resolution)",
192
+ )
193
+
194
+ parser.add_argument(
195
+ "--nozzle_diameter",
196
+ type=float,
197
+ default=0.4,
198
+ help="Diameter of the printer nozzle in mm (details smaller than half this value will be ignored)",
199
+ )
200
+
201
+ parser.add_argument(
202
+ "--early_stopping",
203
+ type=int,
204
+ default=2000,
205
+ help="Number of steps without improvement before stopping",
206
+ )
207
+
208
+ parser.add_argument(
209
+ "--perform_pruning",
210
+ type=bool,
211
+ default=True,
212
+ help="Perform pruning after optimization",
213
+ action=argparse.BooleanOptionalAction,
214
+ )
215
+
216
+ parser.add_argument(
217
+ "--fast_pruning",
218
+ type=bool,
219
+ default=True,
220
+ help="Use fast pruning method",
221
+ action=argparse.BooleanOptionalAction,
222
+ )
223
+ parser.add_argument(
224
+ "--fast_pruning_percent",
225
+ type=float,
226
+ default=0.25,
227
+ help="Percentage of increment search for fast pruning",
228
+ )
229
+
230
+ parser.add_argument(
231
+ "--pruning_max_colors",
232
+ type=int,
233
+ default=100,
234
+ help="Max number of colors allowed after pruning",
235
+ )
236
+ parser.add_argument(
237
+ "--pruning_max_swaps",
238
+ type=int,
239
+ default=100,
240
+ help="Max number of swaps allowed after pruning",
241
+ )
242
+
243
+ parser.add_argument(
244
+ "--pruning_max_layer",
245
+ type=int,
246
+ default=75,
247
+ help="Max number of layers allowed after pruning",
248
+ )
249
+
250
+ parser.add_argument(
251
+ "--random_seed",
252
+ type=int,
253
+ default=0,
254
+ help="Specify the random seed, or use 0 for automatic generation",
255
+ )
256
+
257
+ parser.add_argument(
258
+ "--mps",
259
+ action="store_true",
260
+ help="Use the Metal Performance Shaders (MPS) backend, if available.",
261
+ )
262
+
263
+ parser.add_argument(
264
+ "--run_name", type=str, help="Name of the run used for TensorBoard logging"
265
+ )
266
+
267
+ parser.add_argument(
268
+ "--tensorboard", action="store_true", help="Enable TensorBoard logging"
269
+ )
270
+
271
+ parser.add_argument(
272
+ "--num_init_rounds",
273
+ type=int,
274
+ default=16,
275
+ help="Number of rounds to choose the starting height map from.",
276
+ )
277
+
278
+ parser.add_argument(
279
+ "--num_init_cluster_layers",
280
+ type=int,
281
+ default=-1,
282
+ help="Number of layers to cluster the image into.",
283
+ )
284
+
285
+ parser.add_argument(
286
+ "--disable_visualization_for_gradio",
287
+ type=int,
288
+ default=0,
289
+ help="Simple switch to disable the matplotlib render window for gradio rendering.",
290
+ )
291
+
292
+ parser.add_argument(
293
+ "--best_of",
294
+ type=int,
295
+ default=1,
296
+ help="Run the program multiple times and output the best result.",
297
+ )
298
+
299
+ parser.add_argument(
300
+ "--discrete_check",
301
+ type=int,
302
+ default=100,
303
+ help="Modulo how often to check for new discrete results.",
304
+ )
305
+
306
+ parser.add_argument(
307
+ "--flatforge",
308
+ type=bool,
309
+ default=False,
310
+ help="Enable FlatForge mode to generate separate STL files for each color",
311
+ action=argparse.BooleanOptionalAction,
312
+ )
313
+
314
+ parser.add_argument(
315
+ "--cap_layers",
316
+ type=int,
317
+ default=0,
318
+ help="Number of complete clear/transparent layers to add on top in FlatForge mode",
319
+ )
320
+
321
+ # New: choose heightmap initializer
322
+ parser.add_argument(
323
+ "--init_heightmap_method",
324
+ type=str,
325
+ choices=["kmeans", "depth"],
326
+ default="kmeans",
327
+ help="Initializer for the height map: 'kmeans' (fast, default) or 'depth' (requires transformers).",
328
+ )
329
+ # New priority mask argument (optional)
330
+ parser.add_argument(
331
+ "--priority_mask",
332
+ type=str,
333
+ default="",
334
+ help="Optional path to a priority mask image (same dimensions as input image). Non-empty: apply weighted loss (0.1 outside, 1.0 at max inside).",
335
+ )
336
+
337
+ args = parser.parse_args()
338
+ return args
339
+
340
+
341
+ def _compute_dominant_image_color(
342
+ img_rgb: np.ndarray, alpha: Optional[np.ndarray]
343
+ ) -> Optional[Tuple[str, np.ndarray]]:
344
+ """Compute an approximate dominant color of the input image.
345
+
346
+ Strategy:
347
+ - Optionally downscale very large images for efficiency.
348
+ - Ignore (mostly) transparent pixels if alpha channel is provided.
349
+ - Use frequency counts (np.unique) over exact RGB triplets.
350
+
351
+ Args:
352
+ img_rgb: Image array in RGB order (H,W,3) uint8.
353
+ alpha: Optional alpha mask (H,W,1) or (H,W) uint8; pixels <128 are ignored.
354
+
355
+ Returns:
356
+ (hex_color, normalized_rgb) where hex_color is a '#RRGGBB' string and normalized_rgb
357
+ is float32 in [0,1]^3. Returns None if no valid pixels remain.
358
+ """
359
+ try:
360
+ # Downscale if needed (max side 300 px)
361
+ h, w = img_rgb.shape[:2]
362
+ max_side = max(h, w)
363
+ target_side = 300
364
+ alpha_small: Optional[np.ndarray] = None
365
+ if max_side > target_side:
366
+ scale = target_side / max_side
367
+ new_w = max(1, int(w * scale))
368
+ new_h = max(1, int(h * scale))
369
+ img_small = cv2.resize(
370
+ img_rgb, (new_w, new_h), interpolation=cv2.INTER_AREA
371
+ )
372
+ if alpha is not None:
373
+ alpha_small = cv2.resize(
374
+ alpha, (new_w, new_h), interpolation=cv2.INTER_NEAREST
375
+ )
376
+ else:
377
+ img_small = img_rgb
378
+ alpha_small = alpha
379
+ # Build mask for valid pixels (ignore transparent)
380
+ if alpha_small is not None:
381
+ valid_mask = (
382
+ alpha_small[..., 0] if alpha_small.ndim == 3 else alpha_small
383
+ ) >= 128
384
+ else:
385
+ valid_mask = np.ones(img_small.shape[:2], dtype=bool)
386
+ if valid_mask.sum() == 0:
387
+ return None
388
+ pixels = img_small[valid_mask]
389
+ # Use np.unique to find most frequent RGB triplet
390
+ unique_colors, counts = np.unique(
391
+ pixels.reshape(-1, 3), axis=0, return_counts=True
392
+ )
393
+ idx = int(np.argmax(counts))
394
+ dom_rgb_uint8 = unique_colors[idx]
395
+ dom_rgb_norm = dom_rgb_uint8.astype(np.float32) / 255.0
396
+ hex_color = "#" + "".join(f"{c:02X}" for c in dom_rgb_uint8)
397
+ return hex_color, dom_rgb_norm
398
+ except Exception:
399
+ traceback.print_exc()
400
+ return None
401
+
402
+
403
+ def _auto_select_background_color(
404
+ args,
405
+ img_rgb: np.ndarray,
406
+ alpha: Optional[np.ndarray],
407
+ material_colors_np: np.ndarray,
408
+ material_names: List[str],
409
+ colors_list: List[str],
410
+ ) -> None:
411
+ """Optionally override the user-provided background color with a closest material color.
412
+
413
+ When --auto_background_color is set:
414
+ - Determine dominant image color (ignoring transparency).
415
+ - Find closest filament (Euclidean in normalized RGB).
416
+ - Persist metadata to 'auto_background_color.txt'.
417
+
418
+ Side effects: Mutates args.background_color and attaches background_material_* fields.
419
+
420
+ Args:
421
+ args: Global argument namespace (mutated).
422
+ img_rgb: Full-resolution RGB image (uint8).
423
+ alpha: Optional alpha channel for transparency filtering.
424
+ material_colors_np: (N,3) array of filament RGB colors in [0,1].
425
+ material_names: List of filament names.
426
+ colors_list: List of filament hex color strings (#RRGGBB).
427
+ """
428
+ if not args.auto_background_color:
429
+ return
430
+ res = _compute_dominant_image_color(img_rgb, alpha)
431
+ if res is not None:
432
+ dominant_hex, dominant_rgb = res
433
+ diffs = material_colors_np - dominant_rgb[None, :]
434
+ dists = np.linalg.norm(diffs, axis=1)
435
+ closest_idx = int(np.argmin(dists))
436
+ chosen_hex = colors_list[closest_idx]
437
+ print(
438
+ f"Auto background color: dominant image color {dominant_hex} -> closest filament {chosen_hex} (index {closest_idx})."
439
+ )
440
+ args.background_color = chosen_hex
441
+ args.background_material_index = closest_idx
442
+ try:
443
+ args.background_material_name = material_names[closest_idx]
444
+ except Exception:
445
+ args.background_material_name = None
446
+ try:
447
+ with open(
448
+ os.path.join(args.output_folder, "auto_background_color.txt"), "w"
449
+ ) as f:
450
+ f.write(f"dominant_image_color={dominant_hex}\n")
451
+ f.write(f"chosen_filament_color={chosen_hex}\n")
452
+ f.write(f"closest_filament_index={closest_idx}\n")
453
+ if getattr(args, "background_material_name", None):
454
+ f.write(
455
+ f"closest_filament_name={args.background_material_name}\n"
456
+ )
457
+ except Exception:
458
+ traceback.print_exc()
459
+ else:
460
+ print(
461
+ "Warning: Auto background color computation failed; using provided --background_color."
462
+ )
463
+
464
+
465
+ def _prepare_background_and_materials(
466
+ args, device: torch.device, material_colors_np: np.ndarray, material_TDs_np: np.ndarray
467
+ ) -> Tuple[Tuple[int, int, int], torch.Tensor, torch.Tensor, torch.Tensor]:
468
+ """Create torch tensors for materials & background color.
469
+
470
+ Args:
471
+ args: Global arguments (uses background_color hex string).
472
+ device: Torch device for tensor placement.
473
+ material_colors_np: (N,3) float32 array in [0,1].
474
+ material_TDs_np: (N,*) array of material transmission / diffusion parameters.
475
+
476
+ Returns:
477
+ (bgr_tuple_uint8, background_tensor, material_colors_tensor, material_TDs_tensor)
478
+ """
479
+ bgr_tuple = hex_to_rgb(args.background_color)
480
+ background = torch.tensor(bgr_tuple, dtype=torch.float32, device=device)
481
+ material_colors = torch.tensor(
482
+ material_colors_np, dtype=torch.float32, device=device
483
+ )
484
+ material_TDs = torch.tensor(material_TDs_np, dtype=torch.float32, device=device)
485
+ return bgr_tuple, background, material_colors, material_TDs
486
+
487
+
488
+ def _compute_pixel_sizes(args) -> Tuple[int, int]:
489
+ """Derive pixel dimensions for solving vs. output STL size.
490
+
491
+ We oversample relative to nozzle_diameter to capture detail, then optionally downscale
492
+ for the differentiable optimization pass.
493
+
494
+ Returns:
495
+ (computed_output_size, computed_processing_size)
496
+ """
497
+ computed_output_size = int(round(args.stl_output_size * 2 / args.nozzle_diameter))
498
+ computed_processing_size = int(
499
+ round(computed_output_size / args.processing_reduction_factor)
500
+ )
501
+ print(f"Computed solving pixel size: {computed_output_size}")
502
+ return computed_output_size, computed_processing_size
503
+
504
+
505
+ def _load_priority_mask(
506
+ args, output_img_np: np.ndarray, device: torch.device
507
+ ) -> Optional[torch.Tensor]:
508
+ """Load and resize a priority / focus mask if provided.
509
+
510
+ The mask scales heights during initialization and can later weight loss terms.
511
+
512
+ Behavior:
513
+ - Reads image; converts RGBA/RGB to grayscale.
514
+ - Resizes to full-resolution output size.
515
+ - Persists a diagnostic PNG after normalization.
516
+
517
+ Returns:
518
+ focus_map_full: Float32 tensor (H,W) in [0,1] or None if no mask provided.
519
+ """
520
+ focus_map_full = None
521
+ if args.priority_mask != "":
522
+ pm = imread(args.priority_mask, cv2.IMREAD_UNCHANGED)
523
+ if pm.ndim == 3:
524
+ if pm.shape[2] == 4:
525
+ pm = pm[:, :, :3]
526
+ pm = cv2.cvtColor(pm, cv2.COLOR_BGR2GRAY)
527
+ tgt_h, tgt_w = output_img_np.shape[:2]
528
+ pm_resized = cv2.resize(pm, (tgt_w, tgt_h), interpolation=cv2.INTER_LINEAR)
529
+ pm_float = pm_resized.astype(np.float32) / 255.0
530
+ focus_map_full = torch.tensor(pm_float, dtype=torch.float32, device=device)
531
+ cv2.imwrite(
532
+ os.path.join(args.output_folder, "priority_mask_resized.png"),
533
+ (pm_float * 255).astype(np.uint8),
534
+ )
535
+ return focus_map_full
536
+
537
+
538
+ def _initialize_heightmap(
539
+ args,
540
+ output_img_np: np.ndarray,
541
+ bgr_tuple: Tuple[int, int, int],
542
+ material_colors_np: np.ndarray,
543
+ random_seed: int,
544
+ ) -> Tuple[np.ndarray, Optional[np.ndarray], np.ndarray]:
545
+ """Initialize the height map logits & labels using selected method.
546
+
547
+ Methods:
548
+ depth : Uses an external depth estimation model (requires transformers).
549
+ kmeans : Clusters pixel colors into layer assignments (default).
550
+
551
+ Returns:
552
+ pixel_height_logits_init: (H,W) float32 numpy array of raw logits.
553
+ global_logits_init : (L,*) global logits array or None (depth variant may not use it).
554
+ pixel_height_labels : (H,W) int array of discrete initial layer indices.
555
+ """
556
+ print("Initalizing height map. This can take a moment...")
557
+ if args.init_heightmap_method == "depth":
558
+ try:
559
+ from autoforge.Helper.Heightmaps.DepthEstimateHeightMap import (
560
+ init_height_map_depth_color_adjusted,
561
+ )
562
+ except Exception:
563
+ print(
564
+ "Error: depth initializer requested but could not be imported. Install 'transformers' and try again.",
565
+ file=sys.stderr,
566
+ )
567
+ raise
568
+ pixel_height_logits_init, pixel_height_labels = (
569
+ init_height_map_depth_color_adjusted(
570
+ output_img_np,
571
+ args.max_layers,
572
+ random_seed=random_seed,
573
+ focus_map=None,
574
+ )
575
+ )
576
+ global_logits_init = None
577
+ else:
578
+ pixel_height_logits_init, global_logits_init, pixel_height_labels = (
579
+ run_init_threads(
580
+ output_img_np,
581
+ args.max_layers,
582
+ args.layer_height,
583
+ bgr_tuple,
584
+ random_seed=random_seed,
585
+ num_threads=4,
586
+ num_runs=args.num_init_rounds,
587
+ init_method="kmeans",
588
+ cluster_layers=args.num_init_cluster_layers,
589
+ material_colors=material_colors_np,
590
+ focus_map=None,
591
+ )
592
+ )
593
+
594
+ return pixel_height_logits_init, global_logits_init, pixel_height_labels
595
+
596
+
597
+ def _prepare_processing_targets(
598
+ output_img_np: np.ndarray,
599
+ computed_processing_size: int,
600
+ device: torch.device,
601
+ focus_map_full: Optional[torch.Tensor],
602
+ ) -> Tuple[np.ndarray, torch.Tensor, Optional[torch.Tensor]]:
603
+ """Create downscaled optimization target & focus map for faster iterations.
604
+
605
+ Args:
606
+ output_img_np: Full-resolution RGB image (float or uint8 expected).
607
+ computed_processing_size: Target square size for processing (maintains aspect via resize helper).
608
+ device: Torch device.
609
+ focus_map_full: Optional full-resolution focus map tensor.
610
+
611
+ Returns:
612
+ processing_img_np : Downscaled numpy image (H_p,W_p,3).
613
+ processing_target : Torch tensor version (float32) on device.
614
+ focus_map_proc : Optional downscaled focus map tensor (H_p,W_p).
615
+ """
616
+ processing_img_np = resize_image(output_img_np, computed_processing_size)
617
+ processing_target = torch.tensor(
618
+ processing_img_np, dtype=torch.float32, device=device
619
+ )
620
+
621
+ focus_map_proc = None
622
+ if focus_map_full is not None:
623
+ fm_proc_np = cv2.resize(
624
+ focus_map_full.cpu().numpy().astype(np.float32),
625
+ (processing_target.shape[1], processing_target.shape[0]),
626
+ interpolation=cv2.INTER_LINEAR,
627
+ )
628
+ focus_map_proc = torch.tensor(fm_proc_np, dtype=torch.float32, device=device)
629
+
630
+ return processing_img_np, processing_target, focus_map_proc
631
+
632
+
633
+ def _build_optimizer(
634
+ args,
635
+ processing_target: torch.Tensor,
636
+ processing_pixel_height_logits_init: np.ndarray,
637
+ processing_pixel_height_labels: np.ndarray,
638
+ global_logits_init,
639
+ material_colors: torch.Tensor,
640
+ material_TDs: torch.Tensor,
641
+ background: torch.Tensor,
642
+ device: torch.device,
643
+ perception_loss_module,
644
+ focus_map_proc: Optional[torch.Tensor],
645
+ ) -> FilamentOptimizer:
646
+ """Instantiate the FilamentOptimizer with initial tensors and configuration.
647
+
648
+ Args mirror the optimizer's constructor; this function simply centralizes assembly.
649
+
650
+ Returns:
651
+ FilamentOptimizer: Ready-to-run optimizer instance.
652
+ """
653
+ optimizer = FilamentOptimizer(
654
+ args=args,
655
+ target=processing_target,
656
+ pixel_height_logits_init=processing_pixel_height_logits_init,
657
+ pixel_height_labels=processing_pixel_height_labels,
658
+ global_logits_init=global_logits_init,
659
+ material_colors=material_colors,
660
+ material_TDs=material_TDs,
661
+ background=background,
662
+ device=device,
663
+ perception_loss_module=perception_loss_module,
664
+ focus_map=focus_map_proc,
665
+ )
666
+ return optimizer
667
+
668
+ @spaces.GPU
669
+ def _run_optimization_loop(optimizer: FilamentOptimizer, args, device: torch.device) -> None:
670
+ """Execute the main gradient-based optimization iterations.
671
+
672
+ Features:
673
+ - Automatic mixed precision (bfloat16 unless MPS).
674
+ - Periodic visualization & tensorboard logging (every 100 iterations).
675
+ - Discrete solution snapshots controlled via --discrete_check.
676
+ - Early stopping after a patience window (--early_stopping).
677
+
678
+ Args:
679
+ optimizer: Configured FilamentOptimizer instance.
680
+ args: Global argument namespace.
681
+ device: Torch device for autocast context.
682
+ """
683
+ print("Starting optimization...")
684
+ tbar = tqdm(range(args.iterations))
685
+ dtype = torch.bfloat16 if not args.mps else torch.float32
686
+ with torch.autocast(device.type, dtype=dtype):
687
+ for i in tbar:
688
+ loss_val = optimizer.step(record_best=i % args.discrete_check == 0)
689
+
690
+ optimizer.visualize(interval=100)
691
+ optimizer.log_to_tensorboard(interval=100)
692
+
693
+ if (i + 1) % 100 == 0:
694
+ tbar.set_description(
695
+ f"Iteration {i + 1}, Loss = {loss_val:.4f}, best validation Loss = {optimizer.best_discrete_loss:.4f}, learning_rate= {optimizer.current_learning_rate:.6f}"
696
+ )
697
+ if (
698
+ optimizer.best_step is not None
699
+ and optimizer.num_steps_done - optimizer.best_step > args.early_stopping
700
+ ):
701
+ print(
702
+ "Early stopping after",
703
+ args.early_stopping,
704
+ "steps without improvement.",
705
+ )
706
+ break
707
+
708
+
709
+
710
+ def _post_optimize_and_export(
711
+ args,
712
+ optimizer: FilamentOptimizer,
713
+ pixel_height_logits_init: np.ndarray,
714
+ pixel_height_labels: np.ndarray,
715
+ output_target: torch.Tensor,
716
+ alpha: Optional[np.ndarray],
717
+ material_colors_np: np.ndarray,
718
+ material_TDs_np: np.ndarray,
719
+ material_names: List[str],
720
+ bgr_tuple: Tuple[int, int, int],
721
+ device: torch.device,
722
+ focus_map_full: Optional[torch.Tensor],
723
+ focus_map_proc: Optional[torch.Tensor],
724
+ ) -> float:
725
+ """Finalize solution, optionally prune, and write all output artifacts.
726
+
727
+ Steps:
728
+ - Restore full-resolution logits to optimizer and (optionally) height residual.
729
+ - Replace focus map with full-res version if used.
730
+ - Perform pruning (respecting color slots for background & clear in FlatForge mode).
731
+ - Compute final loss estimate and persist to file.
732
+ - Export preview PNG, STL(s), swap instructions & project file.
733
+
734
+ Returns:
735
+ float: The final reported loss (post-pruning).
736
+ """
737
+ post_opt_step = 0
738
+
739
+ optimizer.log_to_tensorboard(
740
+ interval=1, namespace="post_opt", step=(post_opt_step := post_opt_step + 1)
741
+ )
742
+
743
+ optimizer.pixel_height_logits = torch.from_numpy(pixel_height_logits_init)
744
+ optimizer.best_params["pixel_height_logits"] = torch.from_numpy(
745
+ pixel_height_logits_init
746
+ ).to(device)
747
+ optimizer.target = output_target
748
+ optimizer.pixel_height_labels = torch.tensor(
749
+ pixel_height_labels, dtype=torch.int32, device=device
750
+ )
751
+ if focus_map_proc is not None and focus_map_full is not None:
752
+ optimizer.focus_map = focus_map_full
753
+
754
+ dtype = torch.bfloat16 if not args.mps else torch.float32
755
+ with torch.no_grad():
756
+ with torch.autocast(device.type, dtype=dtype):
757
+ if args.perform_pruning:
758
+ # Adjust pruning_max_colors to account for background and clear filament
759
+ # pruning_max_colors = total filaments needed
760
+ # Need to reserve slots: 1 for background (always), 1 for clear (FlatForge only)
761
+ max_colors_for_pruning = args.pruning_max_colors
762
+
763
+ if args.flatforge:
764
+ # FlatForge: pruning_max_colors = colored + clear + background
765
+ # Reserve 2 slots (1 clear + 1 background)
766
+ max_colors_for_pruning = max(1, args.pruning_max_colors - 2)
767
+ else:
768
+ # Traditional: pruning_max_colors = colored + background
769
+ # Reserve 1 slot for background
770
+ max_colors_for_pruning = max(1, args.pruning_max_colors - 1)
771
+
772
+ post_opt_step = run_pruning(args, max_colors_for_pruning, optimizer, post_opt_step)
773
+
774
+ disc_global, disc_height_image = optimizer.get_discretized_solution(
775
+ best=True
776
+ )
777
+
778
+ final_loss = PruningHelper.get_initial_loss(
779
+ optimizer.best_params["global_logits"].shape[0], optimizer
780
+ )
781
+ with open(os.path.join(args.output_folder, "final_loss.txt"), "w") as f:
782
+ f.write(f"{final_loss}")
783
+
784
+ print("Done. Saving outputs...")
785
+ comp_disc = optimizer.get_best_discretized_image()
786
+ args.max_layers = optimizer.max_layers
787
+
788
+ optimizer.log_to_tensorboard(
789
+ interval=1,
790
+ namespace="post_opt",
791
+ step=(post_opt_step := post_opt_step + 1),
792
+ )
793
+
794
+ comp_disc_np = comp_disc.cpu().numpy().astype(np.uint8)
795
+ comp_disc_np = cv2.cvtColor(comp_disc_np, cv2.COLOR_RGB2BGR)
796
+ cv2.imwrite(
797
+ os.path.join(args.output_folder, "final_model.png"), comp_disc_np
798
+ )
799
+
800
+ # Generate STL files
801
+ if args.flatforge:
802
+ # FlatForge mode: Generate separate STL files for each color
803
+ print("FlatForge mode enabled. Generating separate STL files...")
804
+ generate_flatforge_stls(
805
+ disc_global.cpu().numpy(),
806
+ disc_height_image.cpu().numpy(),
807
+ material_colors_np,
808
+ material_names,
809
+ material_TDs_np,
810
+ args.layer_height,
811
+ args.background_height,
812
+ args.background_color,
813
+ args.stl_output_size,
814
+ args.output_folder,
815
+ cap_layers=args.cap_layers,
816
+ alpha_mask=alpha,
817
+ )
818
+ else:
819
+ # Traditional mode: Generate single STL file
820
+ stl_filename = os.path.join(args.output_folder, "final_model.stl")
821
+ height_map_mm = (
822
+ disc_height_image.cpu().numpy().astype(np.float32)
823
+ ) * args.layer_height
824
+ generate_stl(
825
+ height_map_mm,
826
+ stl_filename,
827
+ args.background_height,
828
+ maximum_x_y_size=args.stl_output_size,
829
+ alpha_mask=alpha,
830
+ )
831
+
832
+ if not args.flatforge:
833
+ background_layers = int(args.background_height // args.layer_height)
834
+ swap_instructions = generate_swap_instructions(
835
+ disc_global.cpu().numpy(),
836
+ disc_height_image.cpu().numpy(),
837
+ args.layer_height,
838
+ background_layers,
839
+ args.background_height,
840
+ material_names,
841
+ getattr(args, "background_material_name", None),
842
+ )
843
+ with open(
844
+ os.path.join(args.output_folder, "swap_instructions.txt"), "w"
845
+ ) as f:
846
+ for line in swap_instructions:
847
+ f.write(line + "\n")
848
+
849
+ project_filename = os.path.join(args.output_folder, "project_file.hfp")
850
+ generate_project_file(
851
+ project_filename,
852
+ args,
853
+ disc_global.cpu().numpy(),
854
+ disc_height_image.cpu().numpy(),
855
+ output_target.shape[1],
856
+ output_target.shape[0],
857
+ os.path.join(args.output_folder, "final_model.stl"),
858
+ args.csv_file,
859
+ )
860
+
861
+ print("All done. Outputs in:", args.output_folder)
862
+ print("Happy Printing!")
863
+ return final_loss
864
+
865
+ @spaces.GPU
866
+ def run_pruning(args, max_colors_for_pruning: int, optimizer: FilamentOptimizer, post_opt_step: int) -> int:
867
+ optimizer.prune(
868
+ max_colors_allowed=max_colors_for_pruning,
869
+ max_swaps_allowed=args.pruning_max_swaps,
870
+ min_layers_allowed=args.min_layers,
871
+ max_layers_allowed=args.pruning_max_layer,
872
+ search_seed=True,
873
+ fast_pruning=args.fast_pruning,
874
+ fast_pruning_percent=args.fast_pruning_percent,
875
+ )
876
+ optimizer.log_to_tensorboard(
877
+ interval=1,
878
+ namespace="post_opt",
879
+ step=(post_opt_step := post_opt_step + 1),
880
+ )
881
+ return post_opt_step
882
+
883
+
884
+ def start(args) -> float:
885
+ """Entry point for a single optimization run.
886
+
887
+ Orchestrates the entire pipeline:
888
+ - Validation & device selection.
889
+ - Material & image loading (+ optional auto background selection).
890
+ - Resolution computation & resizing.
891
+ - Heightmap initialization.
892
+ - Optimizer construction & iterative optimization loop.
893
+ - Post-processing, pruning, and output generation.
894
+
895
+ Args:
896
+ args: Parsed argument namespace.
897
+
898
+ Returns:
899
+ float: Final loss value for this run (after pruning/export).
900
+ """
901
+ if args.num_init_cluster_layers == -1:
902
+ args.num_init_cluster_layers = args.max_layers
903
+
904
+ # check if csv or json is given
905
+ if args.csv_file == "" and args.json_file == "":
906
+ print("Error: No CSV or JSON file given. Please provide one of them.")
907
+ sys.exit(1)
908
+
909
+ device = torch.device("cpu")
910
+
911
+ os.makedirs(args.output_folder, exist_ok=True)
912
+
913
+ perform_basic_check(args)
914
+
915
+ random_seed = set_seed(args)
916
+
917
+ # Load materials (we keep colors_list for potential auto background)
918
+ material_colors_np, material_TDs_np, material_names, colors_list = load_materials(
919
+ args
920
+ )
921
+
922
+ # Read input image early (needed for auto background color)
923
+ img = imread(args.input_image, cv2.IMREAD_UNCHANGED)
924
+ alpha = None
925
+ if img.shape[2] == 4:
926
+ alpha = img[:, :, 3]
927
+ alpha = alpha[..., None]
928
+ img = img[:, :, :3]
929
+
930
+ # Convert image from BGR to RGB for color analysis
931
+ img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
932
+
933
+ # Auto background color selection (optional)
934
+ _auto_select_background_color(
935
+ args, img_rgb, alpha, material_colors_np, material_names, colors_list
936
+ )
937
+
938
+ # Prepare background color tensor and material tensors
939
+ bgr_tuple, background, material_colors, material_TDs = _prepare_background_and_materials(
940
+ args, device, material_colors_np, material_TDs_np
941
+ )
942
+
943
+ # Compute sizes
944
+ computed_output_size, computed_processing_size = _compute_pixel_sizes(args)
945
+
946
+ # Resize alpha if present (match final resolution) after computing size
947
+ if alpha is not None:
948
+ alpha = resize_image(alpha, computed_output_size)
949
+
950
+ # For the final resolution
951
+ output_img_np = resize_image(img_rgb, computed_output_size)
952
+ output_target = torch.tensor(output_img_np, dtype=torch.float32, device=device)
953
+
954
+ # Priority mask handling (full-res)
955
+ focus_map_full = _load_priority_mask(args, output_img_np, device)
956
+
957
+ # Initialize heightmap
958
+ pixel_height_logits_init, global_logits_init, pixel_height_labels = _initialize_heightmap(
959
+ args,
960
+ output_img_np,
961
+ bgr_tuple,
962
+ material_colors_np,
963
+ random_seed,
964
+ )
965
+
966
+ # Prepare processing targets and focus map (processing-res)
967
+ processing_img_np, processing_target, focus_map_proc = _prepare_processing_targets(
968
+ output_img_np, computed_processing_size, device, focus_map_full
969
+ )
970
+
971
+ # Downscale initial logits/labels to processing resolution
972
+ processing_pixel_height_logits_init = cv2.resize(
973
+ src=pixel_height_logits_init,
974
+ interpolation=cv2.INTER_NEAREST,
975
+ dsize=(processing_target.shape[1], processing_target.shape[0]),
976
+ )
977
+ processing_pixel_height_labels = cv2.resize(
978
+ src=pixel_height_labels,
979
+ interpolation=cv2.INTER_NEAREST,
980
+ dsize=(processing_target.shape[1], processing_target.shape[0]),
981
+ )
982
+
983
+ # Apply alpha mask to full-res logits (keep original order/behavior)
984
+ if alpha is not None:
985
+ pixel_height_logits_init[alpha < 128] = -13.815512
986
+
987
+ perception_loss_module = None
988
+
989
+ # Build optimizer
990
+ optimizer = _build_optimizer(
991
+ args,
992
+ processing_target,
993
+ processing_pixel_height_logits_init,
994
+ processing_pixel_height_labels,
995
+ global_logits_init,
996
+ material_colors,
997
+ material_TDs,
998
+ background,
999
+ device,
1000
+ perception_loss_module,
1001
+ focus_map_proc,
1002
+ )
1003
+
1004
+ # Run optimization loop
1005
+ _run_optimization_loop(optimizer, args, torch.device("cuda"))
1006
+
1007
+ # Post-process, prune, and export outputs
1008
+ final_loss = _post_optimize_and_export(
1009
+ args,
1010
+ optimizer,
1011
+ pixel_height_logits_init,
1012
+ pixel_height_labels,
1013
+ output_target,
1014
+ alpha,
1015
+ material_colors_np,
1016
+ material_TDs_np,
1017
+ material_names,
1018
+ bgr_tuple,
1019
+ torch.device("cuda"),
1020
+ focus_map_full,
1021
+ focus_map_proc,
1022
+ )
1023
+
1024
+ return final_loss
1025
+
1026
+
1027
+ def main() -> None:
1028
+ """Support multi-run execution via --best_of; persist best run artifacts.
1029
+
1030
+ If --best_of == 1, simply invokes a single start(). Otherwise:
1031
+ - Creates temporary run subfolders.
1032
+ - Tracks losses, reports statistics (best / median / std).
1033
+ - Moves files from best run folder into the final output folder.
1034
+
1035
+ Note: Memory is periodically reclaimed (gc + CUDA cache clears + closing matplotlib figures).
1036
+ """
1037
+ args = parse_args()
1038
+ final_output_folder = args.output_folder
1039
+ run_best_loss = 1000000000
1040
+ if args.best_of == 1:
1041
+ start(args)
1042
+ else:
1043
+ temp_output_folder = os.path.join(args.output_folder, "temp")
1044
+ ret = []
1045
+ for i in range(args.best_of):
1046
+ try:
1047
+ print(f"Run {i + 1}/{args.best_of}")
1048
+ run_folder = os.path.join(temp_output_folder, f"run_{i + 1}")
1049
+ args.output_folder = run_folder
1050
+ os.makedirs(args.output_folder, exist_ok=True)
1051
+ run_loss = start(args)
1052
+ print(f"Run {i + 1} finished with loss: {run_loss}")
1053
+ if run_loss < run_best_loss:
1054
+ run_best_loss = run_loss
1055
+ print(f"New best loss found: {run_best_loss} in run {i + 1}")
1056
+ ret.append((run_folder, run_loss))
1057
+ torch.cuda.empty_cache()
1058
+ import gc
1059
+
1060
+ gc.collect()
1061
+ torch.cuda.empty_cache()
1062
+ import matplotlib.pyplot as plt
1063
+
1064
+ plt.close("all")
1065
+ except Exception:
1066
+ traceback.print_exc()
1067
+ best_run = min(ret, key=lambda x: x[1])
1068
+ best_run_folder = best_run[0]
1069
+ best_loss = best_run[1]
1070
+
1071
+ losses = [x[1] for x in ret]
1072
+ median_loss = np.median(losses)
1073
+ std_loss = np.std(losses)
1074
+ print(f"Best run folder: {best_run_folder}")
1075
+ print(f"Best run loss: {best_loss}")
1076
+ print(f"Median loss: {median_loss}")
1077
+ print(f"Standard deviation of losses: {std_loss}")
1078
+
1079
+ if not os.path.exists(final_output_folder):
1080
+ os.makedirs(final_output_folder)
1081
+ for file in os.listdir(best_run_folder):
1082
+ src_file = os.path.join(best_run_folder, file)
1083
+ dst_file = os.path.join(final_output_folder, file)
1084
+ if os.path.isfile(src_file):
1085
+ os.rename(src_file, dst_file)
1086
+
1087
+
1088
+ if __name__ == "__main__":
1089
+ main()
requirements.txt CHANGED
@@ -1,4 +1,4 @@
1
- autoforge==1.9.0
2
  sentry-sdk[huggingface_hub]
3
  sentry-sdk[fastapi]
4
  wandb
 
1
+ autoforge==1.9.1
2
  sentry-sdk[huggingface_hub]
3
  sentry-sdk[fastapi]
4
  wandb