"""PartField feature extraction helpers used by `particulate`.

This module locates the vendored PartField checkout, makes sure its checkpoint
is present on disk (downloading it from the Hugging Face Hub when missing), and
exposes :class:`PartFieldFeatureExtractor`, which loads the PartField model on
first use and samples its triplane features at arbitrary points.
"""

from __future__ import annotations

import importlib
import os
import sys
from contextlib import AbstractContextManager, nullcontext
from pathlib import Path
from types import SimpleNamespace
from typing import Any, Optional, Tuple

from huggingface_hub import hf_hub_download
import torch
from torch import Tensor, nn

# Public constant; not referenced inside this module.
# NOTE(review): presumably the channel count of the features returned by
# `PartFieldFeatureExtractor.extract` -- confirm against the PartField config.
PARTFIELD_FEATURE_DIM = 448

# Number of leading channels (dim 2 of the triplane tensor) that `extract`
# drops before sampling features.
# NOTE(review): by its name these are the SDF channels -- confirm upstream.
_PARTFIELD_SDF_CHANNELS = 64

_PARTFIELD_CHECKPOINT_REPO_ID = "mikaelaangel/partfield-ckpt"
_PARTFIELD_CHECKPOINT_FILENAME = "model_objaverse.ckpt"

# The vendored PartField checkout lives two directories above this file's
# directory, in a folder named "partfield".
_PARTFIELD_ROOT = Path(__file__).resolve().parents[2] / "partfield"
_PARTFIELD_CHECKPOINT_PATH = _PARTFIELD_ROOT / "model" / _PARTFIELD_CHECKPOINT_FILENAME
_PARTFIELD_CONFIG_PATH = _PARTFIELD_ROOT / "configs" / "final" / "demo.yaml"


def _resolve_partfield_root() -> Path:
    """Return the vendored PartField root directory.

    Raises:
        FileNotFoundError: If the checkout directory does not exist.
    """
    if _PARTFIELD_ROOT.exists():
        return _PARTFIELD_ROOT
    raise FileNotFoundError(
        f"Could not find vendored PartField checkout at {_PARTFIELD_ROOT}"
    )


def _ensure_partfield_checkpoint(checkpoint_path: Path) -> Path:
    """Return ``checkpoint_path``, downloading the checkpoint if it is missing.

    The download always requests ``_PARTFIELD_CHECKPOINT_FILENAME`` from
    ``_PARTFIELD_CHECKPOINT_REPO_ID`` and places it in ``checkpoint_path``'s
    parent directory, so ``checkpoint_path`` is expected to end with that
    filename; otherwise the post-download existence check fails.

    Args:
        checkpoint_path: Expected on-disk location of the checkpoint.

    Returns:
        ``checkpoint_path`` once the file exists.

    Raises:
        FileNotFoundError: If the file is still missing after the download.
    """
    if checkpoint_path.exists():
        return checkpoint_path

    checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
    hf_hub_download(
        repo_id=_PARTFIELD_CHECKPOINT_REPO_ID,
        filename=_PARTFIELD_CHECKPOINT_FILENAME,
        local_dir=str(checkpoint_path.parent),
        # NOTE(review): recent huggingface_hub releases appear to deprecate
        # this flag; kept so older releases still write a real file.
        local_dir_use_symlinks=False,
        # Either environment variable may carry the token; None when unset.
        token=os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_HUB_TOKEN"),
    )
    if not checkpoint_path.exists():
        raise FileNotFoundError(f"Could not download PartField checkpoint to {checkpoint_path}")
    return checkpoint_path


def _resolve_partfield_config_path() -> Path:
    """Return the PartField demo config path.

    Raises:
        FileNotFoundError: If the config file does not exist.
    """
    if not _PARTFIELD_CONFIG_PATH.exists():
        raise FileNotFoundError(f"Could not find PartField config: {_PARTFIELD_CONFIG_PATH}")
    return _PARTFIELD_CONFIG_PATH


def ensure_partfield_assets_downloaded() -> Path:
    """Validate the PartField checkout and config, then ensure the checkpoint.

    Returns:
        Path to the PartField checkpoint on disk.

    Raises:
        FileNotFoundError: If the checkout, the config, or (after attempting a
            download) the checkpoint is missing.
    """
    _resolve_partfield_root()
    _resolve_partfield_config_path()
    return _ensure_partfield_checkpoint(_PARTFIELD_CHECKPOINT_PATH)


