"""
Model loading utilities for DA3 and other models.
"""

import copy
import logging
import os
from pathlib import Path
from typing import Any, Dict, Optional, Union

import torch  # type: ignore
from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present  # type: ignore

logger = logging.getLogger(__name__)


# ---------------------------------------------------------------------------
# HuggingFace cache location (RunPod optimization)
# ---------------------------------------------------------------------------


def _ensure_workspace_cache_env() -> None:
    """
    Point HF/torch caches at ``$YLFF_WORKSPACE_DIR/.cache`` (default ``/workspace``).

    RunPod pods typically mount a volume at /workspace; placing caches there
    reduces repeated downloads across restarts/redeploys.

    The workspace directory is created if possible; if that fails, nothing is
    changed. Cache variables are set with ``os.environ.setdefault``, so values
    already present in the environment always take precedence.
    """
    workspace = Path(os.environ.get("YLFF_WORKSPACE_DIR", "/workspace"))
    try:
        workspace.mkdir(parents=True, exist_ok=True)
    except OSError:
        # No usable workspace (read-only FS, missing permissions, path is a
        # file, ...): keep the library default cache locations.
        return

    cache_root = workspace / ".cache"
    hf_root = cache_root / "huggingface"
    try:
        # parents=True also creates hf_root itself.
        for directory in (hf_root / "hub", hf_root / "transformers", cache_root / "torch"):
            directory.mkdir(parents=True, exist_ok=True)
    except OSError:
        # If we can't create directories, still set env defaults (caller may have perms)
        pass

    os.environ.setdefault("XDG_CACHE_HOME", str(cache_root))
    os.environ.setdefault("HF_HOME", str(hf_root))
    os.environ.setdefault("HUGGINGFACE_HUB_CACHE", str(hf_root / "hub"))
    os.environ.setdefault("TRANSFORMERS_CACHE", str(hf_root / "transformers"))
    os.environ.setdefault("TORCH_HOME", str(cache_root / "torch"))


# Runs at import time so the env vars are in place before any HF/torch download.
_ensure_workspace_cache_env()

# Optimize cuDNN for consistent input sizes (faster convolutions)
if torch.backends.cudnn.is_available():
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.deterministic = False  # Allow non-deterministic for speed
    logger.debug("cuDNN benchmark mode enabled for faster training")


# Available DA3 models and their characteristics
DA3_MODELS: Dict[str, Dict[str, Any]] = {
    # Main Series - Unified depth-ray representation
    "depth-anything/DA3-GIANT": {
        "series": "main",
        "size": "giant",
        "capabilities": [
            "mono_depth",
            "multi_view_depth",
            "pose_conditioned_depth",
            "pose_estimation",
            "3d_gaussians",
        ],
        "metric": False,
        "description": "Largest model, best quality, all capabilities",
    },
    "depth-anything/DA3-LARGE": {
        "series": "main",
        "size": "large",
        "capabilities": [
            "mono_depth",
            "multi_view_depth",
            "pose_conditioned_depth",
            "pose_estimation",
            "3d_gaussians",
        ],
        "metric": False,
        "description": "Large model, good quality, all capabilities",
    },
    "depth-anything/DA3-BASE": {
        "series": "main",
        "size": "base",
        "capabilities": [
            "mono_depth",
            "multi_view_depth",
            "pose_conditioned_depth",
            "pose_estimation",
            "3d_gaussians",
        ],
        "metric": False,
        "description": "Base model, balanced quality/speed, all capabilities",
    },
    "depth-anything/DA3-SMALL": {
        "series": "main",
        "size": "small",
        "capabilities": [
            "mono_depth",
            "multi_view_depth",
            "pose_conditioned_depth",
            "pose_estimation",
            "3d_gaussians",
        ],
        "metric": False,
        "description": "Small model, fastest, all capabilities",
    },
    # Metric Series - Real-world scale depth
    "depth-anything/DA3Metric-LARGE": {
        "series": "metric",
        "size": "large",
        "capabilities": ["mono_depth", "metric_depth"],
        "metric": True,
        "description": "Specialized for metric depth estimation (real-world scale)",
    },
    # Monocular Series - High-quality relative depth
    "depth-anything/DA3Mono-LARGE": {
        "series": "mono",
        "size": "large",
        "capabilities": ["mono_depth"],
        "metric": False,
        "description": "High-quality relative monocular depth",
    },
    # Nested Series - Best for metric reconstruction
    "depth-anything/DA3NESTED-GIANT-LARGE": {
        "series": "nested",
        "size": "giant-large",
        "capabilities": [
            "mono_depth",
            "multi_view_depth",
            "pose_conditioned_depth",
            "pose_estimation",
            "metric_depth",
        ],
        "metric": True,
        "description": "Combines giant model with metric model for real-world metric scale",
        "recommended_for": ["ba_validation", "fine_tuning", "metric_reconstruction"],
    },
}

