|
|
""" |
|
|
Model loading utilities for DA3 and other models. |
|
|
""" |
|
|
|
|
|
import logging |
|
|
import os |
|
|
from pathlib import Path |
|
|
from typing import Dict, Optional |
|
|
import torch |
|
|
|
|
|
# Module-level logger; handlers/levels are configured by the application, not here.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _ensure_workspace_cache_env() -> None: |
|
|
""" |
|
|
Ensure HF/torch caches live under /workspace when available. |
|
|
|
|
|
RunPod pods typically mount a volume at /workspace; placing caches there reduces |
|
|
repeated downloads across restarts/redeploys. |
|
|
""" |
|
|
workspace = Path(os.environ.get("YLFF_WORKSPACE_DIR", "/workspace")) |
|
|
try: |
|
|
workspace.mkdir(parents=True, exist_ok=True) |
|
|
except Exception: |
|
|
return |
|
|
|
|
|
cache_root = workspace / ".cache" |
|
|
hf_root = cache_root / "huggingface" |
|
|
try: |
|
|
hf_root.mkdir(parents=True, exist_ok=True) |
|
|
(hf_root / "hub").mkdir(parents=True, exist_ok=True) |
|
|
(hf_root / "transformers").mkdir(parents=True, exist_ok=True) |
|
|
(cache_root / "torch").mkdir(parents=True, exist_ok=True) |
|
|
except Exception: |
|
|
|
|
|
pass |
|
|
|
|
|
os.environ.setdefault("XDG_CACHE_HOME", str(cache_root)) |
|
|
os.environ.setdefault("HF_HOME", str(hf_root)) |
|
|
os.environ.setdefault("HUGGINGFACE_HUB_CACHE", str(hf_root / "hub")) |
|
|
os.environ.setdefault("TRANSFORMERS_CACHE", str(hf_root / "transformers")) |
|
|
os.environ.setdefault("TORCH_HOME", str(cache_root / "torch")) |
|
|
|
|
|
|
|
|
# Apply the cache relocation at import time so every module importing this one
# inherits the /workspace cache layout before any HF/torch download happens.
_ensure_workspace_cache_env()


# Enable cuDNN autotuning: faster kernels for fixed input shapes, at the cost
# of non-deterministic kernel selection (deterministic is explicitly False).
if torch.backends.cudnn.is_available():
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.deterministic = False
    logger.debug("cuDNN benchmark mode enabled for faster training")