class PartFieldFeatureExtractor:
    """Lazy wrapper around the local PartField checkpoint used by `particulate`.

    Construction resolves the on-disk assets (and downloads the checkpoint if
    it is missing). The PartField Python modules are imported and the model is
    instantiated only on the first call to :meth:`extract`.
    """

    def __init__(
        self,
    ) -> None:
        """Resolve asset paths and initialise the lazy model state.

        Raises:
            FileNotFoundError: If the PartField checkout, checkpoint, or config
                cannot be found.
        """
        self.partfield_root = _resolve_partfield_root()
        self.partfield_checkpoint_path = _ensure_partfield_checkpoint(_PARTFIELD_CHECKPOINT_PATH)
        self.partfield_config_path = _resolve_partfield_config_path()
        # Everything below is populated lazily on first use.
        self._model: Optional[nn.Module] = None
        self._sample_triplane_feat: Any = None
        # Declared up front so the attributes always exist; they are filled in
        # by `_ensure_imports_loaded`.
        self._partfield_model_cls: Any = None
        self._partfield_setup: Any = None
        self._target_device: Optional[torch.device] = None

    def _ensure_imports_loaded(self) -> None:
        """Import the PartField modules once and cache the needed entry points."""
        if self._sample_triplane_feat is not None:
            return

        # Make the vendored checkout importable as the `partfield` package.
        partfield_root_str = str(self.partfield_root)
        if partfield_root_str not in sys.path:
            sys.path.append(partfield_root_str)

        encoder_pc = importlib.import_module("partfield.model.PVCNN.encoder_pc")
        trainer_module = importlib.import_module("partfield.model_trainer_pvcnn_only_demo")
        config_module = importlib.import_module("partfield.config")

        self._sample_triplane_feat = encoder_pc.sample_triplane_feat
        self._partfield_model_cls = trainer_module.Model
        self._partfield_setup = config_module.setup

    def _load_model(self, device: torch.device) -> nn.Module:
        """Instantiate the PartField model from the checkpoint on ``device``.

        The model is put in eval mode with gradients disabled and cached on the
        instance together with the device it was moved to.

        Args:
            device: Device the loaded model is moved to.

        Returns:
            The loaded model.
        """
        self._ensure_imports_loaded()
        cfg = self._partfield_setup(
            SimpleNamespace(
                config_file=str(self.partfield_config_path),
                opts=[],
            ),
            freeze=False,
        )

        saving_module = importlib.import_module("lightning.pytorch.core.saving")
        original_pl_load = saving_module.pl_load

        def trusted_checkpoint_load(path_or_url: Any, map_location: Any = None) -> Any:
            # SECURITY: weights_only=False unpickles arbitrary objects. This is
            # only acceptable because the checkpoint comes from the fixed,
            # trusted repo configured above; never point it at untrusted files.
            return torch.load(
                path_or_url,
                map_location=map_location,
                weights_only=False,
            )

        # Temporarily swap Lightning's loader for the full-pickle loader above;
        # the original is always restored, even if loading fails.
        saving_module.pl_load = trusted_checkpoint_load
        try:
            model = self._partfield_model_cls.load_from_checkpoint(
                str(self.partfield_checkpoint_path),
                cfg=cfg,
                map_location="cpu",
            )
        finally:
            saving_module.pl_load = original_pl_load

        model.eval()
        model.requires_grad_(False)
        model.to(device=device)
        self._model = model
        self._target_device = device
        return model

    def _get_model(self, device: torch.device) -> nn.Module:
        """Return the cached model on ``device``, loading or moving it as needed."""
        if self._model is None:
            return self._load_model(device)
        if self._target_device != device:
            self._model.to(device=device)
            self._target_device = device
        return self._model

    def _autocast_context(self, device: torch.device) -> AbstractContextManager[Any]:
        """Return a context that mirrors the ambient CUDA autocast state.

        Non-CUDA devices get a no-op context. For CUDA, the returned autocast
        context uses the currently configured autocast dtype and is enabled
        only if autocast is already enabled in the calling scope.
        """
        if device.type != "cuda":
            return nullcontext()

        # Newer PyTorch releases deprecate `torch.get_autocast_gpu_dtype()` in
        # favour of the device-generic `torch.get_autocast_dtype`; prefer the
        # new API when present and fall back on older installations.
        get_autocast_dtype = getattr(torch, "get_autocast_dtype", None)
        if get_autocast_dtype is not None:
            ambient_dtype = get_autocast_dtype("cuda")
        else:
            ambient_dtype = torch.get_autocast_gpu_dtype()

        return torch.autocast(
            device_type="cuda",
            dtype=ambient_dtype,
            enabled=torch.is_autocast_enabled(),
        )

    def _normalize_points(
        self,
        encode_points: Tensor,
        *decode_points: Tensor | None,
    ) -> Tuple[Tensor, Tuple[Tensor | None, ...]]:
        """Centre and scale points using the bounding box of ``encode_points``.

        The bounding box is taken over dim ``-2`` (the points axis) and its
        longest side over dim ``-1`` (the coordinate axis). After normalisation
        the encode points are centred at the origin and their longest bounding
        box side spans ``1.8`` (i.e. ``[-0.9, 0.9]``). The same centre and scale
        are applied to every non-``None`` entry of ``decode_points``, which may
        therefore fall outside that range.

        Args:
            encode_points: Points that define the normalisation.
            *decode_points: Additional point sets (or ``None``) to transform
                with the same centre and scale.

        Returns:
            The normalised encode points and a tuple with one entry per
            ``decode_points`` argument (``None`` entries are passed through).
        """
        bbmin = encode_points.amin(dim=-2, keepdim=True)
        bbmax = encode_points.amax(dim=-2, keepdim=True)
        center = (bbmin + bbmax) * 0.5
        # Clamp so a degenerate (single-point / zero-size) cloud cannot divide
        # by zero.
        extent = (bbmax - bbmin).amax(dim=-1, keepdim=True).clamp_min(1e-6)
        scale = 2.0 * 0.9 / extent

        normalized_encode_points = (encode_points - center) * scale
        normalized_decode_points = tuple(
            None if points is None else (points - center) * scale
            for points in decode_points
        )
        return normalized_encode_points, normalized_decode_points

    @torch.no_grad()
    def extract(
        self,
        *,
        encode_points: Tensor,
        decode_shape_points: Tensor | None = None,
        decode_query_points: Tensor | None = None,
    ) -> tuple[Tensor | None, Tensor | None]:
        """Encode a point cloud and sample PartField features at given points.

        The model is loaded on (or moved to) ``encode_points.device`` on demand.
        All point sets are normalised with the bounding box of
        ``encode_points`` before being passed to the model.

        Args:
            encode_points: Point cloud fed to the PartField encoder.
            decode_shape_points: Optional points at which to sample features.
            decode_query_points: Optional second set of points at which to
                sample features.

        Returns:
            ``(shape_features, query_features)``; each entry is a float32
            tensor, or ``None`` when the corresponding input was ``None``.
        """
        model = self._get_model(encode_points.device)
        normalized_encode_points, (
            normalized_shape_points,
            normalized_query_points,
        ) = self._normalize_points(
            encode_points,
            decode_shape_points,
            decode_query_points,
        )

        with self._autocast_context(encode_points.device):
            # The encoder takes the normalised cloud as both of its arguments.
            encoded = model.pvcnn(normalized_encode_points, normalized_encode_points)
            planes = model.triplane_transformer(encoded)
            # Drop the leading `_PARTFIELD_SDF_CHANNELS` channels along dim 2;
            # only the remaining channels are sampled.
            part_planes = planes[:, :, _PARTFIELD_SDF_CHANNELS:]
            shape_features = (
                None
                if normalized_shape_points is None
                else self._sample_triplane_feat(part_planes, normalized_shape_points)
            )
            query_features = (
                None
                if normalized_query_points is None
                else self._sample_triplane_feat(part_planes, normalized_query_points)
            )

        # Cast back to float32 so callers are unaffected by the autocast dtype.
        return (
            None if shape_features is None else shape_features.float(),
            None if query_features is None else query_features.float(),
        )