# Model used when neither a model name nor a known use case is given.
_DEFAULT_MODEL = "depth-anything/DA3-LARGE"

# Use case -> recommended model. Every value is a key of DA3_MODELS.
_USE_CASE_RECOMMENDATIONS: Dict[str, str] = {
    "ba_validation": "depth-anything/DA3NESTED-GIANT-LARGE",  # Best: metric + pose
    "fine_tuning": "depth-anything/DA3NESTED-GIANT-LARGE",  # Best: metric + pose
    "pose_estimation": "depth-anything/DA3-LARGE",  # Good balance
    "metric_depth": "depth-anything/DA3Metric-LARGE",  # Specialized
    "mono_depth": "depth-anything/DA3Mono-LARGE",  # Specialized
    "fast": "depth-anything/DA3-SMALL",  # Fastest
    "best_quality": "depth-anything/DA3-GIANT",  # Highest quality
}

# Key prefix that torch.compile's OptimizedModule wrapper adds to state_dict keys.
_COMPILED_STATE_DICT_PREFIX = "_orig_mod."


def get_recommended_model(use_case: str = "ba_validation") -> str:
    """
    Get recommended model for a specific use case.

    Args:
        use_case: One of:
            - "ba_validation": BA validation (needs pose + metric depth)
            - "fine_tuning": Fine-tuning (needs pose + metric depth)
            - "pose_estimation": Camera pose estimation
            - "metric_depth": Metric depth estimation
            - "mono_depth": Monocular depth estimation
            - "fast": Fast inference (smaller model)
            - "best_quality": Highest-quality (largest) model
            Unknown values log a warning and fall back to DA3-LARGE.

    Returns:
        Recommended model name
    """
    model = _USE_CASE_RECOMMENDATIONS.get(use_case)
    if model is None:
        logger.warning(
            f"Unknown use case '{use_case}'; falling back to {_DEFAULT_MODEL}. "
            f"Known use cases: {', '.join(sorted(_USE_CASE_RECOMMENDATIONS))}"
        )
        model = _DEFAULT_MODEL
    logger.info(f"Recommended model for '{use_case}': {model}")
    return model


def list_available_models() -> Dict[str, Dict]:
    """
    List all available DA3 models with their characteristics.

    Returns a deep copy so callers can mutate the result (including the nested
    capability lists) without altering the module-level registry.
    """
    return copy.deepcopy(DA3_MODELS)


def get_model_info(model_name: str) -> Optional[Dict]:
    """
    Get information about a specific model.

    Returns a deep copy of the registry entry, or None for unknown model names.
    """
    return copy.deepcopy(DA3_MODELS.get(model_name))


def _resolve_device(device: str) -> str:
    """Fall back from an unavailable MPS request to CUDA, then CPU; otherwise return *device*."""
    # Robust device selection: Fallback from MPS if not available (e.g. running on Linux)
    if torch.device(device).type != "mps" or torch.backends.mps.is_available():
        return device
    if torch.cuda.is_available():
        logger.warning("MPS requested but not available. Falling back to CUDA.")
        return "cuda"
    logger.warning("MPS requested but not available. Falling back to CPU.")
    return "cpu"


def _force_bf16_support_for_non_cuda(device_type: str) -> None:
    """
    Make ``torch.cuda.is_bf16_supported`` report True when not running on CUDA.

    This avoids initializing CUDA on CPU machines and forces bfloat16, which is
    what CPU autocast requires (PyTorch CPU AMP does not support float16, only
    bfloat16). The patch is process-wide and is not undone, so it is skipped
    for CUDA targets, where the real capability check must stay intact.
    """
    if device_type == "cuda" or not hasattr(torch.cuda, "is_bf16_supported"):
        return
    logger.info("Patching torch.cuda.is_bf16_supported=True to enable CPU bfloat16 autocast")
    # Accept any arguments: the real function takes e.g. `including_emulation`.
    torch.cuda.is_bf16_supported = lambda *args, **kwargs: True