|
|
|
|
|
|
|
|
|
|
|
# Registry of known DA3 checkpoints on the HuggingFace Hub, keyed by repo id.
# Fields per entry:
#   series          - model family: "main", "metric", "mono", or "nested"
#   size            - backbone size tier
#   capabilities    - tasks the checkpoint supports
#   metric          - True when predicted depth is in real-world (metric) scale
#   description     - human-readable summary
#   recommended_for - (optional) use cases this checkpoint is preferred for
DA3_MODELS: Dict[str, Dict] = {
    # --- Main series: full capability set, relative (non-metric) depth ---
    "depth-anything/DA3-GIANT": {
        "series": "main",
        "size": "giant",
        "capabilities": [
            "mono_depth",
            "multi_view_depth",
            "pose_conditioned_depth",
            "pose_estimation",
            "3d_gaussians",
        ],
        "metric": False,
        "description": "Largest model, best quality, all capabilities",
    },
    "depth-anything/DA3-LARGE": {
        "series": "main",
        "size": "large",
        "capabilities": [
            "mono_depth",
            "multi_view_depth",
            "pose_conditioned_depth",
            "pose_estimation",
            "3d_gaussians",
        ],
        "metric": False,
        "description": "Large model, good quality, all capabilities",
    },
    "depth-anything/DA3-BASE": {
        "series": "main",
        "size": "base",
        "capabilities": [
            "mono_depth",
            "multi_view_depth",
            "pose_conditioned_depth",
            "pose_estimation",
            "3d_gaussians",
        ],
        "metric": False,
        "description": "Base model, balanced quality/speed, all capabilities",
    },
    "depth-anything/DA3-SMALL": {
        "series": "main",
        "size": "small",
        "capabilities": [
            "mono_depth",
            "multi_view_depth",
            "pose_conditioned_depth",
            "pose_estimation",
            "3d_gaussians",
        ],
        "metric": False,
        "description": "Small model, fastest, all capabilities",
    },

    # --- Metric series: specialized for real-world-scale depth ---
    "depth-anything/DA3Metric-LARGE": {
        "series": "metric",
        "size": "large",
        "capabilities": ["mono_depth", "metric_depth"],
        "metric": True,
        "description": "Specialized for metric depth estimation (real-world scale)",
    },

    # --- Mono series: relative monocular depth only ---
    "depth-anything/DA3Mono-LARGE": {
        "series": "mono",
        "size": "large",
        "capabilities": ["mono_depth"],
        "metric": False,
        "description": "High-quality relative monocular depth",
    },

    # --- Nested series: giant + metric models combined for metric-scale output ---
    "depth-anything/DA3NESTED-GIANT-LARGE": {
        "series": "nested",
        "size": "giant-large",
        "capabilities": [
            "mono_depth",
            "multi_view_depth",
            "pose_conditioned_depth",
            "pose_estimation",
            "metric_depth",
        ],
        "metric": True,
        "description": "Combines giant model with metric model for real-world metric scale",
        "recommended_for": ["ba_validation", "fine_tuning", "metric_reconstruction"],
    },
}
|
|
|
|
|
|
|
|
def get_recommended_model(use_case: str = "ba_validation") -> str:
    """
    Return the DA3 model name best suited to *use_case*.

    Args:
        use_case: Supported values:
            - "ba_validation": BA validation and fine-tuning (pose + metric depth)
            - "fine_tuning": same nested model as BA validation
            - "pose_estimation": camera pose estimation
            - "metric_depth": metric (real-world scale) depth estimation
            - "mono_depth": relative monocular depth estimation
            - "fast": smallest, fastest model
            - "best_quality": largest model
            Any other value falls back to the general-purpose large model.

    Returns:
        HuggingFace model name of the recommended checkpoint.
    """
    use_case_to_model = {
        "ba_validation": "depth-anything/DA3NESTED-GIANT-LARGE",
        "fine_tuning": "depth-anything/DA3NESTED-GIANT-LARGE",
        "pose_estimation": "depth-anything/DA3-LARGE",
        "metric_depth": "depth-anything/DA3Metric-LARGE",
        "mono_depth": "depth-anything/DA3Mono-LARGE",
        "fast": "depth-anything/DA3-SMALL",
        "best_quality": "depth-anything/DA3-GIANT",
    }

    model = use_case_to_model.get(use_case, "depth-anything/DA3-LARGE")
    logger.info(f"Recommended model for '{use_case}': {model}")
    return model
|
|
|
|
|
|
|
|
def list_available_models() -> Dict[str, Dict]:
    """Return a shallow copy of the DA3 model registry, keyed by model name."""
    # dict(...) yields a fresh top-level mapping; nested per-model dicts are shared,
    # exactly as with the shallow dict.copy() this replaces.
    return dict(DA3_MODELS)
|
|
|
|
|
|
|
|
def get_model_info(model_name: str) -> Optional[Dict]:
    """Look up registry metadata for *model_name*; None when unknown."""
    try:
        return DA3_MODELS[model_name]
    except KeyError:
        return None
|
|
|
|
|
|
|
|
def _resolve_model_name(model_name: Optional[str], use_case: Optional[str]) -> str:
    """Pick the model to load: explicit name > use-case recommendation > default."""
    if model_name is not None:
        return model_name
    if use_case:
        model_name = get_recommended_model(use_case)
        logger.info(f"Auto-selected model for '{use_case}': {model_name}")
    else:
        model_name = "depth-anything/DA3-LARGE"
        logger.info(f"Using default model: {model_name}")
    return model_name


def _resolve_device(device: str) -> str:
    """Downgrade the requested device to one that is actually available."""
    if device == "mps" and not torch.backends.mps.is_available():
        if torch.cuda.is_available():
            logger.warning("MPS requested but not available. Falling back to CUDA.")
            return "cuda"
        logger.warning("MPS requested but not available. Falling back to CPU.")
        return "cpu"
    # Robustness fix: a CUDA request on a CPU-only host previously crashed in .to().
    if device == "cuda" and not torch.cuda.is_available():
        logger.warning("CUDA requested but not available. Falling back to CPU.")
        return "cpu"
    return device


