MajorDaniel committed on
Commit
f1ff5ee
·
verified ·
1 Parent(s): d3c02b2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +293 -88
app.py CHANGED
@@ -3,63 +3,48 @@ import subprocess
3
  import time
4
  import sys
5
  import shutil
 
 
 
6
  from pathlib import Path
7
  from unittest.mock import MagicMock
8
 
9
  # ==========================================================
10
- # 1. SYSTEM SETUP (Xvfb & Pfade)
11
  # ==========================================================
12
- if os.environ.get("SPACE_ID"):
13
- # Xvfb Fix für Berechtigungen: Nutze /tmp statt /tmp/.X11-unix
14
- os.environ["PATH"] += os.pathsep + "/usr/bin"
15
-
16
- if not os.path.exists("/tmp/.X99-lock"):
17
- try:
18
- # Wir fügen -vfbdir /tmp hinzu, um den Berechtigungsfehler zu umgehen
19
- subprocess.Popen([
20
- "Xvfb", ":99",
21
- "-screen", "0", "1024x768x24",
22
- "-ac", "+extension", "GLX", "+render",
23
- "-noreset"
24
- ])
25
- time.sleep(3)
26
- except Exception as e:
27
- print(f"Xvfb Setup Error: {e}")
28
-
29
- os.environ.update({
30
- "DISPLAY": ":99",
31
- "PYOPENGL_PLATFORM": "egl"
32
- })
33
-
34
- # bpy-bin Import-Logik
35
- try:
36
- import bpy
37
- print(f"✅ Blender {bpy.app.version_string} über bpy-bin geladen!")
38
- except ImportError:
39
- print("⚠️ bpy nicht direkt gefunden, versuche Pfad-Suche...")
40
- import site
41
- for p in site.getsitepackages():
42
- # bpy-bin installiert sich oft als 'bpy' im site-packages
43
- if os.path.exists(os.path.join(p, "bpy")):
44
- sys.path.append(p)
45
- break
46
- try:
47
- import bpy
48
- print(f"✅ Blender {bpy.app.version_string} nach Suche gefunden!")
49
- except ImportError:
50
- print("❌ bpy-bin konnte nicht geladen werden.")
51
 
52
  # ==========================================================
53
  # 2. BUGFIXES & MOCKS
54
  # ==========================================================
55
  # Fix A: Gradio Schema-Fehler
56
  import gradio_client.utils as client_utils
 
57
  client_utils._json_schema_to_python_type = lambda *args, **kwargs: "Any"
58
  client_utils.json_schema_to_python_type = lambda *args, **kwargs: "Any"
59
 
60
  # Fix B: Flash Attention Mocking
61
  try:
62
- import flash_attn
63
  except ImportError:
64
  mock = MagicMock()
65
  sys.modules["flash_attn"] = mock
@@ -67,36 +52,215 @@ except ImportError:
67
  sys.modules["flash_attn.modules.mha"] = mock
68
  print("Flash Attention gemockt.")
69
 
 
70
  # ==========================================================
71
  # 3. CORE IMPORTS
72
  # ==========================================================
73
  try:
74
  import open3d as o3d
 
75
  o3d.utility.set_verbosity_level(o3d.utility.VerbosityLevel.Error)
76
- except:
77
  pass
78
 
79
  import gradio as gr
80
  import spaces
81
- import torch
82
  import lightning as L
83
  import yaml
84
  from box import Box
85
 
86
- # ... (Ab hier folgen deine Funktionen wie validate_input_file, extract_mesh_python, etc. unverändert)
87
 
88
  # ==========================================================
89
- # 3. DEINE FUNKTIONEN (Unverändert)
90
  # ==========================================================
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91
 
 
 
 
 
92
  def validate_input_file(file_path: str) -> bool:
93
- supported_formats = ['.obj', '.fbx', '.glb']
94
  if not file_path or not Path(file_path).exists():
95
  return False
96
  return Path(file_path).suffix.lower() in supported_formats
97
 
 
98
  def extract_mesh_python(input_file: str, output_dir: str) -> str:
99
- from src.data.extract import extract_builtin, get_files
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
  files = get_files(
101
  data_name="raw_data.npz",
102
  inputs=str(input_file),
@@ -106,12 +270,17 @@ def extract_mesh_python(input_file: str, output_dir: str) -> str:
106
  warning=False,
107
  )
