0xZohar committed on
Commit
1c8c926
·
verified ·
1 Parent(s): f17371f

Fix: Remove engine caching to prevent LDR generation state corruption

Browse files
Files changed (1) hide show
  1. code/demo.py +28 -45
code/demo.py CHANGED
@@ -51,44 +51,7 @@ def get_clip_retriever_cached():
51
  print(f"✅ CLIP retriever loaded ({retriever.features.shape[0]} designs)")
52
  return retriever
53
 
54
- @functools.lru_cache(maxsize=1)
55
- def get_gpt_engine_cached():
56
- """Lazy load GPT engine (initialized only once, cached)"""
57
- print("🔧 Initializing GPT engine (one-time setup)...")
58
-
59
- # Use absolute path relative to this file (fixes HF Spaces deployment)
60
- config_path = os.path.join(os.path.dirname(__file__), 'cube3d/configs/open_model_v0.5.yaml')
61
- gpt_ckpt_path = None # test mode doesn't use this
62
-
63
- # Detect HuggingFace Spaces environment
64
- is_hf_space = os.getenv("SPACE_ID") is not None
65
-
66
- if is_hf_space:
67
- from huggingface_hub import hf_hub_download
68
- print(f"Loading GPT model from HuggingFace Model Hub...")
69
- shape_ckpt_path = hf_hub_download(
70
- repo_id="0xZohar/object-assembler-models",
71
- filename="save_shape_cars_whole_p_rot_scratch_4mask_randp.safetensors",
72
- cache_dir=HF_CACHE_DIR,
73
- local_files_only=False # Allow runtime download on first use
74
- )
75
- save_gpt_ckpt_path = shape_ckpt_path
76
- print(f"✅ GPT model loaded from cache: {shape_ckpt_path}")
77
- else:
78
- shape_ckpt_path = 'model_weights/save_shape_cars_whole_p_rot_scratch_4mask_randp.safetensors'
79
- save_gpt_ckpt_path = 'model_weights/save_shape_cars_whole_p_rot_scratch_4mask_randp.safetensors'
80
-
81
- # GPU T4 small environment: Use EngineFast with CUDA (required for t2t method)
82
- device = torch.device("cuda")
83
- print("✅ Using EngineFast on GPU (T4 small)")
84
-
85
- engine = EngineFast(
86
- config_path, gpt_ckpt_path, shape_ckpt_path, save_gpt_ckpt_path,
87
- device=device,
88
- mode='test'
89
- )
90
- print(f"✅ GPT engine initialized on {device}")
91
- return engine
92
 
93
  # 确保临时目录存在(远程服务器路径)
94
  TMP_DIR = "./tmp/ldr_processor_demo"
@@ -333,22 +296,42 @@ def generate_ldr_gpu(ldr_content, ldr_path):
333
  List of predicted LDR lines
334
  """
335
  print("🤖 Running GPT model to generate new assembly sequence...")
 
336
 
337
  stride = 5
338
  rot_num = 24
339
  bert_shift = 1
340
  shift = 0
341
 
342
- # Lazy load GPT engine (cached, initialized only once)
343
- engine = get_gpt_engine_cached()
 
344
 
345
- device = engine.device
346
- if device.type == "cuda":
347
- print(" Using CUDA graphs (this will take some time to warmup)")
 
 
 
 
 
 
 
 
 
 
348
  else:
349
- print(" Running on CPU (slower). Set GPT_DEVICE=cuda or enable GPU/ZeroGPU for faster runs.")
 
350
 
351
- print(" Graph compiled, starting generation...")
 
 
 
 
 
 
 
352
 
353
  targets_source = torch.from_numpy(ldr_content[0]).to(device).unsqueeze(0)
354
  targets = targets_source.clone()
 
51
  print(f"✅ CLIP retriever loaded ({retriever.features.shape[0]} designs)")
52
  return retriever
53
 
54
+ # Removed cached engine - creates fresh instance each time to prevent state corruption
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
 
56
  # 确保临时目录存在(远程服务器路径)
57
  TMP_DIR = "./tmp/ldr_processor_demo"
 
296
  List of predicted LDR lines
297
  """
298
  print("🤖 Running GPT model to generate new assembly sequence...")
299
+ print(" Using CUDA graphs, this will take some time to warmup and capture the graph.")
300
 
301
  stride = 5
302
  rot_num = 24
303
  bert_shift = 1
304
  shift = 0
305
 
306
+ # Prepare checkpoint paths
307
+ config_path = os.path.join(os.path.dirname(__file__), 'cube3d/configs/open_model_v0.5.yaml')
308
+ gpt_ckpt_path = None # test mode doesn't use this
309
 
310
+ # Detect HuggingFace Spaces environment
311
+ is_hf_space = os.getenv("SPACE_ID") is not None
312
+
313
+ if is_hf_space:
314
+ from huggingface_hub import hf_hub_download
315
+ checkpoint_path = hf_hub_download(
316
+ repo_id="0xZohar/object-assembler-models",
317
+ filename="save_shape_cars_whole_p_rot_scratch_4mask_randp.safetensors",
318
+ cache_dir=HF_CACHE_DIR,
319
+ local_files_only=False
320
+ )
321
+ shape_ckpt_path = checkpoint_path
322
+ save_gpt_ckpt_path = checkpoint_path
323
  else:
324
+ shape_ckpt_path = 'model_weights/save_shape_cars_whole_p_rot_scratch_4mask_randp.safetensors'
325
+ save_gpt_ckpt_path = 'model_weights/save_shape_cars_whole_p_rot_scratch_4mask_randp.safetensors'
326
 
327
+ # Create fresh engine instance (fixes state corruption from caching)
328
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
329
+ engine = EngineFast(
330
+ config_path, gpt_ckpt_path, shape_ckpt_path, save_gpt_ckpt_path,
331
+ device=device,
332
+ mode='test'
333
+ )
334
+ print(" Compiled the graph.")
335
 
336
  targets_source = torch.from_numpy(ldr_content[0]).to(device).unsqueeze(0)
337
  targets = targets_source.clone()