def load_da3_model(
    model_name: Optional[str] = None,
    device: str = "cuda",
    use_case: Optional[str] = None,
    compile_model: bool = True,
    compile_mode: str = "reduce-overhead",
) -> torch.nn.Module:
    """
    Load pretrained DA3 model with optional compilation optimizations.

    Args:
        model_name: HuggingFace model name or local path.
            If None and use_case is provided, uses recommended model.
        device: Device to load model on ("cuda", "cpu", or "mps"); falls back
            to an available device when the requested one is missing.
        use_case: Optional use case to get recommended model if model_name not provided
        compile_model: Whether to compile model with torch.compile (PyTorch 2.0+)
        compile_mode: Compilation mode: "default", "reduce-overhead", "max-autotune"

    Returns:
        Loaded DA3 model in eval mode.

    Raises:
        ImportError: if the depth_anything_3 package is not installed.
        Exception: re-raised from the underlying loader on any other failure.
    """
    model_name = _resolve_model_name(model_name, use_case)

    # Log registry metadata when the model is a known checkpoint.
    model_info = get_model_info(model_name)
    if model_info:
        logger.info(f"Loading {model_info['series']} series model: {model_name}")
        logger.info(f"  Description: {model_info['description']}")
        if model_info.get("recommended_for"):
            logger.info(f"  Recommended for: {', '.join(model_info['recommended_for'])}")

    try:
        from depth_anything_3.api import DepthAnything3

        device = _resolve_device(device)

        if device != "cuda" and hasattr(torch.cuda, "is_bf16_supported"):
            # Bug fix: this patch used to be applied unconditionally, making
            # CUDA GPUs without native bf16 (pre-Ampere) report support and
            # fail inside autocast. Only override the probe when running
            # off-GPU, where DA3 presumably consults it before CPU autocast
            # (TODO confirm against depth_anything_3 internals).
            logger.info("Patching torch.cuda.is_bf16_supported=True to enable CPU bfloat16 autocast")
            # Newer torch passes ``including_emulation``; accept any signature.
            torch.cuda.is_bf16_supported = lambda *args, **kwargs: True

        logger.info(f"Loading DA3 model: {model_name} on {device}")
        model = DepthAnything3.from_pretrained(model_name)
        model = model.to(device)

        if device == "mps":
            # torch.compile support on MPS is unreliable; skip compilation.
            compile_model = False
            logger.info("Disabling torch.compile on MPS device")

        if compile_model and hasattr(torch, "compile"):
            try:
                logger.info(f"Compiling model with torch.compile (mode={compile_mode})...")
                model = torch.compile(model, mode=compile_mode, fullgraph=False)
                logger.info("Model compilation successful")
            except Exception as e:
                # Compilation is an optimization only; never fail the load over it.
                logger.warning(f"Model compilation failed: {e}. Continuing without compilation.")
        elif compile_model:
            logger.warning(
                "torch.compile not available (requires PyTorch 2.0+). Skipping compilation."
            )

        # Inference-only usage: disable dropout/batch-norm training behavior.
        model.eval()

        return model

    except ImportError:
        logger.error(
            "DA3 not found. Install with:\n"
            " git clone https://github.com/ByteDance-Seed/Depth-Anything-3.git\n"
            " cd Depth-Anything-3\n"
            " pip install -e ."
        )
        raise
    except Exception as e:
        logger.error(f"Failed to load DA3 model: {e}")
        raise
|
|
|
|
|
|
|
|
def load_model_from_checkpoint(
    model: torch.nn.Module,
    checkpoint_path: Path,
    device: str = "cuda",
) -> torch.nn.Module:
    """
    Populate *model* with weights stored at *checkpoint_path*.

    Args:
        model: Model architecture to receive the weights (modified in place).
        checkpoint_path: File produced by ``torch.save`` -- either a raw state
            dict or a training checkpoint wrapping one under "model_state_dict".
        device: Device the loaded tensors are mapped onto.

    Returns:
        The same *model* instance with weights loaded.
    """
    payload = torch.load(checkpoint_path, map_location=device)

    # Training checkpoints nest the weights; plain state dicts are used as-is.
    state_dict = payload["model_state_dict"] if "model_state_dict" in payload else payload
    model.load_state_dict(state_dict)

    logger.info(f"Loaded model from {checkpoint_path}")
    return model
|
|
|