Argus / argus /models /argus.py
lixi042
Initial commit: Argus metric panoramic 3D reconstruction demo
510e990
Raw
History Blame Contribute Delete
10.2 kB
import torch
import torch.nn as nn
from typing import Optional, Dict
from huggingface_hub import PyTorchModelHubMixin
# Import model components
from argus.models.aggregator import Aggregator
from argus.heads.camera_head import CameraHead
from argus.heads.dpt_head import DPTHead
from argus.heads.utils import reorder_by_reference
class Argus(nn.Module, PyTorchModelHubMixin):
    """
    Argus multi-task vision model for camera pose estimation, depth prediction, and 3D points.

    Integrates an aggregator backbone with task-specific heads for:
        - Camera pose encoding
        - Depth map prediction
        - 3D camera/rotated/world point prediction

    Args:
        img_size: Input image size (height/width, assumes square) (default: 518)
        patch_size: Patch size for vision transformer backbone (default: 14)
        embed_dim: Embedding dimension for transformer features (default: 1024)
        enable_camera: Enable camera pose estimation head (default: True)
        enable_depth: Enable depth prediction head (default: True)
        enable_cam_point: Enable camera coordinate 3D point prediction head (default: False)
        enable_rotated_point: Enable rotated 3D point prediction head (default: False)
        enable_point: Enable world coordinate 3D point prediction head
            (default: False; please do not set it to True during training)
        reorder_by_learning_ref: Forwarded to the Aggregator and stored on the model
            (default: True). Presumably toggles learned reference-frame selection /
            reordering inside the Aggregator -- TODO confirm against Aggregator.
        restore_metric_scale: If True, forward() multiplies the translation part
            (``[..., :3]``) of every pose encoding and every depth / point-map output
            by ``metric_scale`` (default: False).

    Note:
        All heads share the same aggregated transformer features from the Aggregator backbone.
        Each DPT-based head outputs both predictions and confidence scores.
    """

    # Multiplier applied when ``restore_metric_scale`` is True. Presumably the inverse of
    # the scale normalisation used for training targets -- TODO confirm. Kept at class
    # level (not a constructor argument) so the __init__ signature and the hub config are
    # unchanged; override per instance if a different factor is needed.
    metric_scale: float = 10.0

    def __init__(
        self,
        img_size: int = 518,
        patch_size: int = 14,
        embed_dim: int = 1024,
        enable_camera: bool = True,
        enable_depth: bool = True,
        enable_cam_point: bool = False,
        enable_rotated_point: bool = False,
        enable_point: bool = False,
        reorder_by_learning_ref: bool = True,
        restore_metric_scale: bool = False
    ) -> None:
        super().__init__()

        # Flags consulted by forward()
        self.restore_metric_scale = restore_metric_scale
        self.reorder_by_learning_ref = reorder_by_learning_ref

        # Backbone and geometry transformer
        self.aggregator = Aggregator(
            img_size=img_size,
            patch_size=patch_size,
            embed_dim=embed_dim,
            reorder_by_learning_ref=reorder_by_learning_ref,
        )

        # Every head consumes aggregated tokens of width 2 * embed_dim.
        head_dim = 2 * embed_dim

        # Task-specific prediction heads; a disabled head is stored as None.
        # Construction order is kept fixed so parameter initialisation (RNG
        # consumption) and state_dict layout match across versions.
        self.camera_head: Optional[CameraHead] = (
            CameraHead(dim_in=head_dim) if enable_camera else None
        )
        self.depth_head: Optional[DPTHead] = (
            self._make_dpt_head(head_dim, output_dim=2, activation="exp")
            if enable_depth else None
        )

        # 3D point prediction heads (shared architecture, different output semantics)
        self.cam_point_head: Optional[DPTHead] = (
            self._make_dpt_head(head_dim, output_dim=4, activation="inv_log")
            if enable_cam_point else None
        )
        self.rotated_point_head: Optional[DPTHead] = (
            self._make_dpt_head(head_dim, output_dim=4, activation="inv_log")
            if enable_rotated_point else None
        )
        self.point_head: Optional[DPTHead] = (
            self._make_dpt_head(head_dim, output_dim=4, activation="inv_log")
            if enable_point else None
        )

    @staticmethod
    def _make_dpt_head(dim_in: int, output_dim: int, activation: str) -> DPTHead:
        """Build a DPTHead with the ``"expp1"`` confidence activation shared by all heads."""
        return DPTHead(
            dim_in=dim_in,
            output_dim=output_dim,
            activation=activation,
            conf_activation="expp1",
        )

    def _dpt_heads_with_keys(self) -> tuple:
        """Return ``(head, output_key)`` pairs for the DPT heads, in output order.

        A head's prediction is stored under ``output_key`` and its confidence under
        ``output_key + "_conf"``. Disabled heads appear as ``None``.
        """
        return (
            (self.depth_head, "depth"),
            (self.cam_point_head, "cam_points"),
            (self.rotated_point_head, "rotated_points"),
            (self.point_head, "world_points"),
        )

    @staticmethod
    def _scale_translation(pose_enc: torch.Tensor, scale: float) -> torch.Tensor:
        """Return a copy of ``pose_enc`` with its first three channels multiplied by ``scale``."""
        # Out-of-place so tensors that autograd may still need are left untouched.
        return torch.cat([pose_enc[..., :3] * scale, pose_enc[..., 3:]], dim=-1)

    def _apply_metric_scale(self, predictions: Dict[str, torch.Tensor]) -> None:
        """Rescale pose translations and depth/point outputs in ``predictions`` by ``metric_scale``.

        Only the dict entries are replaced; the tensors produced by the heads are
        never modified in place.
        """
        scale = self.metric_scale
        if self.camera_head is not None:
            # Scale every refinement iteration so the list stays self-consistent,
            # and keep ``pose_enc`` identical to the final entry.
            scaled_list = [
                self._scale_translation(pose_enc, scale)
                for pose_enc in predictions["pose_enc_list"]
            ]
            predictions["pose_enc"] = scaled_list[-1]
            predictions["pose_enc_list"] = scaled_list
        for head, key in self._dpt_heads_with_keys():
            if head is not None:
                predictions[key] = predictions[key] * scale

    def forward(
        self,
        images: torch.Tensor,
    ) -> Dict[str, torch.Tensor]:
        """
        Forward pass of the Argus model.

        Automatically adds batch dimension if missing and processes multi-task predictions.

        Args:
            images: Input RGB images with shape:
                - [S, 3, H, W] (sequence without batch) or
                - [B, S, 3, H, W] (batch of sequences)
                Values in range [0, 1], where:
                - B: batch size
                - S: sequence length (number of frames)
                - 3: RGB channels
                - H/W: image height/width (matches img_size)

        Returns:
            Dictionary of model predictions with task-specific outputs:
            Common outputs:
                - covisibility_scores: Covisibility scores from aggregator (shape varies)
                - ref_idx: Reference frame indices (shape varies)
            Camera head outputs (if enabled):
                - pose_enc: Final camera pose encoding [B, S, 9]
                - pose_enc_list: List of pose encodings from all iterations [List[torch.Tensor]]
            Depth head outputs (if enabled):
                - depth: Predicted depth maps [B, S, H, W, 1]
                - depth_conf: Depth prediction confidence [B, S, H, W]
            Camera point head outputs (if enabled):
                - cam_points: 3D camera coordinates per pixel [B, S, H, W, 3]
                - cam_points_conf: Camera point confidence [B, S, H, W]
            Rotated point head outputs (if enabled):
                - rotated_points: Rotated 3D coordinates per pixel [B, S, H, W, 3]
                - rotated_points_conf: Rotated point confidence [B, S, H, W]
            World point head outputs (if enabled):
                - world_points: 3D world coordinates per pixel [B, S, H, W, 3]
                - world_points_conf: World point confidence [B, S, H, W]
            Inference-only outputs (not training):
                - images: Original input images (for visualization) [B, S, 3, H, W],
                  reordered with ``ref_idx`` when the aggregator returns one.

            When ``restore_metric_scale`` is True, ``pose_enc[..., :3]`` (and the same
            channels of every ``pose_enc_list`` entry), ``depth`` and all point maps
            are multiplied by ``metric_scale``. Confidences are left unchanged.

        Raises:
            ValueError: If ``images`` is neither 4D nor 5D.
        """
        # Add batch dimension if missing (handle [S,3,H,W] -> [1,S,3,H,W])
        if images.dim() == 4:
            images = images.unsqueeze(0)
        elif images.dim() != 5:
            raise ValueError(
                "Expected images of shape [S, 3, H, W] or [B, S, 3, H, W], "
                f"got {tuple(images.shape)}"
            )

        # Extract aggregated features from backbone
        (
            aggregated_tokens_list,  # List of aggregated transformer tokens across iterations
            patch_start_idx,         # Patch start indices for feature reconstruction
            covisibility_scores,     # Covisibility scores between frames
            ref_idx,                 # Reference frame indices
        ) = self.aggregator(images)

        predictions: Dict[str, torch.Tensor] = {}

        # Disable CUDA mixed precision so the heads run at full precision.
        with torch.amp.autocast("cuda", enabled=False):
            # Aggregator outputs are forwarded only when present.
            if covisibility_scores is not None:
                predictions["covisibility_scores"] = covisibility_scores
            if ref_idx is not None:
                predictions["ref_idx"] = ref_idx

            # Camera pose prediction (if enabled)
            if self.camera_head is not None:
                pose_enc_list = self.camera_head(aggregated_tokens_list)
                predictions["pose_enc"] = pose_enc_list[-1]  # Use final iteration encoding
                predictions["pose_enc_list"] = pose_enc_list  # Multi-layer supervision

            # Dense DPT heads: each returns (prediction, confidence).
            for head, key in self._dpt_heads_with_keys():
                if head is None:
                    continue
                pred, conf = head(
                    aggregated_tokens_list,
                    images=images,
                    patch_start_idx=patch_start_idx,
                )
                predictions[key] = pred
                predictions[f"{key}_conf"] = conf

        # Store input images for visualization during inference (skip in training)
        if not self.training:
            display_images = images
            if ref_idx is not None:
                # Only the input images are reordered here; the head outputs are
                # presumably already in the aggregator's reference order -- confirm.
                display_images = reorder_by_reference(images, ref_idx.detach())
            predictions["images"] = display_images

        if self.restore_metric_scale:
            self._apply_metric_scale(predictions)

        return predictions