Spaces: Paused
Fix: Remove engine caching to prevent LDR generation state corruption
Browse files
- code/demo.py +28 -45
code/demo.py
CHANGED
|
@@ -51,44 +51,7 @@ def get_clip_retriever_cached():
|
|
| 51 |
print(f"✅ CLIP retriever loaded ({retriever.features.shape[0]} designs)")
|
| 52 |
return retriever
|
| 53 |
|
| 54 |
-
|
| 55 |
-
def get_gpt_engine_cached():
|
| 56 |
-
"""Lazy load GPT engine (initialized only once, cached)"""
|
| 57 |
-
print("🔧 Initializing GPT engine (one-time setup)...")
|
| 58 |
-
|
| 59 |
-
# Use absolute path relative to this file (fixes HF Spaces deployment)
|
| 60 |
-
config_path = os.path.join(os.path.dirname(__file__), 'cube3d/configs/open_model_v0.5.yaml')
|
| 61 |
-
gpt_ckpt_path = None # test mode doesn't use this
|
| 62 |
-
|
| 63 |
-
# Detect HuggingFace Spaces environment
|
| 64 |
-
is_hf_space = os.getenv("SPACE_ID") is not None
|
| 65 |
-
|
| 66 |
-
if is_hf_space:
|
| 67 |
-
from huggingface_hub import hf_hub_download
|
| 68 |
-
print(f"Loading GPT model from HuggingFace Model Hub...")
|
| 69 |
-
shape_ckpt_path = hf_hub_download(
|
| 70 |
-
repo_id="0xZohar/object-assembler-models",
|
| 71 |
-
filename="save_shape_cars_whole_p_rot_scratch_4mask_randp.safetensors",
|
| 72 |
-
cache_dir=HF_CACHE_DIR,
|
| 73 |
-
local_files_only=False # Allow runtime download on first use
|
| 74 |
-
)
|
| 75 |
-
save_gpt_ckpt_path = shape_ckpt_path
|
| 76 |
-
print(f"✅ GPT model loaded from cache: {shape_ckpt_path}")
|
| 77 |
-
else:
|
| 78 |
-
shape_ckpt_path = 'model_weights/save_shape_cars_whole_p_rot_scratch_4mask_randp.safetensors'
|
| 79 |
-
save_gpt_ckpt_path = 'model_weights/save_shape_cars_whole_p_rot_scratch_4mask_randp.safetensors'
|
| 80 |
-
|
| 81 |
-
# GPU T4 small environment: Use EngineFast with CUDA (required for t2t method)
|
| 82 |
-
device = torch.device("cuda")
|
| 83 |
-
print("✅ Using EngineFast on GPU (T4 small)")
|
| 84 |
-
|
| 85 |
-
engine = EngineFast(
|
| 86 |
-
config_path, gpt_ckpt_path, shape_ckpt_path, save_gpt_ckpt_path,
|
| 87 |
-
device=device,
|
| 88 |
-
mode='test'
|
| 89 |
-
)
|
| 90 |
-
print(f"✅ GPT engine initialized on {device}")
|
| 91 |
-
return engine
|
| 92 |
|
| 93 |
# 确保临时目录存在(远程服务器路径)
|
| 94 |
TMP_DIR = "./tmp/ldr_processor_demo"
|
|
@@ -333,22 +296,42 @@ def generate_ldr_gpu(ldr_content, ldr_path):
|
|
| 333 |
List of predicted LDR lines
|
| 334 |
"""
|
| 335 |
print("🤖 Running GPT model to generate new assembly sequence...")
|
|
|
|
| 336 |
|
| 337 |
stride = 5
|
| 338 |
rot_num = 24
|
| 339 |
bert_shift = 1
|
| 340 |
shift = 0
|
| 341 |
|
| 342 |
-
#
|
| 343 |
-
|
|
|
|
| 344 |
|
| 345 |
-
|
| 346 |
-
|
| 347 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 348 |
else:
|
| 349 |
-
|
|
|
|
| 350 |
|
| 351 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 352 |
|
| 353 |
targets_source = torch.from_numpy(ldr_content[0]).to(device).unsqueeze(0)
|
| 354 |
targets = targets_source.clone()
|
|
|
|
| 51 |
print(f"✅ CLIP retriever loaded ({retriever.features.shape[0]} designs)")
|
| 52 |
return retriever
|
| 53 |
|
| 54 |
+
# The cached engine loader was removed; a fresh engine instance is now created on every call to prevent state corruption
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 55 |
|
| 56 |
# 确保临时目录存在(远程服务器路径)
|
| 57 |
TMP_DIR = "./tmp/ldr_processor_demo"
|
|
|
|
| 296 |
List of predicted LDR lines
|
| 297 |
"""
|
| 298 |
print("🤖 Running GPT model to generate new assembly sequence...")
|
| 299 |
+
print(" Using CUDA graphs, this will take some time to warmup and capture the graph.")
|
| 300 |
|
| 301 |
stride = 5
|
| 302 |
rot_num = 24
|
| 303 |
bert_shift = 1
|
| 304 |
shift = 0
|
| 305 |
|
| 306 |
+
# Prepare checkpoint paths
|
| 307 |
+
config_path = os.path.join(os.path.dirname(__file__), 'cube3d/configs/open_model_v0.5.yaml')
|
| 308 |
+
gpt_ckpt_path = None # test mode doesn't use this
|
| 309 |
|
| 310 |
+
# Detect HuggingFace Spaces environment
|
| 311 |
+
is_hf_space = os.getenv("SPACE_ID") is not None
|
| 312 |
+
|
| 313 |
+
if is_hf_space:
|
| 314 |
+
from huggingface_hub import hf_hub_download
|
| 315 |
+
checkpoint_path = hf_hub_download(
|
| 316 |
+
repo_id="0xZohar/object-assembler-models",
|
| 317 |
+
filename="save_shape_cars_whole_p_rot_scratch_4mask_randp.safetensors",
|
| 318 |
+
cache_dir=HF_CACHE_DIR,
|
| 319 |
+
local_files_only=False
|
| 320 |
+
)
|
| 321 |
+
shape_ckpt_path = checkpoint_path
|
| 322 |
+
save_gpt_ckpt_path = checkpoint_path
|
| 323 |
else:
|
| 324 |
+
shape_ckpt_path = 'model_weights/save_shape_cars_whole_p_rot_scratch_4mask_randp.safetensors'
|
| 325 |
+
save_gpt_ckpt_path = 'model_weights/save_shape_cars_whole_p_rot_scratch_4mask_randp.safetensors'
|
| 326 |
|
| 327 |
+
# Create fresh engine instance (fixes state corruption from caching)
|
| 328 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 329 |
+
engine = EngineFast(
|
| 330 |
+
config_path, gpt_ckpt_path, shape_ckpt_path, save_gpt_ckpt_path,
|
| 331 |
+
device=device,
|
| 332 |
+
mode='test'
|
| 333 |
+
)
|
| 334 |
+
print(" Compiled the graph.")
|
| 335 |
|
| 336 |
targets_source = torch.from_numpy(ldr_content[0]).to(device).unsqueeze(0)
|
| 337 |
targets = targets_source.clone()
|