def load_da3_model(
    model_name: Optional[str] = None,
    device: str = "cuda",
    use_case: Optional[str] = None,
    compile_model: bool = True,
    compile_mode: str = "reduce-overhead",
) -> torch.nn.Module:
    """
    Load pretrained DA3 model with optional compilation optimizations.

    Args:
        model_name: HuggingFace model name or local path. If None and use_case
            is provided, uses recommended model.
        device: Device to load model on. An unavailable MPS device falls back
            to CUDA, then CPU.
        use_case: Optional use case to get recommended model if model_name not provided
        compile_model: Whether to wrap the model with torch.compile (PyTorch 2.0+).
            Ignored on MPS. Compilation itself happens lazily on the first forward
            pass, so compile errors may surface there rather than here.
        compile_mode: Compilation mode: "default", "reduce-overhead", "max-autotune"

    Returns:
        Loaded DA3 model in eval mode

    Raises:
        ImportError: If the ``depth_anything_3`` package is not installed.
        Exception: Any error from downloading/loading the weights is logged and re-raised.
    """
    # Auto-select model if not provided
    if model_name is None:
        if use_case:
            model_name = get_recommended_model(use_case)
            logger.info(f"Auto-selected model for '{use_case}': {model_name}")
        else:
            model_name = _DEFAULT_MODEL  # Default fallback
            logger.info(f"Using default model: {model_name}")

    # Get model info
    model_info = get_model_info(model_name)
    if model_info:
        logger.info(f"Loading {model_info['series']} series model: {model_name}")
        logger.info(f"  Description: {model_info['description']}")
        if model_info.get("recommended_for"):
            logger.info(f"  Recommended for: {', '.join(model_info['recommended_for'])}")

    # Only the import itself is guarded, so ImportErrors raised while loading
    # weights are not misreported as "DA3 not found".
    try:
        from depth_anything_3.api import DepthAnything3  # type: ignore
    except ImportError:
        logger.error(
            "DA3 not found. Install with:\n"
            "  git clone https://github.com/ByteDance-Seed/Depth-Anything-3.git\n"
            "  cd Depth-Anything-3\n"
            "  pip install -e ."
        )
        raise

    try:
        device = _resolve_device(device)
        device_type = torch.device(device).type
        _force_bf16_support_for_non_cuda(device_type)

        logger.info(f"Loading DA3 model: {model_name} on {device}")
        model = DepthAnything3.from_pretrained(model_name)
        model = model.to(device)
    except Exception as e:
        logger.error(f"Failed to load DA3 model: {e}")
        raise

    # Compile model for faster inference/training (PyTorch 2.0+)
    # Disable compilation on MPS (often unstable or unsupported)
    if device_type == "mps":
        compile_model = False
        logger.info("Disabling torch.compile on MPS device")

    if compile_model and hasattr(torch, "compile"):
        try:
            logger.info(f"Compiling model with torch.compile (mode={compile_mode})...")
            model = torch.compile(model, mode=compile_mode, fullgraph=False)
            # torch.compile only wraps the module here; graphs are built on first call.
            logger.info("torch.compile wrapper applied (graphs are compiled lazily on first forward)")
        except Exception as e:
            logger.warning(f"Model compilation failed: {e}. Continuing without compilation.")
    elif compile_model:
        logger.warning("torch.compile not available (requires PyTorch 2.0+). Skipping compilation.")

    model.eval()
    return model


def load_model_from_checkpoint(
    model: torch.nn.Module,
    checkpoint_path: Union[str, "os.PathLike[str]"],
    device: str = "cuda",
) -> torch.nn.Module:
    """
    Load model weights from checkpoint.

    Accepts either a raw state_dict or a dict holding it under
    ``"model_state_dict"``. Checkpoints and models may each be plain or wrapped
    by ``torch.compile``: the ``_orig_mod.`` key prefix is stripped and weights
    are loaded into the unwrapped module, so both combinations work.

    Args:
        model: Model architecture (optionally a torch.compile'd module)
        checkpoint_path: Path to checkpoint
        device: Device to load on (passed to torch.load as map_location)

    Returns:
        The same ``model`` object, with loaded weights

    Raises:
        TypeError: If the checkpoint does not contain a state_dict mapping.
    """
    # NOTE: torch.load unpickles data; only load checkpoints from trusted sources.
    # Depending on the installed torch version, weights_only may default to True.
    checkpoint = torch.load(checkpoint_path, map_location=device)

    if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
        state_dict = checkpoint["model_state_dict"]
    else:
        state_dict = checkpoint
    if not isinstance(state_dict, dict):
        raise TypeError(
            f"Checkpoint {checkpoint_path} does not contain a state_dict "
            f"(got {type(state_dict).__name__})"
        )

    # Normalize keys saved from a compiled model, then load into the underlying
    # module (a compiled model exposes it as `_orig_mod`).
    consume_prefix_in_state_dict_if_present(state_dict, _COMPILED_STATE_DICT_PREFIX)
    target = getattr(model, "_orig_mod", model)
    target.load_state_dict(state_dict)

    logger.info(f"Loaded model from {checkpoint_path}")
    return model