108
  if not files:
109
- raise RuntimeError("No files to extract")
110
- timestamp = str(int(time.time()))
111
- extract_builtin(output_folder=output_dir, target_count=50000, num_runs=1, id=0, time=timestamp, files=files)
112
  return files[0][1]
113
 
114
- def run_inference_python(input_file: str, output_file: str, inference_type: str, seed: int = 12345, npz_dir: str = None) -> str:
 
 
 
 
 
 
 
115
  from src.data.datapath import Datapath
116
  from src.data.dataset import DatasetConfig, UniRigDatasetModule
117
  from src.data.transform import TransformConfig
@@ -123,84 +292,118 @@ def run_inference_python(input_file: str, output_file: str, inference_type: str,
123
 
124
  if inference_type == "skeleton":
125
  L.seed_everything(seed, workers=True)
126
- configs = ["configs/task/quick_inference_skeleton_articulationxl_ar_256.yaml",
127
- "configs/transform/inference_ar_transform.yaml",
128
- "configs/model/unirig_ar_350m_1024_81920_float32.yaml",
129
- "configs/system/ar_inference_articulationxl.yaml",
130
- "configs/tokenizer/tokenizer_parts_articulationxl_256.yaml"]
 
 
131
  data_name = "raw_data.npz"
132
  else:
133
- configs = ["configs/task/quick_inference_unirig_skin.yaml",
134
- "configs/transform/inference_skin_transform.yaml",
135
- "configs/model/unirig_skin.yaml",
136
- "configs/system/skin.yaml", None]
 
 
 
137
  data_name = "predict_skeleton.npz"
138
 
139
- with open(configs[0], 'r') as f: task = Box(yaml.safe_load(f))
140
-
 
141
  if inference_type == "skeleton":
142
- if npz_dir is None: npz_dir = Path(output_file).parent / "npz"
 
143
  npz_dir.mkdir(exist_ok=True)
144
- npz_data_dir = extract_mesh_python(input_file, npz_dir)
145
  datapath = Datapath(files=[npz_data_dir], cls=None)
146
  else:
147
  skeleton_work_dir = Path(input_file).parent
148
  skeleton_npz_dir = list(skeleton_work_dir.rglob("**/*.npz"))[0].parent
149
  datapath = Datapath(files=[str(skeleton_npz_dir)], cls=None)
150
 
151
- data_config = Box(yaml.safe_load(open("configs/data/quick_inference.yaml", 'r')))
152
- transform_config = Box(yaml.safe_load(open(configs[1], 'r')))
153
-
154
  if inference_type == "skeleton":
155
- tokenizer = get_tokenizer(config=TokenizerConfig.parse(config=Box(yaml.safe_load(open(configs[4], 'r')))))
156
- model = get_model(tokenizer=tokenizer, **Box(yaml.safe_load(open(configs[2], 'r'))))
 
 
157
  else:
158
- model = get_model(tokenizer=None, **Box(yaml.safe_load(open(configs[2], 'r'))))
159
 
160
  data = UniRigDatasetModule(
161
- process_fn=model._process_fn,
162
  predict_dataset_config=DatasetConfig.parse(config=data_config.predict_dataset_config).split_by_cls(),
163
  predict_transform_config=TransformConfig.parse(config=transform_config.predict_transform_config),
164
- tokenizer_config=None if inference_type=="skin" else tokenizer.config,
165
- data_name=data_name, datapath=datapath, cls=None
 
 
166
  )
167
 
168
  writer_config = task.writer.copy()
169
  if inference_type == "skeleton":
170
- writer_config.update({'npz_dir': str(npz_dir), 'output_dir': str(Path(output_file).parent), 'output_name': Path(output_file).name, 'user_mode': False})
 
 
 
 
 
 
 
171
  else:
172
- writer_config.update({'npz_dir': str(skeleton_npz_dir), 'output_name': str(output_file), 'user_mode': True, 'export_fbx': True})
 
 
 
 
 
 
 
173
 
174
  callbacks = [get_writer(**writer_config, order_config=data.predict_transform_config.order_config)]
175
- system = get_system(**Box(yaml.safe_load(open(configs[3], 'r'))), model=model, steps_per_epoch=1)
176
-
177
  trainer = L.Trainer(callbacks=callbacks, logger=None, **task.trainer)
178
- trainer.predict(system, datamodule=data, ckpt_path=download(task.resume_from_checkpoint), return_predictions=False)
179
-
 
 
 
 
 
180
  return str(output_file)
181
 
 
182
  def merge_results_python(source_file: str, target_file: str, output_file: str) -> str:
183
  from src.inference.merge import transfer
 
184
  transfer(source=str(source_file), target=str(target_file), output=str(output_file), add_root=False)
185
  return str(output_file)
186
 
 
187
  # ==========================================================
188
- # 4. GRADIO APP
189
  # ==========================================================
190
-
191
  @spaces.GPU()
192
  def main(input_file: str, seed: int = 12345):
193
  temp_dir = Path(__file__).parent / "tmp"
194
  temp_dir.mkdir(exist_ok=True)
195
- if not validate_input_file(input_file): raise gr.Error("Invalid file format")
196
-
 
 
197
  file_stem = Path(input_file).stem
198
  input_model_dir = temp_dir / f"{file_stem}_{seed}"
199
  input_model_dir.mkdir(exist_ok=True)
200
-
201
  input_path = input_model_dir / Path(input_file).name
202
  shutil.copy2(input_file, input_path)
203
-
204
  skel_fbx = input_model_dir / f"{file_stem}_skeleton.fbx"
205
  skel_only = input_model_dir / f"{file_stem}_skeleton_only{input_path.suffix}"
206
  skin_fbx = input_model_dir / f"{file_stem}_skin.fbx"
@@ -208,12 +411,13 @@ def main(input_file: str, seed: int = 12345):
208
 
209
  run_inference_python(str(input_path), str(skel_fbx), "skeleton", seed)
210
  merge_results_python(str(skel_fbx), str(input_path), str(skel_only))
211
-
212
  run_inference_python(str(skel_fbx), str(skin_fbx), "skin")
213
  merge_results_python(str(skin_fbx), str(input_path), str(final_out))
214
 
215
  return str(final_out), [str(skel_only), str(final_out)]
216
 
 
217
  def create_app():
218
  with gr.Blocks(title="UniRig Demo") as interface:
219
  gr.Markdown("# 🎯 UniRig: Automated 3D Model Rigging")
@@ -225,9 +429,10 @@ def create_app():
225
  with gr.Column():
226
  out_3d = gr.Model3D(label="Result")
227
  out_files = gr.Files(label="Download Files")
228
-
229
  btn.click(fn=main, inputs=[input_3d, seed], outputs=[out_3d, out_files])
230
  return interface
231
 
 
232
  if __name__ == "__main__":
233
- create_app().queue().launch(show_api=False)
 
3
  import time
4
  import sys
5
  import shutil
6
+ import tarfile
7
+ import urllib.request
8
+ import site
9
  from pathlib import Path
10
  from unittest.mock import MagicMock
11
 
12
  # ==========================================================
13
+ # 0. GLOBALS (Blender userland download)
14
  # ==========================================================
15
# Blender 3.6 LTS bundles Python 3.10 -> good match for this Space
BLENDER_VERSION = "3.6.5"
# Official release artifact name and download URL (blender.org mirror).
BLENDER_TARBALL = f"blender-{BLENDER_VERSION}-linux-x64.tar.xz"
BLENDER_URL = f"https://download.blender.org/release/Blender3.6/{BLENDER_TARBALL}"

# Cache location writable without root; the tarball unpacks into a
# versioned subdirectory, so BLENDER_BIN points at the extracted executable.
BLENDER_CACHE_DIR = Path.home() / ".cache" / "unirig" / f"blender-{BLENDER_VERSION}"
BLENDER_EXTRACT_DIR = BLENDER_CACHE_DIR / f"blender-{BLENDER_VERSION}-linux-x64"
BLENDER_BIN = BLENDER_EXTRACT_DIR / "blender"

# Where we will write a temporary Blender python script at runtime
# (executed inside Blender's own interpreter via `blender -b --python ...`).
BLENDER_SCRIPT_PATH = BLENDER_CACHE_DIR / "hf_blender_extract.py"
27
+
28
+
29
+ # ==========================================================
30
+ # 1. SYSTEM SETUP (No Xvfb needed when using Blender -b)
31
+ # ==========================================================
32
+ # NOTE: We intentionally do NOT start Xvfb because HF blocks /tmp/.X11-unix creation
33
+ # and Blender is run headless via `-b`.
34
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
 
36
  # ==========================================================
37
  # 2. BUGFIXES & MOCKS
38
  # ==========================================================
39
  # Fix A: Gradio Schema-Fehler
40
  import gradio_client.utils as client_utils
41
+
42
  client_utils._json_schema_to_python_type = lambda *args, **kwargs: "Any"
43
  client_utils.json_schema_to_python_type = lambda *args, **kwargs: "Any"
44
 
45
  # Fix B: Flash Attention Mocking
46
  try:
47
+ import flash_attn # noqa: F401
48
  except ImportError:
49
  mock = MagicMock()
50
  sys.modules["flash_attn"] = mock
 
52
  sys.modules["flash_attn.modules.mha"] = mock
53
  print("Flash Attention gemockt.")
54
 
55
+
56
  # ==========================================================
57
  # 3. CORE IMPORTS
58
  # ==========================================================
59
  try:
60
  import open3d as o3d
61
+
62
  o3d.utility.set_verbosity_level(o3d.utility.VerbosityLevel.Error)
63
+ except Exception:
64
  pass
65
 
66
  import gradio as gr
67
  import spaces
68
+ import torch # noqa: F401
69
  import lightning as L
70
  import yaml
71
  from box import Box
72
 
 
73
 
74
  # ==========================================================
75
+ # 4. BLENDER HELPERS (download + run headless extraction)
76
  # ==========================================================
77
def ensure_blender() -> str:
    """Download and unpack a userland Blender into the cache dir (no root).

    The tarball and the extracted tree are cached under BLENDER_CACHE_DIR, so
    subsequent calls are cheap no-ops. The download is made atomic: data is
    fetched into a ``.part`` file and renamed only on success, so an
    interrupted download can no longer leave a corrupt tarball behind that
    would poison every later extraction attempt.

    Returns:
        Path to the ``blender`` executable as a string.

    Raises:
        RuntimeError: if the executable is missing after extraction.
    """
    # Fast path: a previous run already unpacked everything.
    if BLENDER_BIN.exists():
        return str(BLENDER_BIN)

    BLENDER_CACHE_DIR.mkdir(parents=True, exist_ok=True)
    tar_path = BLENDER_CACHE_DIR / BLENDER_TARBALL

    if not tar_path.exists():
        print(f"⬇️ Downloading Blender {BLENDER_VERSION} from: {BLENDER_URL}")
        # Atomic download: fetch to ".part", rename into place on success.
        part_path = tar_path.with_suffix(tar_path.suffix + ".part")
        try:
            urllib.request.urlretrieve(BLENDER_URL, part_path)
            part_path.replace(tar_path)
        finally:
            # Remove any partial file left by a failed/interrupted download.
            part_path.unlink(missing_ok=True)

    print(f"📦 Extracting Blender to: {BLENDER_CACHE_DIR}")
    with tarfile.open(tar_path, "r:xz") as tf:
        # NOTE(review): archive comes from the official blender.org mirror;
        # extractall() on untrusted archives would need a member filter.
        tf.extractall(path=BLENDER_CACHE_DIR)

    if not BLENDER_BIN.exists():
        raise RuntimeError(f"Blender binary not found after extract: {BLENDER_BIN}")

    return str(BLENDER_BIN)
100
+
101
+
102
def ensure_blender_script():
    """Write the extraction runner that is executed INSIDE Blender's Python.

    Keeping the runner in a separate file avoids needing ``import bpy`` in the
    Space's own interpreter. The script is (re)written whenever its content is
    out of date: the original early-returned on mere file existence, so a
    cached copy from an older app revision could shadow the current logic
    forever. The embedded script text itself is unchanged.
    """
    # This script runs inside Blender's Python; it can import bpy and then
    # call the project's extraction pipeline.
    script = r'''
import sys
import time
from pathlib import Path

def _parse(argv):
    args = {"input": None, "output_dir": None, "target_count": 50000}
    it = iter(argv)
    for k in it:
        if k == "--input":
            args["input"] = next(it)
        elif k == "--output_dir":
            args["output_dir"] = next(it)
        elif k == "--target_count":
            args["target_count"] = int(next(it))
    if not args["input"] or not args["output_dir"]:
        raise SystemExit("Usage: --input <file> --output_dir <dir> [--target_count N]")
    return args

def main():
    argv = sys.argv
    if "--" in argv:
        argv = argv[argv.index("--") + 1 :]
    else:
        argv = []
    args = _parse(argv)

    out = Path(args["output_dir"])
    out.mkdir(parents=True, exist_ok=True)

    # Now import the project's extractor (importing bpy inside Blender is fine)
    from src.data.extract import extract_builtin, get_files

    files = get_files(
        data_name="raw_data.npz",
        inputs=str(args["input"]),
        input_dataset_dir=None,
        output_dataset_dir=str(out),
        force_override=True,
        warning=False,
    )
    if not files:
        raise RuntimeError("No files to extract")

    timestamp = str(int(time.time()))
    extract_builtin(
        output_folder=str(out),
        target_count=int(args["target_count"]),
        num_runs=1,
        id=0,
        time=timestamp,
        files=files,
    )

if __name__ == "__main__":
    main()
'''
    BLENDER_CACHE_DIR.mkdir(parents=True, exist_ok=True)

    # Skip the write only when the on-disk copy already matches exactly.
    if (
        BLENDER_SCRIPT_PATH.exists()
        and BLENDER_SCRIPT_PATH.read_text(encoding="utf-8") == script
    ):
        return

    BLENDER_SCRIPT_PATH.write_text(script, encoding="utf-8")
171
+
172
+
173
def run_blender_extract(input_file: str, output_dir: str, target_count: int = 50000):
    """Run Blender headless (``-b``) and execute the extraction runner script.

    PYTHONPATH is extended with this repo root and the host's site-packages so
    the runner (inside Blender's bundled Python) can import ``src.*`` and any
    pip-installed dependencies it needs.

    Args:
        input_file: Mesh file to extract.
        output_dir: Directory that receives the generated npz data.
        target_count: Sampling target forwarded to the extractor.

    Raises:
        subprocess.CalledProcessError: if Blender exits with a non-zero status.
    """
    blender = ensure_blender()
    ensure_blender_script()

    repo_root = Path(__file__).parent.resolve()

    # Make installed pip packages visible to Blender-Python as well.
    py_paths = []
    try:
        py_paths += site.getsitepackages()
    except Exception:
        # Some embedded/virtualenv layouts don't implement getsitepackages().
        pass
    py_paths.append(str(repo_root))

    env = os.environ.copy()
    # Filter out empty components: an empty PYTHONPATH entry is interpreted by
    # CPython as the current working directory, which the original join added
    # accidentally whenever PYTHONPATH was unset.
    existing = env.get("PYTHONPATH", "")
    env["PYTHONPATH"] = os.pathsep.join(p for p in py_paths + [existing] if p)

    cmd = [
        blender,
        "-b",                       # headless: no window / X display required
        "-noaudio",
        "--python", str(BLENDER_SCRIPT_PATH),
        "--",                       # everything after this goes to the runner
        "--input", str(input_file),
        "--output_dir", str(output_dir),
        "--target_count", str(target_count),
    ]

    print("🧩 Running Blender extract:")
    print("  " + " ".join(cmd))
    subprocess.check_call(cmd, env=env)
212
 
213
+
214
+ # ==========================================================
215
+ # 5. DEINE FUNKTIONEN (mit Blender-Fallback)
216
+ # ==========================================================
217
def validate_input_file(file_path: str) -> bool:
    """Return True iff *file_path* names an existing .obj/.fbx/.glb file."""
    if not file_path:
        return False
    candidate = Path(file_path)
    return candidate.exists() and candidate.suffix.lower() in {".obj", ".fbx", ".glb"}
222
 
223
+
224
  def extract_mesh_python(input_file: str, output_dir: str) -> str:
225
+ """
226
+ 1) Try native bpy (if it ever exists in the Space)
227
+ 2) Otherwise run Blender headless subprocess that generates the npz
228
+ """
229
+ try:
230
+ import bpy # noqa: F401
231
+ from src.data.extract import extract_builtin, get_files
232
+
233
+ files = get_files(
234
+ data_name="raw_data.npz",
235
+ inputs=str(input_file),
236
+ input_dataset_dir=None,
237
+ output_dataset_dir=output_dir,
238
+ force_override=True,
239
+ warning=False,
240
+ )
241
+ if not files:
242
+ raise RuntimeError("No files to extract")
243
+
244
+ timestamp = str(int(time.time()))
245
+ extract_builtin(
246
+ output_folder=output_dir,
247
+ target_count=50000,
248
+ num_runs=1,
249
+ id=0,
250
+ time=timestamp,
251
+ files=files,
252
+ )
253
+ return files[0][1]
254
+ except Exception as e:
255
+ print(f"⚠️ Native bpy extraction failed ({type(e).__name__}: {e}) -> using Blender subprocess fallback.")
256
+
257
+ # Blender subprocess fallback
258
+ Path(output_dir).mkdir(parents=True, exist_ok=True)
259
+ run_blender_extract(input_file=input_file, output_dir=output_dir, target_count=50000)
260
+
261
+ # Recompute expected output path using existing helper
262
+ from src.data.extract import get_files
263
+
264
  files = get_files(
265
  data_name="raw_data.npz",
266
  inputs=str(input_file),
 
270
  warning=False,
271
  )
272
  if not files:
273
+ raise RuntimeError("No files produced by Blender extraction")
 
 
274
  return files[0][1]
275
 
276
+
277
+ def run_inference_python(
278
+ input_file: str,
279
+ output_file: str,
280
+ inference_type: str,
281
+ seed: int = 12345,
282
+ npz_dir: str = None,
283
+ ) -> str:
284
  from src.data.datapath import Datapath
285
  from src.data.dataset import DatasetConfig, UniRigDatasetModule
286
  from src.data.transform import TransformConfig
 
292
 
293
  if inference_type == "skeleton":
294
  L.seed_everything(seed, workers=True)
295
+ configs = [
296
+ "configs/task/quick_inference_skeleton_articulationxl_ar_256.yaml",
297
+ "configs/transform/inference_ar_transform.yaml",
298
+ "configs/model/unirig_ar_350m_1024_81920_float32.yaml",
299
+ "configs/system/ar_inference_articulationxl.yaml",
300
+ "configs/tokenizer/tokenizer_parts_articulationxl_256.yaml",
301
+ ]
302
  data_name = "raw_data.npz"
303
  else:
304
+ configs = [
305
+ "configs/task/quick_inference_unirig_skin.yaml",
306
+ "configs/transform/inference_skin_transform.yaml",
307
+ "configs/model/unirig_skin.yaml",
308
+ "configs/system/skin.yaml",
309
+ None,
310
+ ]
311
  data_name = "predict_skeleton.npz"
312
 
313
+ with open(configs[0], "r") as f:
314
+ task = Box(yaml.safe_load(f))
315
+
316
  if inference_type == "skeleton":
317
+ if npz_dir is None:
318
+ npz_dir = Path(output_file).parent / "npz"
319
  npz_dir.mkdir(exist_ok=True)
320
+ npz_data_dir = extract_mesh_python(input_file, str(npz_dir))
321
  datapath = Datapath(files=[npz_data_dir], cls=None)
322
  else:
323
  skeleton_work_dir = Path(input_file).parent
324
  skeleton_npz_dir = list(skeleton_work_dir.rglob("**/*.npz"))[0].parent
325
  datapath = Datapath(files=[str(skeleton_npz_dir)], cls=None)
326
 
327
+ data_config = Box(yaml.safe_load(open("configs/data/quick_inference.yaml", "r")))
328
+ transform_config = Box(yaml.safe_load(open(configs[1], "r")))
329
+
330
  if inference_type == "skeleton":
331
+ tokenizer = get_tokenizer(
332
+ config=TokenizerConfig.parse(config=Box(yaml.safe_load(open(configs[4], "r"))))
333
+ )
334
+ model = get_model(tokenizer=tokenizer, **Box(yaml.safe_load(open(configs[2], "r"))))
335
  else:
336
+ model = get_model(tokenizer=None, **Box(yaml.safe_load(open(configs[2], "r"))))
337
 
338
  data = UniRigDatasetModule(
339
+ process_fn=model._process_fn,
340
  predict_dataset_config=DatasetConfig.parse(config=data_config.predict_dataset_config).split_by_cls(),
341
  predict_transform_config=TransformConfig.parse(config=transform_config.predict_transform_config),
342
+ tokenizer_config=None if inference_type == "skin" else tokenizer.config,
343
+ data_name=data_name,
344
+ datapath=datapath,
345
+ cls=None,
346
  )
347
 
348
  writer_config = task.writer.copy()
349
  if inference_type == "skeleton":
350
+ writer_config.update(
351
+ {
352
+ "npz_dir": str(npz_dir),
353
+ "output_dir": str(Path(output_file).parent),
354
+ "output_name": Path(output_file).name,
355
+ "user_mode": False,
356
+ }
357
+ )
358
  else:
359
+ writer_config.update(
360
+ {
361
+ "npz_dir": str(skeleton_npz_dir),
362
+ "output_name": str(output_file),
363
+ "user_mode": True,
364
+ "export_fbx": True,
365
+ }
366
+ )
367
 
368
  callbacks = [get_writer(**writer_config, order_config=data.predict_transform_config.order_config)]
369
+ system = get_system(**Box(yaml.safe_load(open(configs[3], "r"))), model=model, steps_per_epoch=1)
370
+
371
  trainer = L.Trainer(callbacks=callbacks, logger=None, **task.trainer)
372
+ trainer.predict(
373
+ system,
374
+ datamodule=data,
375
+ ckpt_path=download(task.resume_from_checkpoint),
376
+ return_predictions=False,
377
+ )
378
+
379
  return str(output_file)
380
 
381
+
382
def merge_results_python(source_file: str, target_file: str, output_file: str) -> str:
    """Bake the rig from *source_file* onto *target_file*, writing *output_file*.

    Thin wrapper around ``src.inference.merge.transfer``; returns the output
    path as a string so calls can be chained.
    """
    from src.inference.merge import transfer

    src_path = str(source_file)
    dst_path = str(target_file)
    out_path = str(output_file)
    transfer(source=src_path, target=dst_path, output=out_path, add_root=False)
    return out_path
387
 
388
+
389
  # ==========================================================
390
+ # 6. GRADIO APP
391
  # ==========================================================
 
392
  @spaces.GPU()
393
  def main(input_file: str, seed: int = 12345):
394
  temp_dir = Path(__file__).parent / "tmp"
395
  temp_dir.mkdir(exist_ok=True)
396
+
397
+ if not validate_input_file(input_file):
398
+ raise gr.Error("Invalid file format")
399
+
400
  file_stem = Path(input_file).stem
401
  input_model_dir = temp_dir / f"{file_stem}_{seed}"
402
  input_model_dir.mkdir(exist_ok=True)
403
+
404
  input_path = input_model_dir / Path(input_file).name
405
  shutil.copy2(input_file, input_path)
406
+
407
  skel_fbx = input_model_dir / f"{file_stem}_skeleton.fbx"
408
  skel_only = input_model_dir / f"{file_stem}_skeleton_only{input_path.suffix}"
409
  skin_fbx = input_model_dir / f"{file_stem}_skin.fbx"
 
411
 
412
  run_inference_python(str(input_path), str(skel_fbx), "skeleton", seed)
413
  merge_results_python(str(skel_fbx), str(input_path), str(skel_only))
414
+
415
  run_inference_python(str(skel_fbx), str(skin_fbx), "skin")
416
  merge_results_python(str(skin_fbx), str(input_path), str(final_out))
417
 
418
  return str(final_out), [str(skel_only), str(final_out)]
419
 
420
+
421
  def create_app():
422
  with gr.Blocks(title="UniRig Demo") as interface:
423
  gr.Markdown("# 🎯 UniRig: Automated 3D Model Rigging")
 
429
  with gr.Column():
430
  out_3d = gr.Model3D(label="Result")
431
  out_files = gr.Files(label="Download Files")
432
+
433
  btn.click(fn=main, inputs=[input_3d, seed], outputs=[out_3d, out_files])
434
  return interface
435
 
436
+
437
  if __name__ == "__main__":
438
+ create_app().queue().launch(show_api=False)