diff --git a/README.md b/README.md
index a5190c8f35d966252e3cb37d0eb58c0aeba6848e..30fd8e650c3c61d52afcffed6c7513b7b574eb84 100644
--- a/README.md
+++ b/README.md
@@ -1,13 +1,67 @@
---
-title: UniSH
-emoji: 🏆
-colorFrom: pink
-colorTo: red
+title: UniSH (Unified Scene & Human Reconstruction)
+emoji: 🏃‍♂️
+colorFrom: blue
+colorTo: purple
sdk: gradio
-sdk_version: 6.3.0
+sdk_version: 5.0.0
app_file: app.py
pinned: false
-license: apache-2.0
+license: cc-by-nc-4.0
---
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+# UniSH: Unifying Scene and Human Reconstruction in a Feed-Forward Pass
+
+
+
+Mengfei Li<sup>1</sup>, Peng Li<sup>1</sup>, Zheng Zhang<sup>2</sup>, Jiahao Lu<sup>1</sup>, Chengfeng Zhao<sup>1</sup>, Wei Xue<sup>1</sup>,
+Qifeng Liu<sup>1</sup>, Sida Peng<sup>3</sup>, Wenxiao Zhang<sup>1</sup>, Wenhan Luo<sup>1</sup>, Yuan Liu<sup>1†</sup>, Yike Guo<sup>1†</sup>
+
+<sup>1</sup>The Hong Kong University of Science and Technology, <sup>2</sup>Beijing University of Posts and Telecommunications, <sup>3</sup>Zhejiang University
+
+
+
+
+
+
+
+## Abstract
+
+We present UniSH, a unified, feed-forward framework for joint metric-scale 3D scene and human reconstruction. A key challenge in this domain is the scarcity of large-scale, annotated real-world data, which forces a reliance on synthetic datasets. This reliance introduces a significant sim-to-real domain gap, leading to poor generalization, low-fidelity human geometry, and misaligned human-scene placement on in-the-wild videos.
+
+To address this, we propose an innovative training paradigm that effectively leverages unlabeled in-the-wild data. Our framework bridges strong but disparate priors from scene reconstruction and human mesh recovery (HMR), and is trained with two core components: (1) a robust distillation strategy that refines human surface geometry by transferring high-frequency detail from an expert depth model, and (2) a two-stage supervision scheme, which first learns coarse localization on synthetic data, then fine-tunes on real data by directly optimizing the geometric correspondence between the SMPL mesh and the human point cloud. This approach enables our feed-forward model to jointly recover high-fidelity scene geometry, human point clouds, camera parameters, and coherent, metric-scale SMPL bodies, all in a single forward pass. Extensive experiments demonstrate that our model achieves state-of-the-art performance on human-centric scene reconstruction and delivers highly competitive results on global human motion estimation, comparing favorably against both optimization-based frameworks and HMR-only methods.
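+
+For concreteness, one plausible instantiation of the stage-two correspondence objective is a Chamfer distance between posed SMPL vertices and the predicted human points (a sketch only; the exact loss used in the paper may differ):
+
+```python
+# Sketch of a SMPL-to-point-cloud correspondence objective (illustrative).
+import torch
+from pytorch3d.loss import chamfer_distance
+
+def correspondence_loss(smpl_vertices: torch.Tensor,
+                        human_points: torch.Tensor) -> torch.Tensor:
+    # smpl_vertices: [B, 6890, 3] metric-scale posed SMPL mesh vertices
+    # human_points:  [B, N, 3]   human points from the predicted pointmaps
+    loss, _ = chamfer_distance(smpl_vertices, human_points)
+    return loss
+```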
+
+## Method
+
+
+
+**The network architecture of UniSH.**
+UniSH takes a monocular video as input. The video frames are processed by the **Reconstruction Branch** to predict per-frame camera extrinsics *E*, confidence maps *C*, and pointmaps *P*. Camera intrinsics *K* are derived from the pointmaps. Human crops from the video are fed into the **Human Body Branch** along with *K* to estimate global SMPL shape parameters *β* and per-frame pose parameters *θ<sub>i</sub>*. Features from both branches are processed by **AlignNet** to predict the global scene scale *s* and per-frame SMPL translations *t<sub>i</sub>* for coherent scene and human alignment.
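+
+In code, this flow corresponds to the helpers used by `app.py` and `inference.py`. A minimal sketch (checkpoint and video paths are placeholders):
+
+```python
+import torch
+from unish.utils.inference_utils import load_model, process_video, run_inference
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+model = load_model("checkpoints/unish_release.safetensors").to(device).eval()
+
+# Defaults mirror the demo: 6 fps sampling, 518 px frames, first detected human.
+data = process_video("input.mp4", 6.0, 0, 518, bbox_scale=1.0)
+results = run_inference(model, data, device, chunk_size=30)
+# `results` feeds the geometry/export helpers shown in inference.py.
+```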
+
+## Usage
+
+This Space provides an interactive demo for UniSH.
+
+1. **Upload a Video**: Upload a monocular video containing a human.
+2. **Set Duration**: Choose the duration to process (default: 3 seconds).
+3. **Run Inference**: Click "Run Inference" to generate the 3D reconstruction.
+4. **Visualize**: The result will be displayed in an interactive 3D viewer where you can rotate, pan, and zoom.
+
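+Beyond the UI, the demo can also be called programmatically with `gradio_client` (a minimal sketch; the Space ID below is a placeholder):
+
+```python
+from gradio_client import Client, handle_file
+
+# Placeholder Space ID; substitute this Space's actual ID.
+client = Client("your-username/UniSH")
+result_html = client.predict(
+    handle_file("input.mp4"),  # monocular video containing a human
+    3,                         # duration to process, in seconds
+)
+```
+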
+## BibTeX
+
+```bibtex
+@misc{li2026unishunifyingscenehuman,
+ title={UniSH: Unifying Scene and Human Reconstruction in a Feed-Forward Pass},
+ author={Mengfei Li and Peng Li and Zheng Zhang and Jiahao Lu and Chengfeng Zhao and Wei Xue and Qifeng Liu and Sida Peng and Wenxiao Zhang and Wenhan Luo and Yuan Liu and Yike Guo},
+ year={2026},
+ eprint={2601.01222},
+ archivePrefix={arXiv},
+ primaryClass={cs.CV},
+ url={https://arxiv.org/abs/2601.01222},
+}
+```
+
+## Acknowledgements
+
+This website is licensed under a [Creative Commons Attribution-ShareAlike 4.0 International License](http://creativecommons.org/licenses/by-sa/4.0/).
+Template borrowed from [Nerfies](https://github.com/nerfies/nerfies.github.io).
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..8fa5114b716f3f97964d29c098128020bf02f7c2
--- /dev/null
+++ b/app.py
@@ -0,0 +1,457 @@
+import gradio as gr
+import spaces
+import os
+import sys
+import shutil
+import tempfile
+import torch
+import cv2
+import subprocess
+import numpy as np
+import trimesh
+from huggingface_hub import hf_hub_download
+
+# Add current directory to path
+sys.path.append(os.path.dirname(os.path.abspath(__file__)))
+
+from unish.utils.inference_utils import (
+ load_model,
+ process_video,
+ run_inference,
+ generate_mixed_geometries_in_memory,
+ save_smpl_meshes_per_frame,
+ save_scene_only_point_clouds,
+ save_human_point_clouds,
+ save_camera_parameters_per_frame
+)
+
+MODEL = None
+BODY_MODELS_PATH = "body_models/"
+
+def download_smpl_assets(body_models_path):
+ """
+ Download SMPL models from private repository if they don't exist.
+ The path logic mimics SMPLWrapper's expectation:
+ 1. SMPLWrapper appends 'smpl' if not present in body_models_path.
+ 2. smplx library expects another 'smpl' folder inside that (or appends it).
+ Based on existing structure 'body_models/smpl/smpl/SMPL_*.pkl', the target dir is constructed below.
+ """
+ if 'smpl' not in body_models_path:
+ model_path = os.path.join(body_models_path, 'smpl')
+ else:
+ model_path = body_models_path
+
+ # smplx looks for a 'smpl' folder inside the given model_path
+ target_dir = os.path.join(model_path, 'smpl')
+
+ os.makedirs(target_dir, exist_ok=True)
+
+ files = ["SMPL_NEUTRAL.pkl", "SMPL_MALE.pkl", "SMPL_FEMALE.pkl"]
+ token = os.environ.get("SMPL_DOWNLOAD_TOKEN")
+
+ for filename in files:
+ file_path = os.path.join(target_dir, filename)
+ if not os.path.exists(file_path):
+ if not token:
+ print(f"Warning: SMPL_DOWNLOAD_TOKEN not set. Cannot download {filename}.")
+ continue
+
+ print(f"Downloading {filename} to {target_dir}...")
+ try:
+ hf_hub_download(
+ repo_id="Murphyyyy/UniSH-Private-Assets",
+ filename=filename,
+ local_dir=target_dir,
+ token=token
+ )
+ except Exception as e:
+ print(f"Failed to download {filename}: {e}")
+
+def pack_sequence_to_glb(base_dir, output_path, start_frame=0, end_frame=60, scene_rate=0.5):
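+    """Pack per-frame scene/human point clouds and SMPL meshes (PLY files under
+    base_dir) into a single GLB, one node per frame; scene points are randomly
+    subsampled by scene_rate to keep the export small."""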
+ scene = trimesh.Scene()
+
+ print(f">>> Packing frames {start_frame} to {end_frame}...")
+
+ valid_count = 0
+
+ for i in range(start_frame, end_frame):
+ frame_node_name = f"frame_{valid_count}"
+
+ s_path = os.path.join(base_dir, "scene_only_point_clouds", f"scene_only_frame_{i:04d}.ply")
+ h_path = os.path.join(base_dir, "human_only_point_clouds", f"human_frame_{i:04d}.ply")
+ smpl_path = os.path.join(base_dir, "smpl_meshes_per_frame", f"smpl_mesh_frame_{i:04d}.ply")
+
+ if not (os.path.exists(h_path) or os.path.exists(smpl_path)):
+ continue
+
+ scene.graph.update(frame_node_name, parent="world")
+
+ if os.path.exists(smpl_path):
+ try:
+ smpl = trimesh.load(smpl_path)
+ flesh_color = [255, 160, 122, 255]
+ smpl.visual.vertex_colors = np.tile(flesh_color, (len(smpl.vertices), 1))
+
+ scene.add_geometry(smpl, node_name=f"{frame_node_name}_smpl", parent_node_name=frame_node_name)
+            except Exception as e:
+                print(f"Warning: failed to load SMPL mesh {smpl_path}: {e}")
+
+ if os.path.exists(h_path):
+ try:
+ human = trimesh.load(h_path)
+ if isinstance(human, trimesh.PointCloud):
+ scene.add_geometry(human, node_name=f"{frame_node_name}_human", parent_node_name=frame_node_name)
+            except Exception: pass
+
+ if os.path.exists(s_path):
+ try:
+ s_obj = trimesh.load(s_path)
+ if isinstance(s_obj, trimesh.PointCloud):
+ total_pts = len(s_obj.vertices)
+ if total_pts > 0:
+ if scene_rate < 0.99:
+ count = int(total_pts * scene_rate)
+ if count > 100:
+ idx = np.random.choice(total_pts, count, replace=False)
+ s_obj = trimesh.PointCloud(s_obj.vertices[idx], colors=s_obj.colors[idx])
+ scene.add_geometry(s_obj, node_name=f"{frame_node_name}_scene", parent_node_name=frame_node_name)
+            except Exception: pass
+
+ valid_count += 1
+
+ if valid_count == 0:
+ print("Error: No valid frames found.")
+ return
+
+ try:
+ rot = trimesh.transformations.rotation_matrix(np.radians(-90), [1, 0, 0])
+ scene.apply_transform(rot)
+    except Exception: pass
+
+ os.makedirs(os.path.dirname(output_path), exist_ok=True)
+ print(f">>> Exporting to {output_path}...")
+ scene.export(output_path)
+ print(f">>> Done! Saved {valid_count} frames.")
+
+def get_player_html(glb_abs_path):
+    # Minimal HTML viewer for the packed GLB sequence. Assumptions: the
+    # <model-viewer> component is loaded from unpkg, and Gradio serves the
+    # temp-dir file via its /file= route. The frame label is static here;
+    # per-frame playback would need a richer three.js player.
+    html_content = f"""
+    <html>
+    <head>
+        <title>UniSH Viewer</title>
+        <script type="module"
+                src="https://unpkg.com/@google/model-viewer/dist/model-viewer.min.js"></script>
+    </head>
+    <body style="margin:0">
+        <div id="loading">Loading 3D Sequence...</div>
+        <model-viewer id="viewer" src="/file={glb_abs_path}" camera-controls
+                      style="width:100%; height:560px;"></model-viewer>
+        <div id="frame-label">Frame: 0</div>
+        <script>
+            document.getElementById('viewer').addEventListener('load', () => {{
+                document.getElementById('loading').style.display = 'none';
+            }});
+        </script>
+    </body>
+    </html>
+    """
+    return html_content
+
+@spaces.GPU(duration=120)
+def predict(video_path, duration_seconds=3.0):
+    global MODEL
+
+    if video_path is None:
+        raise gr.Error("Please upload a video first.")
+
+    # 0. Setup directories
+    output_dir = tempfile.mkdtemp()
+
+ # 1. Trim video
+ duration = min(float(duration_seconds), 10.0)
+ trimmed_video_path = os.path.join(output_dir, "input_trimmed.mp4")
+
+ cmd = [
+ "ffmpeg", "-i", video_path,
+ "-t", str(duration),
+ "-c:v", "libx264", "-c:a", "aac",
+ trimmed_video_path, "-y"
+ ]
+ subprocess.run(cmd, check=True)
+
+ # 2. Load Model
+ if MODEL is None:
+ MODEL = load_model()
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ MODEL.to(device)
+ MODEL.eval()
+
+ # 3. Process Video
+ fps = 6.0
+ target_size = 518
+ human_idx = 0
+ bbox_scale = 1.0
+
+ # Check and download SMPL assets
+ download_smpl_assets(BODY_MODELS_PATH)
+
+ data_dict = process_video(
+ trimmed_video_path, fps, human_idx, target_size,
+ bbox_scale=bbox_scale
+ )
+
+ # 4. Run Inference
+ results = run_inference(MODEL, data_dict, device, chunk_size=30)
+
+ # 5. Generate Geometries & Save
+ seq_name = results['seq_name']
+
+ viz_scene_point_clouds, viz_smpl_meshes, viz_scene_only_point_clouds, smpl_points_for_camera = generate_mixed_geometries_in_memory(
+ results, BODY_MODELS_PATH, fps=fps, conf_thres=0.1
+ )
+
+ # Save to disk
+ save_smpl_meshes_per_frame(results, output_dir, BODY_MODELS_PATH)
+ save_scene_only_point_clouds(viz_scene_only_point_clouds, output_dir, seq_name)
+ save_human_point_clouds(viz_scene_point_clouds, viz_scene_only_point_clouds, output_dir, seq_name, results)
+
+ # 6. Pack to GLB
+ base_dir = os.path.join(output_dir, seq_name)
+ output_glb_path = os.path.join(output_dir, "output.glb")
+
+ num_frames = len(viz_scene_point_clouds)
+
+ pack_sequence_to_glb(
+ base_dir,
+ output_glb_path,
+ start_frame=0,
+ end_frame=num_frames,
+ scene_rate=0.5
+ )
+
+ return get_player_html(output_glb_path)
+
+with gr.Blocks() as demo:
+ gr.Markdown("# UniSH Demo")
+ gr.Markdown("Upload a video to reconstruct scene and human in 3D.")
+
+ with gr.Row():
+ with gr.Column():
+ input_video = gr.Video(label="Input Video")
+ duration_slider = gr.Slider(minimum=1, maximum=10, value=3, step=1, label="Duration to Process (seconds)")
+ submit_btn = gr.Button("Run Inference", variant="primary")
+
+ with gr.Column():
+ output_html = gr.HTML(label="3D Result", min_height=600)
+
+ submit_btn.click(
+ predict,
+ inputs=[input_video, duration_slider],
+ outputs=[output_html]
+ )
+
+demo.queue()
+demo.launch()
+
+
diff --git a/environment.yml b/environment.yml
new file mode 100644
index 0000000000000000000000000000000000000000..380f78e208310132e58e2fc2e28bbf9deb42555e
--- /dev/null
+++ b/environment.yml
@@ -0,0 +1,14 @@
+name: unish
+channels:
+ - conda-forge
+ - defaults
+dependencies:
+ - python=3.10
+ - pip
+ - git
+ - ninja
+ - mesalib
+ - libgl-devel
+ - libegl-devel
+ - gxx_linux-64=11.*
+ - ffmpeg
\ No newline at end of file
diff --git a/inference.py b/inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..568421f811e25e17fc4aa7ccb4c6735bdcd9990c
--- /dev/null
+++ b/inference.py
@@ -0,0 +1,186 @@
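+"""Standalone UniSH video inference script.
+
+Example invocation (paths are illustrative):
+    python inference.py --video_path demo.mp4 --fps 6.0 \
+        --checkpoint checkpoints/unish_release.safetensors \
+        --output_dir inference_results_video
+"""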
+import argparse
+import os
+import torch
+import numpy as np
+import random
+import logging
+from unish.utils.inference_utils import *
+
+def setup_seed(seed):
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed_all(seed)
+ np.random.seed(seed)
+ random.seed(seed)
+ torch.backends.cudnn.deterministic = True
+
+def setup_logging(output_dir):
+ os.makedirs(output_dir, exist_ok=True)
+
+ # Create logger
+ logger = logging.getLogger()
+ logger.setLevel(logging.INFO)
+
+ # Create handlers
+ c_handler = logging.StreamHandler()
+ f_handler = logging.FileHandler(os.path.join(output_dir, 'inference.log'), mode='w')
+
+ # Create formatters and add it to handlers
+ c_format = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
+ f_format = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
+ c_handler.setFormatter(c_format)
+ f_handler.setFormatter(f_format)
+
+ # Add handlers to the logger
+ logger.addHandler(c_handler)
+ logger.addHandler(f_handler)
+
+ return logger
+
+def main():
+ parser = argparse.ArgumentParser(description="Video Inference Script")
+ parser.add_argument("--video_path", type=str, required=True,
+ help="Path to the input video file or directory containing images")
+ parser.add_argument("--fps", type=float, default=6.0,
+ help="Target FPS for frame extraction (default: 6.0)")
+ parser.add_argument("--original_fps", type=float, default=30.0,
+ help="Original FPS of the image sequence (default: 30.0, used only for directory input)")
+ parser.add_argument("--target_size", type=int, default=518,
+ help="Target size for frame processing (default: 518)")
+ parser.add_argument("--checkpoint", type=str, default="checkpoints/unish_release.safetensors",
+ help="Path to the model checkpoint")
+ parser.add_argument("--output_dir", type=str, default="inference_results_video",
+ help="Output directory for results")
+ parser.add_argument("--body_models_path", type=str, default="body_models/",
+ help="Path to SMPL body models")
+ parser.add_argument("--device", type=str, default="cuda",
+ help="Device to run inference on")
+ parser.add_argument("--save_results", action="store_true", default=True,
+ help="Save additional results including smpl_points_for_camera (default: True)")
+ parser.add_argument("--chunk_size", type=int, default=30,
+ help="Number of frames to process in each chunk during inference (default: 30)")
+ parser.add_argument("--gpu_id", type=int, default=0,
+ help="GPU ID to use for inference (default: 0)")
+ parser.add_argument("--camera_mode", type=str, default="fixed",
+ choices=["predicted", "fixed"],
+ help="Camera mode: 'predicted' uses model-predicted camera parameters, "
+ "'fixed' uses a fixed camera angle (default: predicted)")
+ parser.add_argument("--human_idx", type=int, default=0,
+ help="Human index to process (default: 0)")
+ parser.add_argument("--start_idx", type=int, default=None,
+ help="Start frame index for processing (default: None, process from beginning)")
+ parser.add_argument("--end_idx", type=int, default=None,
+ help="End frame index for processing (default: None, process to end)")
+ parser.add_argument("--bbox_scale", type=float, default=1.0,
+ help="Scale factor for bounding box size (default: 1.0)")
+ parser.add_argument("--conf_thres", type=float, default=0.1,
+ help="Confidence threshold for point cloud generation (default: 0.1)")
+
+ # New arguments
+ parser.add_argument("--seed", type=int, default=42, help="Random seed for reproducibility")
+ parser.add_argument("--yolo_ckpt", type=str, default="ckpts/yolo11n.pt", help="Path to YOLO checkpoint")
+ parser.add_argument("--sam2_model", type=str, default="facebook/sam2-hiera-large", help="SAM2 model name or path")
+
+ args = parser.parse_args()
+
+ # Setup seed
+ setup_seed(args.seed)
+
+ # Setup logging
+ logger = setup_logging(args.output_dir)
+
+ # Setup device
+ if torch.cuda.is_available():
+ if args.device == "cuda":
+ # Use specified GPU ID
+ device = torch.device(f"cuda:{args.gpu_id}")
+ # Set the current CUDA device
+ torch.cuda.set_device(args.gpu_id)
+ logger.info(
+ f"Using GPU {args.gpu_id}: {torch.cuda.get_device_name(args.gpu_id)}")
+ else:
+ device = torch.device(args.device)
+ else:
+ device = torch.device("cpu")
+ logger.info("CUDA not available, using CPU")
+
+ logger.info(f"Using device: {device}")
+
+ # Load model
+ logger.info("Loading model...")
+ model = load_model(args.checkpoint)
+ model = model.to(device)
+ model.eval()
+
+ # Process video
+ logger.info(f"Processing video: {args.video_path}")
+ data_dict = process_video(
+ args.video_path, args.fps, args.human_idx, args.target_size,
+ bbox_scale=args.bbox_scale, start_idx=args.start_idx, end_idx=args.end_idx,
+ original_fps=args.original_fps,
+ yolo_ckpt=args.yolo_ckpt, sam2_model=args.sam2_model
+ )
+
+ # Run inference
+ results = run_inference(model, data_dict, device, args.chunk_size)
+
+ # Create output directory
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ viz_scene_point_clouds, viz_smpl_meshes, viz_scene_only_point_clouds, smpl_points_for_camera = generate_mixed_geometries_in_memory(
+ results, args.body_models_path, fps=args.fps, conf_thres=args.conf_thres
+ )
+
+ # Determine camera mode based on arguments
+ use_predicted_camera = (args.camera_mode == "predicted")
+ logger.info(f"Using {args.camera_mode} camera mode")
+
+ original_rgb_images = results['rgb_images']
+
+ if original_rgb_images is not None:
+ if hasattr(original_rgb_images, 'permute'): # It's a torch tensor
+ original_rgb_images = original_rgb_images.permute(
+ 0, 2, 3, 1).cpu().numpy() # [S, H, W, 3]
+ elif not isinstance(original_rgb_images, np.ndarray):
+ original_rgb_images = np.array(original_rgb_images)
+
+ # Ensure proper data type and range
+ if original_rgb_images.max() <= 1.0:
+ original_rgb_images = (
+ original_rgb_images * 255).astype(np.uint8)
+
+ original_human_boxes = data_dict['human_boxes']
+
+ run_visualization(viz_scene_point_clouds, viz_smpl_meshes, smpl_points_for_camera,
+ args.output_dir, results['seq_name'],
+ fps=args.fps, # Use original fps
+ rgb_images=original_rgb_images,
+ human_boxes=original_human_boxes,
+ chunk_size=args.chunk_size, # Use original chunk size
+ results=results,
+ use_predicted_camera=use_predicted_camera,
+ scene_only_point_clouds=viz_scene_only_point_clouds,
+ conf_thres=args.conf_thres)
+
+ if args.save_results:
+
+ logger.info("Creating SMPL meshes per frame...")
+ save_smpl_meshes_per_frame(
+ results, args.output_dir, args.body_models_path)
+
+ logger.info("Saving scene point clouds (without human)...")
+ save_scene_only_point_clouds(
+ viz_scene_only_point_clouds, args.output_dir, results['seq_name'])
+
+ logger.info("Saving human point clouds...")
+ save_human_point_clouds(viz_scene_point_clouds,
+ viz_scene_only_point_clouds, args.output_dir, results['seq_name'], results)
+
+ logger.info("Saving camera parameters per frame...")
+ save_camera_parameters_per_frame(
+ results, args.output_dir, results['seq_name'])
+
+ logger.info(f"Inference completed! Results saved to {args.output_dir}")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/install.sh b/install.sh
new file mode 100644
index 0000000000000000000000000000000000000000..4b447e21ed61920cbebac4612f56aa999fd4ffb9
--- /dev/null
+++ b/install.sh
@@ -0,0 +1,74 @@
+#!/bin/bash
+set -e
+
+# ==========================================
+# UniSH Auto-Install Script
+# ==========================================
+
+get_cuda_version() {
+ if [ ! -z "$1" ]; then echo "$1"; return; fi
+ if command -v nvidia-smi &> /dev/null; then
+ DRIVER_CUDA_MAJOR=$(nvidia-smi | grep "CUDA Version" | awk -F'CUDA Version:' '{print $2}' | awk -F'.' '{print $1}' | tr -d '[:space:]')
+ if [ "$DRIVER_CUDA_MAJOR" == "12" ]; then echo "12.1"; elif [ "$DRIVER_CUDA_MAJOR" == "11" ]; then echo "11.8"; else echo "12.1"; fi
+ else echo "12.1"; fi
+}
+
+if [[ -z "$CONDA_PREFIX" ]]; then
+ echo "❌ Error: Please activate the conda environment first!"
+ exit 1
+fi
+
+TARGET_CUDA=$(get_cuda_version "$1")
+echo "========================================"
+echo " Detected/Selected CUDA: $TARGET_CUDA"
+echo "========================================"
+
+if [[ "$TARGET_CUDA" == "12.1" ]]; then TORCH_INDEX_URL="https://download.pytorch.org/whl/cu121";
+elif [[ "$TARGET_CUDA" == "11.8" ]]; then TORCH_INDEX_URL="https://download.pytorch.org/whl/cu118";
+else TORCH_INDEX_URL=""; fi
+
+echo "[1/6] Installing PyTorch 2.4.1 (CUDA $TARGET_CUDA)..."
+pip install torch==2.4.1 torchvision==0.19.1 --index-url $TORCH_INDEX_URL
+
+echo "[2/6] Installing Safe Requirements..."
+pip install -r requirements.txt
+
+echo "[3/6] Installing Custom Utils3D..."
+pip install "git+https://github.com/EasternJournalist/utils3d.git@3fab839f0be9931dac7c8488eb0e1600c236e183"
+
+echo "[4/6] Installing Heavy Dependencies..."
+pip install open3d==0.19.0 --no-deps
+pip install ultralytics==8.3.227 --no-deps
+pip install timm==1.0.24 --no-deps
+
+echo "[5/6] Installing MMCV & PyTorch3D..."
+pip install mmcv==2.2.0 --no-deps --no-binary mmcv
+pip install "git+https://github.com/facebookresearch/pytorch3d.git@stable" --no-build-isolation
+
+echo "[6/6] Installing SAM 2 (With Setuptools Fix)..."
+
+pip install setuptools==69.5.1 wheel
+rm -rf _tmp_install_sam2
+
+mkdir -p _tmp_install_sam2
+cd _tmp_install_sam2
+
+echo " -> Cloning SAM 2..."
+git clone https://github.com/facebookresearch/segment-anything-2.git --depth 1
+cd segment-anything-2
+
+echo " -> Patching setup.py..."
+python -c "
+path = 'setup.py'
+with open(path, 'r') as f: c = f.read()
+c = c.replace('torch>=2.5.1', 'torch>=2.4.1')
+with open(path, 'w') as f: f.write(c)
+"
+pip install . --no-deps --no-build-isolation
+cd ../..
+rm -rf _tmp_install_sam2
+
+echo "========================================"
+echo "Installation Complete!"
+python -c "import torch; print(f'PyTorch: {torch.__version__} | CUDA: {torch.version.cuda}')"
+echo "========================================"
\ No newline at end of file
diff --git a/packages.txt b/packages.txt
new file mode 100644
index 0000000000000000000000000000000000000000..cd6c2868c489eda0a4e0cb19295419a6af07d496
--- /dev/null
+++ b/packages.txt
@@ -0,0 +1,7 @@
+ffmpeg
+libgl1-mesa-glx
+libglib2.0-0
+libegl1-mesa
+xvfb
+
+
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..09e2469bebf82a89f376907c31ccce547de96451
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,22 @@
+torch==2.4.1
+torchvision==0.19.1
+numpy
+scipy
+trimesh
+tqdm
+opencv-python-headless
+pillow
+gradio
+spaces
+ninja
+einops
+safetensors
+huggingface_hub
+open3d==0.19.0
+ultralytics==8.3.227
+timm==1.0.24
+git+https://github.com/EasternJournalist/utils3d.git@3fab839f0be9931dac7c8488eb0e1600c236e183
+--find-links https://download.openmmlab.com/mmcv/dist/cu121/torch2.4/index.html
+mmcv==2.2.0
+pytorch3d @ https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/py310_cu121_pyt241/pytorch3d-0.7.8-cp310-cp310-linux_x86_64.whl
+git+https://github.com/facebookresearch/segment-anything-2.git
+smplx
diff --git a/static/teaser.svg b/static/teaser.svg
new file mode 100644
index 0000000000000000000000000000000000000000..fccafcd95ab8c49c7bb263ed8ae265fe2f509b9f
--- /dev/null
+++ b/static/teaser.svg
@@ -0,0 +1 @@
+
\ No newline at end of file
diff --git a/unish/__pycache__/pipeline.cpython-310.pyc b/unish/__pycache__/pipeline.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..45e0ce67c07484a2ad3b0174be3074727edabb5a
Binary files /dev/null and b/unish/__pycache__/pipeline.cpython-310.pyc differ
diff --git a/unish/heads/__pycache__/align_net.cpython-310.pyc b/unish/heads/__pycache__/align_net.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1673d3d9e9e3f7adc0d3948c5b18296c645631e5
Binary files /dev/null and b/unish/heads/__pycache__/align_net.cpython-310.pyc differ
diff --git a/unish/heads/__pycache__/dpt_head.cpython-310.pyc b/unish/heads/__pycache__/dpt_head.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..756752d7aab768b63976636293abf71d73919498
Binary files /dev/null and b/unish/heads/__pycache__/dpt_head.cpython-310.pyc differ
diff --git a/unish/heads/__pycache__/head_act.cpython-310.pyc b/unish/heads/__pycache__/head_act.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e4ca1c0796ebb227b8efb5023d735a9b996e48e8
Binary files /dev/null and b/unish/heads/__pycache__/head_act.cpython-310.pyc differ
diff --git a/unish/heads/__pycache__/human_head_cliff.cpython-310.pyc b/unish/heads/__pycache__/human_head_cliff.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8c86d3a43211901e7eaee355860c1772608cc9b7
Binary files /dev/null and b/unish/heads/__pycache__/human_head_cliff.cpython-310.pyc differ
diff --git a/unish/heads/__pycache__/pose_transformer.cpython-310.pyc b/unish/heads/__pycache__/pose_transformer.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e269a8b417c3b5f5e9e6bfe6541fbde539a4a6d1
Binary files /dev/null and b/unish/heads/__pycache__/pose_transformer.cpython-310.pyc differ
diff --git a/unish/heads/__pycache__/t_cond_mlp.cpython-310.pyc b/unish/heads/__pycache__/t_cond_mlp.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..253b56d946b156cfed4c44b081bc9fab50e6445a
Binary files /dev/null and b/unish/heads/__pycache__/t_cond_mlp.cpython-310.pyc differ
diff --git a/unish/heads/__pycache__/utils.cpython-310.pyc b/unish/heads/__pycache__/utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..997995fd99e3f8a36e543852b110c4293f07ba86
Binary files /dev/null and b/unish/heads/__pycache__/utils.cpython-310.pyc differ
diff --git a/unish/heads/__pycache__/vit.cpython-310.pyc b/unish/heads/__pycache__/vit.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8948f83997d8ffc50f7b20263788a15afa623e96
Binary files /dev/null and b/unish/heads/__pycache__/vit.cpython-310.pyc differ
diff --git a/unish/heads/align_net.py b/unish/heads/align_net.py
new file mode 100644
index 0000000000000000000000000000000000000000..b306b7ede708729048d20263d15d68c3b15ec8d3
--- /dev/null
+++ b/unish/heads/align_net.py
@@ -0,0 +1,571 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import math
+import numpy as np
+
+from unish.utils.data_utils import rot6d_to_rotmat
+from unish.utils.constants import SMPL_MEAN_PARAMS
+
+
+class TimeStepRoPE1D(nn.Module):
+ """1D RoPE for timestep embedding, similar to pi3's RoPE2D but for 1D time sequence"""
+
+ def __init__(self, freq=100.0):
+ super().__init__()
+ self.base = freq
+ self.cache = {}
+ self.max_train_len = 120
+
+ def get_cos_sin(self, D, seq_len, device, dtype):
+ if (D, seq_len, device, dtype) in self.cache:
+ return self.cache[D, seq_len, device, dtype]
+
+ if seq_len <= self.max_train_len:
+ assert D % 2 == 0
+
+ inv_freq = 1.0 / (self.base ** (torch.arange(0, D, 2).float().to(device) / D))
+ t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype)
+ freqs = torch.einsum("i,j->ij", t, inv_freq).to(dtype)
+
+ freqs = torch.cat((freqs, freqs), dim=-1)
+ cos = freqs.cos() # (seq_len, D)
+ sin = freqs.sin() # (seq_len, D)
+ self.cache[D, seq_len, device, dtype] = (cos, sin)
+ return cos, sin
+
+ else:
+ cos_train, sin_train = self.get_cos_sin(D, self.max_train_len, device, dtype)
+ cos_train_res = cos_train.transpose(0, 1).unsqueeze(0)
+ sin_train_res = sin_train.transpose(0, 1).unsqueeze(0)
+
+ # [1, D, max_train_len] -> [1, D, seq_len]
+ cos_interp = F.interpolate(cos_train_res, size=seq_len, mode='linear', align_corners=True)
+ sin_interp = F.interpolate(sin_train_res, size=seq_len, mode='linear', align_corners=True)
+
+ # [1, D, seq_len] -> [seq_len, D]
+ cos_final = cos_interp.squeeze(0).transpose(0, 1)
+ sin_final = sin_interp.squeeze(0).transpose(0, 1)
+
+ self.cache[D, seq_len, device, dtype] = (cos_final, sin_final)
+ return cos_final, sin_final
+
+ @staticmethod
+ def rotate_half(x):
+ x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
+ return torch.cat((-x2, x1), dim=-1)
+
+ def apply_rope1d(self, tokens, pos1d, cos, sin):
+ """Apply 1D RoPE to tokens based on 1D positions"""
+ cos = torch.nn.functional.embedding(pos1d, cos)[:, None, :, :] # [batch, 1, seq_len, D]
+ sin = torch.nn.functional.embedding(pos1d, sin)[:, None, :, :] # [batch, 1, seq_len, D]
+ return (tokens * cos) + (self.rotate_half(tokens) * sin)
+
+ def forward(self, tokens, positions):
+ """
+ Apply 1D RoPE to tokens based on timestep positions.
+ Args:
+ tokens: [batch, num_heads, seq_len, head_dim]
+ positions: [batch, seq_len] - timestep positions (0, 1, 2, ...)
+ Returns:
+ tokens with RoPE applied: [batch, num_heads, seq_len, head_dim]
+ """
+ head_dim = tokens.size(3)
+ assert head_dim % 2 == 0, "head_dim should be a multiple of two"
+ assert positions.ndim == 2 # [batch, seq_len]
+
+ cos, sin = self.get_cos_sin(head_dim, int(positions.max()) + 1, tokens.device, tokens.dtype)
+
+ return self.apply_rope1d(tokens, positions.long(), cos, sin)
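+
+    # Usage (shapes only; illustrative, not from the original source):
+    #   rope = TimeStepRoPE1D(freq=100.0)
+    #   tokens = torch.randn(2, 8, 30, 64)    # [batch, heads, seq_len, head_dim]
+    #   pos = torch.arange(30).expand(2, -1)  # [batch, seq_len]
+    #   tokens = rope(tokens, pos)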
+
+
+class TransformerDecoderLayer(nn.Module):
+ """单层Transformer Decoder with RoPE support"""
+
+ def __init__(self, hidden_dim=512, num_heads=8, ff_dim=1024, dropout=0.1, use_rope=True):
+ super().__init__()
+
+ self.use_rope = use_rope
+ self.hidden_dim = hidden_dim
+ self.num_heads = num_heads
+ self.head_dim = hidden_dim // num_heads
+
+ if use_rope:
+ self.self_attention = None
+ self.cross_attention = None
+
+ self.self_q_proj = nn.Linear(hidden_dim, hidden_dim, bias=True)
+ self.self_k_proj = nn.Linear(hidden_dim, hidden_dim, bias=True)
+ self.self_v_proj = nn.Linear(hidden_dim, hidden_dim, bias=True)
+ self.self_out_proj = nn.Linear(hidden_dim, hidden_dim, bias=True)
+
+ self.cross_q_proj = nn.Linear(hidden_dim, hidden_dim, bias=True)
+ self.cross_k_proj = nn.Linear(hidden_dim, hidden_dim, bias=True)
+ self.cross_v_proj = nn.Linear(hidden_dim, hidden_dim, bias=True)
+ self.cross_out_proj = nn.Linear(hidden_dim, hidden_dim, bias=True)
+
+ # RoPE for timestep embedding
+ self.timestep_rope = TimeStepRoPE1D(freq=100.0)
+ else:
+ self.self_attention = nn.MultiheadAttention(
+ embed_dim=hidden_dim,
+ num_heads=num_heads,
+ dropout=dropout,
+ batch_first=True
+ )
+
+ self.cross_attention = nn.MultiheadAttention(
+ embed_dim=hidden_dim,
+ num_heads=num_heads,
+ dropout=dropout,
+ batch_first=True
+ )
+
+ self.feed_forward = nn.Sequential(
+ nn.Linear(hidden_dim, ff_dim),
+ nn.ReLU(),
+ nn.Dropout(dropout),
+ nn.Linear(ff_dim, hidden_dim),
+ nn.Dropout(dropout)
+ )
+
+ self.norm1 = nn.LayerNorm(hidden_dim) # for self attention
+ self.norm2 = nn.LayerNorm(hidden_dim) # for cross attention
+ self.norm3 = nn.LayerNorm(hidden_dim) # for feed forward
+
+ # Dropout
+ self.dropout = nn.Dropout(dropout)
+ self.attn_dropout = nn.Dropout(dropout)
+
+ # Scale factor for attention
+ self.scale = self.head_dim ** -0.5
+
+ # Gradient checkpointing flag
+ self.use_gradient_checkpoint = False
+
+ def gradient_checkpointing_enable(self):
+ """Enable gradient checkpointing for memory optimization."""
+ self.use_gradient_checkpoint = True
+
+ def _rope_attention(self, q_proj, k_proj, v_proj, out_proj, query, key, value, timestep_pos=None):
+ """Apply RoPE-based attention using torch.nn.functional.scaled_dot_product_attention"""
+ batch_size, seq_len, _ = query.shape
+
+ # Project Q, K, V
+ q = q_proj(query).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
+ k = k_proj(key).view(batch_size, key.shape[1], self.num_heads, self.head_dim).transpose(1, 2)
+ v = v_proj(value).view(batch_size, value.shape[1], self.num_heads, self.head_dim).transpose(1, 2)
+
+ # Apply RoPE to Q and K if timestep positions are provided
+ if timestep_pos is not None and self.use_rope:
+ # For self-attention, both q and k use the same timestep positions
+ if query.shape == key.shape: # self-attention case
+ q = self.timestep_rope(q, timestep_pos)
+ k = self.timestep_rope(k, timestep_pos)
+ else: # cross-attention case
+ # Only apply RoPE to query (cam_token), key/value are spatial features
+ q = self.timestep_rope(q, timestep_pos)
+
+ attn_output = F.scaled_dot_product_attention(
+ q, k, v,
+ dropout_p=self.attn_dropout.p if self.training else 0.0,
+ scale=self.scale
+ )
+
+ # Reshape output
+ attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.hidden_dim)
+
+ # Output projection
+ return out_proj(attn_output)
+
+ def forward(self, query, key, value, self_attn_mask=None, cross_attn_mask=None, timestep_pos=None):
+ """
+ Args:
+ query: [batch, num_views, hidden_dim]
+ key: [batch, num_views, hidden_dim]
+ value: [batch, num_views, hidden_dim]
+ timestep_pos: [batch, num_views] - timestep positions for RoPE
+ """
+ if self.use_gradient_checkpoint and self.training:
+ from torch.utils.checkpoint import checkpoint
+
+ if self.use_rope:
+ # 1. Self Attention + Residual with RoPE (with gradient checkpointing)
+ self_attn_output = checkpoint(
+ self._rope_attention,
+ self.self_q_proj, self.self_k_proj, self.self_v_proj, self.self_out_proj,
+ query, query, query, timestep_pos,
+ use_reentrant=False
+ )
+ query = self.norm1(query + self.dropout(self_attn_output))
+
+ # 2. Cross Attention + Residual with RoPE (with gradient checkpointing)
+ cross_attn_output = checkpoint(
+ self._rope_attention,
+ self.cross_q_proj, self.cross_k_proj, self.cross_v_proj, self.cross_out_proj,
+ query, key, value, timestep_pos,
+ use_reentrant=False
+ )
+ query = self.norm2(query + self.dropout(cross_attn_output))
+ else:
+ # 1. Self Attention + Residual (with gradient checkpointing)
+ def self_attn_fn(q, k, v):
+ out, _ = self.self_attention(q, k, v, attn_mask=self_attn_mask)
+ return out
+ self_attn_output = checkpoint(self_attn_fn, query, query, query, use_reentrant=False)
+ query = self.norm1(query + self.dropout(self_attn_output))
+
+ # 2. Cross Attention + Residual (with gradient checkpointing)
+ def cross_attn_fn(q, k, v):
+ out, _ = self.cross_attention(q, k, v, attn_mask=cross_attn_mask)
+ return out
+ cross_attn_output = checkpoint(cross_attn_fn, query, key, value, use_reentrant=False)
+ query = self.norm2(query + self.dropout(cross_attn_output))
+
+ # 3. Feed Forward + Residual (with gradient checkpointing)
+ ff_output = checkpoint(self.feed_forward, query, use_reentrant=False)
+ query = self.norm3(query + ff_output)
+ else:
+ # Original implementation without gradient checkpointing
+ if self.use_rope:
+ # 1. Self Attention + Residual with RoPE
+ self_attn_output = self._rope_attention(
+ self.self_q_proj, self.self_k_proj, self.self_v_proj, self.self_out_proj,
+ query, query, query, timestep_pos
+ )
+ query = self.norm1(query + self.dropout(self_attn_output))
+
+ # 2. Cross Attention + Residual with RoPE
+ cross_attn_output = self._rope_attention(
+ self.cross_q_proj, self.cross_k_proj, self.cross_v_proj, self.cross_out_proj,
+ query, key, value, timestep_pos
+ )
+ query = self.norm2(query + self.dropout(cross_attn_output))
+ else:
+ # 1. Self Attention + Residual (original implementation)
+ self_attn_output, _ = self.self_attention(query, query, query, attn_mask=self_attn_mask)
+ query = self.norm1(query + self.dropout(self_attn_output))
+
+ # 2. Cross Attention + Residual (original implementation)
+ cross_attn_output, _ = self.cross_attention(query, key, value, attn_mask=cross_attn_mask)
+ query = self.norm2(query + self.dropout(cross_attn_output))
+
+ # 3. Feed Forward + Residual
+ ff_output = self.feed_forward(query)
+ query = self.norm3(query + ff_output)
+
+ return query
+
+
+class CrossViewTransformerDecoderLayer(nn.Module):
+ """Cross-view Transformer Decoder Layer for V4 - handles concatenated tokens from multiple views"""
+
+ def __init__(self, hidden_dim=512, num_heads=8, ff_dim=1024, dropout=0.1, use_rope=True):
+ super().__init__()
+
+ self.use_rope = use_rope
+ self.hidden_dim = hidden_dim
+ self.num_heads = num_heads
+ self.head_dim = hidden_dim // num_heads
+
+ if use_rope:
+ self.self_attention = None
+ self.cross_attention = None
+
+ # Self-attention components
+ self.self_q_proj = nn.Linear(hidden_dim, hidden_dim, bias=True)
+ self.self_k_proj = nn.Linear(hidden_dim, hidden_dim, bias=True)
+ self.self_v_proj = nn.Linear(hidden_dim, hidden_dim, bias=True)
+ self.self_out_proj = nn.Linear(hidden_dim, hidden_dim, bias=True)
+
+ # Cross-attention components
+ self.cross_q_proj = nn.Linear(hidden_dim, hidden_dim, bias=True)
+ self.cross_k_proj = nn.Linear(hidden_dim, hidden_dim, bias=True)
+ self.cross_v_proj = nn.Linear(hidden_dim, hidden_dim, bias=True)
+ self.cross_out_proj = nn.Linear(hidden_dim, hidden_dim, bias=True)
+
+ # RoPE for timestep embedding
+ self.timestep_rope = TimeStepRoPE1D(freq=100.0)
+ else:
+            # Self-attention layer
+ self.self_attention = nn.MultiheadAttention(
+ embed_dim=hidden_dim,
+ num_heads=num_heads,
+ dropout=dropout,
+ batch_first=True
+ )
+
+            # Cross-attention layer
+ self.cross_attention = nn.MultiheadAttention(
+ embed_dim=hidden_dim,
+ num_heads=num_heads,
+ dropout=dropout,
+ batch_first=True
+ )
+
+ self.feed_forward = nn.Sequential(
+ nn.Linear(hidden_dim, ff_dim),
+ nn.ReLU(),
+ nn.Dropout(dropout),
+ nn.Linear(ff_dim, hidden_dim),
+ nn.Dropout(dropout)
+ )
+
+ self.norm1 = nn.LayerNorm(hidden_dim) # for self attention
+ self.norm2 = nn.LayerNorm(hidden_dim) # for cross attention
+ self.norm3 = nn.LayerNorm(hidden_dim) # for feed forward
+
+ self.dropout = nn.Dropout(dropout)
+ self.attn_dropout = nn.Dropout(dropout)
+
+ self.scale = self.head_dim ** -0.5
+
+ self.use_gradient_checkpoint = False
+
+ def gradient_checkpointing_enable(self):
+ """Enable gradient checkpointing for memory optimization."""
+ self.use_gradient_checkpoint = True
+
+ def _rope_attention(self, q_proj, k_proj, v_proj, out_proj, query, key, value, query_timestep_pos=None, key_timestep_pos=None):
+ """Apply RoPE-based attention for cross-view scenarios using torch.nn.functional.scaled_dot_product_attention"""
+ batch_size, query_seq_len, _ = query.shape
+ _, key_seq_len, _ = key.shape
+
+ # Project Q, K, V
+ q = q_proj(query).view(batch_size, query_seq_len, self.num_heads, self.head_dim).transpose(1, 2)
+ k = k_proj(key).view(batch_size, key_seq_len, self.num_heads, self.head_dim).transpose(1, 2)
+ v = v_proj(value).view(batch_size, key_seq_len, self.num_heads, self.head_dim).transpose(1, 2)
+
+ # Apply RoPE to Q and K if timestep positions are provided
+ if self.use_rope:
+ if query_timestep_pos is not None:
+ q_scale = q[:, :, 0:1, :] # [batch, num_heads, 1, head_dim] - scale token
+ q_cam = q[:, :, 1:, :] # [batch, num_heads, num_views, head_dim] - cam tokens
+
+ cam_timestep_pos = query_timestep_pos[:, 1:]
+ q_cam_rope = self.timestep_rope(q_cam, cam_timestep_pos)
+
+ q = torch.cat([q_scale, q_cam_rope], dim=2)
+ if key_timestep_pos is not None:
+ k = self.timestep_rope(k, key_timestep_pos)
+
+ attn_output = F.scaled_dot_product_attention(
+ q, k, v,
+ dropout_p=self.attn_dropout.p if self.training else 0.0,
+ scale=self.scale
+ )
+
+ # Reshape output
+ attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, query_seq_len, self.hidden_dim)
+
+ # Output projection
+ return out_proj(attn_output)
+
+ def forward(self, query, key, value, query_timestep_pos=None, key_timestep_pos=None):
+ """
+ Args:
+ query: [batch, num_queries, hidden_dim] - cam tokens + scale token
+ key: [batch, num_views * num_tokens, hidden_dim] - concatenated feature tokens from all views
+ value: [batch, num_views * num_tokens, hidden_dim] - concatenated feature tokens from all views
+ query_timestep_pos: [batch, num_queries] - timestep positions for query tokens
+ key_timestep_pos: [batch, num_views * num_tokens] - timestep positions for key/value tokens
+ """
+ if self.use_gradient_checkpoint and self.training:
+ from torch.utils.checkpoint import checkpoint
+
+ if self.use_rope:
+ # 1. Self Attention + Residual with RoPE (with gradient checkpointing)
+ self_attn_output = checkpoint(
+ self._rope_attention,
+ self.self_q_proj, self.self_k_proj, self.self_v_proj, self.self_out_proj,
+ query, query, query, query_timestep_pos, query_timestep_pos,
+ use_reentrant=False
+ )
+ query = self.norm1(query + self.dropout(self_attn_output))
+
+ # 2. Cross Attention + Residual with RoPE (with gradient checkpointing)
+ cross_attn_output = checkpoint(
+ self._rope_attention,
+ self.cross_q_proj, self.cross_k_proj, self.cross_v_proj, self.cross_out_proj,
+ query, key, value, query_timestep_pos, key_timestep_pos,
+ use_reentrant=False
+ )
+ query = self.norm2(query + self.dropout(cross_attn_output))
+ else:
+ # 1. Self Attention + Residual (with gradient checkpointing)
+ def self_attn_fn(q, k, v):
+ out, _ = self.self_attention(q, k, v)
+ return out
+ self_attn_output = checkpoint(self_attn_fn, query, query, query, use_reentrant=False)
+ query = self.norm1(query + self.dropout(self_attn_output))
+
+ # 2. Cross Attention + Residual (with gradient checkpointing)
+ def cross_attn_fn(q, k, v):
+ out, _ = self.cross_attention(q, k, v)
+ return out
+ cross_attn_output = checkpoint(cross_attn_fn, query, key, value, use_reentrant=False)
+ query = self.norm2(query + self.dropout(cross_attn_output))
+
+ # 3. Feed Forward + Residual (with gradient checkpointing)
+ ff_output = checkpoint(self.feed_forward, query, use_reentrant=False)
+ query = self.norm3(query + ff_output)
+ else:
+ # Original implementation without gradient checkpointing
+ if self.use_rope:
+ # 1. Self Attention + Residual with RoPE
+ self_attn_output = self._rope_attention(
+ self.self_q_proj, self.self_k_proj, self.self_v_proj, self.self_out_proj,
+ query, query, query, query_timestep_pos, query_timestep_pos
+ )
+ query = self.norm1(query + self.dropout(self_attn_output))
+
+ # 2. Cross Attention + Residual with RoPE
+ cross_attn_output = self._rope_attention(
+ self.cross_q_proj, self.cross_k_proj, self.cross_v_proj, self.cross_out_proj,
+ query, key, value, query_timestep_pos, key_timestep_pos
+ )
+ query = self.norm2(query + self.dropout(cross_attn_output))
+ else:
+ # 1. Self Attention + Residual (original implementation)
+ self_attn_output, _ = self.self_attention(query, query, query)
+ query = self.norm1(query + self.dropout(self_attn_output))
+
+ # 2. Cross Attention + Residual (original implementation)
+ cross_attn_output, _ = self.cross_attention(query, key, value)
+ query = self.norm2(query + self.dropout(cross_attn_output))
+
+ # 3. Feed Forward + Residual
+ ff_output = self.feed_forward(query)
+ query = self.norm3(query + ff_output)
+
+ return query
+
+
+class AlignNet(nn.Module):
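+    """Predicts a global scene scale and per-frame SMPL camera-space
+    translations from aggregated scene tokens and per-frame camera tokens,
+    aligning the human with the reconstructed scene (forward returns a dict
+    with "scale" and "trans_cam")."""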
+ def __init__(self, aggregated_dim=2048, cam_dim=1024, hidden_dim=512, num_heads=8, ff_dim=512, dropout=0.1, use_rope=True, num_decoder_layers=2):
+ super().__init__()
+
+ self.use_rope = use_rope
+ self.hidden_dim = hidden_dim
+ self.num_decoder_layers = num_decoder_layers
+
+ self.scale_token = nn.Parameter(torch.randn(1, 1, hidden_dim) * 0.02)
+
+ self.cam_feature_adapter = nn.Sequential(
+ nn.LayerNorm(cam_dim),
+ nn.Linear(cam_dim, hidden_dim),
+ nn.ReLU(),
+ nn.Dropout(dropout)
+ )
+
+ self.patch_feature_adapter = nn.Sequential(
+ nn.LayerNorm(aggregated_dim),
+ nn.Linear(aggregated_dim, hidden_dim),
+ nn.ReLU(),
+ nn.Dropout(dropout)
+ )
+ self.register_feature_adapter = nn.Sequential(
+ nn.LayerNorm(aggregated_dim),
+ nn.Linear(aggregated_dim, hidden_dim),
+ nn.ReLU(),
+ nn.Dropout(dropout)
+ )
+
+ self.decoder_layers = nn.ModuleList([
+ CrossViewTransformerDecoderLayer(hidden_dim, num_heads, ff_dim, dropout, use_rope=use_rope)
+ for _ in range(num_decoder_layers)
+ ])
+
+ mean_params = SMPL_MEAN_PARAMS
+ init_body_pose = torch.from_numpy(mean_params['pose'].astype(np.float32)).unsqueeze(0)
+ init_betas = torch.from_numpy(mean_params['shape'].astype('float32')).unsqueeze(0)
+ init_cam = torch.from_numpy(mean_params['cam'].astype(np.float32)).unsqueeze(0)
+ self.register_buffer('init_body_pose', init_body_pose)
+ self.register_buffer('init_betas', init_betas)
+ self.register_buffer('init_cam', init_cam)
+
+ self.trans_head = nn.Linear(hidden_dim, 3)
+
+ self.scale_head = nn.Linear(hidden_dim, 1)
+
+ self.joint_conversion_fn = rot6d_to_rotmat
+
+ def gradient_checkpointing_enable(self):
+ """Enable gradient checkpointing for memory optimization."""
+ for layer in self.decoder_layers:
+ if hasattr(layer, 'gradient_checkpointing_enable'):
+ layer.gradient_checkpointing_enable()
+
+ def forward(self, hidden_tokens, cam_token, fps=6.0):
+ batch_size, num_views, num_tokens, _ = hidden_tokens.shape
+
+ register_tokens = hidden_tokens[:, :, :5, :]
+ patch_tokens = hidden_tokens[:, :, 5:, :]
+
+ if cam_token.dim() == 4:
+ cam_token = cam_token.squeeze(2) # [batch, num_views, 1, 1024] -> [batch, num_views, 1024]
+
+ cam_adapted = self.cam_feature_adapter(cam_token) # [batch, num_views, hidden_dim]
+
+ patch_tokens_reshaped = patch_tokens.view(batch_size * num_views, patch_tokens.shape[2], -1) # [batch*num_views, 777, 2048]
+ patch_adapted_tokens = self.patch_feature_adapter(patch_tokens_reshaped) # [batch*num_views, 777, hidden_dim]
+ patch_adapted_tokens = patch_adapted_tokens.view(batch_size, num_views, patch_tokens.shape[2], -1) # [batch, num_views, 777, hidden_dim]
+
+ register_tokens_reshaped = register_tokens.view(batch_size * num_views, 5, -1) # [batch*num_views, 5, 2048]
+ register_adapted_tokens = self.register_feature_adapter(register_tokens_reshaped) # [batch*num_views, 5, hidden_dim]
+ register_adapted_tokens = register_adapted_tokens.view(batch_size, num_views, 5, -1) # [batch, num_views, 5, hidden_dim]
+
+ fused_features_per_view = torch.cat([register_adapted_tokens, patch_adapted_tokens], dim=2) # [batch, num_views, 782, hidden_dim]
+
+ concatenated_features = fused_features_per_view.view(batch_size, num_views * num_tokens, -1)
+
+ scale_token_expanded = self.scale_token.expand(batch_size, -1, -1)
+
+ query_tokens = torch.cat([scale_token_expanded, cam_adapted], dim=1)
+
+ if self.use_rope:
+ base_fps = 6.0
+
+ time_scale = base_fps / fps
+
+ scale_timestep = torch.zeros((batch_size, 1), device=cam_adapted.device, dtype=torch.long)
+
+ cam_timestep_float = torch.arange(num_views, device=cam_adapted.device, dtype=torch.float32) * time_scale
+ cam_timestep = cam_timestep_float.round().long().unsqueeze(0).expand(batch_size, -1)
+ query_timestep_pos = torch.cat([scale_timestep, cam_timestep], dim=1) # [batch, 1 + num_views]
+
+ key_timestep_base_float = torch.arange(num_views, device=cam_adapted.device, dtype=torch.float32) * time_scale
+ key_timestep_base = key_timestep_base_float.round().long()
+ key_timestep_pos = key_timestep_base.unsqueeze(1).expand(-1, num_tokens).flatten()
+ key_timestep_pos = key_timestep_pos.unsqueeze(0).expand(batch_size, -1) # [batch, num_views * num_tokens]
+ else:
+ query_timestep_pos = None
+ key_timestep_pos = None
+
+ decoder_output = query_tokens
+ for i, layer in enumerate(self.decoder_layers):
+ residual = decoder_output
+
+ decoder_output = layer(
+ decoder_output, concatenated_features, concatenated_features,
+ query_timestep_pos=query_timestep_pos, key_timestep_pos=key_timestep_pos
+ )
+
+ decoder_output = decoder_output + residual
+
+ scale_output = decoder_output[:, 0, :]
+ cam_outputs = decoder_output[:, 1:, :]
+
+ scale_logits = self.scale_head(scale_output) # [batch, 1]
+ scale = F.softplus(scale_logits)
+
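+        # Reparameterized translation: the head predicts (x, y, log z); the
+        # exponential enforces a positive depth z, and scaling the in-plane
+        # offsets by z yields the camera-space translation (x*z, y*z, z).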
+ trans_raw = self.trans_head(cam_outputs) # [batch, num_views, 3]
+ xy, z = trans_raw.split([2, 1], dim=-1) # xy: [batch, num_views, 2], z: [batch, num_views, 1]
+ z = torch.exp(z)
+ trans = torch.cat([xy * z, z], dim=-1) # [batch, num_views, 3]
+
+
+ return {
+ "scale": scale, # [batch, 1]
+ "trans_cam": trans, # [batch, num_views, 3]
+ }
diff --git a/unish/heads/dpt_head.py b/unish/heads/dpt_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..8f69cf5f08fe41a8bc7c275edce6aa98d3d68a7c
--- /dev/null
+++ b/unish/heads/dpt_head.py
@@ -0,0 +1,500 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+# Inspired by https://github.com/DepthAnything/Depth-Anything-V2
+
+
+import os
+from typing import List, Dict, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from .head_act import activate_head
+from .utils import create_uv_grid, position_grid_to_embed
+
+
+class DPTHead(nn.Module):
+ """
+ DPT Head for dense prediction tasks.
+
+ This implementation follows the architecture described in "Vision Transformers for Dense Prediction"
+ (https://arxiv.org/abs/2103.13413). The DPT head processes features from a vision transformer
+ backbone and produces dense predictions by fusing multi-scale features.
+
+ Args:
+ dim_in (int): Input dimension (channels).
+ patch_size (int, optional): Patch size. Default is 14.
+ output_dim (int, optional): Number of output channels. Default is 4.
+ activation (str, optional): Activation type. Default is "inv_log".
+ conf_activation (str, optional): Confidence activation type. Default is "expp1".
+ features (int, optional): Feature channels for intermediate representations. Default is 256.
+ out_channels (List[int], optional): Output channels for each intermediate layer.
+ intermediate_layer_idx (List[int], optional): Indices of layers from aggregated tokens used for DPT.
+ pos_embed (bool, optional): Whether to use positional embedding. Default is True.
+ feature_only (bool, optional): If True, return features only without the last several layers and activation head. Default is False.
+ down_ratio (int, optional): Downscaling factor for the output resolution. Default is 1.
+ """
+
+ def __init__(
+ self,
+ dim_in: int,
+ patch_size: int = 14,
+ output_dim: int = 4,
+ activation: str = "inv_log",
+ conf_activation: str = "expp1",
+ features: int = 256,
+ out_channels: List[int] = [256, 512, 1024, 1024],
+ intermediate_layer_idx: List[int] = [4, 11, 17, 23],
+ pos_embed: bool = True,
+ feature_only: bool = False,
+ down_ratio: int = 1,
+ ) -> None:
+ super(DPTHead, self).__init__()
+ self.patch_size = patch_size
+ self.activation = activation
+ self.conf_activation = conf_activation
+ self.pos_embed = pos_embed
+ self.feature_only = feature_only
+ self.down_ratio = down_ratio
+ self.intermediate_layer_idx = intermediate_layer_idx
+
+ self.norm = nn.LayerNorm(dim_in)
+
+ # Projection layers for each output channel from tokens.
+ self.projects = nn.ModuleList(
+ [
+ nn.Conv2d(
+ in_channels=dim_in,
+ out_channels=oc,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ )
+ for oc in out_channels
+ ]
+ )
+
+ # Resize layers for upsampling feature maps.
+ self.resize_layers = nn.ModuleList(
+ [
+ nn.ConvTranspose2d(
+ in_channels=out_channels[0], out_channels=out_channels[0], kernel_size=4, stride=4, padding=0
+ ),
+ nn.ConvTranspose2d(
+ in_channels=out_channels[1], out_channels=out_channels[1], kernel_size=2, stride=2, padding=0
+ ),
+ nn.Identity(),
+ nn.Conv2d(
+ in_channels=out_channels[3], out_channels=out_channels[3], kernel_size=3, stride=2, padding=1
+ ),
+ ]
+ )
+
+ self.scratch = _make_scratch(
+ out_channels,
+ features,
+ expand=False,
+ )
+
+ # Attach additional modules to scratch.
+ self.scratch.stem_transpose = None
+ self.scratch.refinenet1 = _make_fusion_block(features)
+ self.scratch.refinenet2 = _make_fusion_block(features)
+ self.scratch.refinenet3 = _make_fusion_block(features)
+ self.scratch.refinenet4 = _make_fusion_block(features, has_residual=False)
+
+ head_features_1 = features
+ head_features_2 = 32
+
+ if feature_only:
+ self.scratch.output_conv1 = nn.Conv2d(head_features_1, head_features_1, kernel_size=3, stride=1, padding=1)
+ else:
+ self.scratch.output_conv1 = nn.Conv2d(
+ head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1
+ )
+ conv2_in_channels = head_features_1 // 2
+
+ self.scratch.output_conv2 = nn.Sequential(
+ nn.Conv2d(conv2_in_channels, head_features_2, kernel_size=3, stride=1, padding=1),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(head_features_2, output_dim, kernel_size=1, stride=1, padding=0),
+ )
+
+ def forward(
+ self,
+ aggregated_tokens_list: List[torch.Tensor],
+ images: torch.Tensor,
+ patch_start_idx: int,
+ frames_chunk_size: int = 8,
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+ """
+ Forward pass through the DPT head, supports processing by chunking frames.
+ Args:
+ aggregated_tokens_list (List[Tensor]): List of token tensors from different transformer layers.
+ images (Tensor): Input images with shape [B, S, 3, H, W], in range [0, 1].
+ patch_start_idx (int): Starting index for patch tokens in the token sequence.
+ Used to separate patch tokens from other tokens (e.g., camera or register tokens).
+ frames_chunk_size (int, optional): Number of frames to process in each chunk.
+ If None or larger than S, all frames are processed at once. Default: 8.
+
+ Returns:
+ Tensor or Tuple[Tensor, Tensor]:
+ - If feature_only=True: Feature maps with shape [B, S, C, H, W]
+ - Otherwise: Tuple of (predictions, confidence) both with shape [B, S, 1, H, W]
+ """
+ B, S, _, H, W = images.shape
+
+ # If frames_chunk_size is not specified or greater than S, process all frames at once
+ if frames_chunk_size is None or frames_chunk_size >= S:
+ return self._forward_impl(aggregated_tokens_list, images, patch_start_idx)
+
+ # Otherwise, process frames in chunks to manage memory usage
+ assert frames_chunk_size > 0
+
+ # Process frames in batches
+ all_preds = []
+ all_conf = []
+
+ for frames_start_idx in range(0, S, frames_chunk_size):
+ frames_end_idx = min(frames_start_idx + frames_chunk_size, S)
+
+ # Process batch of frames
+ if self.feature_only:
+ chunk_output = self._forward_impl(
+ aggregated_tokens_list, images, patch_start_idx, frames_start_idx, frames_end_idx
+ )
+ all_preds.append(chunk_output)
+ else:
+ chunk_preds, chunk_conf = self._forward_impl(
+ aggregated_tokens_list, images, patch_start_idx, frames_start_idx, frames_end_idx
+ )
+ all_preds.append(chunk_preds)
+ all_conf.append(chunk_conf)
+
+ # Concatenate results along the sequence dimension
+ if self.feature_only:
+ return torch.cat(all_preds, dim=1)
+ else:
+ return torch.cat(all_preds, dim=1), torch.cat(all_conf, dim=1)
+
+ def _forward_impl(
+ self,
+ aggregated_tokens_list: List[torch.Tensor],
+ images: torch.Tensor,
+ patch_start_idx: int,
+ frames_start_idx: int = None,
+ frames_end_idx: int = None,
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+ """
+ Implementation of the forward pass through the DPT head.
+
+ This method processes a specific chunk of frames from the sequence.
+
+ Args:
+ aggregated_tokens_list (List[Tensor]): List of token tensors from different transformer layers.
+ images (Tensor): Input images with shape [B, S, 3, H, W].
+ patch_start_idx (int): Starting index for patch tokens.
+ frames_start_idx (int, optional): Starting index for frames to process.
+ frames_end_idx (int, optional): Ending index for frames to process.
+
+ Returns:
+ Tensor or Tuple[Tensor, Tensor]: Feature maps or (predictions, confidence).
+ """
+ if frames_start_idx is not None and frames_end_idx is not None:
+ images = images[:, frames_start_idx:frames_end_idx].contiguous()
+
+ B, S, _, H, W = images.shape
+
+ patch_h, patch_w = H // self.patch_size, W // self.patch_size
+
+ out = []
+ dpt_idx = 0
+
+ for layer_idx in self.intermediate_layer_idx:
+ x = aggregated_tokens_list[layer_idx][:, :, patch_start_idx:]
+
+ x = x.to(self.projects[0].weight.dtype)
+
+ # Select frames if processing a chunk
+ if frames_start_idx is not None and frames_end_idx is not None:
+ x = x[:, frames_start_idx:frames_end_idx]
+
+ x = x.reshape(B * S, -1, x.shape[-1])
+
+ x = self.norm(x)
+
+ x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w))
+
+ x = self.projects[dpt_idx](x)
+ if self.pos_embed:
+ x = self._apply_pos_embed(x, W, H).to(self.projects[0].weight.dtype)
+
+ x = self.resize_layers[dpt_idx](x)
+
+ out.append(x)
+ dpt_idx += 1
+
+ # Fuse features from multiple layers.
+ out = self.scratch_forward(out)
+ # Interpolate fused output to match target image resolution.
+ out = custom_interpolate(
+ out,
+ (int(patch_h * self.patch_size / self.down_ratio), int(patch_w * self.patch_size / self.down_ratio)),
+ mode="bilinear",
+ align_corners=True,
+ )
+
+ if self.pos_embed:
+ out = self._apply_pos_embed(out, W, H).to(self.projects[0].weight.dtype)
+
+ if self.feature_only:
+ return out.view(B, S, *out.shape[1:])
+
+ out = self.scratch.output_conv2(out)
+ preds, conf = activate_head(out, activation=self.activation, conf_activation=self.conf_activation)
+
+ preds = preds.view(B, S, *preds.shape[1:])
+ conf = conf.view(B, S, *conf.shape[1:])
+ return preds, conf
+
+ def _apply_pos_embed(self, x: torch.Tensor, W: int, H: int, ratio: float = 0.1) -> torch.Tensor:
+ """
+ Apply positional embedding to tensor x.
+ """
+ patch_w = x.shape[-1]
+ patch_h = x.shape[-2]
+ pos_embed = create_uv_grid(patch_w, patch_h, aspect_ratio=W / H, dtype=x.dtype, device=x.device)
+ pos_embed = position_grid_to_embed(pos_embed, x.shape[1])
+ pos_embed = pos_embed * ratio
+ pos_embed = pos_embed.permute(2, 0, 1)[None].expand(x.shape[0], -1, -1, -1)
+ return x + pos_embed
+
+ def scratch_forward(self, features: List[torch.Tensor]) -> torch.Tensor:
+ """
+ Forward pass through the fusion blocks.
+
+ Args:
+ features (List[Tensor]): List of feature maps from different layers.
+
+ Returns:
+ Tensor: Fused feature map.
+ """
+ layer_1, layer_2, layer_3, layer_4 = features
+
+ layer_1_rn = self.scratch.layer1_rn(layer_1)
+ layer_2_rn = self.scratch.layer2_rn(layer_2)
+ layer_3_rn = self.scratch.layer3_rn(layer_3)
+ layer_4_rn = self.scratch.layer4_rn(layer_4)
+
+ out = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])
+ del layer_4_rn, layer_4
+
+ out = self.scratch.refinenet3(out, layer_3_rn, size=layer_2_rn.shape[2:])
+ del layer_3_rn, layer_3
+
+ out = self.scratch.refinenet2(out, layer_2_rn, size=layer_1_rn.shape[2:])
+ del layer_2_rn, layer_2
+
+ out = self.scratch.refinenet1(out, layer_1_rn)
+ del layer_1_rn, layer_1
+
+ out = self.scratch.output_conv1(out)
+ return out
+
+
+################################################################################
+# Modules
+################################################################################
+
+
+def _make_fusion_block(features: int, size: int = None, has_residual: bool = True, groups: int = 1) -> nn.Module:
+ return FeatureFusionBlock(
+ features,
+ nn.ReLU(inplace=True),
+ deconv=False,
+ bn=False,
+ expand=False,
+ align_corners=True,
+ size=size,
+ has_residual=has_residual,
+ groups=groups,
+ )
+
+
+def _make_scratch(in_shape: List[int], out_shape: int, groups: int = 1, expand: bool = False) -> nn.Module:
+ scratch = nn.Module()
+ out_shape1 = out_shape
+ out_shape2 = out_shape
+ out_shape3 = out_shape
+ if len(in_shape) >= 4:
+ out_shape4 = out_shape
+
+ if expand:
+ out_shape1 = out_shape
+ out_shape2 = out_shape * 2
+ out_shape3 = out_shape * 4
+ if len(in_shape) >= 4:
+ out_shape4 = out_shape * 8
+
+ scratch.layer1_rn = nn.Conv2d(
+ in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
+ )
+ scratch.layer2_rn = nn.Conv2d(
+ in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
+ )
+ scratch.layer3_rn = nn.Conv2d(
+ in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
+ )
+ if len(in_shape) >= 4:
+ scratch.layer4_rn = nn.Conv2d(
+ in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
+ )
+ return scratch
+
+
+class ResidualConvUnit(nn.Module):
+ """Residual convolution module."""
+
+ def __init__(self, features, activation, bn, groups=1):
+ """Init.
+
+ Args:
+ features (int): number of features
+ """
+ super().__init__()
+
+ self.bn = bn
+ self.groups = groups
+ self.conv1 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
+ self.conv2 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
+
+ self.norm1 = None
+ self.norm2 = None
+
+ self.activation = activation
+ self.skip_add = nn.quantized.FloatFunctional()
+
+ def forward(self, x):
+ """Forward pass.
+
+ Args:
+ x (tensor): input
+
+ Returns:
+ tensor: output
+ """
+
+ out = self.activation(x)
+ out = self.conv1(out)
+ if self.norm1 is not None:
+ out = self.norm1(out)
+
+ out = self.activation(out)
+ out = self.conv2(out)
+ if self.norm2 is not None:
+ out = self.norm2(out)
+
+ return self.skip_add.add(out, x)
+
+
+class FeatureFusionBlock(nn.Module):
+ """Feature fusion block."""
+
+ def __init__(
+ self,
+ features,
+ activation,
+ deconv=False,
+ bn=False,
+ expand=False,
+ align_corners=True,
+ size=None,
+ has_residual=True,
+ groups=1,
+ ):
+ """Init.
+
+ Args:
+ features (int): number of features
+ """
+ super(FeatureFusionBlock, self).__init__()
+
+ self.deconv = deconv
+ self.align_corners = align_corners
+ self.groups = groups
+ self.expand = expand
+ out_features = features
+ if self.expand:
+ out_features = features // 2
+
+ self.out_conv = nn.Conv2d(
+ features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=self.groups
+ )
+
+ if has_residual:
+ self.resConfUnit1 = ResidualConvUnit(features, activation, bn, groups=self.groups)
+
+ self.has_residual = has_residual
+ self.resConfUnit2 = ResidualConvUnit(features, activation, bn, groups=self.groups)
+
+ self.skip_add = nn.quantized.FloatFunctional()
+ self.size = size
+
+ def forward(self, *xs, size=None):
+ """Forward pass.
+
+ Returns:
+ tensor: output
+ """
+ output = xs[0]
+
+ if self.has_residual:
+ res = self.resConfUnit1(xs[1])
+ output = self.skip_add.add(output, res)
+
+ output = self.resConfUnit2(output)
+
+ if (size is None) and (self.size is None):
+ modifier = {"scale_factor": 2}
+ elif size is None:
+ modifier = {"size": self.size}
+ else:
+ modifier = {"size": size}
+
+ output = custom_interpolate(output, **modifier, mode="bilinear", align_corners=self.align_corners)
+ output = self.out_conv(output)
+
+ return output
+
+
+def custom_interpolate(
+ x: torch.Tensor,
+ size: Tuple[int, int] = None,
+ scale_factor: float = None,
+ mode: str = "bilinear",
+ align_corners: bool = True,
+) -> torch.Tensor:
+ """
+ Custom interpolate to avoid INT_MAX issues in nn.functional.interpolate.
+ """
+ if size is None:
+ size = (int(x.shape[-2] * scale_factor), int(x.shape[-1] * scale_factor))
+
+ INT_MAX = 1610612736  # ~1.5 * 2**30, a safe margin below the int32 limit
+
+ output_elements = size[0] * size[1] * x.shape[0] * x.shape[1]
+
+ if output_elements > INT_MAX:
+ chunks = torch.chunk(x, chunks=(output_elements // INT_MAX) + 1, dim=0)
+ interpolated_chunks = [
+ nn.functional.interpolate(chunk, size=size, mode=mode, align_corners=align_corners) for chunk in chunks
+ ]
+ x = torch.cat(interpolated_chunks, dim=0)
+ return x.contiguous()
+ else:
+ return nn.functional.interpolate(x, size=size, mode=mode, align_corners=align_corners)
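+
+
+if __name__ == "__main__":
+ # Illustrative sanity check (a sketch with toy shapes, not model values):
+ # a (2, 8, 16, 16) feature map upsampled to 32x32 stays far below the
+ # INT_MAX element budget, so the plain nn.functional.interpolate path runs.
+ feat = torch.randn(2, 8, 16, 16)
+ up = custom_interpolate(feat, size=(32, 32), mode="bilinear", align_corners=True)
+ assert up.shape == (2, 8, 32, 32)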
diff --git a/unish/heads/head_act.py b/unish/heads/head_act.py
new file mode 100644
index 0000000000000000000000000000000000000000..2dedfcf1180a653dddc99623e60df625e5897489
--- /dev/null
+++ b/unish/heads/head_act.py
@@ -0,0 +1,125 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import torch
+import torch.nn.functional as F
+
+
+def activate_pose(pred_pose_enc, trans_act="linear", quat_act="linear", fl_act="linear"):
+ """
+ Activate pose parameters with specified activation functions.
+
+ Args:
+ pred_pose_enc: Tensor containing encoded pose parameters [translation, quaternion, focal length]
+ trans_act: Activation type for translation component
+ quat_act: Activation type for quaternion component
+ fl_act: Activation type for focal length component
+
+ Returns:
+ Activated pose parameters tensor
+ """
+ T = pred_pose_enc[..., :3]
+ quat = pred_pose_enc[..., 3:7]
+ fl = pred_pose_enc[..., 7:] # or fov
+
+ T = base_pose_act(T, trans_act)
+ quat = base_pose_act(quat, quat_act)
+ fl = base_pose_act(fl, fl_act) # or fov
+
+ pred_pose_enc = torch.cat([T, quat, fl], dim=-1)
+
+ return pred_pose_enc
+
+
+def base_pose_act(pose_enc, act_type="linear"):
+ """
+ Apply basic activation function to pose parameters.
+
+ Args:
+ pose_enc: Tensor containing encoded pose parameters
+ act_type: Activation type ("linear", "inv_log", "exp", "relu")
+
+ Returns:
+ Activated pose parameters
+ """
+ if act_type == "linear":
+ return pose_enc
+ elif act_type == "inv_log":
+ return inverse_log_transform(pose_enc)
+ elif act_type == "exp":
+ return torch.exp(pose_enc)
+ elif act_type == "relu":
+ return F.relu(pose_enc)
+ else:
+ raise ValueError(f"Unknown act_type: {act_type}")
+
+
+def activate_head(out, activation="norm_exp", conf_activation="expp1"):
+ """
+ Process network output to extract 3D points and confidence values.
+
+ Args:
+ out: Network output tensor (B, C, H, W)
+ activation: Activation type for 3D points
+ conf_activation: Activation type for confidence values
+
+ Returns:
+ Tuple of (3D points tensor, confidence tensor)
+ """
+ # Move the channel dimension to the end => (B, H, W, C)
+ fmap = out.permute(0, 2, 3, 1)
+
+ # Split into xyz (first C-1 channels) and confidence (last channel)
+ xyz = fmap[:, :, :, :-1]
+ conf = fmap[:, :, :, -1]
+
+ if activation == "norm_exp":
+ d = xyz.norm(dim=-1, keepdim=True).clamp(min=1e-8)
+ xyz_normed = xyz / d
+ pts3d = xyz_normed * torch.expm1(d)
+ elif activation == "norm":
+ pts3d = xyz / xyz.norm(dim=-1, keepdim=True)
+ elif activation == "exp":
+ pts3d = torch.exp(xyz)
+ elif activation == "relu":
+ pts3d = F.relu(xyz)
+ elif activation == "inv_log":
+ pts3d = inverse_log_transform(xyz)
+ elif activation == "xy_inv_log":
+ xy, z = xyz.split([2, 1], dim=-1)
+ z = inverse_log_transform(z)
+ pts3d = torch.cat([xy * z, z], dim=-1)
+ elif activation == "sigmoid":
+ pts3d = torch.sigmoid(xyz)
+ elif activation == "linear":
+ pts3d = xyz
+ else:
+ raise ValueError(f"Unknown activation: {activation}")
+
+ if conf_activation == "expp1":
+ conf_out = 1 + conf.exp()
+ elif conf_activation == "expp0":
+ conf_out = conf.exp()
+ elif conf_activation == "sigmoid":
+ conf_out = torch.sigmoid(conf)
+ else:
+ raise ValueError(f"Unknown conf_activation: {conf_activation}")
+
+ return pts3d, conf_out
+
+
+def inverse_log_transform(y):
+ """
+ Apply inverse log transform: sign(y) * (exp(|y|) - 1)
+
+ Args:
+ y: Input tensor
+
+ Returns:
+ Transformed tensor
+ """
+ return torch.sign(y) * (torch.expm1(torch.abs(y)))
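+
+
+if __name__ == "__main__":
+ # Illustrative sketch (assumed toy shapes, not model values): a 4-channel
+ # map splits into 3 point channels and 1 confidence channel; "expp1"
+ # guarantees confidences strictly greater than 1.
+ out = torch.randn(2, 4, 8, 8)
+ pts3d, conf = activate_head(out, activation="norm_exp", conf_activation="expp1")
+ assert pts3d.shape == (2, 8, 8, 3) and conf.shape == (2, 8, 8)
+ assert (conf > 1).all()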
diff --git a/unish/heads/human_head_cliff.py b/unish/heads/human_head_cliff.py
new file mode 100644
index 0000000000000000000000000000000000000000..dd2d2964a7b21b03e855d82169b38810692e94a0
--- /dev/null
+++ b/unish/heads/human_head_cliff.py
@@ -0,0 +1,97 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+import einops
+
+
+from unish.utils.data_utils import rot6d_to_rotmat
+from unish.utils.constants import SMPL_MEAN_PARAMS
+from .pose_transformer import TransformerDecoder
+
+TRANSFORMER_DECODER = {
+ 'depth': 6,
+ 'heads': 8,
+ 'mlp_dim': 1024,
+ 'dim_head': 64,
+ 'dropout': 0.0,
+ 'emb_dropout': 0.0,
+ 'norm': 'layer',
+ 'context_dim': 1280,
+}
+
+NUM_POSE_INPUT = 23
+NUM_BETAS_INPUT = 10
+NUM_BETAS = 10
+NUM_POSE_PARAMS = 23
+
+class HumanHeadCliff(nn.Module):
+
+ def __init__(self):
+ super().__init__()
+ self.joint_rep_dim = 6
+ npose = self.joint_rep_dim * (NUM_POSE_INPUT + 1)
+ self.npose = npose
+ transformer_args = dict(
+ num_tokens=1,
+ token_dim=(3 + npose + NUM_BETAS_INPUT + 3),
+ dim=1024,
+ )
+ transformer_args = transformer_args | dict(TRANSFORMER_DECODER)
+ self.transformer = TransformerDecoder(**transformer_args)
+ dim = transformer_args['dim']
+ self.decpose = nn.Linear(dim, self.joint_rep_dim * (NUM_POSE_PARAMS + 1))
+ self.decshape = nn.Linear(dim, NUM_BETAS)
+ # self.deccam = nn.Linear(dim, 3)
+ # self.deckp = nn.Linear(dim, 88)
+
+ mean_params = SMPL_MEAN_PARAMS
+ init_body_pose = torch.from_numpy(mean_params['pose'].astype(np.float32)).unsqueeze(0)
+ init_betas = torch.from_numpy(mean_params['shape'].astype(np.float32)).unsqueeze(0)
+ init_cam = torch.from_numpy(mean_params['cam'].astype(np.float32)).unsqueeze(0)
+ self.register_buffer('init_body_pose', init_body_pose)
+ self.register_buffer('init_betas', init_betas)
+ self.register_buffer('init_cam', init_cam)
+
+ def gradient_checkpointing_enable(self):
+ """Enable gradient checkpointing for memory optimization."""
+ if hasattr(self.transformer, 'gradient_checkpointing_enable'):
+ self.transformer.gradient_checkpointing_enable()
+
+ def forward(self, x, bbox_info, **kwargs):
+ """
+ x: (B, N, C, H, W)
+ bbox_info: [cx / f, cy / f, box_size / f], (B, N, 3)
+ """
+
+ batch_size, num_views = x.shape[:2]
+ x = einops.rearrange(x, 'b n c h w -> (b n) (h w) c')
+
+ init_body_pose = self.init_body_pose.expand(batch_size * num_views, -1)
+ init_betas = self.init_betas.expand(batch_size * num_views, -1)
+ init_cam = self.init_cam.expand(batch_size * num_views, -1)
+ bbox_info = bbox_info.view(-1, 3)
+
+ pred_body_pose = init_body_pose
+ pred_betas = init_betas
+ pred_cam = init_cam
+ token = torch.cat([bbox_info, pred_body_pose, pred_betas, pred_cam], dim=-1)[:, None, :]
+
+ # Pass through transformer
+ token_out = self.transformer(token, context=x)
+ token_out = token_out.squeeze(1)  # (B * N, C)
+
+ pred_body_pose = self.decpose(token_out) + pred_body_pose
+ pred_betas = self.decshape(token_out) + pred_betas
+
+ joint_conversion_fn = rot6d_to_rotmat
+
+ pred_body_pose = pred_body_pose.view(-1, 6)
+ pred_body_pose = joint_conversion_fn(pred_body_pose).view(batch_size, num_views, -1)
+ pred_betas = pred_betas.view(batch_size, num_views, -1).mean(dim=1)
+ token_out = token_out.view(batch_size, num_views, -1)
+
+ pred_smpl_params = {'pose_cam': pred_body_pose,
+ 'token_out': token_out,
+ 'betas': pred_betas}
+ return pred_smpl_params
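+
+
+if __name__ == "__main__":
+ # Hedged usage sketch: dummy ViT features for one video (B=1) with two
+ # crops (N=2) on a 16x12 token grid of 1280 channels (the context_dim
+ # above), plus CLIFF-style bbox info. Assumes SMPL_MEAN_PARAMS (the mean
+ # pose/shape/cam dict) is available at import time.
+ head = HumanHeadCliff()
+ feats = torch.randn(1, 2, 1280, 16, 12)
+ bbox_info = torch.zeros(1, 2, 3)
+ out = head(feats, bbox_info)
+ # pose_cam: (1, 2, 24*3*3) rotation matrices; betas: (1, 10) mean shape
+ print(out['pose_cam'].shape, out['betas'].shape)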
diff --git a/unish/heads/pose_transformer.py b/unish/heads/pose_transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..f16e97133288666e840a1b6d8ba79091bb494740
--- /dev/null
+++ b/unish/heads/pose_transformer.py
@@ -0,0 +1,364 @@
+from inspect import isfunction
+from typing import Callable, Optional
+
+import torch
+from einops import rearrange
+from einops.layers.torch import Rearrange
+from torch import nn
+
+from .t_cond_mlp import (
+ AdaptiveLayerNorm1D,
+ FrequencyEmbedder,
+ normalization_layer,
+)
+# from .vit import Attention, FeedForward
+
+
+def exists(val):
+ return val is not None
+
+
+def default(val, d):
+ if exists(val):
+ return val
+ return d() if isfunction(d) else d
+
+
+class PreNorm(nn.Module):
+ def __init__(self, dim: int, fn: Callable, norm: str = "layer", norm_cond_dim: int = -1):
+ super().__init__()
+ self.norm = normalization_layer(norm, dim, norm_cond_dim)
+ self.fn = fn
+
+ def forward(self, x: torch.Tensor, *args, **kwargs):
+ if isinstance(self.norm, AdaptiveLayerNorm1D):
+ return self.fn(self.norm(x, *args), **kwargs)
+ else:
+ return self.fn(self.norm(x), **kwargs)
+
+
+class FeedForward(nn.Module):
+ def __init__(self, dim, hidden_dim, dropout=0.0):
+ super().__init__()
+ self.net = nn.Sequential(
+ nn.Linear(dim, hidden_dim),
+ nn.GELU(),
+ nn.Dropout(dropout),
+ nn.Linear(hidden_dim, dim),
+ nn.Dropout(dropout),
+ )
+
+ def forward(self, x):
+ return self.net(x)
+
+
+class Attention(nn.Module):
+ def __init__(self, dim, heads=8, dim_head=64, dropout=0.0):
+ super().__init__()
+ inner_dim = dim_head * heads
+ project_out = not (heads == 1 and dim_head == dim)
+
+ self.heads = heads
+ self.scale = dim_head**-0.5
+
+ self.attend = nn.Softmax(dim=-1)
+ self.dropout = nn.Dropout(dropout)
+
+ self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
+
+ self.to_out = (
+ nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout))
+ if project_out
+ else nn.Identity()
+ )
+
+ def forward(self, x):
+ qkv = self.to_qkv(x).chunk(3, dim=-1)
+ q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), qkv)
+
+ dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
+
+ attn = self.attend(dots)
+ attn = self.dropout(attn)
+
+ out = torch.matmul(attn, v)
+ out = rearrange(out, "b h n d -> b n (h d)")
+ return self.to_out(out)
+
+
+class CrossAttention(nn.Module):
+ def __init__(self, dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
+ super().__init__()
+ inner_dim = dim_head * heads
+ project_out = not (heads == 1 and dim_head == dim)
+
+ self.heads = heads
+ self.scale = dim_head**-0.5
+
+ self.attend = nn.Softmax(dim=-1)
+ self.dropout = nn.Dropout(dropout)
+
+ context_dim = default(context_dim, dim)
+ self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias=False)
+ self.to_q = nn.Linear(dim, inner_dim, bias=False)
+
+ self.to_out = (
+ nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout))
+ if project_out
+ else nn.Identity()
+ )
+
+ def forward(self, x, context=None):
+ context = default(context, x)
+ k, v = self.to_kv(context).chunk(2, dim=-1)
+ q = self.to_q(x)
+ q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), [q, k, v])
+
+ dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
+
+ attn = self.attend(dots)
+ attn = self.dropout(attn)
+
+ out = torch.matmul(attn, v)
+ out = rearrange(out, "b h n d -> b n (h d)")
+ return self.to_out(out)
+
+
+class Transformer(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ depth: int,
+ heads: int,
+ dim_head: int,
+ mlp_dim: int,
+ dropout: float = 0.0,
+ norm: str = "layer",
+ norm_cond_dim: int = -1,
+ ):
+ super().__init__()
+ self.layers = nn.ModuleList([])
+ for _ in range(depth):
+ sa = Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)
+ ff = FeedForward(dim, mlp_dim, dropout=dropout)
+ self.layers.append(
+ nn.ModuleList(
+ [
+ PreNorm(dim, sa, norm=norm, norm_cond_dim=norm_cond_dim),
+ PreNorm(dim, ff, norm=norm, norm_cond_dim=norm_cond_dim),
+ ]
+ )
+ )
+
+ def forward(self, x: torch.Tensor, *args):
+ for attn, ff in self.layers:
+ x = attn(x, *args) + x
+ x = ff(x, *args) + x
+ return x
+
+
+class TransformerCrossAttn(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ depth: int,
+ heads: int,
+ dim_head: int,
+ mlp_dim: int,
+ dropout: float = 0.0,
+ norm: str = "layer",
+ norm_cond_dim: int = -1,
+ context_dim: Optional[int] = None,
+ ):
+ super().__init__()
+ self.layers = nn.ModuleList([])
+ for _ in range(depth):
+ sa = Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)
+ ca = CrossAttention(
+ dim, context_dim=context_dim, heads=heads, dim_head=dim_head, dropout=dropout
+ )
+ ff = FeedForward(dim, mlp_dim, dropout=dropout)
+ self.layers.append(
+ nn.ModuleList(
+ [
+ PreNorm(dim, sa, norm=norm, norm_cond_dim=norm_cond_dim),
+ PreNorm(dim, ca, norm=norm, norm_cond_dim=norm_cond_dim),
+ PreNorm(dim, ff, norm=norm, norm_cond_dim=norm_cond_dim),
+ ]
+ )
+ )
+
+ def forward(self, x: torch.Tensor, *args, context=None, context_list=None):
+ if context_list is None:
+ context_list = [context] * len(self.layers)
+ if len(context_list) != len(self.layers):
+ raise ValueError(f"len(context_list) != len(self.layers) ({len(context_list)} != {len(self.layers)})")
+
+ b, n = x.shape[:2]
+
+ for i, (self_attn, cross_attn, ff) in enumerate(self.layers):
+ x = self_attn(x, *args) + x
+ # TODO
+ # x = x.view(b*n, 1, -1)
+ x = cross_attn(x, *args, context=context_list[i]) + x
+ # x = x.view(b, n, -1)
+ x = ff(x, *args) + x
+ return x
+
+
+class DropTokenDropout(nn.Module):
+ def __init__(self, p: float = 0.1):
+ super().__init__()
+ if p < 0 or p > 1:
+ raise ValueError(f"dropout probability has to be between 0 and 1, but got {p}")
+ self.p = p
+
+ def forward(self, x: torch.Tensor):
+ # x: (batch_size, seq_len, dim)
+ if self.training and self.p > 0:
+ zero_mask = torch.full_like(x[0, :, 0], self.p).bernoulli().bool()
+ # TODO: permutation idx for each batch using torch.argsort
+ if zero_mask.any():
+ x = x[:, ~zero_mask, :]
+ return x
+
+
+class ZeroTokenDropout(nn.Module):
+ def __init__(self, p: float = 0.1):
+ super().__init__()
+ if p < 0 or p > 1:
+ raise ValueError(f"dropout probability has to be between 0 and 1, but got {p}")
+ self.p = p
+
+ def forward(self, x: torch.Tensor):
+ # x: (batch_size, seq_len, dim)
+ if self.training and self.p > 0:
+ zero_mask = torch.full_like(x[:, :, 0], self.p).bernoulli().bool()
+ # Zero-out the masked tokens
+ x[zero_mask, :] = 0
+ return x
+
+
+class TransformerEncoder(nn.Module):
+ def __init__(
+ self,
+ num_tokens: int,
+ token_dim: int,
+ dim: int,
+ depth: int,
+ heads: int,
+ mlp_dim: int,
+ dim_head: int = 64,
+ dropout: float = 0.0,
+ emb_dropout: float = 0.0,
+ emb_dropout_type: str = "drop",
+ emb_dropout_loc: str = "token",
+ norm: str = "layer",
+ norm_cond_dim: int = -1,
+ token_pe_numfreq: int = -1,
+ ):
+ super().__init__()
+ if token_pe_numfreq > 0:
+ token_dim_new = token_dim * (2 * token_pe_numfreq + 1)
+ self.to_token_embedding = nn.Sequential(
+ Rearrange("b n d -> (b n) d", n=num_tokens, d=token_dim),
+ FrequencyEmbedder(token_pe_numfreq, token_pe_numfreq - 1),
+ Rearrange("(b n) d -> b n d", n=num_tokens, d=token_dim_new),
+ nn.Linear(token_dim_new, dim),
+ )
+ else:
+ self.to_token_embedding = nn.Linear(token_dim, dim)
+ self.pos_embedding = nn.Parameter(torch.randn(1, num_tokens, dim))
+ if emb_dropout_type == "drop":
+ self.dropout = DropTokenDropout(emb_dropout)
+ elif emb_dropout_type == "zero":
+ self.dropout = ZeroTokenDropout(emb_dropout)
+ else:
+ raise ValueError(f"Unknown emb_dropout_type: {emb_dropout_type}")
+ self.emb_dropout_loc = emb_dropout_loc
+
+ self.transformer = Transformer(
+ dim, depth, heads, dim_head, mlp_dim, dropout, norm=norm, norm_cond_dim=norm_cond_dim
+ )
+
+ def forward(self, inp: torch.Tensor, *args, **kwargs):
+ x = inp
+
+ if self.emb_dropout_loc == "input":
+ x = self.dropout(x)
+ x = self.to_token_embedding(x)
+
+ if self.emb_dropout_loc == "token":
+ x = self.dropout(x)
+ b, n, _ = x.shape
+ x += self.pos_embedding[:, :n]
+
+ if self.emb_dropout_loc == "token_afterpos":
+ x = self.dropout(x)
+ x = self.transformer(x, *args)
+ return x
+
+
+class TransformerDecoder(nn.Module):
+ def __init__(
+ self,
+ num_tokens: int,
+ token_dim: int,
+ dim: int,
+ depth: int,
+ heads: int,
+ mlp_dim: int,
+ dim_head: int = 64,
+ dropout: float = 0.0,
+ emb_dropout: float = 0.0,
+ emb_dropout_type: str = 'drop',
+ norm: str = "layer",
+ norm_cond_dim: int = -1,
+ context_dim: Optional[int] = None,
+ skip_token_embedding: bool = False,
+ ):
+ super().__init__()
+ if not skip_token_embedding:
+ self.to_token_embedding = nn.Linear(token_dim, dim)
+ else:
+ self.to_token_embedding = nn.Identity()
+ if token_dim != dim:
+ raise ValueError(
+ f"token_dim ({token_dim}) != dim ({dim}) when skip_token_embedding is True"
+ )
+
+ self.pos_embedding = nn.Parameter(torch.randn(1, num_tokens, dim))
+ if emb_dropout_type == "drop":
+ self.dropout = DropTokenDropout(emb_dropout)
+ elif emb_dropout_type == "zero":
+ self.dropout = ZeroTokenDropout(emb_dropout)
+ elif emb_dropout_type == "normal":
+ self.dropout = nn.Dropout(emb_dropout)
+ else:
+ raise ValueError(f"Unknown emb_dropout_type: {emb_dropout_type}")
+
+ self.transformer = TransformerCrossAttn(
+ dim,
+ depth,
+ heads,
+ dim_head,
+ mlp_dim,
+ dropout,
+ norm=norm,
+ norm_cond_dim=norm_cond_dim,
+ context_dim=context_dim,
+ )
+
+ def forward(self, inp: torch.Tensor, *args, context=None, context_list=None):
+ x = self.to_token_embedding(inp)
+ b, n, _ = x.shape
+
+ x = self.dropout(x)
+ x += self.pos_embedding[:, :n]
+
+ x = self.transformer(x, *args, context=context, context_list=context_list)
+ return x
+
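+
+if __name__ == "__main__":
+ # Illustrative sketch with made-up sizes: decode a single query token
+ # against a 192-token context, mirroring how HumanHeadCliff drives this
+ # decoder (token_dim=160 matches its bbox + pose + betas + cam vector).
+ dec = TransformerDecoder(num_tokens=1, token_dim=160, dim=256, depth=2,
+ heads=4, mlp_dim=512, context_dim=1280)
+ tok = torch.randn(2, 1, 160)
+ ctx = torch.randn(2, 192, 1280)
+ assert dec(tok, context=ctx).shape == (2, 1, 256)
+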
diff --git a/unish/heads/t_cond_mlp.py b/unish/heads/t_cond_mlp.py
new file mode 100644
index 0000000000000000000000000000000000000000..44d5a09bf54f67712a69953039b7b5af41c3f029
--- /dev/null
+++ b/unish/heads/t_cond_mlp.py
@@ -0,0 +1,199 @@
+import copy
+from typing import List, Optional
+
+import torch
+
+
+class AdaptiveLayerNorm1D(torch.nn.Module):
+ def __init__(self, data_dim: int, norm_cond_dim: int):
+ super().__init__()
+ if data_dim <= 0:
+ raise ValueError(f"data_dim must be positive, but got {data_dim}")
+ if norm_cond_dim <= 0:
+ raise ValueError(f"norm_cond_dim must be positive, but got {norm_cond_dim}")
+ self.norm = torch.nn.LayerNorm(
+ data_dim
+ ) # TODO: Check if elementwise_affine=True is correct
+ self.linear = torch.nn.Linear(norm_cond_dim, 2 * data_dim)
+ torch.nn.init.zeros_(self.linear.weight)
+ torch.nn.init.zeros_(self.linear.bias)
+
+ def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
+ # x: (batch, ..., data_dim)
+ # t: (batch, norm_cond_dim)
+ # return: same shape as x, (batch, ..., data_dim)
+ x = self.norm(x)
+ alpha, beta = self.linear(t).chunk(2, dim=-1)
+
+ # Add singleton dimensions to alpha and beta
+ if x.dim() > 2:
+ alpha = alpha.view(alpha.shape[0], *([1] * (x.dim() - 2)), alpha.shape[1])
+ beta = beta.view(beta.shape[0], *([1] * (x.dim() - 2)), beta.shape[1])
+
+ return x * (1 + alpha) + beta
+
+
+class SequentialCond(torch.nn.Sequential):
+ def forward(self, input, *args, **kwargs):
+ for module in self:
+ if isinstance(module, (AdaptiveLayerNorm1D, SequentialCond, ResidualMLPBlock)):
+ # print(f'Passing on args to {module}', [a.shape for a in args])
+ input = module(input, *args, **kwargs)
+ else:
+ # print(f'Skipping passing args to {module}', [a.shape for a in args])
+ input = module(input)
+ return input
+
+
+def normalization_layer(norm: Optional[str], dim: int, norm_cond_dim: int = -1):
+ if norm == "batch":
+ return torch.nn.BatchNorm1d(dim)
+ elif norm == "layer":
+ return torch.nn.LayerNorm(dim)
+ elif norm == "ada":
+ assert norm_cond_dim > 0, f"norm_cond_dim must be positive, got {norm_cond_dim}"
+ return AdaptiveLayerNorm1D(dim, norm_cond_dim)
+ elif norm is None:
+ return torch.nn.Identity()
+ else:
+ raise ValueError(f"Unknown norm: {norm}")
+
+
+def linear_norm_activ_dropout(
+ input_dim: int,
+ output_dim: int,
+ activation: torch.nn.Module = torch.nn.ReLU(),
+ bias: bool = True,
+ norm: Optional[str] = "layer", # Options: ada/batch/layer
+ dropout: float = 0.0,
+ norm_cond_dim: int = -1,
+) -> SequentialCond:
+ layers = []
+ layers.append(torch.nn.Linear(input_dim, output_dim, bias=bias))
+ if norm is not None:
+ layers.append(normalization_layer(norm, output_dim, norm_cond_dim))
+ layers.append(copy.deepcopy(activation))
+ if dropout > 0.0:
+ layers.append(torch.nn.Dropout(dropout))
+ return SequentialCond(*layers)
+
+
+def create_simple_mlp(
+ input_dim: int,
+ hidden_dims: List[int],
+ output_dim: int,
+ activation: torch.nn.Module = torch.nn.ReLU(),
+ bias: bool = True,
+ norm: Optional[str] = "layer", # Options: ada/batch/layer
+ dropout: float = 0.0,
+ norm_cond_dim: int = -1,
+) -> SequentialCond:
+ layers = []
+ prev_dim = input_dim
+ for hidden_dim in hidden_dims:
+ layers.extend(
+ linear_norm_activ_dropout(
+ prev_dim, hidden_dim, activation, bias, norm, dropout, norm_cond_dim
+ )
+ )
+ prev_dim = hidden_dim
+ layers.append(torch.nn.Linear(prev_dim, output_dim, bias=bias))
+ return SequentialCond(*layers)
+
+
+class ResidualMLPBlock(torch.nn.Module):
+ def __init__(
+ self,
+ input_dim: int,
+ hidden_dim: int,
+ num_hidden_layers: int,
+ output_dim: int,
+ activation: torch.nn.Module = torch.nn.ReLU(),
+ bias: bool = True,
+ norm: Optional[str] = "layer", # Options: ada/batch/layer
+ dropout: float = 0.0,
+ norm_cond_dim: int = -1,
+ ):
+ super().__init__()
+ if not (input_dim == output_dim == hidden_dim):
+ raise NotImplementedError(
+ f"input_dim {input_dim} != output_dim {output_dim} is not implemented"
+ )
+
+ layers = []
+ prev_dim = input_dim
+ for i in range(num_hidden_layers):
+ layers.append(
+ linear_norm_activ_dropout(
+ prev_dim, hidden_dim, activation, bias, norm, dropout, norm_cond_dim
+ )
+ )
+ prev_dim = hidden_dim
+ self.model = SequentialCond(*layers)
+ self.skip = torch.nn.Identity()
+
+ def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
+ return x + self.model(x, *args, **kwargs)
+
+
+class ResidualMLP(torch.nn.Module):
+ def __init__(
+ self,
+ input_dim: int,
+ hidden_dim: int,
+ num_hidden_layers: int,
+ output_dim: int,
+ activation: torch.nn.Module = torch.nn.ReLU(),
+ bias: bool = True,
+ norm: Optional[str] = "layer", # Options: ada/batch/layer
+ dropout: float = 0.0,
+ num_blocks: int = 1,
+ norm_cond_dim: int = -1,
+ ):
+ super().__init__()
+ self.input_dim = input_dim
+ self.model = SequentialCond(
+ linear_norm_activ_dropout(
+ input_dim, hidden_dim, activation, bias, norm, dropout, norm_cond_dim
+ ),
+ *[
+ ResidualMLPBlock(
+ hidden_dim,
+ hidden_dim,
+ num_hidden_layers,
+ hidden_dim,
+ activation,
+ bias,
+ norm,
+ dropout,
+ norm_cond_dim,
+ )
+ for _ in range(num_blocks)
+ ],
+ torch.nn.Linear(hidden_dim, output_dim, bias=bias),
+ )
+
+ def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
+ return self.model(x, *args, **kwargs)
+
+
+class FrequencyEmbedder(torch.nn.Module):
+ def __init__(self, num_frequencies, max_freq_log2):
+ super().__init__()
+ frequencies = 2 ** torch.linspace(0, max_freq_log2, steps=num_frequencies)
+ self.register_buffer("frequencies", frequencies)
+
+ def forward(self, x):
+ # x should be of size (N,) or (N, D)
+ N = x.size(0)
+ if x.dim() == 1: # (N,)
+ x = x.unsqueeze(1) # (N, D) where D=1
+ x_unsqueezed = x.unsqueeze(-1) # (N, D, 1)
+ scaled = self.frequencies.view(1, 1, -1) * x_unsqueezed # (N, D, num_frequencies)
+ s = torch.sin(scaled)
+ c = torch.cos(scaled)
+ embedded = torch.cat([s, c, x_unsqueezed], dim=-1).view(
+ N, -1
+ ) # (N, D * 2 * num_frequencies + D)
+ return embedded
+
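+
+if __name__ == "__main__":
+ # Illustrative shape checks (toy sizes): FrequencyEmbedder maps (N, D) to
+ # (N, D * (2 * num_frequencies + 1)); ResidualMLP with norm="ada" routes a
+ # conditioning vector t into every AdaptiveLayerNorm1D.
+ emb = FrequencyEmbedder(num_frequencies=8, max_freq_log2=7)
+ assert emb(torch.randn(4, 3)).shape == (4, 51)  # 3 * (2 * 8 + 1)
+ mlp = ResidualMLP(32, 32, 1, 8, norm="ada", norm_cond_dim=16)
+ assert mlp(torch.randn(4, 32), torch.randn(4, 16)).shape == (4, 8)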
diff --git a/unish/heads/utils.py b/unish/heads/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..d7af1f68fa0ce0a48d11a708d53aa20aa8f78ba2
--- /dev/null
+++ b/unish/heads/utils.py
@@ -0,0 +1,108 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn as nn
+
+
+def position_grid_to_embed(pos_grid: torch.Tensor, embed_dim: int, omega_0: float = 100) -> torch.Tensor:
+ """
+ Convert 2D position grid (HxWx2) to sinusoidal embeddings (HxWxC)
+
+ Args:
+ pos_grid: Tensor of shape (H, W, 2) containing 2D coordinates
+ embed_dim: Output channel dimension for embeddings
+
+ Returns:
+ Tensor of shape (H, W, embed_dim) with positional embeddings
+ """
+ H, W, grid_dim = pos_grid.shape
+ assert grid_dim == 2
+ pos_flat = pos_grid.reshape(-1, grid_dim) # Flatten to (H*W, 2)
+
+ # Process x and y coordinates separately
+ emb_x = make_sincos_pos_embed(embed_dim // 2, pos_flat[:, 0], omega_0=omega_0) # [1, H*W, D/2]
+ emb_y = make_sincos_pos_embed(embed_dim // 2, pos_flat[:, 1], omega_0=omega_0) # [1, H*W, D/2]
+
+ # Combine and reshape
+ emb = torch.cat([emb_x, emb_y], dim=-1) # [1, H*W, D]
+
+ return emb.view(H, W, embed_dim) # [H, W, D]
+
+
+def make_sincos_pos_embed(embed_dim: int, pos: torch.Tensor, omega_0: float = 100) -> torch.Tensor:
+ """
+ This function generates a 1D positional embedding from a given grid using sine and cosine functions.
+
+ Args:
+ - embed_dim: The embedding dimension.
+ - pos: The position to generate the embedding from.
+
+ Returns:
+ - emb: The generated 1D positional embedding.
+ """
+ assert embed_dim % 2 == 0
+ omega = torch.arange(embed_dim // 2, dtype=torch.double, device=pos.device)
+ omega /= embed_dim / 2.0
+ omega = 1.0 / omega_0**omega # (D/2,)
+
+ pos = pos.reshape(-1) # (M,)
+ out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product
+
+ emb_sin = torch.sin(out) # (M, D/2)
+ emb_cos = torch.cos(out) # (M, D/2)
+
+ emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D)
+ return emb.float()
+
+
+# Inspired by https://github.com/microsoft/moge
+
+
+def create_uv_grid(
+ width: int, height: int, aspect_ratio: float = None, dtype: torch.dtype = None, device: torch.device = None
+) -> torch.Tensor:
+ """
+ Create a normalized UV grid of shape (height, width, 2).
+
+ The grid spans horizontally and vertically according to an aspect ratio,
+ ensuring the top-left corner is at (-x_span, -y_span) and the bottom-right
+ corner is at (x_span, y_span), normalized by the diagonal of the plane.
+
+ Args:
+ width (int): Number of points horizontally.
+ height (int): Number of points vertically.
+ aspect_ratio (float, optional): Width-to-height ratio. Defaults to width/height.
+ dtype (torch.dtype, optional): Data type of the resulting tensor.
+ device (torch.device, optional): Device on which the tensor is created.
+
+ Returns:
+ torch.Tensor: A (height, width, 2) tensor of UV coordinates.
+ """
+ # Derive aspect ratio if not explicitly provided
+ if aspect_ratio is None:
+ aspect_ratio = float(width) / float(height)
+
+ # Compute normalized spans for X and Y
+ diag_factor = (aspect_ratio**2 + 1.0) ** 0.5
+ span_x = aspect_ratio / diag_factor
+ span_y = 1.0 / diag_factor
+
+ # Establish the linspace boundaries
+ left_x = -span_x * (width - 1) / width
+ right_x = span_x * (width - 1) / width
+ top_y = -span_y * (height - 1) / height
+ bottom_y = span_y * (height - 1) / height
+
+ # Generate 1D coordinates
+ x_coords = torch.linspace(left_x, right_x, steps=width, dtype=dtype, device=device)
+ y_coords = torch.linspace(top_y, bottom_y, steps=height, dtype=dtype, device=device)
+
+ # Create a 2D meshgrid (height x width) and stack into UV
+ uu, vv = torch.meshgrid(x_coords, y_coords, indexing="xy")
+ uv_grid = torch.stack((uu, vv), dim=-1)
+
+ return uv_grid
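+
+
+if __name__ == "__main__":
+ # Illustrative check (toy grid): create_uv_grid returns (height, width, 2)
+ # because torch.meshgrid(..., indexing="xy") puts height first; the grid is
+ # then lifted to sin/cos channels by position_grid_to_embed.
+ uv = create_uv_grid(4, 3, aspect_ratio=4 / 3)
+ assert uv.shape == (3, 4, 2)
+ emb = position_grid_to_embed(uv, embed_dim=16)
+ assert emb.shape == (3, 4, 16)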
diff --git a/unish/heads/vit.py b/unish/heads/vit.py
new file mode 100644
index 0000000000000000000000000000000000000000..900795a6587ef479f916df2b8fa451ce60181689
--- /dev/null
+++ b/unish/heads/vit.py
@@ -0,0 +1,346 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import math
+
+import torch
+from functools import partial
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint as checkpoint
+
+from timm.models.layers import drop_path, to_2tuple, trunc_normal_
+
+def vit():
+ return ViT(
+ img_size=(256, 192),
+ patch_size=16,
+ embed_dim=1280,
+ depth=32,
+ num_heads=16,
+ ratio=1,
+ use_checkpoint=False,
+ mlp_ratio=4,
+ qkv_bias=True,
+ drop_path_rate=0.55,
+ )
+
+def get_abs_pos(abs_pos, h, w, ori_h, ori_w, has_cls_token=True):
+ """
+ Resize absolute positional embeddings to the target token grid if needed,
+ temporarily splitting off the cls_token embedding before interpolation.
+ Args:
+ abs_pos (Tensor): absolute positional embeddings with (1, num_position, C).
+ h, w (int): target token grid size.
+ ori_h, ori_w (int): token grid size the embeddings were trained for.
+ has_cls_token (bool): If true, abs_pos holds one extra cls_token embedding.
+
+ Returns:
+ Absolute positional embeddings after processing, shape (1, num_position, C).
+ """
+ cls_token = None
+ B, L, C = abs_pos.shape
+ if has_cls_token:
+ cls_token = abs_pos[:, 0:1]
+ abs_pos = abs_pos[:, 1:]
+
+ if ori_h != h or ori_w != w:
+ new_abs_pos = F.interpolate(
+ abs_pos.reshape(1, ori_h, ori_w, -1).permute(0, 3, 1, 2),
+ size=(h, w),
+ mode="bicubic",
+ align_corners=False,
+ ).permute(0, 2, 3, 1).reshape(B, -1, C)
+
+ else:
+ new_abs_pos = abs_pos
+
+ if cls_token is not None:
+ new_abs_pos = torch.cat([cls_token, new_abs_pos], dim=1)
+ return new_abs_pos
+
+class DropPath(nn.Module):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+ """
+ def __init__(self, drop_prob=None):
+ super(DropPath, self).__init__()
+ self.drop_prob = drop_prob
+
+ def forward(self, x):
+ return drop_path(x, self.drop_prob, self.training)
+
+ def extra_repr(self):
+ return 'p={}'.format(self.drop_prob)
+
+class Mlp(nn.Module):
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = nn.Linear(in_features, hidden_features)
+ self.act = act_layer()
+ self.fc2 = nn.Linear(hidden_features, out_features)
+ self.drop = nn.Dropout(drop)
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.fc2(x)
+ x = self.drop(x)
+ return x
+
+class Attention(nn.Module):
+ def __init__(
+ self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
+ proj_drop=0., attn_head_dim=None,):
+ super().__init__()
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.dim = dim
+
+ if attn_head_dim is not None:
+ head_dim = attn_head_dim
+ all_head_dim = head_dim * self.num_heads
+
+ self.scale = qk_scale or head_dim ** -0.5
+
+ self.qkv = nn.Linear(dim, all_head_dim * 3, bias=qkv_bias)
+
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(all_head_dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ def forward(self, x):
+ B, N, C = x.shape
+ qkv = self.qkv(x)
+ qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
+
+ q = q * self.scale
+ attn = (q @ k.transpose(-2, -1))
+
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+
+ x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+
+ return x
+
+class Block(nn.Module):
+
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None,
+ drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU,
+ norm_layer=nn.LayerNorm, attn_head_dim=None
+ ):
+ super().__init__()
+
+ self.norm1 = norm_layer(dim)
+ self.attn = Attention(
+ dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
+ attn_drop=attn_drop, proj_drop=drop, attn_head_dim=attn_head_dim
+ )
+
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+
+ def forward(self, x):
+ x = x + self.drop_path(self.attn(self.norm1(x)))
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
+ return x
+
+
+class PatchEmbed(nn.Module):
+ """ Image to Patch Embedding
+ """
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, ratio=1):
+ super().__init__()
+ img_size = to_2tuple(img_size)
+ patch_size = to_2tuple(patch_size)
+ num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) * (ratio ** 2)
+ self.patch_shape = (int(img_size[0] // patch_size[0] * ratio), int(img_size[1] // patch_size[1] * ratio))
+ self.origin_patch_shape = (int(img_size[0] // patch_size[0]), int(img_size[1] // patch_size[1]))
+ self.img_size = img_size
+ self.patch_size = patch_size
+ self.num_patches = num_patches
+
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=(patch_size[0] // ratio), padding=4 + 2 * (ratio//2-1))
+
+ def forward(self, x, **kwargs):
+ B, C, H, W = x.shape
+ x = self.proj(x)
+ Hp, Wp = x.shape[2], x.shape[3]
+
+ x = x.flatten(2).transpose(1, 2)
+ return x, (Hp, Wp)
+
+
+class HybridEmbed(nn.Module):
+ """ CNN Feature Map Embedding
+ Extract feature map from CNN, flatten, project to embedding dim.
+ """
+ def __init__(self, backbone, img_size=224, feature_size=None, in_chans=3, embed_dim=768):
+ super().__init__()
+ assert isinstance(backbone, nn.Module)
+ img_size = to_2tuple(img_size)
+ self.img_size = img_size
+ self.backbone = backbone
+ if feature_size is None:
+ with torch.no_grad():
+ training = backbone.training
+ if training:
+ backbone.eval()
+ o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))[-1]
+ feature_size = o.shape[-2:]
+ feature_dim = o.shape[1]
+ backbone.train(training)
+ else:
+ feature_size = to_2tuple(feature_size)
+ feature_dim = self.backbone.feature_info.channels()[-1]
+ self.num_patches = feature_size[0] * feature_size[1]
+ self.proj = nn.Linear(feature_dim, embed_dim)
+
+ def forward(self, x):
+ x = self.backbone(x)[-1]
+ x = x.flatten(2).transpose(1, 2)
+ x = self.proj(x)
+ return x
+
+
+class ViT(nn.Module):
+ def __init__(self,
+ img_size=224, patch_size=16, in_chans=3, num_classes=80, embed_dim=768, depth=12,
+ num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
+ drop_path_rate=0., hybrid_backbone=None, norm_layer=None, use_checkpoint=False,
+ frozen_stages=-1, ratio=1, last_norm=True,
+ patch_padding='pad', freeze_attn=False, freeze_ffn=False,
+ ):
+ super(ViT, self).__init__()
+ norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
+ self.num_classes = num_classes
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
+ self.frozen_stages = frozen_stages
+ self.use_checkpoint = use_checkpoint
+ self.patch_padding = patch_padding
+ self.freeze_attn = freeze_attn
+ self.freeze_ffn = freeze_ffn
+ self.depth = depth
+
+ if hybrid_backbone is not None:
+ self.patch_embed = HybridEmbed(
+ hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim)
+ else:
+ self.patch_embed = PatchEmbed(
+ img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, ratio=ratio)
+ num_patches = self.patch_embed.num_patches
+
+ # since the pretraining model has class token
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
+
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
+
+ self.blocks = nn.ModuleList([
+ Block(
+ dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
+ )
+ for i in range(depth)])
+
+ self.last_norm = norm_layer(embed_dim) if last_norm else nn.Identity()
+
+ if self.pos_embed is not None:
+ trunc_normal_(self.pos_embed, std=.02)
+
+ self._freeze_stages()
+
+ def _freeze_stages(self):
+ """Freeze parameters."""
+ if self.frozen_stages >= 0:
+ self.patch_embed.eval()
+ for param in self.patch_embed.parameters():
+ param.requires_grad = False
+
+ for i in range(1, self.frozen_stages + 1):
+ m = self.blocks[i]
+ m.eval()
+ for param in m.parameters():
+ param.requires_grad = False
+
+ if self.freeze_attn:
+ for i in range(0, self.depth):
+ m = self.blocks[i]
+ m.attn.eval()
+ m.norm1.eval()
+ for param in m.attn.parameters():
+ param.requires_grad = False
+ for param in m.norm1.parameters():
+ param.requires_grad = False
+
+ if self.freeze_ffn:
+ self.pos_embed.requires_grad = False
+ self.patch_embed.eval()
+ for param in self.patch_embed.parameters():
+ param.requires_grad = False
+ for i in range(0, self.depth):
+ m = self.blocks[i]
+ m.mlp.eval()
+ m.norm2.eval()
+ for param in m.mlp.parameters():
+ param.requires_grad = False
+ for param in m.norm2.parameters():
+ param.requires_grad = False
+
+ def init_weights(self):
+ """Initialize the weights in backbone.
+ Args:
+ pretrained (str, optional): Path to pre-trained weights.
+ Defaults to None.
+ """
+ def _init_weights(m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
+
+ self.apply(_init_weights)
+
+ def get_num_layers(self):
+ return len(self.blocks)
+
+ @torch.jit.ignore
+ def no_weight_decay(self):
+ return {'pos_embed', 'cls_token'}
+
+ def forward_features(self, x):
+ B, C, H, W = x.shape
+ x, (Hp, Wp) = self.patch_embed(x)
+
+ if self.pos_embed is not None:
+ # Also add the cls-token entry so every pos_embed element is used
+ # (avoids unused-parameter issues in multi-GPU training); for sin-cos
+ # embeddings the first entry is zero, so the output is unchanged.
+ x = x + self.pos_embed[:, 1:] + self.pos_embed[:, :1]
+
+ for blk in self.blocks:
+ if self.use_checkpoint:
+ x = checkpoint.checkpoint(blk, x)
+ else:
+ x = blk(x)
+
+ x = self.last_norm(x)
+
+ xp = x.permute(0, 2, 1).reshape(B, -1, Hp, Wp).contiguous()
+
+ return xp
+
+ def forward(self, x):
+ x = self.forward_features(x)
+ return x
+
+ def train(self, mode=True):
+ """Convert the model into training mode."""
+ super().train(mode)
+ self._freeze_stages()
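+
+
+if __name__ == "__main__":
+ # Illustrative sketch: the vit() factory above builds a ViT-H backbone
+ # (embed_dim=1280, depth=32), so this allocates roughly 630M parameters;
+ # a 256x192 crop comes out as a (1, 1280, 16, 12) feature map.
+ model = vit().eval()
+ with torch.no_grad():
+ feats = model(torch.randn(1, 3, 256, 192))
+ assert feats.shape == (1, 1280, 16, 12)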
diff --git a/unish/pi3/models/__pycache__/pi3.cpython-310.pyc b/unish/pi3/models/__pycache__/pi3.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..cf684a0686e2421834348445dea99c1ff1ee83f6
Binary files /dev/null and b/unish/pi3/models/__pycache__/pi3.cpython-310.pyc differ
diff --git a/unish/pi3/models/dinov2/__init__.py b/unish/pi3/models/dinov2/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ae847e46898077fe3d8701b8a181d7b4e3d41cd9
--- /dev/null
+++ b/unish/pi3/models/dinov2/__init__.py
@@ -0,0 +1,6 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+__version__ = "0.0.1"
diff --git a/unish/pi3/models/dinov2/__pycache__/__init__.cpython-310.pyc b/unish/pi3/models/dinov2/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..74f1ac74ceef2665c099fb131260fe8c60ab9b17
Binary files /dev/null and b/unish/pi3/models/dinov2/__pycache__/__init__.cpython-310.pyc differ
diff --git a/unish/pi3/models/dinov2/hub/__init__.py b/unish/pi3/models/dinov2/hub/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b88da6bf80be92af00b72dfdb0a806fa64a7a2d9
--- /dev/null
+++ b/unish/pi3/models/dinov2/hub/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
diff --git a/unish/pi3/models/dinov2/hub/__pycache__/__init__.cpython-310.pyc b/unish/pi3/models/dinov2/hub/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f341beaa63e0f12565db99dbf80f62990a1b5e3c
Binary files /dev/null and b/unish/pi3/models/dinov2/hub/__pycache__/__init__.cpython-310.pyc differ
diff --git a/unish/pi3/models/dinov2/hub/__pycache__/backbones.cpython-310.pyc b/unish/pi3/models/dinov2/hub/__pycache__/backbones.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0e3f57770e850af3b2270ca9a506ac857cf1b8bf
Binary files /dev/null and b/unish/pi3/models/dinov2/hub/__pycache__/backbones.cpython-310.pyc differ
diff --git a/unish/pi3/models/dinov2/hub/__pycache__/utils.cpython-310.pyc b/unish/pi3/models/dinov2/hub/__pycache__/utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4070ae81a104de82af65856a2064c620b6e56b3c
Binary files /dev/null and b/unish/pi3/models/dinov2/hub/__pycache__/utils.cpython-310.pyc differ
diff --git a/unish/pi3/models/dinov2/hub/backbones.py b/unish/pi3/models/dinov2/hub/backbones.py
new file mode 100644
index 0000000000000000000000000000000000000000..53fe83719d5107eb77a8f25ef1814c3d73446002
--- /dev/null
+++ b/unish/pi3/models/dinov2/hub/backbones.py
@@ -0,0 +1,156 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+from enum import Enum
+from typing import Union
+
+import torch
+
+from .utils import _DINOV2_BASE_URL, _make_dinov2_model_name
+
+
+class Weights(Enum):
+ LVD142M = "LVD142M"
+
+
+def _make_dinov2_model(
+ *,
+ arch_name: str = "vit_large",
+ img_size: int = 518,
+ patch_size: int = 14,
+ init_values: float = 1.0,
+ ffn_layer: str = "mlp",
+ block_chunks: int = 0,
+ num_register_tokens: int = 0,
+ interpolate_antialias: bool = False,
+ interpolate_offset: float = 0.1,
+ pretrained: bool = True,
+ weights: Union[Weights, str] = Weights.LVD142M,
+ **kwargs,
+):
+ from ..models import vision_transformer as vits
+
+ if isinstance(weights, str):
+ try:
+ weights = Weights[weights]
+ except KeyError:
+ raise AssertionError(f"Unsupported weights: {weights}")
+
+ model_base_name = _make_dinov2_model_name(arch_name, patch_size)
+ vit_kwargs = dict(
+ img_size=img_size,
+ patch_size=patch_size,
+ init_values=init_values,
+ ffn_layer=ffn_layer,
+ block_chunks=block_chunks,
+ num_register_tokens=num_register_tokens,
+ interpolate_antialias=interpolate_antialias,
+ interpolate_offset=interpolate_offset,
+ )
+ vit_kwargs.update(**kwargs)
+ model = vits.__dict__[arch_name](**vit_kwargs)
+
+ if pretrained:
+ model_full_name = _make_dinov2_model_name(arch_name, patch_size, num_register_tokens)
+ url = _DINOV2_BASE_URL + f"/{model_base_name}/{model_full_name}_pretrain.pth"
+ state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu")
+ model.load_state_dict(state_dict, strict=True)
+
+ return model
+
+
+def dinov2_vits14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
+ """
+ DINOv2 ViT-S/14 model (optionally) pretrained on the LVD-142M dataset.
+ """
+ return _make_dinov2_model(arch_name="vit_small", pretrained=pretrained, weights=weights, **kwargs)
+
+
+def dinov2_vitb14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
+ """
+ DINOv2 ViT-B/14 model (optionally) pretrained on the LVD-142M dataset.
+ """
+ return _make_dinov2_model(arch_name="vit_base", pretrained=pretrained, weights=weights, **kwargs)
+
+
+def dinov2_vitl14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
+ """
+ DINOv2 ViT-L/14 model (optionally) pretrained on the LVD-142M dataset.
+ """
+ return _make_dinov2_model(arch_name="vit_large", pretrained=pretrained, weights=weights, **kwargs)
+
+
+def dinov2_vitg14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
+ """
+ DINOv2 ViT-g/14 model (optionally) pretrained on the LVD-142M dataset.
+ """
+ return _make_dinov2_model(
+ arch_name="vit_giant2",
+ ffn_layer="swiglufused",
+ weights=weights,
+ pretrained=pretrained,
+ **kwargs,
+ )
+
+
+def dinov2_vits14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
+ """
+ DINOv2 ViT-S/14 model with registers (optionally) pretrained on the LVD-142M dataset.
+ """
+ return _make_dinov2_model(
+ arch_name="vit_small",
+ pretrained=pretrained,
+ weights=weights,
+ num_register_tokens=4,
+ interpolate_antialias=True,
+ interpolate_offset=0.0,
+ **kwargs,
+ )
+
+
+def dinov2_vitb14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
+ """
+ DINOv2 ViT-B/14 model with registers (optionally) pretrained on the LVD-142M dataset.
+ """
+ return _make_dinov2_model(
+ arch_name="vit_base",
+ pretrained=pretrained,
+ weights=weights,
+ num_register_tokens=4,
+ interpolate_antialias=True,
+ interpolate_offset=0.0,
+ **kwargs,
+ )
+
+
+def dinov2_vitl14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
+ """
+ DINOv2 ViT-L/14 model with registers (optionally) pretrained on the LVD-142M dataset.
+ """
+ return _make_dinov2_model(
+ arch_name="vit_large",
+ pretrained=pretrained,
+ weights=weights,
+ num_register_tokens=4,
+ interpolate_antialias=True,
+ interpolate_offset=0.0,
+ **kwargs,
+ )
+
+
+def dinov2_vitg14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
+ """
+ DINOv2 ViT-g/14 model with registers (optionally) pretrained on the LVD-142M dataset.
+ """
+ return _make_dinov2_model(
+ arch_name="vit_giant2",
+ ffn_layer="swiglufused",
+ weights=weights,
+ pretrained=pretrained,
+ num_register_tokens=4,
+ interpolate_antialias=True,
+ interpolate_offset=0.0,
+ **kwargs,
+ )
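+
+
+if __name__ == "__main__":
+ # Hedged sketch: instantiate an untrained ViT-S/14 skeleton; passing
+ # pretrained=True instead downloads LVD-142M weights from _DINOV2_BASE_URL
+ # via torch.hub. Assumes the bundled vision_transformer module is
+ # importable (run as a module, e.g. with python -m).
+ model = dinov2_vits14(pretrained=False)
+ print(sum(p.numel() for p in model.parameters()))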
diff --git a/unish/pi3/models/dinov2/hub/utils.py b/unish/pi3/models/dinov2/hub/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..9c6641404093652d5a2f19b4cf283d976ec39e64
--- /dev/null
+++ b/unish/pi3/models/dinov2/hub/utils.py
@@ -0,0 +1,39 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+import itertools
+import math
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+_DINOV2_BASE_URL = "https://dl.fbaipublicfiles.com/dinov2"
+
+
+def _make_dinov2_model_name(arch_name: str, patch_size: int, num_register_tokens: int = 0) -> str:
+ compact_arch_name = arch_name.replace("_", "")[:4]
+ registers_suffix = f"_reg{num_register_tokens}" if num_register_tokens else ""
+ return f"dinov2_{compact_arch_name}{patch_size}{registers_suffix}"
+
+
+class CenterPadding(nn.Module):
+ def __init__(self, multiple):
+ super().__init__()
+ self.multiple = multiple
+
+ def _get_pad(self, size):
+ new_size = math.ceil(size / self.multiple) * self.multiple
+ pad_size = new_size - size
+ pad_size_left = pad_size // 2
+ pad_size_right = pad_size - pad_size_left
+ return pad_size_left, pad_size_right
+
+ @torch.inference_mode()
+ def forward(self, x):
+ pads = list(itertools.chain.from_iterable(self._get_pad(m) for m in x.shape[:1:-1]))
+ output = F.pad(x, pads)
+ return output
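+
+
+if __name__ == "__main__":
+ # Illustrative check: pad a 50x37 input so both spatial dims become
+ # multiples of 14 (the DINOv2 patch size), split evenly per side.
+ pad = CenterPadding(multiple=14)
+ assert pad(torch.randn(1, 3, 50, 37)).shape == (1, 3, 56, 42)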
diff --git a/unish/pi3/models/dinov2/layers/__init__.py b/unish/pi3/models/dinov2/layers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..05a0b61868e43abb821ca05a813bab2b8b43629e
--- /dev/null
+++ b/unish/pi3/models/dinov2/layers/__init__.py
@@ -0,0 +1,11 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+from .dino_head import DINOHead
+from .mlp import Mlp
+from .patch_embed import PatchEmbed
+from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused
+from .block import NestedTensorBlock
+from .attention import MemEffAttention
diff --git a/unish/pi3/models/dinov2/layers/__pycache__/__init__.cpython-310.pyc b/unish/pi3/models/dinov2/layers/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..dc9d45493a230610adaff9d09d63998011d87a1c
Binary files /dev/null and b/unish/pi3/models/dinov2/layers/__pycache__/__init__.cpython-310.pyc differ
diff --git a/unish/pi3/models/dinov2/layers/__pycache__/attention.cpython-310.pyc b/unish/pi3/models/dinov2/layers/__pycache__/attention.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..17a315c3ab8e7df03d23783e706a067dc885b4fa
Binary files /dev/null and b/unish/pi3/models/dinov2/layers/__pycache__/attention.cpython-310.pyc differ
diff --git a/unish/pi3/models/dinov2/layers/__pycache__/block.cpython-310.pyc b/unish/pi3/models/dinov2/layers/__pycache__/block.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..80f6ace66ec735bf8a773d9c06838dbfd1204537
Binary files /dev/null and b/unish/pi3/models/dinov2/layers/__pycache__/block.cpython-310.pyc differ
diff --git a/unish/pi3/models/dinov2/layers/__pycache__/dino_head.cpython-310.pyc b/unish/pi3/models/dinov2/layers/__pycache__/dino_head.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..73ed8eab17fe0deff52b4cff78997aaab91718e6
Binary files /dev/null and b/unish/pi3/models/dinov2/layers/__pycache__/dino_head.cpython-310.pyc differ
diff --git a/unish/pi3/models/dinov2/layers/__pycache__/drop_path.cpython-310.pyc b/unish/pi3/models/dinov2/layers/__pycache__/drop_path.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2d19eca021d8c1da3bc814f776908a9900e0dba8
Binary files /dev/null and b/unish/pi3/models/dinov2/layers/__pycache__/drop_path.cpython-310.pyc differ
diff --git a/unish/pi3/models/dinov2/layers/__pycache__/layer_scale.cpython-310.pyc b/unish/pi3/models/dinov2/layers/__pycache__/layer_scale.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6bcb3f04b04ff276ec04c96f0f3276c744f0c2a4
Binary files /dev/null and b/unish/pi3/models/dinov2/layers/__pycache__/layer_scale.cpython-310.pyc differ
diff --git a/unish/pi3/models/dinov2/layers/__pycache__/mlp.cpython-310.pyc b/unish/pi3/models/dinov2/layers/__pycache__/mlp.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..60dd990af2754fa38177ff00532b7b9cae6c08b9
Binary files /dev/null and b/unish/pi3/models/dinov2/layers/__pycache__/mlp.cpython-310.pyc differ
diff --git a/unish/pi3/models/dinov2/layers/__pycache__/patch_embed.cpython-310.pyc b/unish/pi3/models/dinov2/layers/__pycache__/patch_embed.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..cdae295e4fd55b795c2a49b341940119e6ea9eba
Binary files /dev/null and b/unish/pi3/models/dinov2/layers/__pycache__/patch_embed.cpython-310.pyc differ
diff --git a/unish/pi3/models/dinov2/layers/__pycache__/swiglu_ffn.cpython-310.pyc b/unish/pi3/models/dinov2/layers/__pycache__/swiglu_ffn.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c727c861899667f0420341be12f167737b1be556
Binary files /dev/null and b/unish/pi3/models/dinov2/layers/__pycache__/swiglu_ffn.cpython-310.pyc differ
diff --git a/unish/pi3/models/dinov2/layers/attention.py b/unish/pi3/models/dinov2/layers/attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..3fed573116d5c837be46a7525d8acf77422c2400
--- /dev/null
+++ b/unish/pi3/models/dinov2/layers/attention.py
@@ -0,0 +1,89 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
+
+import logging
+import os
+import warnings
+
+from torch import Tensor
+from torch import nn
+
+
+logger = logging.getLogger("dinov2")
+
+
+XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
+try:
+ if XFORMERS_ENABLED:
+ from xformers.ops import memory_efficient_attention, unbind
+
+ XFORMERS_AVAILABLE = True
+ # warnings.warn("xFormers is available (Attention)")
+ else:
+ # warnings.warn("xFormers is disabled (Attention)")
+ raise ImportError
+except ImportError:
+ XFORMERS_AVAILABLE = False
+ # warnings.warn("xFormers is not available (Attention)")
+
+
+class Attention(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int = 8,
+ qkv_bias: bool = False,
+ proj_bias: bool = True,
+ attn_drop: float = 0.0,
+ proj_drop: float = 0.0,
+ ) -> None:
+ super().__init__()
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = head_dim**-0.5
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim, bias=proj_bias)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ def forward(self, x: Tensor, attn_bias=None) -> Tensor:
+ B, N, C = x.shape
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+
+ q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
+ attn = q @ k.transpose(-2, -1)
+
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+class MemEffAttention(Attention):
+ def forward(self, x: Tensor, attn_bias=None) -> Tensor:
+ if not XFORMERS_AVAILABLE:
+ if attn_bias is not None:
+ raise AssertionError("xFormers is required for using nested tensors")
+ return super().forward(x)
+
+ B, N, C = x.shape
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
+
+ q, k, v = unbind(qkv, 2)
+
+ x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
+ x = x.reshape([B, N, C])
+
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
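+
+# Minimal usage sketch (illustrative only; assumes `torch` is imported).
+# Both variants map token sequences of shape (B, N, C) to the same shape:
+#
+#   attn = Attention(dim=768, num_heads=12, qkv_bias=True)
+#   x = torch.randn(2, 197, 768)   # batch of 2, 197 tokens, width 768
+#   y = attn(x)                    # -> torch.Size([2, 197, 768])
+#   y = MemEffAttention(dim=768, num_heads=12)(x)  # xFormers kernel if available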
diff --git a/unish/pi3/models/dinov2/layers/block.py b/unish/pi3/models/dinov2/layers/block.py
new file mode 100644
index 0000000000000000000000000000000000000000..fd5b8a7bb8527b74186af7c1e060e37bdb52c73d
--- /dev/null
+++ b/unish/pi3/models/dinov2/layers/block.py
@@ -0,0 +1,259 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
+
+import logging
+import os
+from typing import Callable, List, Any, Tuple, Dict
+import warnings
+
+import torch
+from torch import nn, Tensor
+
+from .attention import Attention, MemEffAttention
+from .drop_path import DropPath
+from .layer_scale import LayerScale
+from .mlp import Mlp
+
+
+logger = logging.getLogger("dinov2")
+
+
+XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
+try:
+ if XFORMERS_ENABLED:
+ from xformers.ops import fmha, scaled_index_add, index_select_cat
+
+ XFORMERS_AVAILABLE = True
+ # warnings.warn("xFormers is available (Block)")
+ else:
+ # warnings.warn("xFormers is disabled (Block)")
+ raise ImportError
+except ImportError:
+ XFORMERS_AVAILABLE = False
+ # warnings.warn("xFormers is not available (Block)")
+
+
+class Block(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int,
+ mlp_ratio: float = 4.0,
+ qkv_bias: bool = False,
+ proj_bias: bool = True,
+ ffn_bias: bool = True,
+ drop: float = 0.0,
+ attn_drop: float = 0.0,
+ init_values=None,
+ drop_path: float = 0.0,
+ act_layer: Callable[..., nn.Module] = nn.GELU,
+ norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
+ attn_class: Callable[..., nn.Module] = Attention,
+ ffn_layer: Callable[..., nn.Module] = Mlp,
+ ) -> None:
+ super().__init__()
+ # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}")
+ self.norm1 = norm_layer(dim)
+ self.attn = attn_class(
+ dim,
+ num_heads=num_heads,
+ qkv_bias=qkv_bias,
+ proj_bias=proj_bias,
+ attn_drop=attn_drop,
+ proj_drop=drop,
+ )
+ self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
+ self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = ffn_layer(
+ in_features=dim,
+ hidden_features=mlp_hidden_dim,
+ act_layer=act_layer,
+ drop=drop,
+ bias=ffn_bias,
+ )
+ self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
+ self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+ self.sample_drop_ratio = drop_path
+
+ def forward(self, x: Tensor) -> Tensor:
+ def attn_residual_func(x: Tensor) -> Tensor:
+ return self.ls1(self.attn(self.norm1(x)))
+
+ def ffn_residual_func(x: Tensor) -> Tensor:
+ return self.ls2(self.mlp(self.norm2(x)))
+
+ if self.training and self.sample_drop_ratio > 0.1:
+ # the overhead is compensated only for a drop path rate larger than 0.1
+ x = drop_add_residual_stochastic_depth(
+ x,
+ residual_func=attn_residual_func,
+ sample_drop_ratio=self.sample_drop_ratio,
+ )
+ x = drop_add_residual_stochastic_depth(
+ x,
+ residual_func=ffn_residual_func,
+ sample_drop_ratio=self.sample_drop_ratio,
+ )
+ elif self.training and self.sample_drop_ratio > 0.0:
+ x = x + self.drop_path1(attn_residual_func(x))
+            x = x + self.drop_path2(ffn_residual_func(x))
+ else:
+ x = x + attn_residual_func(x)
+ x = x + ffn_residual_func(x)
+ return x
+
+
+def drop_add_residual_stochastic_depth(
+ x: Tensor,
+ residual_func: Callable[[Tensor], Tensor],
+ sample_drop_ratio: float = 0.0,
+) -> Tensor:
+ # 1) extract subset using permutation
+ b, n, d = x.shape
+ sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
+ brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
+ x_subset = x[brange]
+
+ # 2) apply residual_func to get residual
+ residual = residual_func(x_subset)
+
+ x_flat = x.flatten(1)
+ residual = residual.flatten(1)
+
+ residual_scale_factor = b / sample_subset_size
+
+ # 3) add the residual
+ x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
+ return x_plus_residual.view_as(x)
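+
+# Worked example of the rescaling above: with b = 8 and sample_drop_ratio = 0.5,
+# only max(int(8 * 0.5), 1) = 4 random samples run the residual branch, and
+# their residuals are added back with alpha = 8 / 4 = 2.0 so the expected
+# residual magnitude matches the undropped computation.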
+
+
+def get_branges_scales(x, sample_drop_ratio=0.0):
+ b, n, d = x.shape
+ sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
+ brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
+ residual_scale_factor = b / sample_subset_size
+ return brange, residual_scale_factor
+
+
+def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
+ if scaling_vector is None:
+ x_flat = x.flatten(1)
+ residual = residual.flatten(1)
+ x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
+ else:
+ x_plus_residual = scaled_index_add(
+ x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
+ )
+ return x_plus_residual
+
+
+attn_bias_cache: Dict[Tuple, Any] = {}
+
+
+def get_attn_bias_and_cat(x_list, branges=None):
+ """
+ this will perform the index select, cat the tensors, and provide the attn_bias from cache
+ """
+ batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list]
+ all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
+ if all_shapes not in attn_bias_cache.keys():
+ seqlens = []
+ for b, x in zip(batch_sizes, x_list):
+ for _ in range(b):
+ seqlens.append(x.shape[1])
+ attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
+ attn_bias._batch_sizes = batch_sizes
+ attn_bias_cache[all_shapes] = attn_bias
+
+ if branges is not None:
+ cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1])
+ else:
+ tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
+ cat_tensors = torch.cat(tensors_bs1, dim=1)
+
+ return attn_bias_cache[all_shapes], cat_tensors
+
+
+def drop_add_residual_stochastic_depth_list(
+ x_list: List[Tensor],
+ residual_func: Callable[[Tensor, Any], Tensor],
+ sample_drop_ratio: float = 0.0,
+ scaling_vector=None,
+) -> Tensor:
+ # 1) generate random set of indices for dropping samples in the batch
+ branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
+ branges = [s[0] for s in branges_scales]
+ residual_scale_factors = [s[1] for s in branges_scales]
+
+ # 2) get attention bias and index+concat the tensors
+ attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)
+
+ # 3) apply residual_func to get residual, and split the result
+ residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore
+
+ outputs = []
+ for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors):
+ outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x))
+ return outputs
+
+
+class NestedTensorBlock(Block):
+ def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
+ """
+ x_list contains a list of tensors to nest together and run
+ """
+ assert isinstance(self.attn, MemEffAttention)
+
+ if self.training and self.sample_drop_ratio > 0.0:
+
+ def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
+ return self.attn(self.norm1(x), attn_bias=attn_bias)
+
+ def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
+ return self.mlp(self.norm2(x))
+
+ x_list = drop_add_residual_stochastic_depth_list(
+ x_list,
+ residual_func=attn_residual_func,
+ sample_drop_ratio=self.sample_drop_ratio,
+ scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,
+ )
+ x_list = drop_add_residual_stochastic_depth_list(
+ x_list,
+ residual_func=ffn_residual_func,
+ sample_drop_ratio=self.sample_drop_ratio,
+                scaling_vector=self.ls2.gamma if isinstance(self.ls2, LayerScale) else None,
+ )
+ return x_list
+ else:
+
+ def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
+ return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))
+
+ def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
+ return self.ls2(self.mlp(self.norm2(x)))
+
+ attn_bias, x = get_attn_bias_and_cat(x_list)
+ x = x + attn_residual_func(x, attn_bias=attn_bias)
+ x = x + ffn_residual_func(x)
+ return attn_bias.split(x)
+
+ def forward(self, x_or_x_list):
+ if isinstance(x_or_x_list, Tensor):
+ return super().forward(x_or_x_list)
+ elif isinstance(x_or_x_list, list):
+ if not XFORMERS_AVAILABLE:
+ raise AssertionError("xFormers is required for using nested tensors")
+ return self.forward_nested(x_or_x_list)
+ else:
+ raise AssertionError
diff --git a/unish/pi3/models/dinov2/layers/dino_head.py b/unish/pi3/models/dinov2/layers/dino_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..0ace8ffd6297a1dd480b19db407b662a6ea0f565
--- /dev/null
+++ b/unish/pi3/models/dinov2/layers/dino_head.py
@@ -0,0 +1,58 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn as nn
+from torch.nn.init import trunc_normal_
+from torch.nn.utils import weight_norm
+
+
+class DINOHead(nn.Module):
+ def __init__(
+ self,
+ in_dim,
+ out_dim,
+ use_bn=False,
+ nlayers=3,
+ hidden_dim=2048,
+ bottleneck_dim=256,
+ mlp_bias=True,
+ ):
+ super().__init__()
+ nlayers = max(nlayers, 1)
+ self.mlp = _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=hidden_dim, use_bn=use_bn, bias=mlp_bias)
+ self.apply(self._init_weights)
+ self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
+ self.last_layer.weight_g.data.fill_(1)
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=0.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+
+ def forward(self, x):
+ x = self.mlp(x)
+ eps = 1e-6 if x.dtype == torch.float16 else 1e-12
+ x = nn.functional.normalize(x, dim=-1, p=2, eps=eps)
+ x = self.last_layer(x)
+ return x
+
+
+def _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True):
+ if nlayers == 1:
+ return nn.Linear(in_dim, bottleneck_dim, bias=bias)
+ else:
+ layers = [nn.Linear(in_dim, hidden_dim, bias=bias)]
+ if use_bn:
+ layers.append(nn.BatchNorm1d(hidden_dim))
+ layers.append(nn.GELU())
+ for _ in range(nlayers - 2):
+ layers.append(nn.Linear(hidden_dim, hidden_dim, bias=bias))
+ if use_bn:
+ layers.append(nn.BatchNorm1d(hidden_dim))
+ layers.append(nn.GELU())
+ layers.append(nn.Linear(hidden_dim, bottleneck_dim, bias=bias))
+ return nn.Sequential(*layers)
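+
+# Shape sketch (illustrative; out_dim = 65536 is just an example prototype count):
+# DINOHead(in_dim=384, out_dim=65536) maps (B, 384) -> MLP -> (B, 256)
+# -> L2-normalize -> weight-normed linear (no bias) -> (B, 65536) logits.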
diff --git a/unish/pi3/models/dinov2/layers/drop_path.py b/unish/pi3/models/dinov2/layers/drop_path.py
new file mode 100644
index 0000000000000000000000000000000000000000..1d640e0b969b8dcba96260243473700b4e5b24b5
--- /dev/null
+++ b/unish/pi3/models/dinov2/layers/drop_path.py
@@ -0,0 +1,34 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py
+
+
+from torch import nn
+
+
+def drop_path(x, drop_prob: float = 0.0, training: bool = False):
+ if drop_prob == 0.0 or not training:
+ return x
+ keep_prob = 1 - drop_prob
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
+ random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
+ if keep_prob > 0.0:
+ random_tensor.div_(keep_prob)
+ output = x * random_tensor
+ return output
+
+
+class DropPath(nn.Module):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
+
+ def __init__(self, drop_prob=None):
+ super(DropPath, self).__init__()
+ self.drop_prob = drop_prob
+
+ def forward(self, x):
+ return drop_path(x, self.drop_prob, self.training)
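+
+# Note on the scaling: with drop_prob = 0.1, each sample's path is zeroed with
+# probability 0.1 and survivors are divided by keep_prob = 0.9, so the expected
+# output equals x. At eval time (training=False) the function is a no-op.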
diff --git a/unish/pi3/models/dinov2/layers/layer_scale.py b/unish/pi3/models/dinov2/layers/layer_scale.py
new file mode 100644
index 0000000000000000000000000000000000000000..51df0d7ce61f2b41fa9e6369f52391dd7fe7d386
--- /dev/null
+++ b/unish/pi3/models/dinov2/layers/layer_scale.py
@@ -0,0 +1,27 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110
+
+from typing import Union
+
+import torch
+from torch import Tensor
+from torch import nn
+
+
+class LayerScale(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ init_values: Union[float, Tensor] = 1e-5,
+ inplace: bool = False,
+ ) -> None:
+ super().__init__()
+ self.inplace = inplace
+ self.gamma = nn.Parameter(init_values * torch.ones(dim))
+
+ def forward(self, x: Tensor) -> Tensor:
+ return x.mul_(self.gamma) if self.inplace else x * self.gamma
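+
+# Usage sketch (illustrative; assumes `torch` is imported): LayerScale is an
+# elementwise per-channel multiply by a learnable gamma initialised near zero,
+# so residual branches start small and grow during training:
+#
+#   ls = LayerScale(dim=768, init_values=1e-5)
+#   y = ls(torch.randn(2, 197, 768))   # x * gamma, same shape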
diff --git a/unish/pi3/models/dinov2/layers/mlp.py b/unish/pi3/models/dinov2/layers/mlp.py
new file mode 100644
index 0000000000000000000000000000000000000000..bbf9432aae9258612caeae910a7bde17999e328e
--- /dev/null
+++ b/unish/pi3/models/dinov2/layers/mlp.py
@@ -0,0 +1,40 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py
+
+
+from typing import Callable, Optional
+
+from torch import Tensor, nn
+
+
+class Mlp(nn.Module):
+ def __init__(
+ self,
+ in_features: int,
+ hidden_features: Optional[int] = None,
+ out_features: Optional[int] = None,
+ act_layer: Callable[..., nn.Module] = nn.GELU,
+ drop: float = 0.0,
+ bias: bool = True,
+ ) -> None:
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
+ self.act = act_layer()
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
+ self.drop = nn.Dropout(drop)
+
+ def forward(self, x: Tensor) -> Tensor:
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop(x)
+ x = self.fc2(x)
+ x = self.drop(x)
+ return x
diff --git a/unish/pi3/models/dinov2/layers/patch_embed.py b/unish/pi3/models/dinov2/layers/patch_embed.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b7c0804784a42cf80c0297d110dcc68cc85b339
--- /dev/null
+++ b/unish/pi3/models/dinov2/layers/patch_embed.py
@@ -0,0 +1,88 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
+
+from typing import Callable, Optional, Tuple, Union
+
+from torch import Tensor
+import torch.nn as nn
+
+
+def make_2tuple(x):
+ if isinstance(x, tuple):
+ assert len(x) == 2
+ return x
+
+ assert isinstance(x, int)
+ return (x, x)
+
+
+class PatchEmbed(nn.Module):
+ """
+ 2D image to patch embedding: (B,C,H,W) -> (B,N,D)
+
+ Args:
+ img_size: Image size.
+ patch_size: Patch token size.
+ in_chans: Number of input image channels.
+ embed_dim: Number of linear projection output channels.
+ norm_layer: Normalization layer.
+ """
+
+ def __init__(
+ self,
+ img_size: Union[int, Tuple[int, int]] = 224,
+ patch_size: Union[int, Tuple[int, int]] = 16,
+ in_chans: int = 3,
+ embed_dim: int = 768,
+ norm_layer: Optional[Callable] = None,
+ flatten_embedding: bool = True,
+ ) -> None:
+ super().__init__()
+
+ image_HW = make_2tuple(img_size)
+ patch_HW = make_2tuple(patch_size)
+ patch_grid_size = (
+ image_HW[0] // patch_HW[0],
+ image_HW[1] // patch_HW[1],
+ )
+
+ self.img_size = image_HW
+ self.patch_size = patch_HW
+ self.patches_resolution = patch_grid_size
+ self.num_patches = patch_grid_size[0] * patch_grid_size[1]
+
+ self.in_chans = in_chans
+ self.embed_dim = embed_dim
+
+ self.flatten_embedding = flatten_embedding
+
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
+
+ def forward(self, x: Tensor) -> Tensor:
+ _, _, H, W = x.shape
+ patch_H, patch_W = self.patch_size
+
+ assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
+        assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width {patch_W}"
+
+ x = self.proj(x) # B C H W
+ H, W = x.size(2), x.size(3)
+ x = x.flatten(2).transpose(1, 2) # B HW C
+ x = self.norm(x)
+ if not self.flatten_embedding:
+ x = x.reshape(-1, H, W, self.embed_dim) # B H W C
+ return x
+
+ def flops(self) -> float:
+ Ho, Wo = self.patches_resolution
+ flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
+        if not isinstance(self.norm, nn.Identity):
+ flops += Ho * Wo * self.embed_dim
+ return flops
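+
+# Worked shape example: for a 224x224 input with patch_size = 16, the patch
+# grid is 224 // 16 = 14 per side, so num_patches = 14 * 14 = 196 and forward
+# maps (B, 3, 224, 224) -> (B, 196, embed_dim) when flatten_embedding is True.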
diff --git a/unish/pi3/models/dinov2/layers/swiglu_ffn.py b/unish/pi3/models/dinov2/layers/swiglu_ffn.py
new file mode 100644
index 0000000000000000000000000000000000000000..5ce211515774d42e04c8b51003bae53b88f14b35
--- /dev/null
+++ b/unish/pi3/models/dinov2/layers/swiglu_ffn.py
@@ -0,0 +1,72 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+import os
+from typing import Callable, Optional
+import warnings
+
+from torch import Tensor, nn
+import torch.nn.functional as F
+
+
+class SwiGLUFFN(nn.Module):
+ def __init__(
+ self,
+ in_features: int,
+ hidden_features: Optional[int] = None,
+ out_features: Optional[int] = None,
+ act_layer: Callable[..., nn.Module] = None,
+ drop: float = 0.0,
+ bias: bool = True,
+ ) -> None:
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
+ self.w3 = nn.Linear(hidden_features, out_features, bias=bias)
+
+ def forward(self, x: Tensor) -> Tensor:
+ x12 = self.w12(x)
+ x1, x2 = x12.chunk(2, dim=-1)
+ hidden = F.silu(x1) * x2
+ return self.w3(hidden)
+
+
+XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
+try:
+ if XFORMERS_ENABLED:
+ from xformers.ops import SwiGLU
+
+ XFORMERS_AVAILABLE = True
+ # warnings.warn("xFormers is available (SwiGLU)")
+ else:
+ # warnings.warn("xFormers is disabled (SwiGLU)")
+ raise ImportError
+except ImportError:
+ SwiGLU = SwiGLUFFN
+ XFORMERS_AVAILABLE = False
+
+ # warnings.warn("xFormers is not available (SwiGLU)")
+
+
+class SwiGLUFFNFused(SwiGLU):
+ def __init__(
+ self,
+ in_features: int,
+ hidden_features: Optional[int] = None,
+ out_features: Optional[int] = None,
+ act_layer: Callable[..., nn.Module] = None,
+ drop: float = 0.0,
+ bias: bool = True,
+ ) -> None:
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
+ super().__init__(
+ in_features=in_features,
+ hidden_features=hidden_features,
+ out_features=out_features,
+ bias=bias,
+ )
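+
+# Worked example of the fused hidden-width rule: for hidden_features = 1024,
+# int(1024 * 2 / 3) = 682 and (682 + 7) // 8 * 8 = 688 -- the 2/3 reduction
+# compensates for SwiGLU's two input projections, rounded up to a multiple
+# of 8 for kernel efficiency.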
diff --git a/unish/pi3/models/dinov2/models/__init__.py b/unish/pi3/models/dinov2/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..3fdff20badbd5244bf79f16bf18dd2cb73982265
--- /dev/null
+++ b/unish/pi3/models/dinov2/models/__init__.py
@@ -0,0 +1,43 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+import logging
+
+from . import vision_transformer as vits
+
+
+logger = logging.getLogger("dinov2")
+
+
+def build_model(args, only_teacher=False, img_size=224):
+ args.arch = args.arch.removesuffix("_memeff")
+ if "vit" in args.arch:
+ vit_kwargs = dict(
+ img_size=img_size,
+ patch_size=args.patch_size,
+ init_values=args.layerscale,
+ ffn_layer=args.ffn_layer,
+ block_chunks=args.block_chunks,
+ qkv_bias=args.qkv_bias,
+ proj_bias=args.proj_bias,
+ ffn_bias=args.ffn_bias,
+ num_register_tokens=args.num_register_tokens,
+ interpolate_offset=args.interpolate_offset,
+ interpolate_antialias=args.interpolate_antialias,
+ )
+ teacher = vits.__dict__[args.arch](**vit_kwargs)
+ if only_teacher:
+ return teacher, teacher.embed_dim
+ student = vits.__dict__[args.arch](
+ **vit_kwargs,
+ drop_path_rate=args.drop_path_rate,
+ drop_path_uniform=args.drop_path_uniform,
+ )
+ embed_dim = student.embed_dim
+ return student, teacher, embed_dim
+
+
+def build_model_from_cfg(cfg, only_teacher=False):
+ return build_model(cfg.student, only_teacher=only_teacher, img_size=cfg.crops.global_crops_size)
diff --git a/unish/pi3/models/dinov2/models/vision_transformer.py b/unish/pi3/models/dinov2/models/vision_transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..73f15cfb082d0fe629f8aa312c9d9b27a64ad4e7
--- /dev/null
+++ b/unish/pi3/models/dinov2/models/vision_transformer.py
@@ -0,0 +1,404 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
+
+from functools import partial
+import math
+import logging
+from typing import Sequence, Tuple, Union, Callable
+
+import torch
+import torch.nn as nn
+from torch.utils.checkpoint import checkpoint
+from torch.nn.init import trunc_normal_
+
+from ..layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block
+from ...layers.attention import FlashAttention
+
+
+# logger = logging.getLogger("dinov2")
+
+
+def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module:
+ if not depth_first and include_root:
+ fn(module=module, name=name)
+ for child_name, child_module in module.named_children():
+ child_name = ".".join((name, child_name)) if name else child_name
+ named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)
+ if depth_first and include_root:
+ fn(module=module, name=name)
+ return module
+
+
+class BlockChunk(nn.ModuleList):
+ def forward(self, x):
+ for b in self:
+ x = b(x)
+ return x
+
+
+class DinoVisionTransformer(nn.Module):
+ def __init__(
+ self,
+ img_size=224,
+ patch_size=16,
+ in_chans=3,
+ embed_dim=768,
+ depth=12,
+ num_heads=12,
+ mlp_ratio=4.0,
+ qkv_bias=True,
+ ffn_bias=True,
+ proj_bias=True,
+ drop_path_rate=0.0,
+ drop_path_uniform=False,
+ init_values=None, # for layerscale: None or 0 => no layerscale
+ embed_layer=PatchEmbed,
+ act_layer=nn.GELU,
+ block_fn=Block,
+ ffn_layer="mlp",
+ block_chunks=1,
+ num_register_tokens=0,
+ interpolate_antialias=False,
+ interpolate_offset=0.1,
+ ):
+ """
+ Args:
+ img_size (int, tuple): input image size
+ patch_size (int, tuple): patch size
+ in_chans (int): number of input channels
+ embed_dim (int): embedding dimension
+ depth (int): depth of transformer
+ num_heads (int): number of attention heads
+            mlp_ratio (float): ratio of mlp hidden dim to embedding dim
+ qkv_bias (bool): enable bias for qkv if True
+ proj_bias (bool): enable bias for proj in attn if True
+ ffn_bias (bool): enable bias for ffn if True
+ drop_path_rate (float): stochastic depth rate
+ drop_path_uniform (bool): apply uniform drop rate across blocks
+ init_values (float): layer-scale init values
+ embed_layer (nn.Module): patch embedding layer
+ act_layer (nn.Module): MLP activation layer
+ block_fn (nn.Module): transformer block class
+ ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
+ block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
+ num_register_tokens: (int) number of extra cls tokens (so-called "registers")
+            interpolate_antialias: (bool) flag to apply anti-aliasing when interpolating positional embeddings
+ interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings
+ """
+ super().__init__()
+ norm_layer = partial(nn.LayerNorm, eps=1e-6)
+
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
+ self.num_tokens = 1
+ self.n_blocks = depth
+ self.num_heads = num_heads
+ self.patch_size = patch_size
+ self.num_register_tokens = num_register_tokens
+ self.interpolate_antialias = interpolate_antialias
+ self.interpolate_offset = interpolate_offset
+
+ self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
+ num_patches = self.patch_embed.num_patches
+
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
+ assert num_register_tokens >= 0
+ self.register_tokens = (
+ nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None
+ )
+
+ if drop_path_uniform is True:
+ dpr = [drop_path_rate] * depth
+ else:
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
+
+ if ffn_layer == "mlp":
+ # logger.info("using MLP layer as FFN")
+ ffn_layer = Mlp
+ elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
+ # logger.info("using SwiGLU layer as FFN")
+ ffn_layer = SwiGLUFFNFused
+ elif ffn_layer == "identity":
+ # logger.info("using Identity layer as FFN")
+
+ def f(*args, **kwargs):
+ return nn.Identity()
+
+ ffn_layer = f
+ else:
+ raise NotImplementedError
+
+ blocks_list = [
+ block_fn(
+ dim=embed_dim,
+ num_heads=num_heads,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ proj_bias=proj_bias,
+ ffn_bias=ffn_bias,
+ drop_path=dpr[i],
+ norm_layer=norm_layer,
+ act_layer=act_layer,
+ ffn_layer=ffn_layer,
+ init_values=init_values,
+ attn_class=FlashAttention
+ )
+ for i in range(depth)
+ ]
+ if block_chunks > 0:
+ self.chunked_blocks = True
+ chunked_blocks = []
+ chunksize = depth // block_chunks
+ for i in range(0, depth, chunksize):
+ # this is to keep the block index consistent if we chunk the block list
+ chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize])
+ self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
+ else:
+ self.chunked_blocks = False
+ self.blocks = nn.ModuleList(blocks_list)
+
+ self.norm = norm_layer(embed_dim)
+ self.head = nn.Identity()
+
+ self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))
+
+ self.init_weights()
+
+ def init_weights(self):
+ trunc_normal_(self.pos_embed, std=0.02)
+ nn.init.normal_(self.cls_token, std=1e-6)
+ if self.register_tokens is not None:
+ nn.init.normal_(self.register_tokens, std=1e-6)
+ named_apply(init_weights_vit_timm, self)
+
+ def interpolate_pos_encoding(self, x, w, h):
+ previous_dtype = x.dtype
+ npatch = x.shape[1] - 1
+ N = self.pos_embed.shape[1] - 1
+ if npatch == N and w == h:
+ return self.pos_embed
+ pos_embed = self.pos_embed.float()
+ class_pos_embed = pos_embed[:, 0]
+ patch_pos_embed = pos_embed[:, 1:]
+ dim = x.shape[-1]
+ w0 = w // self.patch_size
+ h0 = h // self.patch_size
+ M = int(math.sqrt(N)) # Recover the number of patches in each dimension
+ assert N == M * M
+ kwargs = {}
+ if self.interpolate_offset:
+ # Historical kludge: add a small number to avoid floating point error in the interpolation, see https://github.com/facebookresearch/dino/issues/8
+ # Note: still needed for backward-compatibility, the underlying operators are using both output size and scale factors
+ sx = float(w0 + self.interpolate_offset) / M
+ sy = float(h0 + self.interpolate_offset) / M
+ kwargs["scale_factor"] = (sx, sy)
+ else:
+ # Simply specify an output size instead of a scale factor
+ kwargs["size"] = (w0, h0)
+ patch_pos_embed = nn.functional.interpolate(
+ patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2),
+ mode="bicubic",
+ antialias=self.interpolate_antialias,
+ **kwargs,
+ )
+ assert (w0, h0) == patch_pos_embed.shape[-2:]
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+ return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)
+
+ def prepare_tokens_with_masks(self, x, masks=None):
+ B, nc, w, h = x.shape
+ x = self.patch_embed(x)
+ if masks is not None:
+ x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)
+
+ x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
+ x = x + self.interpolate_pos_encoding(x, w, h)
+
+ if self.register_tokens is not None:
+ x = torch.cat(
+ (
+ x[:, :1],
+ self.register_tokens.expand(x.shape[0], -1, -1),
+ x[:, 1:],
+ ),
+ dim=1,
+ )
+
+ return x
+
+ def forward_features_list(self, x_list, masks_list):
+ x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]
+ for blk in self.blocks:
+ if self.training:
+ x = checkpoint(blk, x, use_reentrant=False)
+ else:
+ x = blk(x)
+
+ all_x = x
+ output = []
+ for x, masks in zip(all_x, masks_list):
+ x_norm = self.norm(x)
+ output.append(
+ {
+ "x_norm_clstoken": x_norm[:, 0],
+ "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
+ "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
+ "x_prenorm": x,
+ "masks": masks,
+ }
+ )
+ return output
+
+ def forward_features(self, x, masks=None):
+ if isinstance(x, list):
+ return self.forward_features_list(x, masks)
+
+ x = self.prepare_tokens_with_masks(x, masks)
+
+ for blk in self.blocks:
+ if self.training:
+ x = checkpoint(blk, x, use_reentrant=False)
+ else:
+ x = blk(x)
+
+ x_norm = self.norm(x)
+ return {
+ "x_norm_clstoken": x_norm[:, 0],
+ "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
+ "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
+ "x_prenorm": x,
+ "masks": masks,
+ }
+
+ def _get_intermediate_layers_not_chunked(self, x, n=1):
+ x = self.prepare_tokens_with_masks(x)
+ # If n is an int, take the n last blocks. If it's a list, take them
+ output, total_block_len = [], len(self.blocks)
+ blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
+ for i, blk in enumerate(self.blocks):
+ x = blk(x)
+ if i in blocks_to_take:
+ output.append(x)
+ assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
+ return output
+
+ def _get_intermediate_layers_chunked(self, x, n=1):
+ x = self.prepare_tokens_with_masks(x)
+ output, i, total_block_len = [], 0, len(self.blocks[-1])
+ # If n is an int, take the n last blocks. If it's a list, take them
+ blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
+ for block_chunk in self.blocks:
+ for blk in block_chunk[i:]: # Passing the nn.Identity()
+ x = blk(x)
+ if i in blocks_to_take:
+ output.append(x)
+ i += 1
+ assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
+ return output
+
+ def get_intermediate_layers(
+ self,
+ x: torch.Tensor,
+ n: Union[int, Sequence] = 1, # Layers or n last layers to take
+ reshape: bool = False,
+ return_class_token: bool = False,
+ norm=True,
+ ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
+ if self.chunked_blocks:
+ outputs = self._get_intermediate_layers_chunked(x, n)
+ else:
+ outputs = self._get_intermediate_layers_not_chunked(x, n)
+ if norm:
+ outputs = [self.norm(out) for out in outputs]
+ class_tokens = [out[:, 0] for out in outputs]
+ outputs = [out[:, 1 + self.num_register_tokens :] for out in outputs]
+ if reshape:
+ B, _, w, h = x.shape
+ outputs = [
+ out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous()
+ for out in outputs
+ ]
+ if return_class_token:
+ return tuple(zip(outputs, class_tokens))
+ return tuple(outputs)
+
+ def forward(self, *args, is_training=False, **kwargs):
+ ret = self.forward_features(*args, **kwargs)
+ if is_training:
+ return ret
+ else:
+ return self.head(ret["x_norm_clstoken"])
+
+
+def init_weights_vit_timm(module: nn.Module, name: str = ""):
+ """ViT weight initialization, original timm impl (for reproducibility)"""
+ if isinstance(module, nn.Linear):
+ trunc_normal_(module.weight, std=0.02)
+ if module.bias is not None:
+ nn.init.zeros_(module.bias)
+
+
+def vit_small(patch_size=16, num_register_tokens=0, **kwargs):
+ model = DinoVisionTransformer(
+ patch_size=patch_size,
+ embed_dim=384,
+ depth=12,
+ num_heads=6,
+ mlp_ratio=4,
+ block_fn=partial(Block, attn_class=MemEffAttention),
+ num_register_tokens=num_register_tokens,
+ **kwargs,
+ )
+ return model
+
+
+def vit_base(patch_size=16, num_register_tokens=0, **kwargs):
+ model = DinoVisionTransformer(
+ patch_size=patch_size,
+ embed_dim=768,
+ depth=12,
+ num_heads=12,
+ mlp_ratio=4,
+ block_fn=partial(Block, attn_class=MemEffAttention),
+ num_register_tokens=num_register_tokens,
+ **kwargs,
+ )
+ return model
+
+
+def vit_large(patch_size=16, num_register_tokens=0, **kwargs):
+ model = DinoVisionTransformer(
+ patch_size=patch_size,
+ embed_dim=1024,
+ depth=24,
+ num_heads=16,
+ mlp_ratio=4,
+ block_fn=partial(Block, attn_class=MemEffAttention),
+ num_register_tokens=num_register_tokens,
+ **kwargs,
+ )
+ return model
+
+
+def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs):
+ """
+ Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64
+ """
+ model = DinoVisionTransformer(
+ patch_size=patch_size,
+ embed_dim=1536,
+ depth=40,
+ num_heads=24,
+ mlp_ratio=4,
+ block_fn=partial(Block, attn_class=MemEffAttention),
+ num_register_tokens=num_register_tokens,
+ **kwargs,
+ )
+ return model
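+
+# Construction sketch (illustrative only; assumes `torch` is imported):
+#
+#   model = vit_small(patch_size=14, num_register_tokens=4, img_size=518).eval()
+#   out = model.forward_features(torch.randn(1, 3, 518, 518))
+#   out["x_norm_patchtokens"].shape   # -> (1, (518 // 14) ** 2, 384) = (1, 1369, 384)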
diff --git a/unish/pi3/models/dinov2/utils/__init__.py b/unish/pi3/models/dinov2/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b88da6bf80be92af00b72dfdb0a806fa64a7a2d9
--- /dev/null
+++ b/unish/pi3/models/dinov2/utils/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
diff --git a/unish/pi3/models/dinov2/utils/cluster.py b/unish/pi3/models/dinov2/utils/cluster.py
new file mode 100644
index 0000000000000000000000000000000000000000..3df87dc3e1eb4f0f8a280dc3137cfef031886314
--- /dev/null
+++ b/unish/pi3/models/dinov2/utils/cluster.py
@@ -0,0 +1,95 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+from enum import Enum
+import os
+from pathlib import Path
+from typing import Any, Dict, Optional
+
+
+class ClusterType(Enum):
+ AWS = "aws"
+ FAIR = "fair"
+ RSC = "rsc"
+
+
+def _guess_cluster_type() -> ClusterType:
+ uname = os.uname()
+ if uname.sysname == "Linux":
+ if uname.release.endswith("-aws"):
+ # Linux kernel versions on AWS instances are of the form "5.4.0-1051-aws"
+ return ClusterType.AWS
+ elif uname.nodename.startswith("rsc"):
+ # Linux kernel versions on RSC instances are standard ones but hostnames start with "rsc"
+ return ClusterType.RSC
+
+ return ClusterType.FAIR
+
+
+def get_cluster_type(cluster_type: Optional[ClusterType] = None) -> Optional[ClusterType]:
+ if cluster_type is None:
+ return _guess_cluster_type()
+
+ return cluster_type
+
+
+def get_checkpoint_path(cluster_type: Optional[ClusterType] = None) -> Optional[Path]:
+ cluster_type = get_cluster_type(cluster_type)
+ if cluster_type is None:
+ return None
+
+ CHECKPOINT_DIRNAMES = {
+ ClusterType.AWS: "checkpoints",
+ ClusterType.FAIR: "checkpoint",
+ ClusterType.RSC: "checkpoint/dino",
+ }
+ return Path("/") / CHECKPOINT_DIRNAMES[cluster_type]
+
+
+def get_user_checkpoint_path(cluster_type: Optional[ClusterType] = None) -> Optional[Path]:
+ checkpoint_path = get_checkpoint_path(cluster_type)
+ if checkpoint_path is None:
+ return None
+
+ username = os.environ.get("USER")
+ assert username is not None
+ return checkpoint_path / username
+
+
+def get_slurm_partition(cluster_type: Optional[ClusterType] = None) -> Optional[str]:
+ cluster_type = get_cluster_type(cluster_type)
+ if cluster_type is None:
+ return None
+
+ SLURM_PARTITIONS = {
+ ClusterType.AWS: "learnlab",
+ ClusterType.FAIR: "learnlab",
+ ClusterType.RSC: "learn",
+ }
+ return SLURM_PARTITIONS[cluster_type]
+
+
+def get_slurm_executor_parameters(
+ nodes: int, num_gpus_per_node: int, cluster_type: Optional[ClusterType] = None, **kwargs
+) -> Dict[str, Any]:
+ # create default parameters
+ params = {
+ "mem_gb": 0, # Requests all memory on a node, see https://slurm.schedmd.com/sbatch.html
+ "gpus_per_node": num_gpus_per_node,
+ "tasks_per_node": num_gpus_per_node, # one task per GPU
+ "cpus_per_task": 10,
+ "nodes": nodes,
+ "slurm_partition": get_slurm_partition(cluster_type),
+ }
+ # apply cluster-specific adjustments
+ cluster_type = get_cluster_type(cluster_type)
+ if cluster_type == ClusterType.AWS:
+ params["cpus_per_task"] = 12
+ del params["mem_gb"]
+ elif cluster_type == ClusterType.RSC:
+ params["cpus_per_task"] = 12
+ # set additional parameters / apply overrides
+ params.update(kwargs)
+ return params
diff --git a/unish/pi3/models/dinov2/utils/config.py b/unish/pi3/models/dinov2/utils/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..c9de578787bbcb376f8bd5a782206d0eb7ec1f52
--- /dev/null
+++ b/unish/pi3/models/dinov2/utils/config.py
@@ -0,0 +1,72 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+import math
+import logging
+import os
+
+from omegaconf import OmegaConf
+
+import dinov2.distributed as distributed
+from dinov2.logging import setup_logging
+from dinov2.utils import utils
+from dinov2.configs import dinov2_default_config
+
+
+logger = logging.getLogger("dinov2")
+
+
+def apply_scaling_rules_to_cfg(cfg): # to fix
+ if cfg.optim.scaling_rule == "sqrt_wrt_1024":
+ base_lr = cfg.optim.base_lr
+ cfg.optim.lr = base_lr
+ cfg.optim.lr *= math.sqrt(cfg.train.batch_size_per_gpu * distributed.get_global_size() / 1024.0)
+ logger.info(f"sqrt scaling learning rate; base: {base_lr}, new: {cfg.optim.lr}")
+ else:
+ raise NotImplementedError
+ return cfg
+
+
+def write_config(cfg, output_dir, name="config.yaml"):
+ logger.info(OmegaConf.to_yaml(cfg))
+ saved_cfg_path = os.path.join(output_dir, name)
+ with open(saved_cfg_path, "w") as f:
+ OmegaConf.save(config=cfg, f=f)
+ return saved_cfg_path
+
+
+def get_cfg_from_args(args):
+ args.output_dir = os.path.abspath(args.output_dir)
+ args.opts += [f"train.output_dir={args.output_dir}"]
+ default_cfg = OmegaConf.create(dinov2_default_config)
+ cfg = OmegaConf.load(args.config_file)
+ cfg = OmegaConf.merge(default_cfg, cfg, OmegaConf.from_cli(args.opts))
+ return cfg
+
+
+def default_setup(args):
+ distributed.enable(overwrite=True)
+ seed = getattr(args, "seed", 0)
+ rank = distributed.get_global_rank()
+
+ global logger
+ setup_logging(output=args.output_dir, level=logging.INFO)
+ logger = logging.getLogger("dinov2")
+
+ utils.fix_random_seeds(seed + rank)
+ logger.info("git:\n {}\n".format(utils.get_sha()))
+ logger.info("\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items())))
+
+
+def setup(args):
+ """
+ Create configs and perform basic setups.
+ """
+ cfg = get_cfg_from_args(args)
+ os.makedirs(args.output_dir, exist_ok=True)
+ default_setup(args)
+ apply_scaling_rules_to_cfg(cfg)
+ write_config(cfg, args.output_dir)
+ return cfg
diff --git a/unish/pi3/models/dinov2/utils/dtype.py b/unish/pi3/models/dinov2/utils/dtype.py
new file mode 100644
index 0000000000000000000000000000000000000000..80f4cd74d99faa2731dbe9f8d3a13d71b3f8e3a8
--- /dev/null
+++ b/unish/pi3/models/dinov2/utils/dtype.py
@@ -0,0 +1,37 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+
+from typing import Dict, Union
+
+import numpy as np
+import torch
+
+
+TypeSpec = Union[str, np.dtype, torch.dtype]
+
+
+_NUMPY_TO_TORCH_DTYPE: Dict[np.dtype, torch.dtype] = {
+ np.dtype("bool"): torch.bool,
+ np.dtype("uint8"): torch.uint8,
+ np.dtype("int8"): torch.int8,
+ np.dtype("int16"): torch.int16,
+ np.dtype("int32"): torch.int32,
+ np.dtype("int64"): torch.int64,
+ np.dtype("float16"): torch.float16,
+ np.dtype("float32"): torch.float32,
+ np.dtype("float64"): torch.float64,
+ np.dtype("complex64"): torch.complex64,
+ np.dtype("complex128"): torch.complex128,
+}
+
+
+def as_torch_dtype(dtype: TypeSpec) -> torch.dtype:
+ if isinstance(dtype, torch.dtype):
+ return dtype
+ if isinstance(dtype, str):
+ dtype = np.dtype(dtype)
+    assert isinstance(dtype, np.dtype), f"Expected an instance of numpy dtype, got {type(dtype)}"
+ return _NUMPY_TO_TORCH_DTYPE[dtype]
diff --git a/unish/pi3/models/dinov2/utils/param_groups.py b/unish/pi3/models/dinov2/utils/param_groups.py
new file mode 100644
index 0000000000000000000000000000000000000000..9a5d2ff627cddadc222e5f836864ee39c865208f
--- /dev/null
+++ b/unish/pi3/models/dinov2/utils/param_groups.py
@@ -0,0 +1,103 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+from collections import defaultdict
+import logging
+
+
+logger = logging.getLogger("dinov2")
+
+
+def get_vit_lr_decay_rate(name, lr_decay_rate=1.0, num_layers=12, force_is_backbone=False, chunked_blocks=False):
+ """
+ Calculate lr decay rate for different ViT blocks.
+ Args:
+ name (string): parameter name.
+ lr_decay_rate (float): base lr decay rate.
+ num_layers (int): number of ViT blocks.
+ Returns:
+ lr decay rate for the given parameter.
+ """
+ layer_id = num_layers + 1
+ if name.startswith("backbone") or force_is_backbone:
+ if (
+ ".pos_embed" in name
+ or ".patch_embed" in name
+ or ".mask_token" in name
+ or ".cls_token" in name
+ or ".register_tokens" in name
+ ):
+ layer_id = 0
+ elif force_is_backbone and (
+ "pos_embed" in name
+ or "patch_embed" in name
+ or "mask_token" in name
+ or "cls_token" in name
+ or "register_tokens" in name
+ ):
+ layer_id = 0
+ elif ".blocks." in name and ".residual." not in name:
+ layer_id = int(name[name.find(".blocks.") :].split(".")[2]) + 1
+ elif chunked_blocks and "blocks." in name and "residual." not in name:
+ layer_id = int(name[name.find("blocks.") :].split(".")[2]) + 1
+ elif "blocks." in name and "residual." not in name:
+ layer_id = int(name[name.find("blocks.") :].split(".")[1]) + 1
+
+ return lr_decay_rate ** (num_layers + 1 - layer_id)
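+
+# Worked example: with lr_decay_rate = 0.9 and num_layers = 12, embeddings
+# (layer_id = 0) get 0.9 ** 13 ~= 0.254, the last block (layer_id = 12) gets
+# 0.9 ** 1 = 0.9, and non-backbone params (layer_id = 13) get 0.9 ** 0 = 1.0.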
+
+
+def get_params_groups_with_decay(model, lr_decay_rate=1.0, patch_embed_lr_mult=1.0):
+ chunked_blocks = False
+ if hasattr(model, "n_blocks"):
+ logger.info("chunked fsdp")
+ n_blocks = model.n_blocks
+ chunked_blocks = model.chunked_blocks
+ elif hasattr(model, "blocks"):
+ logger.info("first code branch")
+ n_blocks = len(model.blocks)
+ elif hasattr(model, "backbone"):
+ logger.info("second code branch")
+ n_blocks = len(model.backbone.blocks)
+ else:
+ logger.info("else code branch")
+ n_blocks = 0
+ all_param_groups = []
+
+ for name, param in model.named_parameters():
+ name = name.replace("_fsdp_wrapped_module.", "")
+ if not param.requires_grad:
+ continue
+ decay_rate = get_vit_lr_decay_rate(
+ name, lr_decay_rate, num_layers=n_blocks, force_is_backbone=n_blocks > 0, chunked_blocks=chunked_blocks
+ )
+ d = {"params": param, "is_last_layer": False, "lr_multiplier": decay_rate, "wd_multiplier": 1.0, "name": name}
+
+ if "last_layer" in name:
+ d.update({"is_last_layer": True})
+
+ if name.endswith(".bias") or "norm" in name or "gamma" in name:
+ d.update({"wd_multiplier": 0.0})
+
+ if "patch_embed" in name:
+ d.update({"lr_multiplier": d["lr_multiplier"] * patch_embed_lr_mult})
+
+ all_param_groups.append(d)
+ logger.info(f"""{name}: lr_multiplier: {d["lr_multiplier"]}, wd_multiplier: {d["wd_multiplier"]}""")
+
+ return all_param_groups
+
+
+def fuse_params_groups(all_params_groups, keys=("lr_multiplier", "wd_multiplier", "is_last_layer")):
+ fused_params_groups = defaultdict(lambda: {"params": []})
+ for d in all_params_groups:
+ identifier = ""
+ for k in keys:
+ identifier += k + str(d[k]) + "_"
+
+ for k in keys:
+ fused_params_groups[identifier][k] = d[k]
+ fused_params_groups[identifier]["params"].append(d["params"])
+
+ return fused_params_groups.values()
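+
+# Example of the fused grouping: every per-parameter dict collapses into one
+# optimizer group per unique key combination, e.g. all no-weight-decay params
+# at full LR share the identifier
+# "lr_multiplier1.0_wd_multiplier0.0_is_last_layerFalse_".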
diff --git a/unish/pi3/models/dinov2/utils/utils.py b/unish/pi3/models/dinov2/utils/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..e8842e4145414f6f040c4ae83bf38552de8f65b2
--- /dev/null
+++ b/unish/pi3/models/dinov2/utils/utils.py
@@ -0,0 +1,95 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+import logging
+import os
+import random
+import subprocess
+from urllib.parse import urlparse
+
+import numpy as np
+import torch
+from torch import nn
+
+
+# logger = logging.getLogger("dinov2")
+
+
+def load_pretrained_weights(model, pretrained_weights, checkpoint_key):
+    if urlparse(pretrained_weights).scheme: # If it looks like a URL
+ state_dict = torch.hub.load_state_dict_from_url(pretrained_weights, map_location="cpu")
+ else:
+ state_dict = torch.load(pretrained_weights, map_location="cpu")
+ if checkpoint_key is not None and checkpoint_key in state_dict:
+ # logger.info(f"Take key {checkpoint_key} in provided checkpoint dict")
+ state_dict = state_dict[checkpoint_key]
+ # remove `module.` prefix
+ state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
+ # remove `backbone.` prefix induced by multicrop wrapper
+ state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()}
+ msg = model.load_state_dict(state_dict, strict=False)
+ # logger.info("Pretrained weights found at {} and loaded with msg: {}".format(pretrained_weights, msg))
+
+
+def fix_random_seeds(seed=31):
+ """
+ Fix random seeds.
+ """
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed_all(seed)
+ np.random.seed(seed)
+ random.seed(seed)
+
+
+def get_sha():
+ cwd = os.path.dirname(os.path.abspath(__file__))
+
+ def _run(command):
+ return subprocess.check_output(command, cwd=cwd).decode("ascii").strip()
+
+ sha = "N/A"
+ diff = "clean"
+ branch = "N/A"
+ try:
+ sha = _run(["git", "rev-parse", "HEAD"])
+ subprocess.check_output(["git", "diff"], cwd=cwd)
+ diff = _run(["git", "diff-index", "HEAD"])
+ diff = "has uncommitted changes" if diff else "clean"
+ branch = _run(["git", "rev-parse", "--abbrev-ref", "HEAD"])
+ except Exception:
+ pass
+ message = f"sha: {sha}, status: {diff}, branch: {branch}"
+ return message
+
+
+class CosineScheduler(object):
+ def __init__(self, base_value, final_value, total_iters, warmup_iters=0, start_warmup_value=0, freeze_iters=0):
+ super().__init__()
+ self.final_value = final_value
+ self.total_iters = total_iters
+
+    freeze_schedule = np.zeros(freeze_iters)
+
+ warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters)
+
+ iters = np.arange(total_iters - warmup_iters - freeze_iters)
+ schedule = final_value + 0.5 * (base_value - final_value) * (1 + np.cos(np.pi * iters / len(iters)))
+ self.schedule = np.concatenate((freeze_schedule, warmup_schedule, schedule))
+
+ assert len(self.schedule) == self.total_iters
+
+ def __getitem__(self, it):
+ if it >= self.total_iters:
+ return self.final_value
+ else:
+ return self.schedule[it]
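+
+# Worked example: CosineScheduler(base_value=1.0, final_value=0.0,
+# total_iters=100, warmup_iters=10) ramps linearly 0 -> 1 over the first 10
+# steps, then follows a half-cosine from 1 down to 0 over the remaining 90;
+# indexing at or past total_iters returns final_value.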
+
+
+def has_batchnorms(model):
+ bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm)
+    for module in model.modules():
+ if isinstance(module, bn_types):
+ return True
+ return False
diff --git a/unish/pi3/models/layers/attention.py b/unish/pi3/models/layers/attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..3fed6682917436dac1b4f1893682dbef5128d625
--- /dev/null
+++ b/unish/pi3/models/layers/attention.py
@@ -0,0 +1,403 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
+
+import logging
+import os
+import warnings
+
+from torch import Tensor
+from torch import nn
+import torch
+
+from torch.nn.functional import scaled_dot_product_attention
+from torch.nn.attention import SDPBackend, sdpa_kernel
+
+XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
+try:
+ if XFORMERS_ENABLED:
+ from xformers.ops import memory_efficient_attention, unbind
+
+ XFORMERS_AVAILABLE = True
+ # warnings.warn("xFormers is available (Attention)")
+ else:
+ # warnings.warn("xFormers is disabled (Attention)")
+ raise ImportError
+except ImportError:
+ XFORMERS_AVAILABLE = False
+ # warnings.warn("xFormers is not available (Attention)")
+
+
+class Attention(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int = 8,
+ qkv_bias: bool = False,
+ proj_bias: bool = True,
+ attn_drop: float = 0.0,
+ proj_drop: float = 0.0,
+ ) -> None:
+ super().__init__()
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = head_dim**-0.5
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim, bias=proj_bias)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ def forward(self, x: Tensor, attn_bias=None) -> Tensor:
+ B, N, C = x.shape
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+
+ q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
+ attn = q @ k.transpose(-2, -1)
+
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+class MemEffAttention(Attention):
+ def forward(self, x: Tensor, attn_bias=None) -> Tensor:
+ if not XFORMERS_AVAILABLE:
+ if attn_bias is not None:
+ raise AssertionError("xFormers is required for using nested tensors")
+ return super().forward(x)
+
+ B, N, C = x.shape
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
+
+ # q, k, v = unbind(qkv, 2)
+ q, k, v = [qkv[:,:,i] for i in range(3)]
+
+ x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
+ x = x.reshape([B, N, C])
+
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+
+class FlashAttention(Attention):
+ def forward(self, x: Tensor, attn_bias=None) -> Tensor:
+ B, N, C = x.shape
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).transpose(1, 3)
+
+ # q, k, v = unbind(qkv, 2)
+ q, k, v = [qkv[:,:,i] for i in range(3)]
+
+        if q.dtype == torch.bfloat16:
+            with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
+                x = scaled_dot_product_attention(q, k, v)
+        else:
+            with sdpa_kernel([SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION]):
+                x = scaled_dot_product_attention(q, k, v)
+
+ x = x.transpose(1, 2).reshape([B, N, C])
+
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
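+
+# Note: the dtype branch above exists because PyTorch's flash-attention SDP
+# backend only supports half-precision (fp16/bf16) inputs; float32 falls back
+# to the math / memory-efficient backends.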
+
+
+"""
+Following is written by GPT-4o
+"""
+class CrossAttentionRope(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int = 8,
+ qkv_bias: bool = False,
+ proj_bias: bool = True,
+ attn_drop: float = 0.0,
+ proj_drop: float = 0.0,
+ qk_norm: bool = False,
+ norm_layer: nn.Module = nn.LayerNorm,
+ rope=None,
+ ) -> None:
+ super().__init__()
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = head_dim**-0.5
+
+ # Separate projection layers for query, key, and value
+ self.q_proj = nn.Linear(dim, dim, bias=qkv_bias)
+ self.k_proj = nn.Linear(dim, dim, bias=qkv_bias)
+ self.v_proj = nn.Linear(dim, dim, bias=qkv_bias)
+
+ self.q_norm = norm_layer(head_dim) if qk_norm else nn.Identity()
+ self.k_norm = norm_layer(head_dim) if qk_norm else nn.Identity()
+
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim, bias=proj_bias)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ self.rope = rope
+
+ def forward(self, query: Tensor, key: Tensor, value: Tensor, attn_bias=None, qpos=None, kpos=None) -> Tensor:
+ """
+ Args:
+ query: Tensor of shape (B, N, C), input query
+ key: Tensor of shape (B, M, C), input key
+ value: Tensor of shape (B, M, C), input value
+ attn_bias: Optional tensor for attention bias
+ Returns:
+ Tensor of shape (B, N, C), output of cross-attention
+ """
+ B, N, C = query.shape
+ _, M, _ = key.shape
+
+ # Project query, key, and value
+ q = self.q_proj(query).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
+ k = self.k_proj(key).reshape(B, M, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
+ v = self.v_proj(value).reshape(B, M, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
+ q, k = self.q_norm(q).to(v.dtype), self.k_norm(k).to(v.dtype)
+
+ if self.rope is not None:
+ q = self.rope(q, qpos)
+ k = self.rope(k, kpos)
+
+ # Scale query
+ q = q * self.scale
+
+ # Compute attention scores
+ attn = q @ k.transpose(-2, -1) # (B, num_heads, N, M)
+ if attn_bias is not None:
+ attn = attn + attn_bias
+
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+
+ # Compute attention output
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C) # (B, N, C)
+
+ # Final projection
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
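+
+# A minimal usage sketch (illustrative shapes):
+#   xattn = CrossAttentionRope(dim=512, num_heads=8)
+#   out = xattn(query, key, value)   # query: (B, N, 512), key/value: (B, M, 512) -> (B, N, 512)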
+
+
+class MemEffCrossAttentionRope(CrossAttentionRope):
+ def forward(self, query: Tensor, key: Tensor, value: Tensor, attn_bias=None, qpos=None, kpos=None) -> Tensor:
+ """
+ Args:
+ query: Tensor of shape (B, N, C), input query
+ key: Tensor of shape (B, M, C), input key
+ value: Tensor of shape (B, M, C), input value
+ attn_bias: Optional tensor for attention bias
+ Returns:
+ Tensor of shape (B, N, C), output of cross-attention
+ """
+ if not XFORMERS_AVAILABLE:
+ if attn_bias is not None:
+ raise AssertionError("xFormers is required for using nested tensors")
+            return super().forward(query, key, value, attn_bias, qpos=qpos, kpos=kpos)
+
+ B, N, C = query.shape
+ _, M, _ = key.shape
+
+ # Project query, key, and value
+ q = self.q_proj(query).reshape(B, N, self.num_heads, C // self.num_heads)
+ k = self.k_proj(key).reshape(B, M, self.num_heads, C // self.num_heads)
+ v = self.v_proj(value).reshape(B, M, self.num_heads, C // self.num_heads)
+
+ q = q.transpose(1, 2)
+ k = k.transpose(1, 2)
+ q, k = self.q_norm(q).to(v.dtype), self.k_norm(k).to(v.dtype)
+
+ if self.rope is not None:
+ q = self.rope(q, qpos)
+ k = self.rope(k, kpos)
+
+ q = q.transpose(1, 2)
+ k = k.transpose(1, 2)
+
+ # Compute memory-efficient attention
+ x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
+ x = x.reshape(B, N, C)
+
+ # Final projection
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+class FlashCrossAttentionRope(CrossAttentionRope):
+ def forward(self, query: Tensor, key: Tensor, value: Tensor, attn_bias=None, qpos=None, kpos=None) -> Tensor:
+ B, N, C = query.shape
+ _, M, _ = key.shape
+
+        # 1. Project query, key and value, reshaping to (B, num_heads, seq_len, head_dim)
+ q = self.q_proj(query).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
+ k = self.k_proj(key).reshape(B, M, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
+ v = self.v_proj(value).reshape(B, M, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
+
+ q, k = self.q_norm(q).to(v.dtype), self.k_norm(k).to(v.dtype)
+ if self.rope is not None:
+ q = self.rope(q, qpos)
+ k = self.rope(k, kpos)
+
+ dropout_p = self.attn_drop.p if self.training else 0.0
+
+ if q.dtype == torch.bfloat16:
+ with nn.attention.sdpa_kernel(SDPBackend.FLASH_ATTENTION):
+ x = scaled_dot_product_attention(
+ q, k, v, attn_mask=attn_bias, dropout_p=dropout_p
+ )
+ else:
+ with nn.attention.sdpa_kernel([SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION]):
+ x = scaled_dot_product_attention(
+ q, k, v, attn_mask=attn_bias, dropout_p=dropout_p
+ )
+
+ x = x.transpose(1, 2).reshape(B, N, C)
+
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+class AttentionRope(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int = 8,
+ qkv_bias: bool = False,
+ proj_bias: bool = True,
+ attn_drop: float = 0.0,
+ proj_drop: float = 0.0,
+ qk_norm: bool = False,
+ norm_layer: nn.Module = nn.LayerNorm,
+ rope=None
+ ) -> None:
+ super().__init__()
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = head_dim**-0.5
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim, bias=proj_bias)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ self.q_norm = norm_layer(head_dim) if qk_norm else nn.Identity()
+ self.k_norm = norm_layer(head_dim) if qk_norm else nn.Identity()
+
+ self.rope = rope
+
+ def forward(self, x: Tensor, attn_bias=None, xpos=None) -> Tensor:
+ B, N, C = x.shape
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+ q, k, v = qkv[0], qkv[1], qkv[2]
+ q, k = self.q_norm(q).to(v.dtype), self.k_norm(k).to(v.dtype)
+
+ if self.rope is not None:
+ q = self.rope(q, xpos)
+ k = self.rope(k, xpos)
+
+ q = q * self.scale
+ attn = q @ k.transpose(-2, -1)
+
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+class MemEffAttentionRope(AttentionRope):
+ def forward(self, x: Tensor, attn_bias=None, xpos=None) -> Tensor:
+ if not XFORMERS_AVAILABLE:
+ if attn_bias is not None:
+ raise AssertionError("xFormers is required for using nested tensors")
+            return super().forward(x, xpos=xpos)
+
+ B, N, C = x.shape
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
+
+ qkv = qkv.transpose(1, 3)
+ # q, k, v = unbind(qkv, 2)
+ q, k, v = [qkv[:,:,i] for i in range(3)]
+ q, k = self.q_norm(q).to(v.dtype), self.k_norm(k).to(v.dtype)
+
+ if self.rope is not None:
+ q = self.rope(q, xpos)
+ k = self.rope(k, xpos)
+
+ q = q.transpose(1, 2)
+ k = k.transpose(1, 2)
+ v = v.transpose(1, 2)
+
+ x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
+ x = x.reshape([B, N, C])
+
+
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+class FlashAttentionRope(AttentionRope):
+ def forward(self, x: Tensor, attn_bias=None, xpos=None) -> Tensor:
+ B, N, C = x.shape
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).transpose(1, 3)
+
+ # q, k, v = unbind(qkv, 2)
+ q, k, v = [qkv[:,:,i] for i in range(3)]
+ q, k = self.q_norm(q).to(v.dtype), self.k_norm(k).to(v.dtype)
+
+ if self.rope is not None:
+ q = self.rope(q, xpos)
+ k = self.rope(k, xpos)
+
+ if q.dtype == torch.bfloat16:
+ with nn.attention.sdpa_kernel(SDPBackend.FLASH_ATTENTION):
+ x = scaled_dot_product_attention(q, k, v)
+ else:
+ with nn.attention.sdpa_kernel([SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION]):
+ x = scaled_dot_product_attention(q, k, v)
+
+ x = x.transpose(1, 2).reshape([B, N, C])
+
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
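+# Shape walkthrough for get_attn_score below (illustrative): with batch B,
+# frame_num F and token_length T (so N = F*T tokens), the head-summed score
+# matrix (B, N, N) is reshaped to (B, F, T, F, T), averaged over the two token
+# axes into a (B, F, F) frame-to-frame affinity, then summed over the last
+# axis to give one score per frame, shape (B, F).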
+def get_attn_score(blk_class, x, frame_num, token_length, xpos=None):
+ x = blk_class.norm1(x)
+
+ B, N, C = x.shape
+ qkv = blk_class.attn.qkv(x).reshape(B, N, 3, blk_class.attn.num_heads, C // blk_class.attn.num_heads)
+
+ qkv = qkv.transpose(1, 3)
+ # q, k, v = unbind(qkv, 2)
+ q, k, v = [qkv[:,:,i] for i in range(3)]
+ q, k = blk_class.attn.q_norm(q).to(v.dtype), blk_class.attn.k_norm(k).to(v.dtype)
+
+ if blk_class.attn.rope is not None:
+ q = blk_class.attn.rope(q, xpos)
+ k = blk_class.attn.rope(k, xpos)
+
+ q = q.transpose(1, 2)
+ k = k.transpose(1, 2)
+
+ score = (q.permute(0, 2, 1, 3) * blk_class.attn.scale @ k.permute(0, 2, 1, 3).transpose(-2, -1)).sum(dim=1).reshape(B, frame_num, token_length, frame_num, token_length).mean(dim=[2, 4]).sum(-1)
+
+ return score
\ No newline at end of file
diff --git a/unish/pi3/models/layers/block.py b/unish/pi3/models/layers/block.py
new file mode 100644
index 0000000000000000000000000000000000000000..c2c1f95bd4ff36c6a7fafc2e004364c030e20c7d
--- /dev/null
+++ b/unish/pi3/models/layers/block.py
@@ -0,0 +1,406 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
+
+import logging
+import os
+from typing import Callable, List, Any, Tuple, Dict
+import warnings
+
+import torch
+from torch import nn, Tensor
+
+from .attention import Attention, MemEffAttention, CrossAttentionRope, MemEffCrossAttentionRope, FlashAttentionRope
+from ..dinov2.layers.drop_path import DropPath
+from ..dinov2.layers.layer_scale import LayerScale
+from ..dinov2.layers.mlp import Mlp
+
+
+XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
+try:
+ if XFORMERS_ENABLED:
+ from xformers.ops import fmha, scaled_index_add, index_select_cat
+
+ XFORMERS_AVAILABLE = True
+ # warnings.warn("xFormers is available (Block)")
+ else:
+ # warnings.warn("xFormers is disabled (Block)")
+ raise ImportError
+except ImportError:
+ XFORMERS_AVAILABLE = False
+ # warnings.warn("xFormers is not available (Block)")
+
+
+class Block(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int,
+ mlp_ratio: float = 4.0,
+ qkv_bias: bool = False,
+ proj_bias: bool = True,
+ ffn_bias: bool = True,
+ drop: float = 0.0,
+ attn_drop: float = 0.0,
+ init_values=None,
+ drop_path: float = 0.0,
+ act_layer: Callable[..., nn.Module] = nn.GELU,
+ norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
+ attn_class: Callable[..., nn.Module] = Attention,
+ ffn_layer: Callable[..., nn.Module] = Mlp,
+ ) -> None:
+ super().__init__()
+ # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}")
+ self.norm1 = norm_layer(dim)
+ self.attn = attn_class(
+ dim,
+ num_heads=num_heads,
+ qkv_bias=qkv_bias,
+ proj_bias=proj_bias,
+ attn_drop=attn_drop,
+ proj_drop=drop,
+ )
+
+ self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
+ self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = ffn_layer(
+ in_features=dim,
+ hidden_features=mlp_hidden_dim,
+ act_layer=act_layer,
+ drop=drop,
+ bias=ffn_bias,
+ )
+ self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
+ self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+ self.sample_drop_ratio = drop_path
+
+ def forward(self, x: Tensor) -> Tensor:
+ def attn_residual_func(x: Tensor) -> Tensor:
+ return self.ls1(self.attn(self.norm1(x)))
+
+ def ffn_residual_func(x: Tensor) -> Tensor:
+ return self.ls2(self.mlp(self.norm2(x)))
+
+ if self.training and self.sample_drop_ratio > 0.1:
+ # the overhead is compensated only for a drop path rate larger than 0.1
+ x = drop_add_residual_stochastic_depth(
+ x,
+ residual_func=attn_residual_func,
+ sample_drop_ratio=self.sample_drop_ratio,
+ )
+ x = drop_add_residual_stochastic_depth(
+ x,
+ residual_func=ffn_residual_func,
+ sample_drop_ratio=self.sample_drop_ratio,
+ )
+ elif self.training and self.sample_drop_ratio > 0.0:
+ x = x + self.drop_path1(attn_residual_func(x))
+            x = x + self.drop_path2(ffn_residual_func(x))
+ else:
+ x = x + attn_residual_func(x)
+ x = x + ffn_residual_func(x)
+ return x
+
+
+def drop_add_residual_stochastic_depth(
+ x: Tensor,
+ residual_func: Callable[[Tensor], Tensor],
+ sample_drop_ratio: float = 0.0,
+) -> Tensor:
+ # 1) extract subset using permutation
+ b, n, d = x.shape
+ sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
+ brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
+ x_subset = x[brange]
+
+ # 2) apply residual_func to get residual
+ residual = residual_func(x_subset)
+
+ x_flat = x.flatten(1)
+ residual = residual.flatten(1)
+
+ residual_scale_factor = b / sample_subset_size
+
+ # 3) add the residual
+ x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
+ return x_plus_residual.view_as(x)
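+
+# Worked example: with b=8 samples and sample_drop_ratio=0.25, the residual
+# branch runs on a random subset of max(int(8*0.75), 1) = 6 samples and is
+# added back with alpha = 8/6, preserving its expected magnitude over the batch.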
+
+
+def get_branges_scales(x, sample_drop_ratio=0.0):
+ b, n, d = x.shape
+ sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
+ brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
+ residual_scale_factor = b / sample_subset_size
+ return brange, residual_scale_factor
+
+
+def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
+ if scaling_vector is None:
+ x_flat = x.flatten(1)
+ residual = residual.flatten(1)
+ x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
+ else:
+ x_plus_residual = scaled_index_add(
+ x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
+ )
+ return x_plus_residual
+
+
+attn_bias_cache: Dict[Tuple, Any] = {}
+
+
+def get_attn_bias_and_cat(x_list, branges=None):
+ """
+ this will perform the index select, cat the tensors, and provide the attn_bias from cache
+ """
+ batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list]
+ all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
+ if all_shapes not in attn_bias_cache.keys():
+ seqlens = []
+ for b, x in zip(batch_sizes, x_list):
+ for _ in range(b):
+ seqlens.append(x.shape[1])
+ attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
+ attn_bias._batch_sizes = batch_sizes
+ attn_bias_cache[all_shapes] = attn_bias
+
+ if branges is not None:
+ cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1])
+ else:
+ tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
+ cat_tensors = torch.cat(tensors_bs1, dim=1)
+
+ return attn_bias_cache[all_shapes], cat_tensors
+
+
+def drop_add_residual_stochastic_depth_list(
+ x_list: List[Tensor],
+ residual_func: Callable[[Tensor, Any], Tensor],
+ sample_drop_ratio: float = 0.0,
+ scaling_vector=None,
+) -> Tensor:
+ # 1) generate random set of indices for dropping samples in the batch
+ branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
+ branges = [s[0] for s in branges_scales]
+ residual_scale_factors = [s[1] for s in branges_scales]
+
+ # 2) get attention bias and index+concat the tensors
+ attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)
+
+ # 3) apply residual_func to get residual, and split the result
+ residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore
+
+ outputs = []
+ for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors):
+ outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x))
+ return outputs
+
+
+class NestedTensorBlock(Block):
+ def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
+ """
+ x_list contains a list of tensors to nest together and run
+ """
+ assert isinstance(self.attn, MemEffAttention)
+
+ if self.training and self.sample_drop_ratio > 0.0:
+
+ def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
+ return self.attn(self.norm1(x), attn_bias=attn_bias)
+
+ def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
+ return self.mlp(self.norm2(x))
+
+ x_list = drop_add_residual_stochastic_depth_list(
+ x_list,
+ residual_func=attn_residual_func,
+ sample_drop_ratio=self.sample_drop_ratio,
+ scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,
+ )
+ x_list = drop_add_residual_stochastic_depth_list(
+ x_list,
+ residual_func=ffn_residual_func,
+ sample_drop_ratio=self.sample_drop_ratio,
+                scaling_vector=self.ls2.gamma if isinstance(self.ls2, LayerScale) else None,
+ )
+ return x_list
+ else:
+
+ def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
+ return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))
+
+ def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
+ return self.ls2(self.mlp(self.norm2(x)))
+
+ attn_bias, x = get_attn_bias_and_cat(x_list)
+ x = x + attn_residual_func(x, attn_bias=attn_bias)
+ x = x + ffn_residual_func(x)
+ return attn_bias.split(x)
+
+ def forward(self, x_or_x_list):
+ if isinstance(x_or_x_list, Tensor):
+ return super().forward(x_or_x_list)
+ elif isinstance(x_or_x_list, list):
+ if not XFORMERS_AVAILABLE:
+ raise AssertionError("xFormers is required for using nested tensors")
+ return self.forward_nested(x_or_x_list)
+ else:
+ raise AssertionError
+
+class BlockRope(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int,
+ mlp_ratio: float = 4.0,
+ qkv_bias: bool = False,
+ proj_bias: bool = True,
+ ffn_bias: bool = True,
+ drop: float = 0.0,
+ attn_drop: float = 0.0,
+ init_values=None,
+ drop_path: float = 0.0,
+ act_layer: Callable[..., nn.Module] = nn.GELU,
+ norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
+ attn_class: Callable[..., nn.Module] = Attention,
+ ffn_layer: Callable[..., nn.Module] = Mlp,
+ qk_norm: bool=False,
+ rope=None
+ ) -> None:
+ super().__init__()
+ # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}")
+ self.norm1 = norm_layer(dim)
+ self.attn = attn_class(
+ dim,
+ num_heads=num_heads,
+ qkv_bias=qkv_bias,
+ proj_bias=proj_bias,
+ attn_drop=attn_drop,
+ proj_drop=drop,
+ qk_norm=qk_norm,
+ rope=rope
+ )
+
+ self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
+ self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = ffn_layer(
+ in_features=dim,
+ hidden_features=mlp_hidden_dim,
+ act_layer=act_layer,
+ drop=drop,
+ bias=ffn_bias,
+ )
+ self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
+ self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+ self.sample_drop_ratio = drop_path
+
+ def forward(self, x: Tensor, xpos=None) -> Tensor:
+ def attn_residual_func(x: Tensor) -> Tensor:
+ return self.ls1(self.attn(self.norm1(x), xpos=xpos))
+
+ def ffn_residual_func(x: Tensor) -> Tensor:
+ return self.ls2(self.mlp(self.norm2(x)))
+
+ if self.training and self.sample_drop_ratio > 0.1:
+ # the overhead is compensated only for a drop path rate larger than 0.1
+ x = drop_add_residual_stochastic_depth(
+ x,
+ residual_func=attn_residual_func,
+ sample_drop_ratio=self.sample_drop_ratio,
+ )
+ x = drop_add_residual_stochastic_depth(
+ x,
+ residual_func=ffn_residual_func,
+ sample_drop_ratio=self.sample_drop_ratio,
+ )
+ elif self.training and self.sample_drop_ratio > 0.0:
+ x = x + self.drop_path1(attn_residual_func(x))
+            x = x + self.drop_path2(ffn_residual_func(x))
+ else:
+ x = x + attn_residual_func(x)
+ x = x + ffn_residual_func(x)
+ return x
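+
+# A minimal usage sketch (illustrative values; RoPE2D and PositionGetter live in
+# .pos_embed in this package, FlashAttentionRope in .attention):
+#   blk = BlockRope(dim=768, num_heads=12, attn_class=FlashAttentionRope, rope=rope)
+#   pos = PositionGetter()(b, h=16, w=16, device=x.device)   # (b, 256, 2)
+#   x = blk(x, xpos=pos)                                     # x: (b, 256, 768)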
+
+
+class CrossBlockRope(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int,
+ mlp_ratio: float = 4.0,
+ qkv_bias: bool = False,
+ proj_bias: bool = True,
+ ffn_bias: bool = True,
+ act_layer: Callable[..., nn.Module] = nn.GELU,
+ norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
+ attn_class: Callable[..., nn.Module] = Attention,
+ cross_attn_class: Callable[..., nn.Module] = CrossAttentionRope,
+ ffn_layer: Callable[..., nn.Module] = Mlp,
+ init_values=None,
+ qk_norm: bool=False,
+ rope=None
+ ) -> None:
+ super().__init__()
+ # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}")
+ self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
+ self.norm1 = norm_layer(dim)
+ self.attn = attn_class(
+ dim,
+ num_heads=num_heads,
+ qkv_bias=qkv_bias,
+ proj_bias=proj_bias,
+ rope=rope,
+ qk_norm=qk_norm
+ )
+
+ self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
+ self.ls_y = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
+ self.norm2 = norm_layer(dim)
+ self.norm_y = norm_layer(dim)
+ self.cross_attn = cross_attn_class(
+ dim,
+ num_heads=num_heads,
+ qkv_bias=qkv_bias,
+ proj_bias=proj_bias,
+ rope=rope,
+ qk_norm=qk_norm
+ )
+
+ self.norm3 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = ffn_layer(
+ in_features=dim,
+ hidden_features=mlp_hidden_dim,
+ act_layer=act_layer,
+ bias=ffn_bias,
+ )
+
+ def forward(self, x: Tensor, y: Tensor, xpos=None, ypos=None) -> Tensor:
+ def attn_residual_func(x: Tensor) -> Tensor:
+ return self.ls1(self.attn(self.norm1(x), xpos=xpos))
+
+ def cross_attn_residual_func(x: Tensor, y: Tensor) -> Tensor:
+ return self.ls_y(self.cross_attn(self.norm2(x), y, y, qpos=xpos, kpos=ypos))
+
+ def ffn_residual_func(x: Tensor) -> Tensor:
+ return self.ls2(self.mlp(self.norm3(x)))
+
+ x = x + attn_residual_func(x)
+ y_ = self.norm_y(y)
+ x = x + cross_attn_residual_func(x, y_)
+ x = x + ffn_residual_func(x)
+
+ return x
\ No newline at end of file
diff --git a/unish/pi3/models/layers/camera_head.py b/unish/pi3/models/layers/camera_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..7d844f7b76851c3e523e419e18358838e9d23410
--- /dev/null
+++ b/unish/pi3/models/layers/camera_head.py
@@ -0,0 +1,93 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+# code adapted from 'https://github.com/nianticlabs/marepo/blob/9a45e2bb07e5bb8cb997620088d352b439b13e0e/transformer/transformer.py#L172'
+class ResConvBlock(nn.Module):
+ """
+ 1x1 convolution residual block
+ """
+ def __init__(self, in_channels, out_channels):
+ super().__init__()
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.head_skip = nn.Identity() if self.in_channels == self.out_channels else nn.Conv2d(self.in_channels, self.out_channels, 1, 1, 0)
+ # self.res_conv1 = nn.Conv2d(self.in_channels, self.out_channels, 1, 1, 0)
+ # self.res_conv2 = nn.Conv2d(self.out_channels, self.out_channels, 1, 1, 0)
+ # self.res_conv3 = nn.Conv2d(self.out_channels, self.out_channels, 1, 1, 0)
+
+        # the 1x1 convolutions above are replaced by equivalent Linear layers
+ self.res_conv1 = nn.Linear(self.in_channels, self.out_channels)
+ self.res_conv2 = nn.Linear(self.out_channels, self.out_channels)
+ self.res_conv3 = nn.Linear(self.out_channels, self.out_channels)
+
+ def forward(self, res):
+ x = F.relu(self.res_conv1(res))
+ x = F.relu(self.res_conv2(x))
+ x = F.relu(self.res_conv3(x))
+ res = self.head_skip(res) + x
+ return res
+
+class CameraHead(nn.Module):
+ def __init__(self, dim=512):
+ super().__init__()
+ output_dim = dim
+        self.res_conv = nn.ModuleList([ResConvBlock(output_dim, output_dim)
+                                       for _ in range(2)])
+ self.avgpool = nn.AdaptiveAvgPool2d(1)
+ self.more_mlps = nn.Sequential(
+ nn.Linear(output_dim,output_dim),
+ nn.ReLU(),
+ nn.Linear(output_dim,output_dim),
+ nn.ReLU()
+ )
+ self.fc_t = nn.Linear(output_dim, 3)
+ self.fc_rot = nn.Linear(output_dim, 9)
+
+ def forward(self, feat, patch_h, patch_w):
+ BN, hw, c = feat.shape
+
+ for i in range(2):
+ feat = self.res_conv[i](feat)
+
+        # reshape tokens back into a (BN, C, patch_h, patch_w) grid before global average pooling
+        feat = self.avgpool(feat.permute(0, 2, 1).reshape(BN, -1, patch_h, patch_w).contiguous())
+ feat = feat.view(feat.size(0), -1)
+
+ feat = self.more_mlps(feat) # [B, D_]
+ with torch.amp.autocast(device_type='cuda', enabled=False):
+ out_t = self.fc_t(feat.float()) # [B,3]
+ out_r = self.fc_rot(feat.float()) # [B,9]
+ pose = self.convert_pose_to_4x4(BN, out_r, out_t, feat.device)
+
+ return pose
+
+ def convert_pose_to_4x4(self, B, out_r, out_t, device):
+ out_r = self.svd_orthogonalize(out_r) # [N,3,3]
+ pose = torch.zeros((B, 4, 4), device=device)
+ pose[:, :3, :3] = out_r
+ pose[:, :3, 3] = out_t
+ pose[:, 3, 3] = 1.
+ return pose
+
+ def svd_orthogonalize(self, m):
+ """Convert 9D representation to SO(3) using SVD orthogonalization.
+
+ Args:
+ m: [BATCH, 3, 3] 3x3 matrices.
+
+ Returns:
+ [BATCH, 3, 3] SO(3) rotation matrices.
+ """
+ if m.dim() < 3:
+ m = m.reshape((-1, 3, 3))
+ m_transpose = torch.transpose(torch.nn.functional.normalize(m, p=2, dim=-1), dim0=-1, dim1=-2)
+ u, s, v = torch.svd(m_transpose)
+ det = torch.det(torch.matmul(v, u.transpose(-2, -1)))
+ # Check orientation reflection.
+ r = torch.matmul(
+ torch.cat([v[:, :, :-1], v[:, :, -1:] * det.view(-1, 1, 1)], dim=2),
+ u.transpose(-2, -1)
+ )
+ return r
\ No newline at end of file
diff --git a/unish/pi3/models/layers/pos_embed.py b/unish/pi3/models/layers/pos_embed.py
new file mode 100644
index 0000000000000000000000000000000000000000..e27ea0fce111bc3ba49a1e9f0062f956101116b8
--- /dev/null
+++ b/unish/pi3/models/layers/pos_embed.py
@@ -0,0 +1,174 @@
+# Copyright (C) 2022-present Naver Corporation. All rights reserved.
+# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+
+
+# --------------------------------------------------------
+# Position embedding utils
+# --------------------------------------------------------
+
+
+
+import numpy as np
+
+import torch
+
+# --------------------------------------------------------
+# 2D sine-cosine position embedding
+# References:
+# MAE: https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
+# Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py
+# MoCo v3: https://github.com/facebookresearch/moco-v3
+# --------------------------------------------------------
+def get_2d_sincos_pos_embed(embed_dim, grid_size, n_cls_token=0):
+ """
+ grid_size: int of the grid height and width
+ return:
+ pos_embed: [grid_size*grid_size, embed_dim] or [n_cls_token+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
+ """
+ grid_h = np.arange(grid_size, dtype=np.float32)
+ grid_w = np.arange(grid_size, dtype=np.float32)
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
+ grid = np.stack(grid, axis=0)
+
+ grid = grid.reshape([2, 1, grid_size, grid_size])
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
+ if n_cls_token>0:
+ pos_embed = np.concatenate([np.zeros([n_cls_token, embed_dim]), pos_embed], axis=0)
+ return pos_embed
+
+
+def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
+ assert embed_dim % 2 == 0
+
+ # use half of dimensions to encode grid_h
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
+
+ emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
+ return emb
+
+
+def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
+ """
+ embed_dim: output dimension for each position
+ pos: a list of positions to be encoded: size (M,)
+ out: (M, D)
+ """
+ assert embed_dim % 2 == 0
+ omega = np.arange(embed_dim // 2, dtype=float)
+ omega /= embed_dim / 2.
+ omega = 1. / 10000**omega # (D/2,)
+
+ pos = pos.reshape(-1) # (M,)
+ out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
+
+ emb_sin = np.sin(out) # (M, D/2)
+ emb_cos = np.cos(out) # (M, D/2)
+
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
+ return emb
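+
+# Worked example: get_1d_sincos_pos_embed_from_grid(4, np.array([0., 1.])) uses
+# omega = [1, 0.01] and returns
+#   [[sin(0), sin(0),    cos(0), cos(0)],
+#    [sin(1), sin(0.01), cos(1), cos(0.01)]]   # shape (2, 4)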
+
+
+# --------------------------------------------------------
+# Interpolate position embeddings for high-resolution
+# References:
+# MAE: https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
+# DeiT: https://github.com/facebookresearch/deit
+# --------------------------------------------------------
+def interpolate_pos_embed(model, checkpoint_model):
+ if 'pos_embed' in checkpoint_model:
+ pos_embed_checkpoint = checkpoint_model['pos_embed']
+ embedding_size = pos_embed_checkpoint.shape[-1]
+ num_patches = model.patch_embed.num_patches
+ num_extra_tokens = model.pos_embed.shape[-2] - num_patches
+ # height (== width) for the checkpoint position embedding
+ orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
+ # height (== width) for the new position embedding
+ new_size = int(num_patches ** 0.5)
+ # class_token and dist_token are kept unchanged
+ if orig_size != new_size:
+ print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
+ extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
+ # only the position tokens are interpolated
+ pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
+ pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
+ pos_tokens = torch.nn.functional.interpolate(
+ pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
+ pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
+ new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
+ checkpoint_model['pos_embed'] = new_pos_embed
+
+
+#----------------------------------------------------------
+# RoPE2D: RoPE implementation in 2D
+#----------------------------------------------------------
+
+try:
+ from models.curope import cuRoPE2D
+ RoPE2D = cuRoPE2D
+except ImportError:
+    print('Warning: cannot find the CUDA-compiled version of RoPE2D; falling back to a slower PyTorch implementation')
+
+ class RoPE2D(torch.nn.Module):
+
+ def __init__(self, freq=100.0, F0=1.0):
+ super().__init__()
+ self.base = freq
+ self.F0 = F0
+ self.cache = {}
+
+ def get_cos_sin(self, D, seq_len, device, dtype):
+ if (D,seq_len,device,dtype) not in self.cache:
+ inv_freq = 1.0 / (self.base ** (torch.arange(0, D, 2).float().to(device) / D))
+ t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype)
+ freqs = torch.einsum("i,j->ij", t, inv_freq).to(dtype)
+ freqs = torch.cat((freqs, freqs), dim=-1)
+ cos = freqs.cos() # (Seq, Dim)
+ sin = freqs.sin()
+ self.cache[D,seq_len,device,dtype] = (cos,sin)
+ return self.cache[D,seq_len,device,dtype]
+
+ @staticmethod
+ def rotate_half(x):
+ x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
+ return torch.cat((-x2, x1), dim=-1)
+
+ def apply_rope1d(self, tokens, pos1d, cos, sin):
+ assert pos1d.ndim==2
+ cos = torch.nn.functional.embedding(pos1d, cos)[:, None, :, :]
+ sin = torch.nn.functional.embedding(pos1d, sin)[:, None, :, :]
+ return (tokens * cos) + (self.rotate_half(tokens) * sin)
+
+ def forward(self, tokens, positions):
+ """
+ input:
+ * tokens: batch_size x nheads x ntokens x dim
+ * positions: batch_size x ntokens x 2 (y and x position of each token)
+ output:
+            * tokens after applying RoPE2D (batch_size x nheads x ntokens x dim)
+ """
+ assert tokens.size(3)%2==0, "number of dimensions should be a multiple of two"
+ D = tokens.size(3) // 2
+ assert positions.ndim==3 and positions.shape[-1] == 2 # Batch, Seq, 2
+ cos, sin = self.get_cos_sin(D, int(positions.max())+1, tokens.device, tokens.dtype)
+ # split features into two along the feature dimension, and apply rope1d on each half
+ y, x = tokens.chunk(2, dim=-1)
+ y = self.apply_rope1d(y, positions[:,:,0], cos, sin)
+ x = self.apply_rope1d(x, positions[:,:,1], cos, sin)
+ tokens = torch.cat((y, x), dim=-1)
+ return tokens
+
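+# A minimal usage sketch for RoPE2D (illustrative shapes): for 12 heads of dim 64
+# over a 16x16 patch grid,
+#   rope = RoPE2D(freq=100.0)
+#   pos = PositionGetter()(b=2, h=16, w=16, device=device)        # (2, 256, 2)
+#   out = rope(torch.randn(2, 12, 256, 64, device=device), pos)   # shape preserved
+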
+# patch embedding
+class PositionGetter(object):
+ """ return positions of patches """
+
+ def __init__(self):
+ self.cache_positions = {}
+
+ def __call__(self, b, h, w, device):
+ if not (h,w) in self.cache_positions:
+ x = torch.arange(w, device=device)
+ y = torch.arange(h, device=device)
+            self.cache_positions[h,w] = torch.cartesian_prod(y, x) # (h*w, 2)
+ pos = self.cache_positions[h,w].view(1, h*w, 2).expand(b, -1, 2).clone()
+ return pos
\ No newline at end of file
diff --git a/unish/pi3/models/layers/transformer_head.py b/unish/pi3/models/layers/transformer_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..ba64e72eaac4ae278b0bfc6846b4c73e49ecc808
--- /dev/null
+++ b/unish/pi3/models/layers/transformer_head.py
@@ -0,0 +1,337 @@
+from .attention import FlashAttentionRope, FlashCrossAttentionRope
+from .block import BlockRope, CrossBlockRope
+from ..dinov2.layers import Mlp
+import torch.nn as nn
+from functools import partial
+from torch.utils.checkpoint import checkpoint
+import torch
+import torch.nn.functional as F
+from typing import *
+import functools
+import itertools
+
+class TransformerDecoder(nn.Module):
+ def __init__(
+ self,
+ in_dim,
+ out_dim,
+ dec_embed_dim=512,
+ depth=5,
+ dec_num_heads=8,
+ mlp_ratio=4,
+ rope=None,
+ need_project=True,
+ use_checkpoint=False,
+ return_intermediate_layers=None, # list of layer indices to return
+ ):
+ super().__init__()
+
+ self.projects = nn.Linear(in_dim, dec_embed_dim) if need_project else nn.Identity()
+ self.use_checkpoint = use_checkpoint
+ self.return_intermediate_layers = return_intermediate_layers or []
+
+ self.blocks = nn.ModuleList([
+ BlockRope(
+ dim=dec_embed_dim,
+ num_heads=dec_num_heads,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=True,
+ proj_bias=True,
+ ffn_bias=True,
+ drop_path=0.0,
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
+ act_layer=nn.GELU,
+ ffn_layer=Mlp,
+ init_values=None,
+ qk_norm=False,
+ # attn_class=MemEffAttentionRope,
+ attn_class=FlashAttentionRope,
+ rope=rope
+ ) for _ in range(depth)])
+
+ self.linear_out = nn.Linear(dec_embed_dim, out_dim)
+
+ def gradient_checkpointing_enable(self):
+ """Enable gradient checkpointing for memory optimization."""
+ self.use_checkpoint = True
+
+ def forward(self, hidden, xpos=None):
+ hidden = self.projects(hidden)
+ intermediate_features = [hidden]
+
+ for i, blk in enumerate(self.blocks):
+ if self.use_checkpoint and self.training:
+ hidden = checkpoint(blk, hidden, xpos=xpos, use_reentrant=False)
+ else:
+ hidden = blk(hidden, xpos=xpos)
+
+ # Store intermediate features if requested
+ if i in self.return_intermediate_layers:
+ intermediate_features.append(hidden)
+
+ out = self.linear_out(hidden)
+
+ if self.return_intermediate_layers:
+ return out, intermediate_features
+ return out
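+
+# A minimal usage sketch (illustrative values; rope and xpos as used elsewhere in Pi3):
+#   dec = TransformerDecoder(in_dim=1024, out_dim=512, dec_embed_dim=512, depth=5, rope=rope)
+#   out = dec(hidden, xpos=pos)   # hidden: (B, N, 1024) -> out: (B, N, 512)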
+
+class LinearPts3d(nn.Module):
+    """
+    Linear head for DUSt3R-style pointmap regression.
+    Each token outputs a patch_size x patch_size tile of 3D points (+ optional confidence).
+    """
+
+ def __init__(self, patch_size, dec_embed_dim, output_dim=3,):
+ super().__init__()
+ self.patch_size = patch_size
+
+ self.proj = nn.Linear(dec_embed_dim, (output_dim)*self.patch_size**2)
+
+ def enable_gradient_checkpointing(self):
+ """Enable gradient checkpointing for memory optimization."""
+        # LinearPts3d is simple enough that checkpointing is unnecessary;
+        # the method exists for interface consistency.
+ pass
+
+ def forward(self, decout, img_shape):
+ H, W = img_shape
+ tokens = decout[-1]
+ B, S, D = tokens.shape
+
+ # extract 3D points
+        feat = self.proj(tokens)  # (B, S, output_dim * patch_size**2)
+ feat = feat.transpose(-1, -2).view(B, -1, H//self.patch_size, W//self.patch_size)
+ feat = F.pixel_shuffle(feat, self.patch_size) # B,3,H,W
+
+ # permute + norm depth
+ return feat.permute(0, 2, 3, 1)
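+
+# Note: each token predicts a patch_size x patch_size tile with `output_dim`
+# channels; F.pixel_shuffle turns the (B, output_dim*p**2, H/p, W/p) map into a
+# full-resolution (B, output_dim, H, W) map, returned channel-last as (B, H, W, output_dim).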
+
+
+class ContextTransformerDecoder(nn.Module):
+ def __init__(
+ self,
+ in_dim,
+ out_dim,
+ dec_embed_dim=512,
+ depth=5,
+ dec_num_heads=8,
+ mlp_ratio=4,
+ rope=None,
+ ):
+ super().__init__()
+
+ self.projects_x = nn.Linear(in_dim, dec_embed_dim)
+ self.projects_y = nn.Linear(in_dim, dec_embed_dim)
+
+ self.blocks = nn.ModuleList([
+ CrossBlockRope(
+ dim=dec_embed_dim,
+ num_heads=dec_num_heads,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=True,
+ proj_bias=True,
+ ffn_bias=True,
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
+ act_layer=nn.GELU,
+ ffn_layer=Mlp,
+ init_values=None,
+ qk_norm=False,
+ # attn_class=MemEffAttentionRope,
+ # cross_attn_class=MemEffCrossAttentionRope,
+ attn_class=FlashAttentionRope,
+ cross_attn_class=FlashCrossAttentionRope,
+ rope=rope
+ ) for _ in range(depth)])
+
+ self.linear_out = nn.Linear(dec_embed_dim, out_dim)
+ self.use_checkpoint = False
+
+ def gradient_checkpointing_enable(self):
+ """Enable gradient checkpointing for memory optimization."""
+ self.use_checkpoint = True
+
+ def forward(self, hidden, context, xpos=None, ypos=None):
+ hidden = self.projects_x(hidden)
+ context = self.projects_y(context)
+
+ for i, blk in enumerate(self.blocks):
+ if self.use_checkpoint and self.training:
+ hidden = checkpoint(blk, hidden, context, xpos=xpos, ypos=ypos, use_reentrant=False)
+ else:
+ hidden = blk(hidden, context, xpos=xpos, ypos=ypos)
+
+ out = self.linear_out(hidden)
+
+ return out
+
+
+def wrap_module_with_gradient_checkpointing(module: nn.Module):
+ from torch.utils.checkpoint import checkpoint
+ class _CheckpointingWrapper(module.__class__):
+ _restore_cls = module.__class__
+ def forward(self, *args, **kwargs):
+ return checkpoint(super().forward, *args, use_reentrant=False, **kwargs)
+
+ module.__class__ = _CheckpointingWrapper
+ return module
+
+
+class ResidualConvBlock(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int = None,
+ hidden_channels: int = None,
+ kernel_size: int = 3,
+ padding_mode: str = 'replicate',
+ activation: Literal['relu', 'leaky_relu', 'silu', 'elu'] = 'relu',
+ in_norm: Literal['group_norm', 'layer_norm', 'instance_norm', 'none'] = 'layer_norm',
+ hidden_norm: Literal['group_norm', 'layer_norm', 'instance_norm'] = 'group_norm',
+ ):
+ super(ResidualConvBlock, self).__init__()
+ if out_channels is None:
+ out_channels = in_channels
+ if hidden_channels is None:
+ hidden_channels = in_channels
+
+ if activation =='relu':
+ activation_cls = nn.ReLU
+ elif activation == 'leaky_relu':
+ activation_cls = functools.partial(nn.LeakyReLU, negative_slope=0.2)
+ elif activation =='silu':
+ activation_cls = nn.SiLU
+ elif activation == 'elu':
+ activation_cls = nn.ELU
+ else:
+ raise ValueError(f'Unsupported activation function: {activation}')
+
+ self.layers = nn.Sequential(
+ nn.GroupNorm(in_channels // 32, in_channels) if in_norm == 'group_norm' else \
+ nn.GroupNorm(1, in_channels) if in_norm == 'layer_norm' else \
+ nn.InstanceNorm2d(in_channels) if in_norm == 'instance_norm' else \
+ nn.Identity(),
+ activation_cls(),
+ nn.Conv2d(in_channels, hidden_channels, kernel_size=kernel_size, padding=kernel_size // 2, padding_mode=padding_mode),
+ nn.GroupNorm(hidden_channels // 32, hidden_channels) if hidden_norm == 'group_norm' else \
+ nn.GroupNorm(1, hidden_channels) if hidden_norm == 'layer_norm' else \
+ nn.InstanceNorm2d(hidden_channels) if hidden_norm == 'instance_norm' else\
+ nn.Identity(),
+ activation_cls(),
+ nn.Conv2d(hidden_channels, out_channels, kernel_size=kernel_size, padding=kernel_size // 2, padding_mode=padding_mode)
+ )
+
+ self.skip_connection = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0) if in_channels != out_channels else nn.Identity()
+
+ def forward(self, x):
+ skip = self.skip_connection(x)
+ x = self.layers(x)
+ x = x + skip
+ return x
+
+
+class Resampler(nn.Sequential):
+ def __init__(self,
+ in_channels: int,
+ out_channels: int,
+ type_: Literal['pixel_shuffle', 'nearest', 'bilinear', 'conv_transpose', 'pixel_unshuffle', 'avg_pool', 'max_pool'],
+ scale_factor: int = 2,
+ ):
+ if type_ == 'pixel_shuffle':
+ nn.Sequential.__init__(self,
+ nn.Conv2d(in_channels, out_channels * (scale_factor ** 2), kernel_size=3, stride=1, padding=1, padding_mode='replicate'),
+ nn.PixelShuffle(scale_factor),
+ nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, padding_mode='replicate')
+ )
+ for i in range(1, scale_factor ** 2):
+ self[0].weight.data[i::scale_factor ** 2] = self[0].weight.data[0::scale_factor ** 2]
+ self[0].bias.data[i::scale_factor ** 2] = self[0].bias.data[0::scale_factor ** 2]
+ elif type_ in ['nearest', 'bilinear']:
+ nn.Sequential.__init__(self,
+ nn.Upsample(scale_factor=scale_factor, mode=type_, align_corners=False if type_ == 'bilinear' else None),
+ nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, padding_mode='replicate')
+ )
+ elif type_ == 'conv_transpose':
+ nn.Sequential.__init__(self,
+ nn.ConvTranspose2d(in_channels, out_channels, kernel_size=scale_factor, stride=scale_factor),
+ nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, padding_mode='replicate')
+ )
+ self[0].weight.data[:] = self[0].weight.data[:, :, :1, :1]
+ elif type_ == 'pixel_unshuffle':
+ nn.Sequential.__init__(self,
+ nn.PixelUnshuffle(scale_factor),
+ nn.Conv2d(in_channels * (scale_factor ** 2), out_channels, kernel_size=3, stride=1, padding=1, padding_mode='replicate')
+ )
+ elif type_ == 'avg_pool':
+ nn.Sequential.__init__(self,
+ nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, padding_mode='replicate'),
+ nn.AvgPool2d(kernel_size=scale_factor, stride=scale_factor),
+ )
+ elif type_ == 'max_pool':
+ nn.Sequential.__init__(self,
+ nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, padding_mode='replicate'),
+ nn.MaxPool2d(kernel_size=scale_factor, stride=scale_factor),
+ )
+ else:
+ raise ValueError(f'Unsupported resampler type: {type_}')
+
+
+class ConvStack(nn.Module):
+ def __init__(self,
+ dim_in: List[Optional[int]],
+ dim_res_blocks: List[int],
+ dim_out: List[Optional[int]],
+ resamplers: Union[Literal['pixel_shuffle', 'nearest', 'bilinear', 'conv_transpose', 'pixel_unshuffle', 'avg_pool', 'max_pool'], List],
+ dim_times_res_block_hidden: int = 1,
+ num_res_blocks: int = 1,
+ res_block_in_norm: Literal['layer_norm', 'group_norm' , 'instance_norm', 'none'] = 'layer_norm',
+ res_block_hidden_norm: Literal['layer_norm', 'group_norm' , 'instance_norm', 'none'] = 'group_norm',
+ activation: Literal['relu', 'leaky_relu', 'silu', 'elu'] = 'relu',
+ ):
+ super().__init__()
+ self.input_blocks = nn.ModuleList([
+ nn.Conv2d(dim_in_, dim_res_block_, kernel_size=1, stride=1, padding=0) if dim_in_ is not None else nn.Identity()
+ for dim_in_, dim_res_block_ in zip(dim_in if isinstance(dim_in, Sequence) else itertools.repeat(dim_in), dim_res_blocks)
+ ])
+ self.resamplers = nn.ModuleList([
+ Resampler(dim_prev, dim_succ, scale_factor=2, type_=resampler)
+ for i, (dim_prev, dim_succ, resampler) in enumerate(zip(
+ dim_res_blocks[:-1],
+ dim_res_blocks[1:],
+ resamplers if isinstance(resamplers, Sequence) else itertools.repeat(resamplers)
+ ))
+ ])
+ self.res_blocks = nn.ModuleList([
+ nn.Sequential(
+ *(
+ ResidualConvBlock(
+ dim_res_block_, dim_res_block_, dim_times_res_block_hidden * dim_res_block_,
+ activation=activation, in_norm=res_block_in_norm, hidden_norm=res_block_hidden_norm
+ ) for _ in range(num_res_blocks[i] if isinstance(num_res_blocks, list) else num_res_blocks)
+ )
+ ) for i, dim_res_block_ in enumerate(dim_res_blocks)
+ ])
+ self.output_blocks = nn.ModuleList([
+ nn.Conv2d(dim_res_block_, dim_out_, kernel_size=1, stride=1, padding=0) if dim_out_ is not None else nn.Identity()
+ for dim_out_, dim_res_block_ in zip(dim_out if isinstance(dim_out, Sequence) else itertools.repeat(dim_out), dim_res_blocks)
+ ])
+
+ def enable_gradient_checkpointing(self):
+ for i in range(len(self.resamplers)):
+ self.resamplers[i] = wrap_module_with_gradient_checkpointing(self.resamplers[i])
+ for i in range(len(self.res_blocks)):
+ for j in range(len(self.res_blocks[i])):
+ self.res_blocks[i][j] = wrap_module_with_gradient_checkpointing(self.res_blocks[i][j])
+
+ def forward(self, in_features: List[torch.Tensor]):
+ out_features = []
+ for i in range(len(self.res_blocks)):
+ feature = self.input_blocks[i](in_features[i])
+ if i == 0:
+ x = feature
+ elif feature is not None:
+ x = x + feature
+ x = self.res_blocks[i](x)
+ out_features.append(self.output_blocks[i](x))
+ if i < len(self.res_blocks) - 1:
+ x = self.resamplers[i](x)
+ return out_features
\ No newline at end of file
diff --git a/unish/pi3/models/loss.py b/unish/pi3/models/loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..2845ac9b795ece36d7b4d3e934b58d884f4ae6ae
--- /dev/null
+++ b/unish/pi3/models/loss.py
@@ -0,0 +1,358 @@
+import torch
+import torch.nn.functional as F
+import torch.nn as nn
+from typing import *
+import math
+
+from ..utils.geometry import homogenize_points, se3_inverse, depth_edge
+from ..utils.alignment import align_points_scale, align_points_scale_z_shift
+
+# from datasets import __HIGH_QUALITY_DATASETS__, __MIDDLE_QUALITY_DATASETS__
+__HIGH_QUALITY_DATASETS__ = ["bedlam"]
+
+# ---------------------------------------------------------------------------
+# Some functions from MoGe
+# ---------------------------------------------------------------------------
+
+def weighted_mean(x: torch.Tensor, w: torch.Tensor = None, dim: Union[int, torch.Size] = None, keepdim: bool = False, eps: float = 1e-7) -> torch.Tensor:
+ if w is None:
+ return x.mean(dim=dim, keepdim=keepdim)
+ else:
+ w = w.to(x.dtype)
+ return (x * w).mean(dim=dim, keepdim=keepdim) / w.mean(dim=dim, keepdim=keepdim).add(eps)
+
+def _smooth(err: torch.FloatTensor, beta: float = 0.0) -> torch.FloatTensor:
+ if beta == 0:
+ return err
+ else:
+ return torch.where(err < beta, 0.5 * err.square() / beta, err - 0.5 * beta)
+
+def angle_diff_vec3(v1: torch.Tensor, v2: torch.Tensor, eps: float = 1e-12):
+ return torch.atan2(torch.cross(v1, v2, dim=-1).norm(dim=-1) + eps, (v1 * v2).sum(dim=-1))
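+
+# Worked example: _smooth is a smooth-L1 (Huber-style) penalty. With beta=0.1,
+# an error of 0.05 maps to 0.5 * 0.05**2 / 0.1 = 0.0125 (quadratic zone), while
+# an error of 0.5 maps to 0.5 - 0.05 = 0.45 (linear zone).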
+
+# ---------------------------------------------------------------------------
+# PointLoss: Scale-invariant Local Pointmap
+# ---------------------------------------------------------------------------
+
+class PointLoss(nn.Module):
+ def __init__(self, local_align_res=4096, train_conf=False, expected_dist_thresh=0.02):
+ super().__init__()
+ self.local_align_res = local_align_res
+ self.criteria_local = nn.L1Loss(reduction='none')
+
+ self.train_conf = train_conf
+ if self.train_conf:
+ self.prepare_segformer()
+ self.conf_loss_fn = torch.nn.BCEWithLogitsLoss()
+ self.expected_dist_thresh = expected_dist_thresh
+
+ def prepare_segformer(self):
+ from pi3.models.segformer.model import EncoderDecoder
+ self.segformer = EncoderDecoder()
+ self.segformer.load_state_dict(torch.load('ckpts/segformer.b0.512x512.ade.160k.pth', map_location=torch.device('cpu'), weights_only=False)['state_dict'])
+ self.segformer = self.segformer.cuda()
+
+ def predict_sky_mask(self, imgs):
+ with torch.no_grad():
+ output = self.segformer.inference_(imgs)
+ output = output == 2
+ return output
+
+ def prepare_ROE(self, pts, mask, target_size=4096):
+ B, N, H, W, C = pts.shape
+ output = []
+
+ for i in range(B):
+ valid_pts = pts[i][mask[i]]
+
+ if valid_pts.shape[0] > 0:
+ valid_pts = valid_pts.permute(1, 0).unsqueeze(0) # (1, 3, N1)
+                # NOTE: It is important to use nearest interpolation; linear interpolation leads to unstable results!
+ valid_pts = F.interpolate(valid_pts, size=target_size, mode='nearest') # (1, 3, target_size)
+ valid_pts = valid_pts.squeeze(0).permute(1, 0) # (target_size, 3)
+            else:
+                # no valid points in this sample: fall back to a constant placeholder
+                valid_pts = torch.ones((target_size, C), device=pts.device)
+
+ output.append(valid_pts)
+
+ return torch.stack(output, dim=0)
+
+    def normal_loss(self, points, gt_points, mask):
+ not_edge = ~depth_edge(gt_points[..., 2], rtol=0.03)
+ mask = torch.logical_and(mask, not_edge)
+
+ leftup, rightup, leftdown, rightdown = points[..., :-1, :-1, :], points[..., :-1, 1:, :], points[..., 1:, :-1, :], points[..., 1:, 1:, :]
+ upxleft = torch.cross(rightup - rightdown, leftdown - rightdown, dim=-1)
+ leftxdown = torch.cross(leftup - rightup, rightdown - rightup, dim=-1)
+ downxright = torch.cross(leftdown - leftup, rightup - leftup, dim=-1)
+ rightxup = torch.cross(rightdown - leftdown, leftup - leftdown, dim=-1)
+
+ gt_leftup, gt_rightup, gt_leftdown, gt_rightdown = gt_points[..., :-1, :-1, :], gt_points[..., :-1, 1:, :], gt_points[..., 1:, :-1, :], gt_points[..., 1:, 1:, :]
+ gt_upxleft = torch.cross(gt_rightup - gt_rightdown, gt_leftdown - gt_rightdown, dim=-1)
+ gt_leftxdown = torch.cross(gt_leftup - gt_rightup, gt_rightdown - gt_rightup, dim=-1)
+ gt_downxright = torch.cross(gt_leftdown - gt_leftup, gt_rightup - gt_leftup, dim=-1)
+ gt_rightxup = torch.cross(gt_rightdown - gt_leftdown, gt_leftup - gt_leftdown, dim=-1)
+
+ mask_leftup, mask_rightup, mask_leftdown, mask_rightdown = mask[..., :-1, :-1], mask[..., :-1, 1:], mask[..., 1:, :-1], mask[..., 1:, 1:]
+ mask_upxleft = mask_rightup & mask_leftdown & mask_rightdown
+ mask_leftxdown = mask_leftup & mask_rightdown & mask_rightup
+ mask_downxright = mask_leftdown & mask_rightup & mask_leftup
+ mask_rightxup = mask_rightdown & mask_leftup & mask_leftdown
+
+ MIN_ANGLE, MAX_ANGLE, BETA_RAD = math.radians(1), math.radians(90), math.radians(3)
+
+ loss = mask_upxleft * _smooth(angle_diff_vec3(upxleft, gt_upxleft).clamp(MIN_ANGLE, MAX_ANGLE), beta=BETA_RAD) \
+ + mask_leftxdown * _smooth(angle_diff_vec3(leftxdown, gt_leftxdown).clamp(MIN_ANGLE, MAX_ANGLE), beta=BETA_RAD) \
+ + mask_downxright * _smooth(angle_diff_vec3(downxright, gt_downxright).clamp(MIN_ANGLE, MAX_ANGLE), beta=BETA_RAD) \
+ + mask_rightxup * _smooth(angle_diff_vec3(rightxup, gt_rightxup).clamp(MIN_ANGLE, MAX_ANGLE), beta=BETA_RAD)
+
+ loss = loss.mean() / (4 * max(points.shape[-3:-1]))
+
+ return loss
+
+ def forward(self, pred, gt):
+ pred_local_pts = pred['local_points']
+ gt_local_pts = gt['local_points']
+
+ valid_masks = gt['point_masks']
+ gt_human_mask = gt['human_masks'].clone() if gt['human_masks'] is not None else torch.zeros_like(valid_masks)
+
+ # valid_masks = torch.logical_and(valid_masks, ~gt_human_mask)
+
+ # if gt["pointmaps_relative_valid"] is not None:
+ # # pm_valid_mask = gt["pointmaps_relative_valid"].view(gt_local_pts.shape[0], -1)
+ # # valid_ratio = pm_valid_mask.sum(dim=-1) / pm_valid_mask.shape[-1]
+ # # valid_scene = valid_ratio > 0.8
+ # # valid_masks = valid_masks & valid_scene.view(-1, 1, 1, 1)
+ # valid_masks = valid_masks & gt["pointmaps_relative_valid"]
+
+ details = dict()
+ final_loss = 0.0
+
+ B, N, H, W, _ = pred_local_pts.shape
+
+ weights_ = gt_local_pts[..., 2]
+ weights_ = weights_.clamp_min(0.1 * weighted_mean(weights_, valid_masks, dim=(-2, -1), keepdim=True))
+ weights_ = 1 / (weights_ + 1e-6)
+
+ # alignment
+ with torch.no_grad():
+ xyz_pred_local = self.prepare_ROE(pred_local_pts.reshape(B, N, H, W, 3), valid_masks.reshape(B, N, H, W), target_size=self.local_align_res).contiguous()
+ xyz_gt_local = self.prepare_ROE(gt_local_pts.reshape(B, N, H, W, 3), valid_masks.reshape(B, N, H, W), target_size=self.local_align_res).contiguous()
+ xyz_weights_local = self.prepare_ROE((weights_[..., None]).reshape(B, N, H, W, 1), valid_masks.reshape(B, N, H, W), target_size=self.local_align_res).contiguous()[:, :, 0]
+
+ S_opt_local = align_points_scale(xyz_pred_local, xyz_gt_local, xyz_weights_local)
+ # S_opt_local, _ = align_points_scale_z_shift(xyz_pred_local, xyz_gt_local, xyz_weights_local)
+ S_opt_local[S_opt_local <= 0] *= -1
+
+ aligned_local_pts = S_opt_local.view(B, 1, 1, 1, 1) * pred_local_pts
+
+ # local point loss
+ local_pts_loss = self.criteria_local(aligned_local_pts[valid_masks].float(), gt_local_pts[valid_masks].float()) * weights_[valid_masks].float()[..., None]
+
+ # conf loss
+ if self.train_conf:
+ pred_conf = pred['point_conf']
+
+ # probability loss
+ valid = local_pts_loss.detach().mean(-1, keepdims=True) < self.expected_dist_thresh
+ local_conf_loss = self.conf_loss_fn(pred_conf[valid_masks], valid.float())
+
+ sky_mask = self.predict_sky_mask(gt['imgs'].reshape(B*N, 3, H, W)).reshape(B, N, H, W)
+ sky_mask[valid_masks] = False
+ if sky_mask.sum() == 0:
+ sky_mask_loss = 0.0 * aligned_local_pts.mean()
+ else:
+ sky_mask_loss = self.conf_loss_fn(pred_conf[sky_mask], torch.zeros_like(pred_conf[sky_mask]))
+
+ final_loss += 0.05 * (local_conf_loss + sky_mask_loss)
+ details['local_conf_loss'] = (local_conf_loss + sky_mask_loss)
+
+ final_loss += local_pts_loss.mean()
+ details['local_pts_loss'] = local_pts_loss.mean()
+
+ # normal loss
+ normal_batch_id = [i for i in range(len(gt['datasets_name'])) if gt['datasets_name'][i] in __HIGH_QUALITY_DATASETS__]
+ if len(normal_batch_id) == 0:
+ normal_loss = 0.0 * aligned_local_pts.mean()
+ else:
+            normal_loss = self.normal_loss(aligned_local_pts[normal_batch_id], gt_local_pts[normal_batch_id], valid_masks[normal_batch_id])
+ final_loss += normal_loss.mean()
+ details['normal_loss'] = normal_loss.mean()
+
+ # # [Optional] Global Point Loss
+ # if 'global_points' in pred and pred['global_points'] is not None:
+ # gt_pts = gt['global_points']
+
+ # pred_global_pts = pred['global_points'] * S_opt_local.view(B, 1, 1, 1, 1)
+ # global_pts_loss = self.criteria_local(pred_global_pts[valid_masks].float(), gt_pts[valid_masks].float()) * weights_[valid_masks].float()[..., None]
+
+ # final_loss += global_pts_loss.mean()
+ # details['global_pts_loss'] = global_pts_loss.mean()
+
+ return final_loss, S_opt_local, details
+
+# ---------------------------------------------------------------------------
+# CameraLoss: Affine-invariant Camera Pose
+# ---------------------------------------------------------------------------
+
+class CameraLoss(nn.Module):
+ def __init__(self, alpha=100):
+ super().__init__()
+ self.alpha = alpha
+
+ def rot_ang_loss(self, R, Rgt, eps=1e-6):
+ """
+ Args:
+ R: estimated rotation matrix [B, 3, 3]
+ Rgt: ground-truth rotation matrix [B, 3, 3]
+ Returns:
+ R_err: rotation angular error
+ """
+ residual = torch.matmul(R.transpose(1, 2), Rgt)
+ trace = torch.diagonal(residual, dim1=-2, dim2=-1).sum(-1)
+ cosine = (trace - 1) / 2
+ R_err = torch.acos(torch.clamp(cosine, -1.0 + eps, 1.0 - eps)) # handle numerical errors and NaNs
+ return R_err.mean() # [0, 3.14]
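+
+    # The angle above is the geodesic distance on SO(3):
+    #   theta = arccos((trace(R^T @ Rgt) - 1) / 2), clamped to avoid NaNs at the boundary.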
+
+ def forward(self, pred, gt, scale):
+ pred_pose = pred['c2ws']
+ gt_pose = se3_inverse(gt['extrinsics'])
+
+ B, N, _, _ = pred_pose.shape
+
+ pred_pose_align = pred_pose.clone()
+ pred_pose_align[..., :3, 3] *= scale.view(B, 1, 1)
+
+ pred_w2c = se3_inverse(pred_pose_align)
+ gt_w2c = gt['extrinsics']
+
+ pred_w2c_exp = pred_w2c.unsqueeze(2)
+ pred_pose_exp = pred_pose_align.unsqueeze(1)
+
+ gt_w2c_exp = gt_w2c.unsqueeze(2)
+ gt_pose_exp = gt_pose.unsqueeze(1)
+
+ pred_rel_all = torch.matmul(pred_w2c_exp, pred_pose_exp)
+ gt_rel_all = torch.matmul(gt_w2c_exp, gt_pose_exp)
+
+ mask = ~torch.eye(N, dtype=torch.bool, device=pred_pose.device)
+
+ t_pred = pred_rel_all[..., :3, 3][:, mask, ...]
+ R_pred = pred_rel_all[..., :3, :3][:, mask, ...]
+
+ t_gt = gt_rel_all[..., :3, 3][:, mask, ...]
+ R_gt = gt_rel_all[..., :3, :3][:, mask, ...]
+
+ trans_loss = F.huber_loss(t_pred, t_gt, reduction='mean', delta=0.1)
+
+ rot_loss = self.rot_ang_loss(
+ R_pred.reshape(-1, 3, 3),
+ R_gt.reshape(-1, 3, 3)
+ )
+
+ total_loss = self.alpha * trans_loss + rot_loss
+
+ return total_loss, dict(trans_loss=trans_loss, rot_loss=rot_loss)
+
+# ---------------------------------------------------------------------------
+# Final Loss
+# ---------------------------------------------------------------------------
+
+class Pi3Loss(nn.Module):
+ def __init__(
+ self,
+ train_conf=False,
+ ):
+ super().__init__()
+ self.point_loss = PointLoss(train_conf=train_conf)
+ self.camera_loss = CameraLoss()
+
+ # def prepare_gt(self, gt):
+ # gt_pts = torch.stack([view['pts3d'] for view in gt], dim=1)
+ # masks = torch.stack([view['valid_mask'] for view in gt], dim=1)
+ # poses = torch.stack([view['camera_pose'] for view in gt], dim=1)
+
+ # B, N, H, W, _ = gt_pts.shape
+
+ # # transform to first frame camera coordinate
+ # w2c_target = se3_inverse(poses[:, 0])
+ # gt_pts = torch.einsum('bij, bnhwj -> bnhwi', w2c_target, homogenize_points(gt_pts))[..., :3]
+ # poses = torch.einsum('bij, bnjk -> bnik', w2c_target, poses)
+
+ # # normalize points
+ # valid_batch = masks.sum([-1, -2, -3]) > 0
+ # if valid_batch.sum() > 0:
+ # B_ = valid_batch.sum()
+ # all_pts = gt_pts[valid_batch].clone()
+ # all_pts[~masks[valid_batch]] = 0
+ # all_pts = all_pts.reshape(B_, N, -1, 3)
+ # all_dis = all_pts.norm(dim=-1)
+ # norm_factor = all_dis.sum(dim=[-1, -2]) / (masks[valid_batch].float().sum(dim=[-1, -2, -3]) + 1e-8)
+
+ # gt_pts[valid_batch] = gt_pts[valid_batch] / norm_factor[..., None, None, None, None]
+ # poses[valid_batch, ..., :3, 3] /= norm_factor[..., None, None]
+
+ # extrinsics = se3_inverse(poses)
+ # gt_local_pts = torch.einsum('bnij, bnhwj -> bnhwi', extrinsics, homogenize_points(gt_pts))[..., :3]
+
+ # dataset_names = gt[0]['dataset']
+
+ # return dict(
+ # imgs = torch.stack([view['img'] for view in gt], dim=1),
+ # global_points=gt_pts,
+ # local_points=gt_local_pts,
+ # valid_masks=masks,
+ # camera_poses=poses,
+ # dataset_names=dataset_names
+ # )
+
+ def normalize_pred(self, pred, gt):
+ local_points = pred['local_points']
+ camera_poses = pred['c2ws']
+ B, N, H, W, _ = local_points.shape
+ masks = gt['point_masks']
+
+ masks = masks & pred["masks"]
+ # if gt['pointmaps_relative_valid'] is not None:
+ # masks = masks & gt['pointmaps_relative_valid']
+
+ # normalize predict points
+ all_pts = local_points.clone()
+ all_pts[~masks] = 0
+ all_pts = all_pts.reshape(B, N, -1, 3)
+ all_dis = all_pts.norm(dim=-1)
+ norm_factor = all_dis.sum(dim=[-1, -2]) / (masks.float().sum(dim=[-1, -2, -3]) + 1e-8)
+ local_points = local_points / norm_factor[..., None, None, None, None]
+
+ # if 'global_points' in pred and pred['global_points'] is not None:
+ # pred['global_points'] /= norm_factor[..., None, None, None, None]
+
+ camera_poses_normalized = camera_poses.clone()
+ camera_poses_normalized[..., :3, 3] /= norm_factor.view(B, 1, 1)
+
+ pred['local_points'] = local_points
+ pred['c2ws'] = camera_poses_normalized
+
+ return pred, norm_factor
+
+ def forward(self, pred, gt):
+ # gt = self.prepare_gt(gt_raw)
+ pred, norm_factor = self.normalize_pred(pred, gt)
+
+ details = dict()
+ details['norm_factor'] = norm_factor
+
+ # Local Point Loss
+ point_loss, scale, point_loss_details = self.point_loss(pred, gt)
+ details.update(point_loss_details)
+
+ if point_loss.isnan():
+ point_loss = torch.tensor(0.0, device=point_loss.device, dtype=point_loss.dtype)
+
+ # Camera Loss
+ camera_loss, camera_loss_details = self.camera_loss(pred, gt, scale)
+ details.update(camera_loss_details)
+
+ return point_loss, camera_loss, scale, details
+
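+def _example_normalize_pred():
+    """Sketch of `Pi3Loss.normalize_pred` on dummy tensors (illustrative
+    only; shapes are assumptions). After normalization, the mean distance
+    of valid local points from the origin is ~1 and camera translations
+    are rescaled by the same factor."""
+    B, N, H, W = 1, 2, 14, 14
+    loss_fn = Pi3Loss()
+    pred = dict(
+        local_points=torch.rand(B, N, H, W, 3),
+        c2ws=torch.eye(4).repeat(B, N, 1, 1),
+        masks=torch.ones(B, N, H, W, dtype=torch.bool),
+    )
+    gt = dict(point_masks=torch.ones(B, N, H, W, dtype=torch.bool))
+    pred, norm_factor = loss_fn.normalize_pred(pred, gt)
+    return pred, norm_factor
+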
diff --git a/unish/pi3/models/pi3.py b/unish/pi3/models/pi3.py
new file mode 100644
index 0000000000000000000000000000000000000000..3bc9ab3491ecce410fd4beda44360288771b75c6
--- /dev/null
+++ b/unish/pi3/models/pi3.py
@@ -0,0 +1,319 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from functools import partial
+from copy import deepcopy
+import itertools
+
+from .dinov2.layers import Mlp
+from ..utils.geometry import homogenize_points, se3_inverse
+from .layers.pos_embed import RoPE2D, PositionGetter
+from .layers.block import BlockRope
+from .layers.attention import FlashAttentionRope
+from .layers.transformer_head import TransformerDecoder, LinearPts3d, ConvStack
+from .layers.camera_head import CameraHead
+from ...heads.dpt_head import DPTHead
+from .dinov2.hub.backbones import dinov2_vitl14, dinov2_vitl14_reg
+from huggingface_hub import PyTorchModelHubMixin
+
+class Pi3(nn.Module, PyTorchModelHubMixin):
+ def __init__(
+ self,
+ pos_type='rope100',
+ decoder_size='large',
+ ):
+ super().__init__()
+
+ # ----------------------
+ # Encoder
+ # ----------------------
+ self.encoder = dinov2_vitl14_reg(pretrained=False)
+ self.patch_size = 14
+ del self.encoder.mask_token
+
+ # ----------------------
+        # Positional Encoding
+        # ----------------------
+        self.pos_type = pos_type if pos_type is not None else 'none'
+        self.rope = None
+        if self.pos_type.startswith('rope'):  # e.g. 'rope100'
+            if RoPE2D is None:
+                raise ImportError("Cannot find cuRoPE2D, please install it following the README instructions")
+ freq = float(self.pos_type[len('rope'):])
+ self.rope = RoPE2D(freq=freq)
+ self.position_getter = PositionGetter()
+ else:
+ raise NotImplementedError
+
+
+ # ----------------------
+ # Decoder
+ # ----------------------
+ enc_embed_dim = self.encoder.blocks[0].attn.qkv.in_features # 1024
+ if decoder_size == 'small':
+ dec_embed_dim = 384
+ dec_num_heads = 6
+ mlp_ratio = 4
+ dec_depth = 24
+ elif decoder_size == 'base':
+ dec_embed_dim = 768
+ dec_num_heads = 12
+ mlp_ratio = 4
+ dec_depth = 24
+ elif decoder_size == 'large':
+ dec_embed_dim = 1024
+ dec_num_heads = 16
+ mlp_ratio = 4
+ dec_depth = 36
+ else:
+ raise NotImplementedError
+ self.decoder = nn.ModuleList([
+ BlockRope(
+ dim=dec_embed_dim,
+ num_heads=dec_num_heads,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=True,
+ proj_bias=True,
+ ffn_bias=True,
+ drop_path=0.0,
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
+ act_layer=nn.GELU,
+ ffn_layer=Mlp,
+ init_values=0.01,
+ qk_norm=True,
+ attn_class=FlashAttentionRope,
+ rope=self.rope
+ ) for _ in range(dec_depth)])
+ self.dec_embed_dim = dec_embed_dim
+
+ # ----------------------
+ # Register_token
+ # ----------------------
+ num_register_tokens = 5
+ self.patch_start_idx = num_register_tokens
+ self.register_token = nn.Parameter(torch.randn(1, 1, num_register_tokens, self.dec_embed_dim))
+ nn.init.normal_(self.register_token, std=1e-6)
+
+ # ----------------------
+ # Local Points Decoder
+ # ----------------------
+ self.point_decoder = TransformerDecoder(
+ in_dim=2*self.dec_embed_dim,
+ dec_embed_dim=1024,
+ dec_num_heads=16,
+ out_dim=1024,
+ rope=self.rope,
+            return_intermediate_layers=[0, 2, 4],  # tap the 1st, 3rd and 5th decoder blocks
+ )
+ # DPTHead for multi-scale point prediction
+ self.point_head = LinearPts3d(patch_size=14, dec_embed_dim=1024, output_dim=3)
+ # self.point_head_2 = DPTHead(
+ # dim_in=self.dec_embed_dim, # input dimension from decoder
+ # patch_size=14,
+ # output_dim=4, # 3D points + 1 dummy channel (we'll ignore confidence)
+ # activation="linear", # linear activation for raw 3D coordinates
+ # conf_activation="expp1", # for the dummy confidence channel
+ # intermediate_layer_idx=[0, 1, 2, 3], # use input + 3 intermediate layers
+ # )
+
+ # # Zero-initialize point_head_2 for ControlNet-like behavior
+ # self._zero_init_point_head_2()
+
+ # ----------------------
+ # Conf Decoder
+ # ----------------------
+ self.conf_decoder = deepcopy(self.point_decoder)
+ self.conf_head = LinearPts3d(patch_size=14, dec_embed_dim=1024, output_dim=1)
+
+ # ----------------------
+ # Camera Pose Decoder
+ # ----------------------
+ self.camera_decoder = TransformerDecoder(
+ in_dim=2*self.dec_embed_dim,
+ dec_embed_dim=1024,
+ dec_num_heads=16, # 8
+ out_dim=512,
+ rope=self.rope,
+ use_checkpoint=False
+ )
+ self.camera_head = CameraHead(dim=512)
+
+ # For ImageNet Normalize
+ image_mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
+ image_std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
+
+ self.register_buffer("image_mean", image_mean)
+ self.register_buffer("image_std", image_std)
+
+ def gradient_checkpointing_enable(self):
+ """Enable gradient checkpointing for memory optimization."""
+ # Enable gradient checkpointing for encoder (DinoV2)
+ if hasattr(self.encoder, 'gradient_checkpointing_enable'):
+ self.encoder.gradient_checkpointing_enable()
+
+ # Enable gradient checkpointing for decoder blocks
+ from torch.utils.checkpoint import checkpoint
+ from .layers.transformer_head import wrap_module_with_gradient_checkpointing
+ for i, block in enumerate(self.decoder):
+ self.decoder[i] = wrap_module_with_gradient_checkpointing(block)
+
+ # Enable gradient checkpointing for transformer decoders
+ if hasattr(self.point_decoder, 'gradient_checkpointing_enable'):
+ self.point_decoder.gradient_checkpointing_enable()
+ if hasattr(self.conf_decoder, 'gradient_checkpointing_enable'):
+ self.conf_decoder.gradient_checkpointing_enable()
+ if hasattr(self.camera_decoder, 'gradient_checkpointing_enable'):
+ self.camera_decoder.gradient_checkpointing_enable()
+
+ # Enable gradient checkpointing for heads that support it
+ if hasattr(self.point_head, 'enable_gradient_checkpointing'):
+ self.point_head.enable_gradient_checkpointing()
+ if hasattr(self.conf_head, 'enable_gradient_checkpointing'):
+ self.conf_head.enable_gradient_checkpointing()
+
+ def _zero_init_point_head_2(self):
+ """
+ Zero-initialize the final output layer of point_head_2 for ControlNet-like behavior.
+ This ensures that point_head_2 outputs zero initially and doesn't affect the original model.
+ """
+ # Zero-initialize the final output convolution layer
+ final_conv = self.point_head_2.scratch.output_conv2[-1] # Last Conv2d layer
+ nn.init.zeros_(final_conv.weight)
+ if final_conv.bias is not None:
+ nn.init.zeros_(final_conv.bias)
+
+ # Optionally, also zero-initialize the output convolution layers in fusion blocks
+ # This provides additional stability during early training
+ for refinenet in [self.point_head_2.scratch.refinenet1,
+ self.point_head_2.scratch.refinenet2,
+ self.point_head_2.scratch.refinenet3,
+ self.point_head_2.scratch.refinenet4]:
+ if hasattr(refinenet, 'out_conv'):
+ nn.init.zeros_(refinenet.out_conv.weight)
+ if refinenet.out_conv.bias is not None:
+ nn.init.zeros_(refinenet.out_conv.bias)
+
+ def decode(self, hidden, N, H, W):
+ BN, hw, _ = hidden.shape
+ B = BN // N
+
+ final_output = []
+
+ hidden = hidden.reshape(B*N, hw, -1)
+
+ register_token = self.register_token.repeat(B, N, 1, 1).reshape(B*N, *self.register_token.shape[-2:])
+
+ # Concatenate special tokens with patch tokens
+ hidden = torch.cat([register_token, hidden], dim=1)
+ hw = hidden.shape[1]
+
+ if self.pos_type.startswith('rope'):
+ pos = self.position_getter(B * N, H//self.patch_size, W//self.patch_size, hidden.device)
+
+            if self.patch_start_idx > 0:
+                # Special (register) tokens get position 0, i.e. no RoPE rotation;
+                # patch positions are shifted by +1 so they stay distinct from the
+                # special tokens.
+                pos = pos + 1
+ pos_special = torch.zeros(B * N, self.patch_start_idx, 2).to(hidden.device).to(pos.dtype)
+ pos = torch.cat([pos_special, pos], dim=1)
+
+ for i in range(len(self.decoder)):
+ blk = self.decoder[i]
+
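+            # Alternate attention scope: even blocks attend within each frame
+            # (tokens shaped (B*N, hw)), odd blocks attend jointly across the
+            # whole sequence (tokens shaped (B, N*hw)).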
+ if i % 2 == 0:
+ pos = pos.reshape(B*N, hw, -1)
+ hidden = hidden.reshape(B*N, hw, -1)
+ else:
+ pos = pos.reshape(B, N*hw, -1)
+ hidden = hidden.reshape(B, N*hw, -1)
+
+ hidden = blk(hidden, xpos=pos)
+
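+            # Keep the outputs of the last two decoder blocks; their channel-wise
+            # concatenation feeds the point/conf/camera decoders.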
+ if i+1 in [len(self.decoder)-1, len(self.decoder)]:
+ final_output.append(hidden.reshape(B*N, hw, -1))
+
+ return torch.cat([final_output[0], final_output[1]], dim=-1), pos.reshape(B*N, hw, -1)
+
+ def forward(self, imgs):
+ imgs = (imgs - self.image_mean) / self.image_std
+
+        B, N, C, H, W = imgs.shape
+ patch_h, patch_w = H // 14, W // 14
+
+ # encode by dinov2
+        imgs = imgs.reshape(B*N, C, H, W).to(torch.bfloat16)
+ encoder_output = self.encoder(imgs, is_training=True)
+
+ if isinstance(encoder_output, dict):
+ hidden = encoder_output["x_norm_patchtokens"]
+ else:
+ hidden = encoder_output
+
+ hidden, pos = self.decode(hidden, N, H, W)
+
+ point_hidden, point_intermediates = self.point_decoder(hidden, xpos=pos)
+ conf_hidden, _ = self.conf_decoder(hidden, xpos=pos) # ignore intermediate features for conf
+ camera_hidden = self.camera_decoder(hidden, xpos=pos)
+
+ with torch.amp.autocast(device_type='cuda', enabled=False, dtype=torch.float32):
+ # local points - prepare aggregated tokens list for DPTHead
+ point_hidden = point_hidden.float()
+ self.point_head.to(torch.float32)
+ self.conf_head.to(torch.float32)
+ self.camera_head.to(torch.float32)
+
+ # Prepare aggregated_tokens_list: [input, intermediate_layer_1, intermediate_layer_3, intermediate_layer_5]
+ # Reshape to [B, N, num_patches, dim] format expected by DPTHead
+ # patch_h, patch_w = H // self.patch_size, W // self.patch_size
+
+ # # Intermediate tokens from point_decoder (layers 1, 3, 5)
+ # for i, feat in enumerate(point_intermediates):
+ # point_intermediates[i] = feat.reshape(B, N, -1, feat.shape[-1]) # [B, N, hw+register, dim]
+
+ # # Create fake images tensor with correct shape for DPTHead
+ # fake_images = torch.zeros(B, N, 3, H, W, device=hidden.device, dtype=torch.float32)
+
+ # # Call DPTHead
+ # dpt_output, dpt_conf = self.point_head_2(
+ # aggregated_tokens_list=point_intermediates,
+ # images=fake_images,
+ # patch_start_idx=self.patch_start_idx
+ # )
+
+ # # dpt_output shape: [B, N, H, W, 4] (3D points + 1 dummy confidence)
+ # # We only use the first 3 channels for 3D points, ignore the 4th channel
+ # res= dpt_output[..., :3] # [B, N, H, W, 3]
+
+ # # Apply depth processing similar to original code
+ # xy, z = local_points_raw.split([2, 1], dim=-1)
+
+ ret = self.point_head([point_hidden[:, self.patch_start_idx:]], (H, W)).reshape(B, N, H, W, -1)
+ # ret += res
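+            # The head predicts (x/z, y/z, log z): exponentiation guarantees a
+            # positive depth, and xy * z recovers camera-space X and Y.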
+ xy, z = ret.split([2, 1], dim=-1)
+ z = torch.exp(z)
+ local_points = torch.cat([xy * z, z], dim=-1)
+
+ # confidence
+ conf_hidden = conf_hidden.float()
+ conf = self.conf_head([conf_hidden[:, self.patch_start_idx:]], (H, W)).reshape(B, N, H, W, -1)
+
+ # camera
+ camera_hidden = camera_hidden.float()
+ camera_poses = self.camera_head(camera_hidden[:, self.patch_start_idx:], patch_h, patch_w).reshape(B, N, 4, 4)
+
+ # transform to first frame camera coordinate
+ w2c_target = se3_inverse(camera_poses[:, 0])
+ # gt_pts = torch.einsum('bij, bnhwj -> bnhwi', w2c_target, homogenize_points(gt_pts))[..., :3]
+ camera_poses_cano = torch.einsum('bij, bnjk -> bnik', w2c_target, camera_poses)
+
+ # unproject local points using camera poses
+ points_cano = torch.einsum('bnij, bnhwj -> bnhwi', camera_poses_cano, homogenize_points(local_points))[..., :3]
+
+
+ return dict(
+ world_points=points_cano,
+ local_points=local_points,
+ point_conf=conf, # only from original conf_head
+ c2ws=camera_poses,
+ c2ws_cano=camera_poses_cano,
+ hidden=hidden,
+ )
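+
+# Usage sketch (illustrative; assumes a CUDA device, bfloat16 support and the
+# cuRoPE2D extension required by RoPE2D -- none of which are checked here):
+#
+#   model = Pi3().cuda().eval()
+#   imgs = torch.rand(1, 4, 3, 224, 224, device='cuda')  # (B, N, 3, H, W), H/W multiples of 14
+#   with torch.no_grad():
+#       out = model(imgs)
+#   out['world_points']   # (1, 4, 224, 224, 3), in the first frame's camera coordinates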
diff --git a/unish/pi3/models/segformer/backbone.py b/unish/pi3/models/segformer/backbone.py
new file mode 100644
index 0000000000000000000000000000000000000000..3a7793e9638b8418ad4f92262e3eee7be883b416
--- /dev/null
+++ b/unish/pi3/models/segformer/backbone.py
@@ -0,0 +1,365 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import math
+from functools import partial
+from timm.models.layers import to_2tuple, trunc_normal_, DropPath
+
+class DWConv(nn.Module):
+ def __init__(self, dim=768):
+ super(DWConv, self).__init__()
+ self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)
+
+ def forward(self, x, H, W):
+ B, N, C = x.shape
+ x = x.transpose(1, 2).view(B, C, H, W)
+ x = self.dwconv(x)
+ x = x.flatten(2).transpose(1, 2)
+
+ return x
+
+
+class Mlp(nn.Module):
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = nn.Linear(in_features, hidden_features)
+ self.dwconv = DWConv(hidden_features)
+ self.act = act_layer()
+ self.fc2 = nn.Linear(hidden_features, out_features)
+ self.drop = nn.Dropout(drop)
+
+ self.apply(self._init_weights)
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
+ elif isinstance(m, nn.Conv2d):
+ fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+ fan_out //= m.groups
+ m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
+ if m.bias is not None:
+ m.bias.data.zero_()
+
+ def forward(self, x, H, W):
+ x = self.fc1(x)
+ x = self.dwconv(x, H, W)
+ x = self.act(x)
+ x = self.drop(x)
+ x = self.fc2(x)
+ x = self.drop(x)
+ return x
+
+
+class Attention(nn.Module):
+ def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1):
+ super().__init__()
+ assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."
+
+ self.dim = dim
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = qk_scale or head_dim ** -0.5
+
+ self.q = nn.Linear(dim, dim, bias=qkv_bias)
+ self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ self.sr_ratio = sr_ratio
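+        # Spatial-reduction attention (SRA): for sr_ratio > 1, keys/values are
+        # computed from a strided-conv downsampled map, shrinking the number of
+        # K/V tokens by a factor of sr_ratio**2.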
+ if sr_ratio > 1:
+ self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
+ self.norm = nn.LayerNorm(dim)
+
+ self.apply(self._init_weights)
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
+ elif isinstance(m, nn.Conv2d):
+ fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+ fan_out //= m.groups
+ m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
+ if m.bias is not None:
+ m.bias.data.zero_()
+
+ def forward(self, x, H, W):
+ B, N, C = x.shape
+ q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
+
+ if self.sr_ratio > 1:
+ x_ = x.permute(0, 2, 1).reshape(B, C, H, W)
+ x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1)
+ x_ = self.norm(x_)
+ kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+ else:
+ kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+ k, v = kv[0], kv[1]
+
+ attn = (q @ k.transpose(-2, -1)) * self.scale
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+
+ return x
+
+
+class Block(nn.Module):
+
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
+ drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1):
+ super().__init__()
+ self.norm1 = norm_layer(dim)
+ self.attn = Attention(
+ dim,
+ num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
+ attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio)
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+
+ self.apply(self._init_weights)
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
+ elif isinstance(m, nn.Conv2d):
+ fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+ fan_out //= m.groups
+ m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
+ if m.bias is not None:
+ m.bias.data.zero_()
+
+ def forward(self, x, H, W):
+ x = x + self.drop_path(self.attn(self.norm1(x), H, W))
+ x = x + self.drop_path(self.mlp(self.norm2(x), H, W))
+
+ return x
+
+class OverlapPatchEmbed(nn.Module):
+ """ Image to Patch Embedding
+ """
+
+ def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768):
+ super().__init__()
+ img_size = to_2tuple(img_size)
+ patch_size = to_2tuple(patch_size)
+
+ self.img_size = img_size
+ self.patch_size = patch_size
+ self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1]
+ self.num_patches = self.H * self.W
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride,
+ padding=(patch_size[0] // 2, patch_size[1] // 2))
+ self.norm = nn.LayerNorm(embed_dim)
+
+ self.apply(self._init_weights)
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
+ elif isinstance(m, nn.Conv2d):
+ fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+ fan_out //= m.groups
+ m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
+ if m.bias is not None:
+ m.bias.data.zero_()
+
+ def forward(self, x):
+ x = self.proj(x)
+ _, _, H, W = x.shape
+ x = x.flatten(2).transpose(1, 2)
+ x = self.norm(x)
+
+ return x, H, W
+
+
+
+
+class MixVisionTransformer(nn.Module):
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dims=[64, 128, 256, 512],
+ num_heads=[1, 2, 4, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=False, qk_scale=None, drop_rate=0.,
+ attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm,
+ depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1]):
+ super().__init__()
+ self.num_classes = num_classes
+ self.depths = depths
+
+ # patch_embed
+ self.patch_embed1 = OverlapPatchEmbed(img_size=img_size, patch_size=7, stride=4, in_chans=in_chans,
+ embed_dim=embed_dims[0])
+ self.patch_embed2 = OverlapPatchEmbed(img_size=img_size // 4, patch_size=3, stride=2, in_chans=embed_dims[0],
+ embed_dim=embed_dims[1])
+ self.patch_embed3 = OverlapPatchEmbed(img_size=img_size // 8, patch_size=3, stride=2, in_chans=embed_dims[1],
+ embed_dim=embed_dims[2])
+ self.patch_embed4 = OverlapPatchEmbed(img_size=img_size // 16, patch_size=3, stride=2, in_chans=embed_dims[2],
+ embed_dim=embed_dims[3])
+
+ # transformer encoder
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
+ cur = 0
+ self.block1 = nn.ModuleList([Block(
+ dim=embed_dims[0], num_heads=num_heads[0], mlp_ratio=mlp_ratios[0], qkv_bias=qkv_bias, qk_scale=qk_scale,
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
+ sr_ratio=sr_ratios[0])
+ for i in range(depths[0])])
+ self.norm1 = norm_layer(embed_dims[0])
+
+ cur += depths[0]
+ self.block2 = nn.ModuleList([Block(
+ dim=embed_dims[1], num_heads=num_heads[1], mlp_ratio=mlp_ratios[1], qkv_bias=qkv_bias, qk_scale=qk_scale,
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
+ sr_ratio=sr_ratios[1])
+ for i in range(depths[1])])
+ self.norm2 = norm_layer(embed_dims[1])
+
+ cur += depths[1]
+ self.block3 = nn.ModuleList([Block(
+ dim=embed_dims[2], num_heads=num_heads[2], mlp_ratio=mlp_ratios[2], qkv_bias=qkv_bias, qk_scale=qk_scale,
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
+ sr_ratio=sr_ratios[2])
+ for i in range(depths[2])])
+ self.norm3 = norm_layer(embed_dims[2])
+
+ cur += depths[2]
+ self.block4 = nn.ModuleList([Block(
+ dim=embed_dims[3], num_heads=num_heads[3], mlp_ratio=mlp_ratios[3], qkv_bias=qkv_bias, qk_scale=qk_scale,
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
+ sr_ratio=sr_ratios[3])
+ for i in range(depths[3])])
+ self.norm4 = norm_layer(embed_dims[3])
+
+ # classification head
+ # self.head = nn.Linear(embed_dims[3], num_classes) if num_classes > 0 else nn.Identity()
+
+ self.apply(self._init_weights)
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
+ elif isinstance(m, nn.Conv2d):
+ fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+ fan_out //= m.groups
+ m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
+ if m.bias is not None:
+ m.bias.data.zero_()
+
+ def init_weights(self, pretrained=None):
+ if isinstance(pretrained, str):
+            # `strict` is an argument of load_state_dict, not torch.load
+            self.load_state_dict(torch.load(pretrained, map_location='cpu'), strict=False)
+
+ def reset_drop_path(self, drop_path_rate):
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(self.depths))]
+ cur = 0
+ for i in range(self.depths[0]):
+ self.block1[i].drop_path.drop_prob = dpr[cur + i]
+
+ cur += self.depths[0]
+ for i in range(self.depths[1]):
+ self.block2[i].drop_path.drop_prob = dpr[cur + i]
+
+ cur += self.depths[1]
+ for i in range(self.depths[2]):
+ self.block3[i].drop_path.drop_prob = dpr[cur + i]
+
+ cur += self.depths[2]
+ for i in range(self.depths[3]):
+ self.block4[i].drop_path.drop_prob = dpr[cur + i]
+
+ def freeze_patch_emb(self):
+ self.patch_embed1.requires_grad = False
+
+ @torch.jit.ignore
+ def no_weight_decay(self):
+ return {'pos_embed1', 'pos_embed2', 'pos_embed3', 'pos_embed4', 'cls_token'} # has pos_embed may be better
+
+ def get_classifier(self):
+ return self.head
+
+ def reset_classifier(self, num_classes, global_pool=''):
+ self.num_classes = num_classes
+ self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
+
+ def forward_features(self, x):
+ B = x.shape[0]
+ outs = []
+
+ # stage 1
+ x, H, W = self.patch_embed1(x)
+ for i, blk in enumerate(self.block1):
+ x = blk(x, H, W)
+ x = self.norm1(x)
+ x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
+ outs.append(x)
+
+ # stage 2
+ x, H, W = self.patch_embed2(x)
+ for i, blk in enumerate(self.block2):
+ x = blk(x, H, W)
+ x = self.norm2(x)
+ x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
+ outs.append(x)
+
+ # stage 3
+ x, H, W = self.patch_embed3(x)
+ for i, blk in enumerate(self.block3):
+ x = blk(x, H, W)
+ x = self.norm3(x)
+ x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
+ outs.append(x)
+
+ # stage 4
+ x, H, W = self.patch_embed4(x)
+ for i, blk in enumerate(self.block4):
+ x = blk(x, H, W)
+ x = self.norm4(x)
+ x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
+ outs.append(x)
+
+ return outs
+
+ def forward(self, x):
+ x = self.forward_features(x)
+ # x = self.head(x)
+
+ return x
+
+
+class mit_b0(MixVisionTransformer):
+ def __init__(self, **kwargs):
+ super(mit_b0, self).__init__(
+ patch_size=4, embed_dims=[32, 64, 160, 256], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
+ qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1],
+ drop_rate=0.0, drop_path_rate=0.1)
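+
+
+def _example_mit_b0_features():
+    """Usage sketch (illustrative only): mit_b0 returns a 4-level feature
+    pyramid at strides 4/8/16/32, e.g. (1, 32, 56, 56), (1, 64, 28, 28),
+    (1, 160, 14, 14) and (1, 256, 7, 7) for a 224x224 input."""
+    model = mit_b0().eval()
+    with torch.no_grad():
+        feats = model(torch.rand(1, 3, 224, 224))
+    return [tuple(f.shape) for f in feats]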
diff --git a/unish/pi3/models/segformer/head.py b/unish/pi3/models/segformer/head.py
new file mode 100644
index 0000000000000000000000000000000000000000..f21e0cdcd6ab8ffa68befb92ac8b70d921581854
--- /dev/null
+++ b/unish/pi3/models/segformer/head.py
@@ -0,0 +1,714 @@
+import inspect
+import warnings
+from abc import ABCMeta, abstractmethod
+from functools import partial
+from typing import Dict, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn import Conv2d, Tanh, PReLU, Sigmoid, GELU, ReLU, BatchNorm2d
+from torch.nn import SyncBatchNorm as SyncBN
+from torch.nn.modules.batchnorm import _BatchNorm
+from torch.nn.modules.instancenorm import _InstanceNorm
+from timm.models.layers import trunc_normal_
+
+# from mmcv.runner import auto_fp16, force_fp32
+
+
+def constant_init(module, val, bias=0):
+ if hasattr(module, 'weight') and module.weight is not None:
+ nn.init.constant_(module.weight, val)
+ if hasattr(module, 'bias') and module.bias is not None:
+ nn.init.constant_(module.bias, bias)
+
+
+def xavier_init(module, gain=1, bias=0, distribution='normal'):
+ assert distribution in ['uniform', 'normal']
+ if hasattr(module, 'weight') and module.weight is not None:
+ if distribution == 'uniform':
+ nn.init.xavier_uniform_(module.weight, gain=gain)
+ else:
+ nn.init.xavier_normal_(module.weight, gain=gain)
+ if hasattr(module, 'bias') and module.bias is not None:
+ nn.init.constant_(module.bias, bias)
+
+
+def normal_init(module, mean=0, std=1, bias=0):
+ if hasattr(module, 'weight') and module.weight is not None:
+ nn.init.normal_(module.weight, mean, std)
+ if hasattr(module, 'bias') and module.bias is not None:
+ nn.init.constant_(module.bias, bias)
+
+
+def trunc_normal_init(module: nn.Module,
+ mean: float = 0,
+ std: float = 1,
+ a: float = -2,
+ b: float = 2,
+ bias: float = 0) -> None:
+ if hasattr(module, 'weight') and module.weight is not None:
+ trunc_normal_(module.weight, mean, std, a, b) # type: ignore
+ if hasattr(module, 'bias') and module.bias is not None:
+ nn.init.constant_(module.bias, bias) # type: ignore
+
+
+def uniform_init(module, a=0, b=1, bias=0):
+ if hasattr(module, 'weight') and module.weight is not None:
+ nn.init.uniform_(module.weight, a, b)
+ if hasattr(module, 'bias') and module.bias is not None:
+ nn.init.constant_(module.bias, bias)
+
+def kaiming_init(module,
+ a=0,
+ mode='fan_out',
+ nonlinearity='relu',
+ bias=0,
+ distribution='normal'):
+ assert distribution in ['uniform', 'normal']
+ if hasattr(module, 'weight') and module.weight is not None:
+ if distribution == 'uniform':
+ nn.init.kaiming_uniform_(
+ module.weight, a=a, mode=mode, nonlinearity=nonlinearity)
+ else:
+ nn.init.kaiming_normal_(
+ module.weight, a=a, mode=mode, nonlinearity=nonlinearity)
+ if hasattr(module, 'bias') and module.bias is not None:
+ nn.init.constant_(module.bias, bias)
+
+def build_conv_layer(cfg: Optional[Dict], *args, **kwargs) -> nn.Module:
+ """Build convolution layer.
+
+ Args:
+ cfg (None or dict): The conv layer config, which should contain:
+ - type (str): Layer type.
+ - layer args: Args needed to instantiate an conv layer.
+ args (argument list): Arguments passed to the `__init__`
+ method of the corresponding conv layer.
+ kwargs (keyword arguments): Keyword arguments passed to the `__init__`
+ method of the corresponding conv layer.
+
+ Returns:
+ nn.Module: Created conv layer.
+ """
+ if cfg is None:
+ cfg_ = dict(type='Conv2d')
+ else:
+ if not isinstance(cfg, dict):
+ raise TypeError('cfg must be a dict')
+ if 'type' not in cfg:
+ raise KeyError('the cfg dict must contain the key "type"')
+ cfg_ = cfg.copy()
+
+ layer_type = cfg_.pop('type')
+ if inspect.isclass(layer_type):
+ return layer_type(*args, **kwargs, **cfg_) # type: ignore
+    # The registry machinery has been stripped out; resolve the layer type by
+    # name against the classes imported above (e.g. 'Conv2d').
+    try:
+        conv_layer = eval(layer_type)
+    except NameError:
+        raise KeyError(f'Cannot find conv layer type {layer_type!r}')
+    layer = conv_layer(*args, **kwargs, **cfg_)
+
+ return layer
+
+def build_padding_layer(cfg: Dict, *args, **kwargs) -> nn.Module:
+ """Build padding layer.
+
+ Args:
+ cfg (dict): The padding layer config, which should contain:
+ - type (str): Layer type.
+ - layer args: Args needed to instantiate a padding layer.
+
+ Returns:
+ nn.Module: Created padding layer.
+ """
+ if not isinstance(cfg, dict):
+ raise TypeError('cfg must be a dict')
+ if 'type' not in cfg:
+ raise KeyError('the cfg dict must contain the key "type"')
+
+ cfg_ = cfg.copy()
+ padding_type = cfg_.pop('type')
+ if inspect.isclass(padding_type):
+ return padding_type(*args, **kwargs, **cfg_)
+    # The registry machinery has been stripped out; resolve the padding type
+    # by name against the imported torch.nn layers.
+    try:
+        padding_layer = eval(padding_type)
+    except NameError:
+        raise KeyError(f'Cannot find padding layer type {padding_type!r}')
+    layer = padding_layer(*args, **kwargs, **cfg_)
+
+ return layer
+
+class ConvModule(nn.Module):
+ """A conv block that bundles conv/norm/activation layers.
+
+ This block simplifies the usage of convolution layers, which are commonly
+ used with a norm layer (e.g., BatchNorm) and activation layer (e.g., ReLU).
+ It is based upon three build methods: `build_conv_layer()`,
+ `build_norm_layer()` and `build_activation_layer()`.
+
+ Besides, we add some additional features in this module.
+ 1. Automatically set `bias` of the conv layer.
+ 2. Spectral norm is supported.
+ 3. More padding modes are supported. Before PyTorch 1.5, nn.Conv2d only
+ supports zero and circular padding, and we add "reflect" padding mode.
+
+ Args:
+ in_channels (int): Number of channels in the input feature map.
+ Same as that in ``nn._ConvNd``.
+ out_channels (int): Number of channels produced by the convolution.
+ Same as that in ``nn._ConvNd``.
+ kernel_size (int | tuple[int]): Size of the convolving kernel.
+ Same as that in ``nn._ConvNd``.
+ stride (int | tuple[int]): Stride of the convolution.
+ Same as that in ``nn._ConvNd``.
+ padding (int | tuple[int]): Zero-padding added to both sides of
+ the input. Same as that in ``nn._ConvNd``.
+ dilation (int | tuple[int]): Spacing between kernel elements.
+ Same as that in ``nn._ConvNd``.
+ groups (int): Number of blocked connections from input channels to
+ output channels. Same as that in ``nn._ConvNd``.
+ bias (bool | str): If specified as `auto`, it will be decided by the
+ norm_cfg. Bias will be set as True if `norm_cfg` is None, otherwise
+ False. Default: "auto".
+ conv_cfg (dict): Config dict for convolution layer. Default: None,
+ which means using conv2d.
+ norm_cfg (dict): Config dict for normalization layer. Default: None.
+ act_cfg (dict): Config dict for activation layer.
+ Default: dict(type='ReLU').
+ inplace (bool): Whether to use inplace mode for activation.
+ Default: True.
+ with_spectral_norm (bool): Whether use spectral norm in conv module.
+ Default: False.
+ padding_mode (str): If the `padding_mode` has not been supported by
+ current `Conv2d` in PyTorch, we will use our own padding layer
+ instead. Currently, we support ['zeros', 'circular'] with official
+ implementation and ['reflect'] with our own implementation.
+ Default: 'zeros'.
+ order (tuple[str]): The order of conv/norm/activation layers. It is a
+ sequence of "conv", "norm" and "act". Common examples are
+ ("conv", "norm", "act") and ("act", "conv", "norm").
+ Default: ('conv', 'norm', 'act').
+ efficient_conv_bn_eval (bool): Whether use efficient conv when the
+ consecutive bn is in eval mode (either training or testing), as
+ proposed in https://arxiv.org/abs/2305.11624 . Default: `False`.
+ """
+
+ _abbr_ = 'conv_block'
+
+ def __init__(self,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: Union[int, Tuple[int, int]],
+ stride: Union[int, Tuple[int, int]] = 1,
+ padding: Union[int, Tuple[int, int]] = 0,
+ dilation: Union[int, Tuple[int, int]] = 1,
+ groups: int = 1,
+ bias: Union[bool, str] = 'auto',
+ conv_cfg: Optional[Dict] = None,
+ norm_cfg: Optional[Dict] = None,
+ act_cfg: Optional[Dict] = dict(type='ReLU'),
+ inplace: bool = True,
+ with_spectral_norm: bool = False,
+ padding_mode: str = 'zeros',
+ order: tuple = ('conv', 'norm', 'act'),
+ efficient_conv_bn_eval: bool = False):
+ super().__init__()
+ assert conv_cfg is None or isinstance(conv_cfg, dict)
+ assert norm_cfg is None or isinstance(norm_cfg, dict)
+ assert act_cfg is None or isinstance(act_cfg, dict)
+ official_padding_mode = ['zeros', 'circular']
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ self.act_cfg = act_cfg
+ self.inplace = inplace
+ self.with_spectral_norm = with_spectral_norm
+ self.with_explicit_padding = padding_mode not in official_padding_mode
+ self.order = order
+ assert isinstance(self.order, tuple) and len(self.order) == 3
+ assert set(order) == {'conv', 'norm', 'act'}
+
+ self.with_norm = norm_cfg is not None
+ self.with_activation = act_cfg is not None
+ # if the conv layer is before a norm layer, bias is unnecessary.
+ if bias == 'auto':
+ bias = not self.with_norm
+ self.with_bias = bias
+
+ if self.with_explicit_padding:
+ pad_cfg = dict(type=padding_mode)
+ self.padding_layer = build_padding_layer(pad_cfg, padding)
+
+ # reset padding to 0 for conv module
+ conv_padding = 0 if self.with_explicit_padding else padding
+ # build convolution layer
+ self.conv = build_conv_layer(
+ conv_cfg,
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=stride,
+ padding=conv_padding,
+ dilation=dilation,
+ groups=groups,
+ bias=bias)
+ # export the attributes of self.conv to a higher level for convenience
+ self.in_channels = self.conv.in_channels
+ self.out_channels = self.conv.out_channels
+ self.kernel_size = self.conv.kernel_size
+ self.stride = self.conv.stride
+ self.padding = padding
+ self.dilation = self.conv.dilation
+ self.transposed = self.conv.transposed
+ self.output_padding = self.conv.output_padding
+ self.groups = self.conv.groups
+
+ if self.with_spectral_norm:
+ self.conv = nn.utils.spectral_norm(self.conv)
+
+ # build normalization layers
+ if self.with_norm:
+ # norm layer is after conv layer
+ if order.index('norm') > order.index('conv'):
+ norm_channels = out_channels
+ else:
+ norm_channels = in_channels
+ # self.norm_name = norm_cfg.pop('type')
+ # norm = eval(self.norm_name)(**norm_cfg)
+ # norm = eval(self.norm_name)()
+ self.norm_name = 'bn'
+ norm = BatchNorm2d(norm_channels)
+
+ # self.norm_name, norm = build_norm_layer(
+ # norm_cfg, norm_channels) # type: ignore
+ self.add_module(self.norm_name, norm)
+ if self.with_bias:
+ if isinstance(norm, (_BatchNorm, _InstanceNorm)):
+ warnings.warn(
+ 'Unnecessary conv bias before batch/instance norm')
+ else:
+ self.norm_name = None # type: ignore
+
+ self.turn_on_efficient_conv_bn_eval(efficient_conv_bn_eval)
+
+ # build activation layer
+ if self.with_activation:
+ act_cfg_ = act_cfg.copy() # type: ignore
+ # nn.Tanh has no 'inplace' argument
+ if act_cfg_['type'] not in [
+ 'Tanh', 'PReLU', 'Sigmoid', 'HSigmoid', 'Swish', 'GELU'
+ ]:
+ act_cfg_.setdefault('inplace', inplace)
+ activate_ = eval(act_cfg_.pop('type'))
+ # self.activate = build_activation_layer(act_cfg_)
+ self.activate = activate_(**act_cfg_)
+
+ # Use msra init by default
+ self.init_weights()
+
+ @property
+ def norm(self):
+ if self.norm_name:
+ return getattr(self, self.norm_name)
+ else:
+ return None
+
+ def init_weights(self):
+ # 1. It is mainly for customized conv layers with their own
+ # initialization manners by calling their own ``init_weights()``,
+ # and we do not want ConvModule to override the initialization.
+ # 2. For customized conv layers without their own initialization
+ # manners (that is, they don't have their own ``init_weights()``)
+ # and PyTorch's conv layers, they will be initialized by
+ # this method with default ``kaiming_init``.
+ # Note: For PyTorch's conv layers, they will be overwritten by our
+ # initialization implementation using default ``kaiming_init``.
+ if not hasattr(self.conv, 'init_weights'):
+ if self.with_activation and self.act_cfg['type'] == 'LeakyReLU':
+ nonlinearity = 'leaky_relu'
+ a = self.act_cfg.get('negative_slope', 0.01)
+ else:
+ nonlinearity = 'relu'
+ a = 0
+ kaiming_init(self.conv, a=a, nonlinearity=nonlinearity)
+ if self.with_norm:
+ constant_init(self.norm, 1, bias=0)
+
+ def forward(self,
+ x: torch.Tensor,
+ activate: bool = True,
+ norm: bool = True) -> torch.Tensor:
+ layer_index = 0
+ while layer_index < len(self.order):
+ layer = self.order[layer_index]
+ if layer == 'conv':
+ if self.with_explicit_padding:
+ x = self.padding_layer(x)
+ # if the next operation is norm and we have a norm layer in
+ # eval mode and we have enabled `efficient_conv_bn_eval` for
+ # the conv operator, then activate the optimized forward and
+ # skip the next norm operator since it has been fused
+ if layer_index + 1 < len(self.order) and \
+ self.order[layer_index + 1] == 'norm' and norm and \
+ self.with_norm and not self.norm.training and \
+ self.efficient_conv_bn_eval_forward is not None:
+ self.conv.forward = partial(
+ self.efficient_conv_bn_eval_forward, self.norm,
+ self.conv)
+ layer_index += 1
+ x = self.conv(x)
+ del self.conv.forward
+ else:
+ x = self.conv(x)
+ elif layer == 'norm' and norm and self.with_norm:
+ x = self.norm(x)
+ elif layer == 'act' and activate and self.with_activation:
+ x = self.activate(x)
+ layer_index += 1
+ return x
+
+ def turn_on_efficient_conv_bn_eval(self, efficient_conv_bn_eval=True):
+ # efficient_conv_bn_eval works for conv + bn
+ # with `track_running_stats` option
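+        # NOTE: `efficient_conv_bn_eval_forward` is an mmcv helper that is not
+        # vendored here; this optimized path is only usable when that function
+        # is supplied. The default `efficient_conv_bn_eval=False` avoids it.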
+ if efficient_conv_bn_eval and self.norm \
+ and isinstance(self.norm, _BatchNorm) \
+ and self.norm.track_running_stats:
+ self.efficient_conv_bn_eval_forward = efficient_conv_bn_eval_forward # noqa: E501
+ else:
+ self.efficient_conv_bn_eval_forward = None # type: ignore
+
+ @staticmethod
+ def create_from_conv_bn(conv: torch.nn.modules.conv._ConvNd,
+ bn: torch.nn.modules.batchnorm._BatchNorm,
+ efficient_conv_bn_eval=True) -> 'ConvModule':
+ """Create a ConvModule from a conv and a bn module."""
+ self = ConvModule.__new__(ConvModule)
+ super(ConvModule, self).__init__()
+
+ self.conv_cfg = None
+ self.norm_cfg = None
+ self.act_cfg = None
+ self.inplace = False
+ self.with_spectral_norm = False
+ self.with_explicit_padding = False
+ self.order = ('conv', 'norm', 'act')
+
+ self.with_norm = True
+ self.with_activation = False
+ self.with_bias = conv.bias is not None
+
+ # build convolution layer
+ self.conv = conv
+ # export the attributes of self.conv to a higher level for convenience
+ self.in_channels = self.conv.in_channels
+ self.out_channels = self.conv.out_channels
+ self.kernel_size = self.conv.kernel_size
+ self.stride = self.conv.stride
+ self.padding = self.conv.padding
+ self.dilation = self.conv.dilation
+ self.transposed = self.conv.transposed
+ self.output_padding = self.conv.output_padding
+ self.groups = self.conv.groups
+
+ # build normalization layers
+ self.norm_name, norm = 'bn', bn
+ self.add_module(self.norm_name, norm)
+
+ self.turn_on_efficient_conv_bn_eval(efficient_conv_bn_eval)
+
+ return self
+
+def resize(input,
+ size=None,
+ scale_factor=None,
+ mode='nearest',
+ align_corners=None,
+ warning=True):
+ if warning:
+ if size is not None and align_corners:
+ input_h, input_w = tuple(int(x) for x in input.shape[2:])
+ output_h, output_w = tuple(int(x) for x in size)
+            if output_h > input_h or output_w > input_w:
+ if ((output_h > 1 and output_w > 1 and input_h > 1
+ and input_w > 1) and (output_h - 1) % (input_h - 1)
+ and (output_w - 1) % (input_w - 1)):
+ warnings.warn(
+ f'When align_corners={align_corners}, '
+ 'the output would more aligned if '
+ f'input size {(input_h, input_w)} is `x+1` and '
+ f'out size {(output_h, output_w)} is `nx+1`')
+ if isinstance(size, torch.Size):
+ size = tuple(int(x) for x in size)
+ return F.interpolate(input, size, scale_factor, mode, align_corners)
+
+class BaseDecodeHead(nn.Module, metaclass=ABCMeta):
+ """Base class for BaseDecodeHead.
+
+ Args:
+ in_channels (int|Sequence[int]): Input channels.
+ channels (int): Channels after modules, before conv_seg.
+ num_classes (int): Number of classes.
+ dropout_ratio (float): Ratio of dropout layer. Default: 0.1.
+ conv_cfg (dict|None): Config of conv layers. Default: None.
+ norm_cfg (dict|None): Config of norm layers. Default: None.
+ act_cfg (dict): Config of activation layers.
+ Default: dict(type='ReLU')
+ in_index (int|Sequence[int]): Input feature index. Default: -1
+ input_transform (str|None): Transformation type of input features.
+ Options: 'resize_concat', 'multiple_select', None.
+            'resize_concat': Multiple feature maps will be resized to the
+                same size as the first one and then concatenated together.
+                Usually used in FCN head of HRNet.
+            'multiple_select': Multiple feature maps will be bundled into
+                a list and passed into decode head.
+            None: Only one selected feature map is allowed.
+ Default: None.
+ loss_decode (dict): Config of decode loss.
+ Default: dict(type='CrossEntropyLoss').
+ ignore_index (int | None): The label index to be ignored. When using
+ masked BCE loss, ignore_index should be set to None. Default: 255
+ sampler (dict|None): The config of segmentation map sampler.
+ Default: None.
+ align_corners (bool): align_corners argument of F.interpolate.
+ Default: False.
+ """
+
+ def __init__(self,
+ in_channels,
+ channels,
+ *,
+ num_classes,
+ dropout_ratio=0.1,
+ conv_cfg=None,
+ norm_cfg=None,
+ act_cfg=dict(type='ReLU'),
+ in_index=-1,
+ input_transform=None,
+ loss_decode=dict(
+ type='CrossEntropyLoss',
+ use_sigmoid=False,
+ loss_weight=1.0),
+ decoder_params=None,
+ ignore_index=255,
+ sampler=None,
+ align_corners=False):
+ super(BaseDecodeHead, self).__init__()
+ self._init_inputs(in_channels, in_index, input_transform)
+ self.channels = channels
+ self.num_classes = num_classes
+ self.dropout_ratio = dropout_ratio
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ self.act_cfg = act_cfg
+ self.in_index = in_index
+ self.ignore_index = ignore_index
+ self.align_corners = align_corners
+
+ # if sampler is not None:
+ # self.sampler = build_pixel_sampler(sampler, context=self)
+ # else:
+ self.sampler = None
+
+ self.conv_seg = nn.Conv2d(channels, num_classes, kernel_size=1)
+ if dropout_ratio > 0:
+ self.dropout = nn.Dropout2d(dropout_ratio)
+ else:
+ self.dropout = None
+ self.fp16_enabled = False
+
+ def extra_repr(self):
+ """Extra repr."""
+ s = f'input_transform={self.input_transform}, ' \
+ f'ignore_index={self.ignore_index}, ' \
+ f'align_corners={self.align_corners}'
+ return s
+
+ def _init_inputs(self, in_channels, in_index, input_transform):
+ """Check and initialize input transforms.
+
+ The in_channels, in_index and input_transform must match.
+        Specifically, when input_transform is None, only a single feature map
+        will be selected, so in_channels and in_index must be of type int.
+        When input_transform is not None, in_channels and in_index must be
+        lists or tuples of the same length.
+
+ Args:
+ in_channels (int|Sequence[int]): Input channels.
+ in_index (int|Sequence[int]): Input feature index.
+ input_transform (str|None): Transformation type of input features.
+ Options: 'resize_concat', 'multiple_select', None.
+                'resize_concat': Multiple feature maps will be resized to the
+                    same size as the first one and then concatenated together.
+                    Usually used in FCN head of HRNet.
+                'multiple_select': Multiple feature maps will be bundled into
+                    a list and passed into decode head.
+                None: Only one selected feature map is allowed.
+ """
+
+ if input_transform is not None:
+ assert input_transform in ['resize_concat', 'multiple_select']
+ self.input_transform = input_transform
+ self.in_index = in_index
+ if input_transform is not None:
+ assert isinstance(in_channels, (list, tuple))
+ assert isinstance(in_index, (list, tuple))
+ assert len(in_channels) == len(in_index)
+ if input_transform == 'resize_concat':
+ self.in_channels = sum(in_channels)
+ else:
+ self.in_channels = in_channels
+ else:
+ assert isinstance(in_channels, int)
+ assert isinstance(in_index, int)
+ self.in_channels = in_channels
+
+ def init_weights(self):
+ """Initialize weights of classification layer."""
+ # normal_init(self.conv_seg, mean=0, std=0.01)
+ pass
+
+ def _transform_inputs(self, inputs):
+ """Transform inputs for decoder.
+
+ Args:
+ inputs (list[Tensor]): List of multi-level img features.
+
+ Returns:
+ Tensor: The transformed inputs
+ """
+
+ if self.input_transform == 'resize_concat':
+ inputs = [inputs[i] for i in self.in_index]
+ upsampled_inputs = [
+ resize(
+ input=x,
+ size=inputs[0].shape[2:],
+ mode='bilinear',
+ align_corners=self.align_corners) for x in inputs
+ ]
+ inputs = torch.cat(upsampled_inputs, dim=1)
+ elif self.input_transform == 'multiple_select':
+ inputs = [inputs[i] for i in self.in_index]
+ else:
+ inputs = inputs[self.in_index]
+
+ return inputs
+
+ # @auto_fp16()
+ @abstractmethod
+ def forward(self, inputs):
+ """Placeholder of forward function."""
+ pass
+
+ def forward_train(self, inputs, img_metas, gt_semantic_seg, train_cfg):
+ """Forward function for training.
+ Args:
+ inputs (list[Tensor]): List of multi-level img features.
+ img_metas (list[dict]): List of image info dict where each dict
+ has: 'img_shape', 'scale_factor', 'flip', and may also contain
+ 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
+ For details on the values of these keys see
+ `mmseg/datasets/pipelines/formatting.py:Collect`.
+ gt_semantic_seg (Tensor): Semantic segmentation masks
+ used if the architecture supports semantic segmentation task.
+ train_cfg (dict): The training config.
+
+ Returns:
+ dict[str, Tensor]: a dictionary of loss components
+ """
+ seg_logits = self.forward(inputs)
+ losses = self.losses(seg_logits, gt_semantic_seg)
+ return losses
+
+ def forward_test(self, inputs, img_metas, test_cfg):
+ """Forward function for testing.
+
+ Args:
+ inputs (list[Tensor]): List of multi-level img features.
+ img_metas (list[dict]): List of image info dict where each dict
+ has: 'img_shape', 'scale_factor', 'flip', and may also contain
+ 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
+ For details on the values of these keys see
+ `mmseg/datasets/pipelines/formatting.py:Collect`.
+ test_cfg (dict): The testing config.
+
+ Returns:
+ Tensor: Output segmentation map.
+ """
+ return self.forward(inputs)
+
+ def cls_seg(self, feat):
+ """Classify each pixel."""
+ if self.dropout is not None:
+ feat = self.dropout(feat)
+ output = self.conv_seg(feat)
+ return output
+
+
+
+class MLP(nn.Module):
+ """
+ Linear Embedding
+ """
+ def __init__(self, input_dim=2048, embed_dim=768):
+ super().__init__()
+ self.proj = nn.Linear(input_dim, embed_dim)
+
+ def forward(self, x):
+ x = x.flatten(2).transpose(1, 2)
+ x = self.proj(x)
+ return x
+
+
+
+class SegFormerHead(BaseDecodeHead):
+ """
+ SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers
+ """
+ def __init__(self, feature_strides, **kwargs):
+ super(SegFormerHead, self).__init__(input_transform='multiple_select', **kwargs)
+ assert len(feature_strides) == len(self.in_channels)
+ assert min(feature_strides) == feature_strides[0]
+ self.feature_strides = feature_strides
+
+ c1_in_channels, c2_in_channels, c3_in_channels, c4_in_channels = self.in_channels
+
+ decoder_params = kwargs['decoder_params']
+ embedding_dim = decoder_params['embed_dim']
+
+ self.linear_c4 = MLP(input_dim=c4_in_channels, embed_dim=embedding_dim)
+ self.linear_c3 = MLP(input_dim=c3_in_channels, embed_dim=embedding_dim)
+ self.linear_c2 = MLP(input_dim=c2_in_channels, embed_dim=embedding_dim)
+ self.linear_c1 = MLP(input_dim=c1_in_channels, embed_dim=embedding_dim)
+
+ self.linear_fuse = ConvModule(
+ in_channels=embedding_dim*4,
+ out_channels=embedding_dim,
+ kernel_size=1,
+ norm_cfg=dict(type='SyncBN', requires_grad=True)
+ )
+
+ self.linear_pred = nn.Conv2d(embedding_dim, self.num_classes, kernel_size=1)
+
+ def forward(self, inputs):
+ x = self._transform_inputs(inputs) # len=4, 1/4,1/8,1/16,1/32
+ c1, c2, c3, c4 = x
+
+ ############## MLP decoder on C1-C4 ###########
+ n, _, h, w = c4.shape
+
+ _c4 = self.linear_c4(c4).permute(0,2,1).reshape(n, -1, c4.shape[2], c4.shape[3])
+ _c4 = resize(_c4, size=c1.size()[2:],mode='bilinear',align_corners=False)
+
+ _c3 = self.linear_c3(c3).permute(0,2,1).reshape(n, -1, c3.shape[2], c3.shape[3])
+ _c3 = resize(_c3, size=c1.size()[2:],mode='bilinear',align_corners=False)
+
+ _c2 = self.linear_c2(c2).permute(0,2,1).reshape(n, -1, c2.shape[2], c2.shape[3])
+ _c2 = resize(_c2, size=c1.size()[2:],mode='bilinear',align_corners=False)
+
+ _c1 = self.linear_c1(c1).permute(0,2,1).reshape(n, -1, c1.shape[2], c1.shape[3])
+
+ _c = self.linear_fuse(torch.cat([_c4, _c3, _c2, _c1], dim=1))
+
+ x = self.dropout(_c)
+ x = self.linear_pred(x)
+
+ return x
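+
+
+def _example_segformer_head():
+    """Usage sketch (illustrative only; channel/stride values mirror the
+    config in model.py). Fuses a 4-level feature pyramid into per-pixel
+    logits at 1/4 resolution."""
+    head = SegFormerHead(
+        in_channels=[32, 64, 160, 256], in_index=[0, 1, 2, 3],
+        feature_strides=[4, 8, 16, 32], channels=128, dropout_ratio=0.1,
+        num_classes=150, decoder_params=dict(embed_dim=256),
+        align_corners=False).eval()
+    feats = [torch.rand(1, c, s, s)
+             for c, s in zip([32, 64, 160, 256], [56, 28, 14, 7])]
+    with torch.no_grad():
+        logits = head(feats)  # (1, 150, 56, 56)
+    return logits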
diff --git a/unish/pi3/models/segformer/model.py b/unish/pi3/models/segformer/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..dcd33489a7442564ad7f637bffad41d11de8daf6
--- /dev/null
+++ b/unish/pi3/models/segformer/model.py
@@ -0,0 +1,129 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+# from .. import builder
+# from .segmentor import BaseSegmentor
+import warnings
+
+from .head import SegFormerHead, resize
+from .backbone import mit_b0
+
+
+
+class EncoderDecoder(nn.Module):
+ """Encoder Decoder segmentors.
+
+ EncoderDecoder typically consists of backbone, decode_head, auxiliary_head.
+ Note that auxiliary_head is only used for deep supervision during training,
+ which could be dumped during inference.
+ """
+
+ def __init__(self):
+ super(EncoderDecoder, self).__init__()
+ # self.backbone = builder.build_backbone(backbone)
+ self.backbone = mit_b0() ##############
+
+ decode_head=dict(
+ in_channels=[32, 64, 160, 256],
+ in_index=[0, 1, 2, 3],
+ feature_strides=[4, 8, 16, 32],
+ channels=128,
+ dropout_ratio=0.1,
+ num_classes=150,
+ norm_cfg=dict(type='SyncBN', requires_grad=True),
+ align_corners=False,
+ decoder_params=dict(embed_dim=256),
+ loss_decode=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0))
+
+ self.decode_head = SegFormerHead(**decode_head)
+ self.align_corners = self.decode_head.align_corners
+ self.num_classes = self.decode_head.num_classes
+
+ self.with_neck = False
+
+
+ def extract_feat(self, img):
+ """Extract features from images."""
+ x = self.backbone(img)
+ if self.with_neck:
+ x = self.neck(x)
+ return x
+
+ def _decode_head_forward_test(self, x, img_metas):
+ """Run forward function and calculate loss for decode head in
+ inference."""
+ seg_logits = self.decode_head.forward_test(x, img_metas, None)
+ return seg_logits
+
+ def encode_decode(self, img, img_metas):
+ """Encode images with backbone and decode into a semantic segmentation
+ map of the same size as input."""
+ x = self.extract_feat(img)
+ out = self._decode_head_forward_test(x, img_metas)
+ out = resize(
+ input=out,
+ size=img.shape[2:],
+ mode='bilinear',
+ align_corners=self.align_corners)
+ return out
+
+ def whole_inference(self, img, img_meta, rescale):
+ """Inference with full image."""
+
+ seg_logit = self.encode_decode(img, img_meta)
+ if rescale:
+ seg_logit = resize(
+ seg_logit,
+ size=img_meta[0]['ori_shape'][:2],
+ mode='bilinear',
+ align_corners=self.align_corners,
+ warning=False)
+
+ return seg_logit
+
+ def inference(self, img, img_meta, rescale):
+ """Inference with slide/whole style.
+
+ Args:
+ img (Tensor): The input image of shape (N, 3, H, W).
+ img_meta (dict): Image info dict where each dict has: 'img_shape',
+ 'scale_factor', 'flip', and may also contain
+ 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
+ For details on the values of these keys see
+ `mmseg/datasets/pipelines/formatting.py:Collect`.
+ rescale (bool): Whether rescale back to original shape.
+
+ Returns:
+ Tensor: The output segmentation map.
+ """
+
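+        # NOTE: `self.test_cfg` (and `slide_inference`) follow the mmseg
+        # convention and are expected to be attached externally; only
+        # `inference_` below is fully self-contained.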
+ assert self.test_cfg.mode in ['slide', 'whole']
+ ori_shape = img_meta[0]['ori_shape']
+ assert all(_['ori_shape'] == ori_shape for _ in img_meta)
+ if self.test_cfg.mode == 'slide':
+ seg_logit = self.slide_inference(img, img_meta, rescale)
+ else:
+ seg_logit = self.whole_inference(img, img_meta, rescale)
+ output = F.softmax(seg_logit, dim=1)
+ flip = img_meta[0]['flip']
+ if flip:
+ flip_direction = img_meta[0]['flip_direction']
+ assert flip_direction in ['horizontal', 'vertical']
+ if flip_direction == 'horizontal':
+ output = output.flip(dims=(3, ))
+ elif flip_direction == 'vertical':
+ output = output.flip(dims=(2, ))
+
+ return output
+
+ def inference_(self, imgs):
+ imgs_meta = dict(
+ img_shape=(imgs.shape[2], imgs.shape[3]),
+ scale_factor=1.0,
+ flip=False
+ )
+ seg_logit = self.encode_decode(imgs, imgs_meta)
+ output = F.softmax(seg_logit, dim=1)
+ output = torch.argmax(output, dim=1)
+ return output
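+
+
+def _example_encoder_decoder():
+    """Usage sketch (illustrative only): full-image semantic segmentation
+    with the SegFormer-B0 backbone and head; 150 output classes as
+    configured above."""
+    model = EncoderDecoder().eval()
+    with torch.no_grad():
+        labels = model.inference_(torch.rand(1, 3, 224, 224))
+    return labels  # (1, 224, 224) integer class map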
diff --git a/unish/pi3/models/segformer/segmentor.py b/unish/pi3/models/segformer/segmentor.py
new file mode 100644
index 0000000000000000000000000000000000000000..a9ac477f7c19f16138bd0fb3a1c927d2de78b327
--- /dev/null
+++ b/unish/pi3/models/segformer/segmentor.py
@@ -0,0 +1,202 @@
+import logging
+from abc import ABCMeta, abstractmethod
+
+import numpy as np
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+from mmcv.runner import auto_fp16
+
+
+class BaseSegmentor(nn.Module, metaclass=ABCMeta):
+    """Base class for segmentors."""
+
+ def __init__(self):
+ super(BaseSegmentor, self).__init__()
+ self.fp16_enabled = False
+
+ @property
+ def with_neck(self):
+ """bool: whether the segmentor has neck"""
+ return hasattr(self, 'neck') and self.neck is not None
+
+ @property
+ def with_auxiliary_head(self):
+ """bool: whether the segmentor has auxiliary head"""
+ return hasattr(self,
+ 'auxiliary_head') and self.auxiliary_head is not None
+
+ @property
+ def with_decode_head(self):
+ """bool: whether the segmentor has decode head"""
+ return hasattr(self, 'decode_head') and self.decode_head is not None
+
+ @abstractmethod
+ def extract_feat(self, imgs):
+ """Placeholder for extract features from images."""
+ pass
+
+ @abstractmethod
+ def encode_decode(self, img, img_metas):
+ """Placeholder for encode images with backbone and decode into a
+ semantic segmentation map of the same size as input."""
+ pass
+
+ @abstractmethod
+ def forward_train(self, imgs, img_metas, **kwargs):
+ """Placeholder for Forward function for training."""
+ pass
+
+ @abstractmethod
+ def simple_test(self, img, img_meta, **kwargs):
+ """Placeholder for single image test."""
+ pass
+
+ @abstractmethod
+ def aug_test(self, imgs, img_metas, **kwargs):
+ """Placeholder for augmentation test."""
+ pass
+
+ def init_weights(self, pretrained=None):
+ """Initialize the weights in segmentor.
+
+ Args:
+ pretrained (str, optional): Path to pre-trained weights.
+ Defaults to None.
+ """
+ if pretrained is not None:
+ logger = logging.getLogger()
+ logger.info(f'load model from: {pretrained}')
+
+ def forward_test(self, imgs, img_metas, **kwargs):
+ """
+ Args:
+ imgs (List[Tensor]): the outer list indicates test-time
+ augmentations and inner Tensor should have a shape NxCxHxW,
+ which contains all images in the batch.
+ img_metas (List[List[dict]]): the outer list indicates test-time
+ augs (multiscale, flip, etc.) and the inner list indicates
+ images in a batch.
+ """
+ for var, name in [(imgs, 'imgs'), (img_metas, 'img_metas')]:
+ if not isinstance(var, list):
+ raise TypeError(f'{name} must be a list, but got '
+ f'{type(var)}')
+
+ num_augs = len(imgs)
+ if num_augs != len(img_metas):
+ raise ValueError(f'num of augmentations ({len(imgs)}) != '
+ f'num of image meta ({len(img_metas)})')
+ # all images in the same aug batch must share the same ori_shape and
+ # pad shape
+ for img_meta in img_metas:
+ ori_shapes = [_['ori_shape'] for _ in img_meta]
+ assert all(shape == ori_shapes[0] for shape in ori_shapes)
+ img_shapes = [_['img_shape'] for _ in img_meta]
+ assert all(shape == img_shapes[0] for shape in img_shapes)
+ pad_shapes = [_['pad_shape'] for _ in img_meta]
+ assert all(shape == pad_shapes[0] for shape in pad_shapes)
+
+ if num_augs == 1:
+ return self.simple_test(imgs[0], img_metas[0], **kwargs)
+ else:
+ return self.aug_test(imgs, img_metas, **kwargs)
+
+ @auto_fp16(apply_to=('img', ))
+ def forward(self, img, img_metas, return_loss=True, **kwargs):
+ """Calls either :func:`forward_train` or :func:`forward_test` depending
+ on whether ``return_loss`` is ``True``.
+
+ Note this setting will change the expected inputs. When
+ ``return_loss=True``, img and img_meta are single-nested (i.e. Tensor
+ and List[dict]), and when ``return_loss=False``, img and img_meta
+ should be double nested (i.e. List[Tensor], List[List[dict]]), with
+ the outer list indicating test time augmentations.
+ """
+ if return_loss:
+ return self.forward_train(img, img_metas, **kwargs)
+ else:
+ return self.forward_test(img, img_metas, **kwargs)
+
+ def train_step(self, data_batch, optimizer, **kwargs):
+ """The iteration step during training.
+
+ This method defines an iteration step during training, except for the
+ back propagation and optimizer updating, which are done in an optimizer
+ hook. Note that in some complicated cases or models, the whole process
+ including back propagation and optimizer updating is also defined in
+ this method, such as GAN.
+
+ Args:
+ data_batch (dict): The output of dataloader.
+ optimizer (:obj:`torch.optim.Optimizer` | dict): The optimizer of
+ runner is passed to ``train_step()``. This argument is unused
+ and reserved.
+
+ Returns:
+ dict: It should contain at least 3 keys: ``loss``, ``log_vars``,
+ ``num_samples``.
+ ``loss`` is a tensor for back propagation, which can be a
+ weighted sum of multiple losses.
+ ``log_vars`` contains all the variables to be sent to the
+ logger.
+ ``num_samples`` indicates the batch size (when the model is
+ DDP, it means the batch size on each GPU), which is used for
+ averaging the logs.
+ """
+ losses = self(**data_batch)
+ loss, log_vars = self._parse_losses(losses)
+
+ outputs = dict(
+ loss=loss,
+ log_vars=log_vars,
+ num_samples=len(data_batch['img'].data))
+
+ return outputs
+
+ def val_step(self, data_batch, **kwargs):
+ """The iteration step during validation.
+
+ This method shares the same signature as :func:`train_step`, but is used
+ during val epochs. Note that the evaluation after training epochs is
+ not implemented with this method, but an evaluation hook.
+ """
+ output = self(**data_batch, **kwargs)
+ return output
+
+ @staticmethod
+ def _parse_losses(losses):
+ """Parse the raw outputs (losses) of the network.
+
+ Args:
+ losses (dict): Raw output of the network, which usually contains
+ losses and other necessary information.
+
+ Returns:
+ tuple[Tensor, dict]: (loss, log_vars), loss is the loss tensor
+ which may be a weighted sum of all losses, log_vars contains
+ all the variables to be sent to the logger.
+ """
+ log_vars = OrderedDict()
+ for loss_name, loss_value in losses.items():
+ if isinstance(loss_value, torch.Tensor):
+ log_vars[loss_name] = loss_value.mean()
+ elif isinstance(loss_value, list):
+ log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value)
+ else:
+ raise TypeError(
+ f'{loss_name} is not a tensor or list of tensors')
+
+ loss = sum(_value for _key, _value in log_vars.items()
+ if 'loss' in _key)
+
+ log_vars['loss'] = loss
+ for loss_name, loss_value in log_vars.items():
+ # reduce loss when distributed training
+ if dist.is_available() and dist.is_initialized():
+ loss_value = loss_value.data.clone()
+ dist.all_reduce(loss_value.div_(dist.get_world_size()))
+ log_vars[loss_name] = loss_value.item()
+
+ return loss, log_vars
+
+
\ No newline at end of file
diff --git a/unish/pi3/utils/alignment.py b/unish/pi3/utils/alignment.py
new file mode 100644
index 0000000000000000000000000000000000000000..4e4ff68b3ca2f27980413ae24be5d994a486a563
--- /dev/null
+++ b/unish/pi3/utils/alignment.py
@@ -0,0 +1,499 @@
+import math
+from typing import Callable, Optional, Tuple, Union
+
+import torch
+# import utils3d
+
+
+def scatter_min(size: int, dim: int, index: torch.LongTensor, src: torch.Tensor) -> torch.return_types.min:
+ "Scatter the minimum value along the given dimension of `input` into `src` at the indices specified in `index`."
+ shape = src.shape[:dim] + (size,) + src.shape[dim + 1:]
+ minimum = torch.full(shape, float('inf'), dtype=src.dtype, device=src.device).scatter_reduce(dim=dim, index=index, src=src, reduce='amin', include_self=False)
+ minimum_where = torch.where(src == torch.gather(minimum, dim=dim, index=index))
+ indices = torch.full(shape, -1, dtype=torch.long, device=src.device)
+ indices[(*minimum_where[:dim], index[minimum_where], *minimum_where[dim + 1:])] = minimum_where[dim]
+ return torch.return_types.min((minimum, indices))
+
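+# Semantics sketch for `scatter_min` (hypothetical values, illustration only):
+# grouping `src` by `index` along dim 0,
+# scatter_min(size=2, dim=0,
+# index=torch.tensor([0, 0, 1, 1]),
+# src=torch.tensor([3.0, 1.0, 2.0, 0.5]))
+# returns per-group minima (1.0, 0.5) and their positions (1, 3) in `src`.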
+
+def split_batch_fwd(fn: Callable, chunk_size: int, *args, **kwargs):
+ batch_size = next(x for x in (*args, *kwargs.values()) if isinstance(x, torch.Tensor)).shape[0]
+ n_chunks = batch_size // chunk_size + (batch_size % chunk_size > 0)
+ split_args = tuple(arg.split(chunk_size, dim=0) if isinstance(arg, torch.Tensor) else [arg] * n_chunks for arg in args)
+ # Do not wrap the kwarg chunks in an extra list: each value must itself be
+ # indexable by chunk index below.
+ split_kwargs = {k: v.split(chunk_size, dim=0) if isinstance(v, torch.Tensor) else [v] * n_chunks for k, v in kwargs.items()}
+ results = []
+ for i in range(n_chunks):
+ chunk_args = tuple(arg[i] for arg in split_args)
+ chunk_kwargs = {k: v[i] for k, v in split_kwargs.items()}
+ results.append(fn(*chunk_args, **chunk_kwargs))
+
+ if isinstance(results[0], tuple):
+ return tuple(torch.cat(r, dim=0) for r in zip(*results))
+ else:
+ return torch.cat(results, dim=0)
+
+
+def _pad_inf(x_: torch.Tensor):
+ return torch.cat([torch.full_like(x_[..., :1], -torch.inf), x_, torch.full_like(x_[..., :1], torch.inf)], dim=-1)
+
+
+def _pad_cumsum(cumsum: torch.Tensor):
+ return torch.cat([torch.zeros_like(cumsum[..., :1]), cumsum, cumsum[..., -1:]], dim=-1)
+
+
+def _compute_residual(a: torch.Tensor, xyw: torch.Tensor, trunc: float):
+ return a.mul(xyw[..., 0]).sub_(xyw[..., 1]).abs_().mul_(xyw[..., 2]).clamp_max_(trunc).sum(dim=-1)
+
+
+def align(x: torch.Tensor, y: torch.Tensor, w: torch.Tensor, trunc: Optional[Union[float, torch.Tensor]] = None, eps: float = 1e-7) -> Tuple[torch.Tensor, torch.Tensor, torch.LongTensor]:
+ """
+ If trunc is None, solve `min sum_i w_i * |a * x_i - y_i|`, otherwise solve `min sum_i min(trunc, w_i * |a * x_i - y_i|)`.
+
+ w_i must be >= 0.
+
+ ### Parameters:
+ - `x`: tensor of shape (..., n)
+ - `y`: tensor of shape (..., n)
+ - `w`: tensor of shape (..., n)
+ - `trunc`: optional, float or tensor of shape (..., n) or None
+
+ ### Returns:
+ - `a`: tensor of shape (...), differentiable
+ - `loss`: tensor of shape (...), value of loss function at `a`, detached
+ - `index`: tensor of shape (...), where a = y[idx] / x[idx]
+ """
+ if trunc is None:
+ x, y, w = torch.broadcast_tensors(x, y, w)
+ sign = torch.sign(x)
+ x, y = x * sign, y * sign
+ y_div_x = y / x.clamp_min(eps)
+ y_div_x, argsort = y_div_x.sort(dim=-1)
+
+ wx = torch.gather(x * w, dim=-1, index=argsort)
+ derivatives = 2 * wx.cumsum(dim=-1) - wx.sum(dim=-1, keepdim=True)
+ search = torch.searchsorted(derivatives, torch.zeros_like(derivatives[..., :1]), side='left').clamp_max(derivatives.shape[-1] - 1)
+
+ a = y_div_x.gather(dim=-1, index=search).squeeze(-1)
+ index = argsort.gather(dim=-1, index=search).squeeze(-1)
+ loss = (w * (a[..., None] * x - y).abs()).sum(dim=-1)
+
+ else:
+ # Reshape to (batch_size, n) for simplicity
+ x, y, w = torch.broadcast_tensors(x, y, w)
+ batch_shape = x.shape[:-1]
+ batch_size = math.prod(batch_shape)
+ x, y, w = x.reshape(-1, x.shape[-1]), y.reshape(-1, y.shape[-1]), w.reshape(-1, w.shape[-1])
+
+ sign = torch.sign(x)
+ x, y = x * sign, y * sign
+ wx, wy = w * x, w * y
+ xyw = torch.stack([x, y, w], dim=-1) # Stacked for convenient gathering
+
+ y_div_x = A = y / x.clamp_min(eps)
+ B = (wy - trunc) / wx.clamp_min(eps)
+ C = (wy + trunc) / wx.clamp_min(eps)
+ with torch.no_grad():
+ # Calculate prefix sums in the sorted orders of A, B, C
+ A, A_argsort = A.sort(dim=-1)
+ Q_A = torch.cumsum(torch.gather(wx, dim=-1, index=A_argsort), dim=-1)
+ A, Q_A = _pad_inf(A), _pad_cumsum(Q_A) # Pad [-inf, A1, ..., An, inf] and [0, Q1, ..., Qn, Qn] to handle edge cases.
+
+ B, B_argsort = B.sort(dim=-1)
+ Q_B = torch.cumsum(torch.gather(wx, dim=-1, index=B_argsort), dim=-1)
+ B, Q_B = _pad_inf(B), _pad_cumsum(Q_B)
+
+ C, C_argsort = C.sort(dim=-1)
+ Q_C = torch.cumsum(torch.gather(wx, dim=-1, index=C_argsort), dim=-1)
+ C, Q_C = _pad_inf(C), _pad_cumsum(Q_C)
+
+ # Calculate the left and right derivatives at A
+ j_A = torch.searchsorted(A, y_div_x, side='left').sub_(1)
+ j_B = torch.searchsorted(B, y_div_x, side='left').sub_(1)
+ j_C = torch.searchsorted(C, y_div_x, side='left').sub_(1)
+ left_derivative = 2 * torch.gather(Q_A, dim=-1, index=j_A) - torch.gather(Q_B, dim=-1, index=j_B) - torch.gather(Q_C, dim=-1, index=j_C)
+ j_A = torch.searchsorted(A, y_div_x, side='right').sub_(1)
+ j_B = torch.searchsorted(B, y_div_x, side='right').sub_(1)
+ j_C = torch.searchsorted(C, y_div_x, side='right').sub_(1)
+ right_derivative = 2 * torch.gather(Q_A, dim=-1, index=j_A) - torch.gather(Q_B, dim=-1, index=j_B) - torch.gather(Q_C, dim=-1, index=j_C)
+
+ # Find extrema
+ is_extrema = (left_derivative < 0) & (right_derivative >= 0)
+ is_extrema[..., 0] |= ~is_extrema.any(dim=-1) # In case all derivatives are zero, take the first one as extrema.
+ where_extrema_batch, where_extrema_index = torch.where(is_extrema)
+
+ # Calculate objective value at extrema
+ extrema_a = y_div_x[where_extrema_batch, where_extrema_index] # (num_extrema,)
+ MAX_ELEMENTS = 4096 ** 2 # Split into small batches to avoid OOM when there are too many extrema (~1 GB).
+ SPLIT_SIZE = MAX_ELEMENTS // x.shape[-1]
+ extrema_value = torch.cat([
+ _compute_residual(extrema_a_split[:, None], xyw[extrema_i_split, :, :], trunc)
+ for extrema_a_split, extrema_i_split in zip(extrema_a.split(SPLIT_SIZE), where_extrema_batch.split(SPLIT_SIZE))
+ ]) # (num_extrema,)
+
+ # Find minima among corresponding extrema
+ minima, indices = scatter_min(size=batch_size, dim=0, index=where_extrema_batch, src=extrema_value) # (batch_size,)
+ index = where_extrema_index[indices]
+
+ a = torch.gather(y, dim=-1, index=index[..., None]) / torch.gather(x, dim=-1, index=index[..., None]).clamp_min(eps)
+ a = a.reshape(batch_shape)
+ loss = minima.reshape(batch_shape)
+ index = index.reshape(batch_shape)
+
+ return a, loss, index
+
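+# Usage sketch (illustrative): robustly fit a single scale `a` so that
+# a * x ~= y under the weighted L1 objective.
+#
+# x = torch.tensor([[1.0, 2.0, 3.0]])
+# y = torch.tensor([[2.1, 3.9, 6.0]])
+# a, loss, idx = align(x, y, torch.ones_like(x)) # a ~= 2.0, the weighted median of y / x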
+
+def align_depth_scale(depth_src: torch.Tensor, depth_tgt: torch.Tensor, weight: Optional[torch.Tensor], trunc: Optional[Union[float, torch.Tensor]] = None):
+ """
+ Align `depth_src` to `depth_tgt` with given constant weights.
+
+ ### Parameters:
+ - `depth_src: torch.Tensor` of shape (..., N)
+ - `depth_tgt: torch.Tensor` of shape (..., N)
+
+ """
+ scale, _, _ = align(depth_src, depth_tgt, weight, trunc)
+
+ return scale
+
+
+def align_depth_affine(depth_src: torch.Tensor, depth_tgt: torch.Tensor, weight: Optional[torch.Tensor], trunc: Optional[Union[float, torch.Tensor]] = None):
+ """
+ Align `depth_src` to `depth_tgt` with given constant weights.
+
+ ### Parameters:
+ - `depth_src: torch.Tensor` of shape (..., N)
+ - `depth_tgt: torch.Tensor` of shape (..., N)
+ - `weight: torch.Tensor` of shape (..., N)
+ - `trunc: float` or tensor of shape (..., N) or None
+
+ ### Returns:
+ - `scale: torch.Tensor` of shape (...).
+ - `shift: torch.Tensor` of shape (...).
+ """
+ dtype, device = depth_src.dtype, depth_src.device
+
+ # Flatten batch dimensions for simplicity
+ batch_shape, n = depth_src.shape[:-1], depth_src.shape[-1]
+ batch_size = math.prod(batch_shape)
+ depth_src, depth_tgt, weight = depth_src.reshape(batch_size, n), depth_tgt.reshape(batch_size, n), weight.reshape(batch_size, n)
+
+ # Here, we take anchors only for non-zero weights.
+ # Although the results would still be correct even if anchor points had
+ # zero weight, it wastes computation and may cause instability in some
+ # cases, e.g. too many extrema.
+ anchors_where_batch, anchors_where_n = torch.where(weight > 0)
+
+ # Stop gradient when solving optimal anchors
+ with torch.no_grad():
+ depth_src_anchor = depth_src[anchors_where_batch, anchors_where_n] # (anchors)
+ depth_tgt_anchor = depth_tgt[anchors_where_batch, anchors_where_n] # (anchors)
+
+ depth_src_anchored = depth_src[anchors_where_batch, :] - depth_src_anchor[..., None] # (anchors, n)
+ depth_tgt_anchored = depth_tgt[anchors_where_batch, :] - depth_tgt_anchor[..., None] # (anchors, n)
+ weight_anchored = weight[anchors_where_batch, :] # (anchors, n)
+
+ scale, loss, index = align(depth_src_anchored, depth_tgt_anchored, weight_anchored, trunc) # (anchors)
+
+ loss, index_anchor = scatter_min(size=batch_size, dim=0, index=anchors_where_batch, src=loss) # (batch_size,)
+
+ # Reproduce by indexing for shorter compute graph
+ index_1 = anchors_where_n[index_anchor] # (batch_size,)
+ index_2 = index[index_anchor] # (batch_size,)
+
+ tgt_1, src_1 = torch.gather(depth_tgt, dim=1, index=index_1[..., None]).squeeze(-1), torch.gather(depth_src, dim=1, index=index_1[..., None]).squeeze(-1)
+ tgt_2, src_2 = torch.gather(depth_tgt, dim=1, index=index_2[..., None]).squeeze(-1), torch.gather(depth_src, dim=1, index=index_2[..., None]).squeeze(-1)
+
+ scale = (tgt_2 - tgt_1) / torch.where(src_2 != src_1, src_2 - src_1, 1e-7)
+ shift = tgt_1 - scale * src_1
+
+ scale, shift = scale.reshape(batch_shape), shift.reshape(batch_shape)
+
+ return scale, shift
+
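+# Usage sketch (illustrative; `depth_src`, `depth_tgt` and `valid` are assumed
+# (..., N) tensors, not defined in this module):
+#
+# scale, shift = align_depth_affine(depth_src, depth_tgt,
+# weight=valid.float(), trunc=0.1)
+# depth_aligned = scale[..., None] * depth_src + shift[..., None]
+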
+def align_depth_affine_irls(depth_src: torch.Tensor, depth_tgt: torch.Tensor, weight: Optional[torch.Tensor], max_iter: int = 100, eps: float = 1e-12):
+ """
+ Align `depth_src` to `depth_tgt` with given constant weights using IRLS.
+ """
+ dtype, device = depth_src.dtype, depth_src.device
+
+ w = weight
+ x = torch.stack([depth_src, torch.ones_like(depth_src)], dim=-1)
+ y = depth_tgt
+
+ for i in range(max_iter):
+ beta = (x.transpose(-1, -2) @ (w * y)) @ (x.transpose(-1, -2) @ (w[..., None] * x)).inverse().transpose(-2, -1)
+ w = 1 / (y - (x @ beta[..., None])[..., 0]).abs().clamp_min(eps)
+
+ return beta[..., 0], beta[..., 1]
+
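+# Note: the IRLS variant above approximates the L1 objective by repeatedly
+# solving a weighted least-squares problem and reweighting each sample by the
+# inverse of its absolute residual, trading the exact sorting-based solver of
+# `align` for a fixed number of cheap linear solves.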
+
+def align_points_scale(points_src: torch.Tensor, points_tgt: torch.Tensor, weight: Optional[torch.Tensor], trunc: Optional[Union[float, torch.Tensor]] = None):
+ """
+ ### Parameters:
+ - `points_src: torch.Tensor` of shape (..., N, 3)
+ - `points_tgt: torch.Tensor` of shape (..., N, 3)
+ - `weight: torch.Tensor` of shape (..., N)
+
+ ### Returns:
+ - `scale: torch.Tensor` of shape (...). Only positive solutions are guaranteed; filter out negative scales before use.
+ """
+ dtype, device = points_src.dtype, points_src.device
+
+ scale, _, _ = align(points_src.flatten(-2), points_tgt.flatten(-2), weight[..., None].expand_as(points_src).flatten(-2), trunc)
+
+ return scale
+
+
+def align_points_scale_z_shift(points_src: torch.Tensor, points_tgt: torch.Tensor, weight: Optional[torch.Tensor], trunc: Optional[Union[float, torch.Tensor]] = None):
+ """
+ Align `points_src` to `points_tgt` with respect to a shared xyz scale and z shift.
+ It is similar to `align_affine` but scale and shift are applied to different dimensions.
+
+ ### Parameters:
+ - `points_src: torch.Tensor` of shape (..., N, 3)
+ - `points_tgt: torch.Tensor` of shape (..., N, 3)
+ - `weight: torch.Tensor` of shape (..., N)
+
+ ### Returns:
+ - `scale: torch.Tensor` of shape (...).
+ - `shift: torch.Tensor` of shape (..., 3). x and y shifts are zeros.
+ """
+ dtype, device = points_src.dtype, points_src.device
+
+ # Flatten batch dimensions for simplicity
+ batch_shape, n = points_src.shape[:-2], points_src.shape[-2]
+ batch_size = math.prod(batch_shape)
+ points_src, points_tgt, weight = points_src.reshape(batch_size, n, 3), points_tgt.reshape(batch_size, n, 3), weight.reshape(batch_size, n)
+
+ # Take anchors
+ anchor_where_batch, anchor_where_n = torch.where(weight > 0)
+ with torch.no_grad():
+ zeros = torch.zeros(anchor_where_batch.shape[0], device=device, dtype=dtype)
+ points_src_anchor = torch.stack([zeros, zeros, points_src[anchor_where_batch, anchor_where_n, 2]], dim=-1) # (anchors, 3)
+ points_tgt_anchor = torch.stack([zeros, zeros, points_tgt[anchor_where_batch, anchor_where_n, 2]], dim=-1) # (anchors, 3)
+
+ points_src_anchored = points_src[anchor_where_batch, :, :] - points_src_anchor[..., None, :] # (anchors, n, 3)
+ points_tgt_anchored = points_tgt[anchor_where_batch, :, :] - points_tgt_anchor[..., None, :] # (anchors, n, 3)
+ weight_anchored = weight[anchor_where_batch, :, None].expand(-1, -1, 3) # (anchors, n, 3)
+
+ # Solve optimal scale and shift for each anchor
+ MAX_ELEMENTS = 2 ** 20
+ scale, loss, index = split_batch_fwd(align, MAX_ELEMENTS // n, points_src_anchored.flatten(-2), points_tgt_anchored.flatten(-2), weight_anchored.flatten(-2), trunc) # (anchors,)
+
+ loss, index_anchor = scatter_min(size=batch_size, dim=0, index=anchor_where_batch, src=loss) # (batch_size,)
+
+ # Reproduce by indexing for shorter compute graph
+ index_2 = index[index_anchor] # (batch_size,) [0, 3n)
+ index_1 = anchor_where_n[index_anchor] * 3 + index_2 % 3 # (batch_size,) [0, 3n)
+
+ zeros = torch.zeros((batch_size, n), device=device, dtype=dtype)
+ points_tgt_00z, points_src_00z = torch.stack([zeros, zeros, points_tgt[..., 2]], dim=-1), torch.stack([zeros, zeros, points_src[..., 2]], dim=-1)
+ tgt_1, src_1 = torch.gather(points_tgt_00z.flatten(-2), dim=1, index=index_1[..., None]).squeeze(-1), torch.gather(points_src_00z.flatten(-2), dim=1, index=index_1[..., None]).squeeze(-1)
+ tgt_2, src_2 = torch.gather(points_tgt.flatten(-2), dim=1, index=index_2[..., None]).squeeze(-1), torch.gather(points_src.flatten(-2), dim=1, index=index_2[..., None]).squeeze(-1)
+
+ scale = (tgt_2 - tgt_1) / torch.where(src_2 != src_1, src_2 - src_1, 1.0)
+ shift = torch.gather(points_tgt_00z, dim=1, index=(index_1 // 3)[..., None, None].expand(-1, -1, 3)).squeeze(-2) - scale[..., None] * torch.gather(points_src_00z, dim=1, index=(index_1 // 3)[..., None, None].expand(-1, -1, 3)).squeeze(-2)
+ scale, shift = scale.reshape(batch_shape), shift.reshape(*batch_shape, 3)
+
+ return scale, shift
+
+
+def align_points_scale_xyz_shift(points_src: torch.Tensor, points_tgt: torch.Tensor, weight: Optional[torch.Tensor], trunc: Optional[Union[float, torch.Tensor]] = None, max_iters: int = 30, eps: float = 1e-6):
+ """
+ Align `points_src` to `points_tgt` with respect to a shared xyz scale and z shift.
+ It is similar to `align_affine` but scale and shift are applied to different dimensions.
+
+ ### Parameters:
+ - `points_src: torch.Tensor` of shape (..., N, 3)
+ - `points_tgt: torch.Tensor` of shape (..., N, 3)
+ - `weights: torch.Tensor` of shape (..., N)
+
+ ### Returns:
+ - `scale: torch.Tensor` of shape (...).
+ - `shift: torch.Tensor` of shape (..., 3)
+ """
+ dtype, device = points_src.dtype, points_src.device
+
+ # Flatten batch dimensions for simplicity
+ batch_shape, n = points_src.shape[:-2], points_src.shape[-2]
+ batch_size = math.prod(batch_shape)
+ points_src, points_tgt, weight = points_src.reshape(batch_size, n, 3), points_tgt.reshape(batch_size, n, 3), weight.reshape(batch_size, n)
+
+ # Take anchors
+ anchor_where_batch, anchor_where_n = torch.where(weight > 0)
+
+ with torch.no_grad():
+ points_src_anchor = points_src[anchor_where_batch, anchor_where_n] # (anchors, 3)
+ points_tgt_anchor = points_tgt[anchor_where_batch, anchor_where_n] # (anchors, 3)
+
+ points_src_anchored = points_src[anchor_where_batch, :, :] - points_src_anchor[..., None, :] # (anchors, n, 3)
+ points_tgt_anchored = points_tgt[anchor_where_batch, :, :] - points_tgt_anchor[..., None, :] # (anchors, n, 3)
+ weight_anchored = weight[anchor_where_batch, :, None].expand(-1, -1, 3) # (anchors, n, 3)
+
+ # Solve optimal scale and shift for each anchor
+ MAX_ELEMENTS = 2 ** 20
+ scale, loss, index = split_batch_fwd(align, MAX_ELEMENTS // n, points_src_anchored.flatten(-2), points_tgt_anchored.flatten(-2), weight_anchored.flatten(-2), trunc) # (anchors,)
+
+ # Get optimal scale and shift for each batch element
+ loss, index_anchor = scatter_min(size=batch_size, dim=0, index=anchor_where_batch, src=loss) # (batch_size,)
+
+ index_2 = index[index_anchor] # (batch_size,) [0, 3n)
+ index_1 = anchor_where_n[index_anchor] * 3 + index_2 % 3 # (batch_size,) [0, 3n)
+
+ src_1, tgt_1 = torch.gather(points_src.flatten(-2), dim=1, index=index_1[..., None]).squeeze(-1), torch.gather(points_tgt.flatten(-2), dim=1, index=index_1[..., None]).squeeze(-1)
+ src_2, tgt_2 = torch.gather(points_src.flatten(-2), dim=1, index=index_2[..., None]).squeeze(-1), torch.gather(points_tgt.flatten(-2), dim=1, index=index_2[..., None]).squeeze(-1)
+
+ scale = (tgt_2 - tgt_1) / torch.where(src_2 != src_1, src_2 - src_1, 1.0)
+ shift = torch.gather(points_tgt, dim=1, index=(index_1 // 3)[..., None, None].expand(-1, -1, 3)).squeeze(-2) - scale[..., None] * torch.gather(points_src, dim=1, index=(index_1 // 3)[..., None, None].expand(-1, -1, 3)).squeeze(-2)
+
+ scale, shift = scale.reshape(batch_shape), shift.reshape(*batch_shape, 3)
+
+ return scale, shift
+
+
+def align_points_z_shift(points_src: torch.Tensor, points_tgt: torch.Tensor, weight: Optional[torch.Tensor], trunc: Optional[Union[float, torch.Tensor]] = None, max_iters: int = 30, eps: float = 1e-6):
+ """
+ Align `points_src` to `points_tgt` with respect to a Z-axis shift.
+
+ ### Parameters:
+ - `points_src: torch.Tensor` of shape (..., N, 3)
+ - `points_tgt: torch.Tensor` of shape (..., N, 3)
+ - `weights: torch.Tensor` of shape (..., N)
+
+ ### Returns:
+ - `scale: torch.Tensor` of shape (...).
+ - `shift: torch.Tensor` of shape (..., 3)
+ """
+ dtype, device = points_src.dtype, points_src.device
+
+ shift, _, _ = align(torch.ones_like(points_src[..., 2]), points_tgt[..., 2] - points_src[..., 2], weight, trunc)
+ shift = torch.stack([torch.zeros_like(shift), torch.zeros_like(shift), shift], dim=-1)
+
+ return shift
+
+
+def align_points_xyz_shift(points_src: torch.Tensor, points_tgt: torch.Tensor, weight: Optional[torch.Tensor], trunc: Optional[Union[float, torch.Tensor]] = None, max_iters: int = 30, eps: float = 1e-6):
+ """
+ Align `points_src` to `points_tgt` with respect to a Z-axis shift.
+
+ ### Parameters:
+ - `points_src: torch.Tensor` of shape (..., N, 3)
+ - `points_tgt: torch.Tensor` of shape (..., N, 3)
+ - `weights: torch.Tensor` of shape (..., N)
+
+ ### Returns:
+ - `scale: torch.Tensor` of shape (...).
+ - `shift: torch.Tensor` of shape (..., 3)
+ """
+ dtype, device = points_src.dtype, points_src.device
+
+ shift, _, _ = align(torch.ones_like(points_src).swapaxes(-2, -1), (points_tgt - points_src).swapaxes(-2, -1), weight[..., None, :], trunc)
+
+ return shift
+
+
+def align_affine_lstsq(x: torch.Tensor, y: torch.Tensor, w: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Solve `min sum_i w_i * (a * x_i + b - y_i ) ^ 2`, where `a` and `b` are scalars, with respect to `a` and `b` using least squares.
+
+ ### Parameters:
+ - `x: torch.Tensor` of shape (..., N)
+ - `y: torch.Tensor` of shape (..., N)
+ - `w: torch.Tensor` of shape (..., N)
+
+ ### Returns:
+ - `a: torch.Tensor` of shape (...,)
+ - `b: torch.Tensor` of shape (...,)
+ """
+ w_sqrt = torch.ones_like(x) if w is None else w.sqrt()
+ # Both columns of the design matrix must carry the sqrt-weights so that
+ # lstsq minimizes sum_i w_i * (a * x_i + b - y_i)^2.
+ A = torch.stack([w_sqrt * x, w_sqrt], dim=-1)
+ B = (w_sqrt * y)[..., None]
+ a, b = torch.linalg.lstsq(A, B)[0].squeeze(-1).unbind(-1)
+ return a, b
+
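+# Sanity-check sketch (illustrative): with exact data y = 2 * x + 1 the solver
+# recovers a ~= 2 and b ~= 1.
+#
+# x = torch.linspace(0, 1, 8)[None]
+# y = 2 * x + 1
+# a, b = align_affine_lstsq(x, y)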
+
+def align_affine_lstsq_z_shift(x: torch.Tensor, y: torch.Tensor, w: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Solve `min sum_i w_i * ||a * x_i + b - y_i||^2`, where x_i and y_i are 3D points,
+ `a` is a scalar (isotropic scaling), and `b` is a translation vector of the form `[0, 0, shift_z]`.
+ The minimization is with respect to `a` (scalar_scale) and `shift_z`.
+
+ The input point clouds x and y are expected to have a shape like (..., N, 3),
+ where N is the number of points and the last dimension has size 3 (X, Y, Z).
+ The weights w, if provided, should have shape (..., N) corresponding to the points.
+
+ This function adapts the structure of a 1D affine least squares solver to this specific
+ 3D problem by reformulating the design matrix A and observation vector B for torch.linalg.lstsq.
+
+ Parameters:
+ - `x: torch.Tensor` of shape (..., N, 3), representing the source point cloud.
+ - `y: torch.Tensor` of shape (..., N, 3), representing the target point cloud.
+ - `w: torch.Tensor` (optional) of shape (..., N), representing weights for each point.
+ If None, all points are weighted equally.
+
+ Returns:
+ - `a: torch.Tensor` of shape (...,), the scalar scaling factor.
+ - `b: torch.Tensor` of shape (..., 3), the translation vector `[0, 0, shift_z]`.
+ """
+ if x.shape[-1] != 3 or y.shape[-1] != 3:
+ raise ValueError("Input tensors x and y must have 3 features in the last dimension (X, Y, Z). "
+ f"Got x shape: {x.shape}, y shape: {y.shape}")
+ # Check all dimensions except the last one (feature dimension)
+ if x.shape[:-1] != y.shape[:-1]:
+ raise ValueError("Input tensors x and y must have matching shapes up to the last dimension. "
+ f"Got x shape: {x.shape}, y shape: {y.shape}")
+ if w is not None and w.shape != x.shape[:-1]:
+ raise ValueError("Weights w, if provided, must have shape (..., N) matching x and y's point dimensions. "
+ f"Got w shape: {w.shape}, x shape: {x.shape}")
+
+ # Determine batch shape and number of points
+ # Example: x shape (B1, B2, N, 3) -> batch_shape (B1, B2), num_points N
+ batch_shape = x.shape[:-2]
+ num_points = x.shape[-2]
+
+ # Prepare w_sqrt. If w is None, use unit weights.
+ # w_sqrt_points will have shape (..., N)
+ if w is None:
+ w_sqrt_points = torch.ones(*batch_shape, num_points, device=x.device, dtype=x.dtype)
+ else:
+ w_sqrt_points = w.sqrt()
+
+ # Dimension along which to concatenate point data from different coordinates (X, Y, Z)
+ dim_to_cat = len(batch_shape)
+
+ # Coefficients for 'a_val' (the scalar scale)
+ s_terms_x = w_sqrt_points * x[..., :, 0] # Shape (..., N)
+ s_terms_y = w_sqrt_points * x[..., :, 1] # Shape (..., N)
+ s_terms_z = w_sqrt_points * x[..., :, 2] # Shape (..., N)
+ a_val_coeff_column = torch.cat([s_terms_x, s_terms_y, s_terms_z], dim=dim_to_cat) # Shape (..., 3*N)
+
+ # Coefficients for 'shift_z_val'
+ zeros_for_shift_coeffs = torch.zeros_like(s_terms_x) # Shape (..., N)
+ shift_z_val_coeff_column = torch.cat([zeros_for_shift_coeffs, zeros_for_shift_coeffs, w_sqrt_points], dim=dim_to_cat) # Shape (..., 3*N)
+
+ # Construct the design matrix A_ls (shape (..., 3*N, 2))
+ A_ls = torch.stack([a_val_coeff_column, shift_z_val_coeff_column], dim=-1)
+
+ # Construct the observation vector B_ls (shape (..., 3*N, 1))
+ B_terms_x = w_sqrt_points * y[..., :, 0] # Shape (..., N)
+ B_terms_y = w_sqrt_points * y[..., :, 1] # Shape (..., N)
+ B_terms_z = w_sqrt_points * y[..., :, 2] # Shape (..., N)
+ B_ls_flat = torch.cat([B_terms_x, B_terms_y, B_terms_z], dim=dim_to_cat) # Shape (..., 3*N)
+ B_ls = B_ls_flat.unsqueeze(-1)
+
+ # Solve the least squares problem
+ solution = torch.linalg.lstsq(A_ls, B_ls)[0] # solution shape (..., 2, 1)
+
+ # Extract the scalar scale 'a_val' and 'shift_z_val'
+ a_val = solution[..., 0, 0] # Shape (...,)
+ shift_z_val = solution[..., 1, 0] # Shape (...,)
+
+ # Construct the output translation vector b = [0, 0, shift_z_val]
+ zeros_for_b = torch.zeros_like(a_val)
+ b_vector = torch.stack([zeros_for_b, zeros_for_b, shift_z_val], dim=-1) # Shape (..., 3)
+
+ return a_val, b_vector
diff --git a/unish/pi3/utils/basic.py b/unish/pi3/utils/basic.py
new file mode 100644
index 0000000000000000000000000000000000000000..9ac73492409b1f2441a84e2f9de9681b3cf3ca9f
--- /dev/null
+++ b/unish/pi3/utils/basic.py
@@ -0,0 +1,223 @@
+import os
+import os.path as osp
+import math
+import cv2
+from PIL import Image
+import torch
+from torchvision import transforms
+from plyfile import PlyData, PlyElement
+import numpy as np
+
+def load_images_as_tensor(path='data/truck', interval=1, PIXEL_LIMIT=255000):
+ """
+ Loads images from a directory or video, resizes them to a uniform size,
+ then converts and stacks them into a single [N, 3, H, W] PyTorch tensor.
+ """
+ sources = []
+
+ # --- 1. Load image paths or video frames ---
+ if osp.isdir(path):
+ print(f"Loading images from directory: {path}")
+ filenames = sorted([x for x in os.listdir(path) if x.lower().endswith(('.png', '.jpg', '.jpeg'))])
+ for i in range(0, len(filenames), interval):
+ img_path = osp.join(path, filenames[i])
+ try:
+ sources.append(Image.open(img_path).convert('RGB'))
+ except Exception as e:
+ print(f"Could not load image {filenames[i]}: {e}")
+ elif path.lower().endswith('.mp4'):
+ print(f"Loading frames from video: {path}")
+ cap = cv2.VideoCapture(path)
+ if not cap.isOpened(): raise IOError(f"Cannot open video file: {path}")
+ frame_idx = 0
+ while True:
+ ret, frame = cap.read()
+ if not ret: break
+ if frame_idx % interval == 0:
+ rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+ sources.append(Image.fromarray(rgb_frame))
+ frame_idx += 1
+ cap.release()
+ else:
+ raise ValueError(f"Unsupported path. Must be a directory or a .mp4 file: {path}")
+
+ if not sources:
+ print("No images found or loaded.")
+ return torch.empty(0)
+
+ print(f"Found {len(sources)} images/frames. Processing...")
+
+ # --- 2. Determine a uniform target size for all images based on the first image ---
+ # This is necessary to ensure all tensors have the same dimensions for stacking.
+ first_img = sources[0]
+ W_orig, H_orig = first_img.size
+ scale = math.sqrt(PIXEL_LIMIT / (W_orig * H_orig)) if W_orig * H_orig > 0 else 1
+ W_target, H_target = W_orig * scale, H_orig * scale
+ k, m = round(W_target / 14), round(H_target / 14)
+ while (k * 14) * (m * 14) > PIXEL_LIMIT:
+ if k / m > W_target / H_target: k -= 1
+ else: m -= 1
+ TARGET_W, TARGET_H = max(1, k) * 14, max(1, m) * 14
+ print(f"All images will be resized to a uniform size: ({TARGET_W}, {TARGET_H})")
+
+ # --- 3. Resize images and convert them to tensors in the [0, 1] range ---
+ tensor_list = []
+ # Define a transform to convert a PIL Image to a CxHxW tensor and normalize to [0,1]
+ to_tensor_transform = transforms.ToTensor()
+
+ for img_pil in sources:
+ try:
+ # Resize to the uniform target size
+ resized_img = img_pil.resize((TARGET_W, TARGET_H), Image.Resampling.LANCZOS)
+ # Convert to tensor
+ img_tensor = to_tensor_transform(resized_img)
+ tensor_list.append(img_tensor)
+ except Exception as e:
+ print(f"Error processing an image: {e}")
+
+ if not tensor_list:
+ print("No images were successfully processed.")
+ return torch.empty(0)
+
+ # --- 4. Stack the list of tensors into a single [N, C, H, W] batch tensor ---
+ return torch.stack(tensor_list, dim=0)
+
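+# Usage sketch (the path below is hypothetical):
+#
+# imgs = load_images_as_tensor('examples/video.mp4', interval=2)
+# print(imgs.shape) # torch.Size([N, 3, H, W]), values in [0, 1]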
+
+def tensor_to_pil(tensor):
+ """
+ Converts a PyTorch tensor to a PIL image. Automatically moves the channel dimension
+ (if it has size 3) to the last axis before converting.
+
+ Args:
+ tensor (torch.Tensor): Input tensor. Expected shape can be [C, H, W], [H, W, C], or [H, W].
+
+ Returns:
+ PIL.Image: The converted PIL image.
+ """
+ if torch.is_tensor(tensor):
+ array = tensor.detach().cpu().numpy()
+ else:
+ array = tensor
+
+ return array_to_pil(array)
+
+
+def array_to_pil(array):
+ """
+ Converts a NumPy array to a PIL image. Automatically:
+ - Squeezes dimensions of size 1.
+ - Moves the channel dimension (if it has size 3) to the last axis.
+
+ Args:
+ array (np.ndarray): Input array. Expected shape can be [C, H, W], [H, W, C], or [H, W].
+
+ Returns:
+ PIL.Image: The converted PIL image.
+ """
+ # Remove singleton dimensions
+ array = np.squeeze(array)
+
+ # Ensure the array has the channel dimension as the last axis
+ if array.ndim == 3 and array.shape[0] == 3: # If the channel is the first axis
+ array = np.transpose(array, (1, 2, 0)) # Move channel to the last axis
+
+ # Handle single-channel grayscale images
+ if array.ndim == 2: # [H, W]
+ return Image.fromarray((array * 255).astype(np.uint8), mode="L")
+ elif array.ndim == 3 and array.shape[2] == 3: # [H, W, C] with 3 channels
+ return Image.fromarray((array * 255).astype(np.uint8), mode="RGB")
+ else:
+ raise ValueError(f"Unsupported array shape for PIL conversion: {array.shape}")
+
+
+def rotate_target_dim_to_last_axis(x, target_dim=3):
+ """Move the last axis whose size equals `target_dim` to the final position."""
+ shape = x.shape
+ axis_to_move = -1
+ # 1. Iterate backwards to find the first occurrence from the end
+ # (i.e. the last axis of size `target_dim` in the original order).
+ for i in range(len(shape) - 1, -1, -1):
+ if shape[i] == target_dim:
+ axis_to_move = i
+ break
+
+ # 2. If the axis is found and it's not already in the last position, move it.
+ if axis_to_move != -1 and axis_to_move != len(shape) - 1:
+ # Create the new dimension order.
+ dims_order = list(range(len(shape)))
+ dims_order.pop(axis_to_move)
+ dims_order.append(axis_to_move)
+
+ # Reorder the dimensions. torch's `transpose` only accepts two axes,
+ # so use `permute` for tensors; NumPy's `transpose` takes the full order.
+ if torch.is_tensor(x):
+ ret = x.permute(*dims_order)
+ else:
+ ret = x.transpose(*dims_order)
+ else:
+ ret = x
+
+ return ret
+
+
+def write_ply(
+ xyz,
+ rgb=None,
+ path='output.ply',
+) -> None:
+ if torch.is_tensor(xyz):
+ xyz = xyz.detach().cpu().numpy()
+
+ if torch.is_tensor(rgb):
+ rgb = rgb.detach().cpu().numpy()
+
+ if rgb is not None and rgb.max() > 1:
+ rgb = rgb / 255.
+
+ xyz = rotate_target_dim_to_last_axis(xyz, 3)
+ xyz = xyz.reshape(-1, 3)
+
+ if rgb is not None:
+ rgb = rotate_target_dim_to_last_axis(rgb, 3)
+ rgb = rgb.reshape(-1, 3)
+
+ if rgb is None:
+ min_coord = np.min(xyz, axis=0)
+ max_coord = np.max(xyz, axis=0)
+ normalized_coord = (xyz - min_coord) / (max_coord - min_coord + 1e-8)
+
+ hue = 0.7 * normalized_coord[:,0] + 0.2 * normalized_coord[:,1] + 0.1 * normalized_coord[:,2]
+ hsv = np.stack([hue, 0.9*np.ones_like(hue), 0.8*np.ones_like(hue)], axis=1)
+
+ c = hsv[:,2:] * hsv[:,1:2]
+ x = c * (1 - np.abs( (hsv[:,0:1]*6) % 2 - 1 ))
+ m = hsv[:,2:] - c
+
+ rgb = np.zeros_like(hsv)
+ cond = (0 <= hsv[:,0]*6%6) & (hsv[:,0]*6%6 < 1)
+ rgb[cond] = np.hstack([c[cond], x[cond], np.zeros_like(x[cond])])
+ cond = (1 <= hsv[:,0]*6%6) & (hsv[:,0]*6%6 < 2)
+ rgb[cond] = np.hstack([x[cond], c[cond], np.zeros_like(x[cond])])
+ cond = (2 <= hsv[:,0]*6%6) & (hsv[:,0]*6%6 < 3)
+ rgb[cond] = np.hstack([np.zeros_like(x[cond]), c[cond], x[cond]])
+ cond = (3 <= hsv[:,0]*6%6) & (hsv[:,0]*6%6 < 4)
+ rgb[cond] = np.hstack([np.zeros_like(x[cond]), x[cond], c[cond]])
+ cond = (4 <= hsv[:,0]*6%6) & (hsv[:,0]*6%6 < 5)
+ rgb[cond] = np.hstack([x[cond], np.zeros_like(x[cond]), c[cond]])
+ cond = (5 <= hsv[:,0]*6%6) & (hsv[:,0]*6%6 < 6)
+ rgb[cond] = np.hstack([c[cond], np.zeros_like(x[cond]), x[cond]])
+ rgb = (rgb + m)
+
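+ # PLY vertex layout: float32 position (x, y, z) and normal (nx, ny, nz),
+ # followed by uint8 color (red, green, blue). Normals are written as zeros
+ # since none are computed here.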
+ dtype = [
+ ("x", "f4"),
+ ("y", "f4"),
+ ("z", "f4"),
+ ("nx", "f4"),
+ ("ny", "f4"),
+ ("nz", "f4"),
+ ("red", "u1"),
+ ("green", "u1"),
+ ("blue", "u1"),
+ ]
+ normals = np.zeros_like(xyz)
+ elements = np.empty(xyz.shape[0], dtype=dtype)
+ attributes = np.concatenate((xyz, normals, rgb * 255), axis=1)
+ elements[:] = list(map(tuple, attributes))
+ vertex_element = PlyElement.describe(elements, "vertex")
+ ply_data = PlyData([vertex_element])
+ ply_data.write(path)
\ No newline at end of file
diff --git a/unish/pi3/utils/cropping.py b/unish/pi3/utils/cropping.py
new file mode 100644
index 0000000000000000000000000000000000000000..a27429600a4136892e7bea6c04ed71654b16653f
--- /dev/null
+++ b/unish/pi3/utils/cropping.py
@@ -0,0 +1,197 @@
+# Copyright (C) 2024-present Naver Corporation. All rights reserved.
+# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+#
+# --------------------------------------------------------
+# cropping utilities
+# --------------------------------------------------------
+import PIL.Image
+import os
+os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
+import cv2 # noqa
+import numpy as np # noqa
+try:
+ lanczos = PIL.Image.Resampling.LANCZOS
+ bicubic = PIL.Image.Resampling.BICUBIC
+except AttributeError:
+ lanczos = PIL.Image.LANCZOS
+ bicubic = PIL.Image.BICUBIC
+
+from utils.basic import colmap_to_opencv_intrinsics, opencv_to_colmap_intrinsics
+
+class ImageList:
+ """ Convenience class to aply the same operation to a whole set of images.
+ """
+
+ def __init__(self, images):
+ if not isinstance(images, (tuple, list, set)):
+ images = [images]
+ self.images = []
+ for image in images:
+ if not isinstance(image, PIL.Image.Image):
+ image = PIL.Image.fromarray(image)
+ self.images.append(image)
+
+ def __len__(self):
+ return len(self.images)
+
+ def to_pil(self):
+ return tuple(self.images) if len(self.images) > 1 else self.images[0]
+
+ @property
+ def size(self):
+ sizes = [im.size for im in self.images]
+ assert all(sizes[0] == s for s in sizes)
+ return sizes[0]
+
+ def resize(self, *args, **kwargs):
+ return ImageList(self._dispatch('resize', *args, **kwargs))
+
+ def crop(self, *args, **kwargs):
+ return ImageList(self._dispatch('crop', *args, **kwargs))
+
+ def _dispatch(self, func, *args, **kwargs):
+ return [getattr(im, func)(*args, **kwargs) for im in self.images]
+
+
+def rescale_image_depthmap(image, depthmap, camera_intrinsics, output_resolution, force=True, normal=None, far_mask=None):
+ """ Jointly rescale a (image, depthmap)
+ so that (out_width, out_height) >= output_res
+ """
+ image = ImageList(image)
+ input_resolution = np.array(image.size) # (W,H)
+ output_resolution = np.array(output_resolution)
+ if depthmap is not None:
+ # can also use this with masks instead of depthmaps
+ assert tuple(depthmap.shape[:2]) == image.size[::-1]
+
+ # define output resolution
+ assert output_resolution.shape == (2,)
+ scale_final = max(output_resolution / image.size) + 1e-8
+ if scale_final >= 1 and not force: # image is already smaller than what is asked
+ return image.to_pil(), depthmap, camera_intrinsics, normal, far_mask
+ output_resolution = np.floor(input_resolution * scale_final).astype(int)
+
+ # first rescale the image so that it contains the crop
+ image = image.resize(tuple(output_resolution), resample=lanczos if scale_final < 1 else bicubic)
+ if depthmap is not None:
+ depthmap = cv2.resize(depthmap, tuple(output_resolution), fx=scale_final,
+ fy=scale_final, interpolation=cv2.INTER_NEAREST)
+
+ if normal is not None:
+ normal = cv2.resize(normal, tuple(output_resolution), fx=scale_final,
+ fy=scale_final, interpolation=cv2.INTER_NEAREST)
+ if far_mask is not None:
+ far_mask = cv2.resize(far_mask, tuple(output_resolution), fx=scale_final,
+ fy=scale_final, interpolation=cv2.INTER_NEAREST)
+
+ # no offset here; simple rescaling
+ camera_intrinsics = camera_matrix_of_crop(
+ camera_intrinsics, input_resolution, output_resolution, scaling=scale_final)
+
+ return image.to_pil(), depthmap, camera_intrinsics, normal, far_mask
+
+def center_crop_image_depthmap(image, depthmap, camera_intrinsics, crop_scale, normal=None, far_mask=None):
+ """
+ Jointly center-crop an image and its depthmap, and adjust the camera intrinsics accordingly.
+
+ Parameters:
+ - image: PIL.Image or similar, the input image.
+ - depthmap: np.ndarray, the corresponding depth map.
+ - camera_intrinsics: np.ndarray, the 3x3 camera intrinsics matrix.
+ - crop_scale: float between 0 and 1, the fraction of the image to keep.
+
+ Returns:
+ - cropped_image: PIL.Image, the center-cropped image.
+ - cropped_depthmap: np.ndarray, the center-cropped depth map.
+ - adjusted_intrinsics: np.ndarray, the adjusted camera intrinsics matrix.
+ """
+ # Ensure crop_scale is valid
+ assert 0 < crop_scale <= 1, "crop_scale must be between 0 and 1"
+
+ # Convert image to ImageList for consistent processing
+ image = ImageList(image)
+ input_resolution = np.array(image.size) # (width, height)
+ if depthmap is not None:
+ # Ensure depthmap matches the image size
+ assert depthmap.shape[:2] == tuple(image.size[::-1]), "Depthmap size must match image size"
+
+ # Compute output resolution after cropping
+ output_resolution = np.floor(input_resolution * crop_scale).astype(int)
+ # get the correct crop_scale
+ crop_scale = output_resolution / input_resolution
+
+ # Compute margins (amount to crop from each side)
+ margins = input_resolution - output_resolution
+ offset = margins / 2 # Since we are center cropping
+
+ # Calculate the crop bounding box
+ l, t = offset.astype(int)
+ r = l + output_resolution[0]
+ b = t + output_resolution[1]
+ crop_bbox = (l, t, r, b)
+
+ # Crop the image and depthmap
+ image = image.crop(crop_bbox)
+ if depthmap is not None:
+ depthmap = depthmap[t:b, l:r]
+ if normal is not None:
+ normal = normal[t:b, l:r]
+ if far_mask is not None:
+ far_mask = far_mask[t:b, l:r]
+
+ # Adjust the camera intrinsics
+ adjusted_intrinsics = camera_intrinsics.copy()
+
+ # No need to adjust focal lengths (fx, fy) when cropping:
+ # adjusted_intrinsics[0, 0] /= crop_scale[0] # fx
+ # adjusted_intrinsics[1, 1] /= crop_scale[1] # fy
+
+ # Adjust principal point (cx, cy)
+ adjusted_intrinsics[0, 2] -= l # cx
+ adjusted_intrinsics[1, 2] -= t # cy
+
+ return image.to_pil(), depthmap, adjusted_intrinsics, normal, far_mask
+
+
+def camera_matrix_of_crop(input_camera_matrix, input_resolution, output_resolution, scaling=1, offset_factor=0.5, offset=None):
+ # Margins to offset the origin
+ margins = np.asarray(input_resolution) * scaling - output_resolution
+ assert np.all(margins >= 0.0)
+ if offset is None:
+ offset = offset_factor * margins
+
+ # Generate new camera parameters
+ output_camera_matrix_colmap = opencv_to_colmap_intrinsics(input_camera_matrix)
+ output_camera_matrix_colmap[:2, :] *= scaling
+ output_camera_matrix_colmap[:2, 2] -= offset
+ output_camera_matrix = colmap_to_opencv_intrinsics(output_camera_matrix_colmap)
+
+ return output_camera_matrix
+
+
+def crop_image_depthmap(image, depthmap, camera_intrinsics, crop_bbox, normal=None, far_mask=None):
+ """
+ Return a crop of the input view.
+ """
+ image = ImageList(image)
+ l, t, r, b = crop_bbox
+
+ image = image.crop((l, t, r, b))
+ depthmap = depthmap[t:b, l:r]
+ if normal is not None:
+ normal = normal[t:b, l:r]
+ if far_mask is not None:
+ far_mask = far_mask[t:b, l:r]
+
+ camera_intrinsics = camera_intrinsics.copy()
+ camera_intrinsics[0, 2] -= l
+ camera_intrinsics[1, 2] -= t
+
+ return image.to_pil(), depthmap, camera_intrinsics, normal, far_mask
+
+
+def bbox_from_intrinsics_in_out(input_camera_matrix, output_camera_matrix, output_resolution):
+ out_width, out_height = output_resolution
+ l, t = np.int32(np.round(input_camera_matrix[:2, 2] - output_camera_matrix[:2, 2]))
+ crop_bbox = (l, t, l + out_width, t + out_height)
+ return crop_bbox
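+
+# Usage sketch (illustrative; `img`, `depth` and `K` are hypothetical inputs,
+# and the target resolution below is arbitrary):
+#
+# img, depth, K, _, _ = rescale_image_depthmap(img, depth, K, (518, 294))
+# img, depth, K, _, _ = center_crop_image_depthmap(img, depth, K, crop_scale=0.9)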
diff --git a/unish/pi3/utils/geometry.py b/unish/pi3/utils/geometry.py
new file mode 100644
index 0000000000000000000000000000000000000000..9ee2f09ac253600cf897a4790c169b4255cbdcf9
--- /dev/null
+++ b/unish/pi3/utils/geometry.py
@@ -0,0 +1,475 @@
+import numpy as np
+import torch
+import torch.nn.functional as F
+from typing import Tuple
+from functools import partial
+
+def se3_inverse(T):
+ """
+ Computes the inverse of a batch of SE(3) matrices.
+ """
+
+ if torch.is_tensor(T):
+ R = T[..., :3, :3]
+ t = T[..., :3, 3].unsqueeze(-1)
+ R_inv = R.transpose(-2, -1)
+ t_inv = -torch.matmul(R_inv, t)
+ T_inv = torch.cat([
+ torch.cat([R_inv, t_inv], dim=-1),
+ torch.tensor([0, 0, 0, 1], device=T.device, dtype=T.dtype).repeat(*T.shape[:-2], 1, 1)
+ ], dim=-2)
+ else:
+ R = T[..., :3, :3]
+ t = T[..., :3, 3, np.newaxis]
+
+ R_inv = np.swapaxes(R, -2, -1)
+ t_inv = -R_inv @ t
+
+ bottom_row = np.zeros((*T.shape[:-2], 1, 4), dtype=T.dtype)
+ bottom_row[..., :, 3] = 1
+
+ top_part = np.concatenate([R_inv, t_inv], axis=-1)
+ T_inv = np.concatenate([top_part, bottom_row], axis=-2)
+
+ return T_inv
+
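+# Sanity-check sketch (illustrative): composing a pose with its inverse gives
+# the identity.
+#
+# T = torch.eye(4)[None].clone()
+# T[0, :3, 3] = torch.tensor([1.0, 2.0, 3.0])
+# assert torch.allclose(T @ se3_inverse(T), torch.eye(4))
+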
+def get_pixel(H, W):
+ # get 2D pixels (u, v) for image_a in cam_a pixel space
+ u_a, v_a = np.meshgrid(np.arange(W), np.arange(H))
+ # u_a = np.flip(u_a, axis=1)
+ # v_a = np.flip(v_a, axis=0)
+ pixels_a = np.stack([
+ u_a.flatten() + 0.5,
+ v_a.flatten() + 0.5,
+ np.ones_like(u_a.flatten())
+ ], axis=0)
+
+ return pixels_a
+
+def depthmap_to_absolute_camera_coordinates(depthmap, camera_intrinsics, camera_pose, z_far=0, **kw):
+ """
+ Args:
+ - depthmap (HxW array):
+ - camera_intrinsics: a 3x3 matrix
+ - camera_pose: a 3x4 or 4x4 cam2world matrix
+ Returns:
+ pointmap of absolute coordinates (HxWx3 array), and a mask specifying valid pixels."""
+ X_cam, valid_mask = depthmap_to_camera_coordinates(depthmap, camera_intrinsics)
+ if z_far > 0:
+ valid_mask = valid_mask & (depthmap < z_far)
+
+ X_world = X_cam # default
+ if camera_pose is not None:
+ # R_cam2world = np.float32(camera_params["R_cam2world"])
+ # t_cam2world = np.float32(camera_params["t_cam2world"]).squeeze()
+ R_cam2world = camera_pose[:3, :3]
+ t_cam2world = camera_pose[:3, 3]
+
+ # Express in absolute (world) coordinates
+ X_world = np.einsum("ik, vuk -> vui", R_cam2world, X_cam) + t_cam2world[None, None, :]
+
+ return X_world, valid_mask
+
+
+def depthmap_to_camera_coordinates(depthmap, camera_intrinsics, pseudo_focal=None):
+ """
+ Args:
+ - depthmap (HxW array):
+ - camera_intrinsics: a 3x3 matrix
+ Returns:
+ pointmap of camera coordinates (HxWx3 array), and a mask specifying valid pixels.
+ """
+ camera_intrinsics = np.float32(camera_intrinsics)
+ H, W = depthmap.shape
+
+ # Compute 3D ray associated with each pixel
+ # Strong assumption: there are no skew terms
+ # assert camera_intrinsics[0, 1] == 0.0
+ # assert camera_intrinsics[1, 0] == 0.0
+ if pseudo_focal is None:
+ fu = camera_intrinsics[0, 0]
+ fv = camera_intrinsics[1, 1]
+ else:
+ assert pseudo_focal.shape == (H, W)
+ fu = fv = pseudo_focal
+ cu = camera_intrinsics[0, 2]
+ cv = camera_intrinsics[1, 2]
+
+ u, v = np.meshgrid(np.arange(W), np.arange(H))
+ z_cam = depthmap
+ x_cam = (u - cu) * z_cam / fu
+ y_cam = (v - cv) * z_cam / fv
+ X_cam = np.stack((x_cam, y_cam, z_cam), axis=-1).astype(np.float32)
+
+ # Mask for valid coordinates
+ valid_mask = (depthmap > 0.0)
+ return X_cam, valid_mask
+
+def homogenize_points(
+ points,
+):
+ """Convert batched points (xyz) to (xyz1)."""
+ return torch.cat([points, torch.ones_like(points[..., :1])], dim=-1)
+
+
+def get_gt_warp(depth1, depth2, T_1to2, K1, K2, depth_interpolation_mode = 'bilinear', relative_depth_error_threshold = 0.05, H = None, W = None):
+
+ if H is None:
+ B,H,W = depth1.shape
+ else:
+ B = depth1.shape[0]
+ with torch.no_grad():
+ x1_n = torch.meshgrid(
+ *[
+ torch.linspace(
+ -1 + 1 / n, 1 - 1 / n, n, device=depth1.device
+ )
+ for n in (B, H, W)
+ ],
+ indexing = 'ij'
+ )
+ x1_n = torch.stack((x1_n[2], x1_n[1]), dim=-1).reshape(B, H * W, 2)
+ mask, x2 = warp_kpts(
+ x1_n.double(),
+ depth1.double(),
+ depth2.double(),
+ T_1to2.double(),
+ K1.double(),
+ K2.double(),
+ depth_interpolation_mode = depth_interpolation_mode,
+ relative_depth_error_threshold = relative_depth_error_threshold,
+ )
+ prob = mask.float().reshape(B, H, W)
+ x2 = x2.reshape(B, H, W, 2)
+ return x2, prob
+
+@torch.no_grad()
+def warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1, smooth_mask = False, return_relative_depth_error = False, depth_interpolation_mode = "bilinear", relative_depth_error_threshold = 0.05):
+ """Warp kpts0 from I0 to I1 with depth, K and Rt
+ Also check covisibility and depth consistency.
+ Depth is consistent if the relative error is below ``relative_depth_error_threshold`` (0.05 by default).
+ # https://github.com/zju3dv/LoFTR/blob/94e98b695be18acb43d5d3250f52226a8e36f839/src/loftr/utils/geometry.py adapted from here
+ Args:
+ kpts0 (torch.Tensor): [N, L, 2] - , should be normalized in (-1,1)
+ depth0 (torch.Tensor): [N, H, W],
+ depth1 (torch.Tensor): [N, H, W],
+ T_0to1 (torch.Tensor): [N, 3, 4],
+ K0 (torch.Tensor): [N, 3, 3],
+ K1 (torch.Tensor): [N, 3, 3],
+ Returns:
+ calculable_mask (torch.Tensor): [N, L]
+ warped_keypoints0 (torch.Tensor): [N, L, 2]
+ """
+ (
+ n,
+ h,
+ w,
+ ) = depth0.shape
+ if depth_interpolation_mode == "combined":
+ # Inspired by approach in inloc, try to fill holes from bilinear interpolation by nearest neighbour interpolation
+ if smooth_mask:
+ raise NotImplementedError("Combined bilinear and NN warp not implemented")
+ valid_bilinear, warp_bilinear = warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1,
+ smooth_mask = smooth_mask,
+ return_relative_depth_error = return_relative_depth_error,
+ depth_interpolation_mode = "bilinear",
+ relative_depth_error_threshold = relative_depth_error_threshold)
+ valid_nearest, warp_nearest = warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1,
+ smooth_mask = smooth_mask,
+ return_relative_depth_error = return_relative_depth_error,
+ depth_interpolation_mode = "nearest-exact",
+ relative_depth_error_threshold = relative_depth_error_threshold)
+ nearest_valid_bilinear_invalid = (~valid_bilinear).logical_and(valid_nearest)
+ warp = warp_bilinear.clone()
+ warp[nearest_valid_bilinear_invalid] = warp_nearest[nearest_valid_bilinear_invalid]
+ valid = valid_bilinear | valid_nearest
+ return valid, warp
+
+
+ kpts0_depth = F.grid_sample(depth0[:, None], kpts0[:, :, None], mode = depth_interpolation_mode, align_corners=False)[
+ :, 0, :, 0
+ ]
+ kpts0 = torch.stack(
+ (w * (kpts0[..., 0] + 1) / 2, h * (kpts0[..., 1] + 1) / 2), dim=-1
+ ) # [-1+1/h, 1-1/h] -> [0.5, h-0.5]
+ # Sample depth, get calculable_mask on depth != 0
+ # nonzero_mask = kpts0_depth != 0
+ # Sample depth, get calculable_mask on depth > 0
+ nonzero_mask = kpts0_depth > 0
+
+ # Unproject
+ kpts0_h = (
+ torch.cat([kpts0, torch.ones_like(kpts0[:, :, [0]])], dim=-1)
+ * kpts0_depth[..., None]
+ ) # (N, L, 3)
+ kpts0_n = K0.inverse() @ kpts0_h.transpose(2, 1) # (N, 3, L)
+ kpts0_cam = kpts0_n
+
+ # Rigid Transform
+ w_kpts0_cam = T_0to1[:, :3, :3] @ kpts0_cam + T_0to1[:, :3, [3]] # (N, 3, L)
+ w_kpts0_depth_computed = w_kpts0_cam[:, 2, :]
+
+ # Project
+ w_kpts0_h = (K1 @ w_kpts0_cam).transpose(2, 1) # (N, L, 3)
+ w_kpts0 = w_kpts0_h[:, :, :2] / (
+ w_kpts0_h[:, :, [2]] + 1e-4
+ ) # (N, L, 2), +1e-4 to avoid zero depth
+
+ # Covisible Check
+ h, w = depth1.shape[1:3]
+ covisible_mask = (
+ (w_kpts0[:, :, 0] > 0)
+ * (w_kpts0[:, :, 0] < w - 1)
+ * (w_kpts0[:, :, 1] > 0)
+ * (w_kpts0[:, :, 1] < h - 1)
+ )
+ w_kpts0 = torch.stack(
+ (2 * w_kpts0[..., 0] / w - 1, 2 * w_kpts0[..., 1] / h - 1), dim=-1
+ ) # from [0.5,h-0.5] -> [-1+1/h, 1-1/h]
+ # w_kpts0[~covisible_mask, :] = -5 # xd
+
+ w_kpts0_depth = F.grid_sample(
+ depth1[:, None], w_kpts0[:, :, None], mode=depth_interpolation_mode, align_corners=False
+ )[:, 0, :, 0]
+
+ relative_depth_error = (
+ (w_kpts0_depth - w_kpts0_depth_computed) / w_kpts0_depth
+ ).abs()
+ if not smooth_mask:
+ consistent_mask = relative_depth_error < relative_depth_error_threshold
+ else:
+ consistent_mask = (-relative_depth_error/smooth_mask).exp()
+ valid_mask = nonzero_mask * covisible_mask * consistent_mask
+ if return_relative_depth_error:
+ return relative_depth_error, w_kpts0
+ else:
+ return valid_mask, w_kpts0
+
+
+def geotrf(Trf, pts, ncol=None, norm=False):
+ """ Apply a geometric transformation to a list of 3-D points.
+
+ Trf: 3x3 or 4x4 projection matrix (typically a Homography)
+ pts: numpy/torch/tuple of coordinates. Shape must be (...,2) or (...,3)
+
+ ncol: int. number of columns of the result (2 or 3)
+ norm: float. if != 0, the result is projected on the z=norm plane.
+
+ Returns an array of projected 2d points.
+ """
+ assert Trf.ndim >= 2
+ if isinstance(Trf, np.ndarray):
+ pts = np.asarray(pts)
+ elif isinstance(Trf, torch.Tensor):
+ pts = torch.as_tensor(pts, dtype=Trf.dtype)
+
+ # adapt shape if necessary
+ output_reshape = pts.shape[:-1]
+ ncol = ncol or pts.shape[-1]
+
+ # optimized code
+ if (isinstance(Trf, torch.Tensor) and isinstance(pts, torch.Tensor) and
+ Trf.ndim == 3 and pts.ndim == 4):
+ d = pts.shape[3]
+ if Trf.shape[-1] == d:
+ pts = torch.einsum("bij, bhwj -> bhwi", Trf, pts)
+ elif Trf.shape[-1] == d + 1:
+ pts = torch.einsum("bij, bhwj -> bhwi", Trf[:, :d, :d], pts) + Trf[:, None, None, :d, d]
+ else:
+ raise ValueError(f'bad shape, not ending with 3 or 4, for {pts.shape=}')
+ else:
+ if Trf.ndim >= 3:
+ n = Trf.ndim - 2
+ assert Trf.shape[:n] == pts.shape[:n], 'batch size does not match'
+ Trf = Trf.reshape(-1, Trf.shape[-2], Trf.shape[-1])
+
+ if pts.ndim > Trf.ndim:
+ # Trf == (B,d,d) & pts == (B,H,W,d) --> (B, H*W, d)
+ pts = pts.reshape(Trf.shape[0], -1, pts.shape[-1])
+ elif pts.ndim == 2:
+ # Trf == (B,d,d) & pts == (B,d) --> (B, 1, d)
+ pts = pts[:, None, :]
+
+ if pts.shape[-1] + 1 == Trf.shape[-1]:
+ Trf = Trf.swapaxes(-1, -2) # transpose Trf
+ pts = pts @ Trf[..., :-1, :] + Trf[..., -1:, :]
+ elif pts.shape[-1] == Trf.shape[-1]:
+ Trf = Trf.swapaxes(-1, -2) # transpose Trf
+ pts = pts @ Trf
+ else:
+ pts = Trf @ pts.T
+ if pts.ndim >= 2:
+ pts = pts.swapaxes(-1, -2)
+
+ if norm:
+ pts = pts / pts[..., -1:] # DONT DO /= BECAUSE OF WEIRD PYTORCH BUG
+ if norm != 1:
+ pts *= norm
+
+ res = pts[..., :ncol].reshape(*output_reshape, ncol)
+ return res
+
+
+def inv(mat):
+ """ Invert a torch or numpy matrix
+ """
+ if isinstance(mat, torch.Tensor):
+ return torch.linalg.inv(mat)
+ if isinstance(mat, np.ndarray):
+ return np.linalg.inv(mat)
+ raise ValueError(f'bad matrix type = {type(mat)}')
+
+def opencv_camera_to_plucker(poses, K, H, W):
+ device = poses.device
+ B = poses.shape[0]
+
+ pixel = torch.from_numpy(get_pixel(H, W).astype(np.float32)).to(device).T.reshape(H, W, 3)[None].repeat(B, 1, 1, 1) # (B, H, W, 3)
+ pixel = torch.einsum('bij, bhwj -> bhwi', torch.inverse(K), pixel)
+ ray_directions = torch.einsum('bij, bhwj -> bhwi', poses[..., :3, :3], pixel)
+
+ ray_origins = poses[..., :3, 3][:, None, None].repeat(1, H, W, 1)
+
+ ray_directions = ray_directions / ray_directions.norm(dim=-1, keepdim=True)
+ plucker_normal = torch.cross(ray_origins, ray_directions, dim=-1)
+ plucker_ray = torch.cat([ray_directions, plucker_normal], dim=-1)
+
+ return plucker_ray
+
+
+def depth_edge(depth: torch.Tensor, atol: float = None, rtol: float = None, kernel_size: int = 3, mask: torch.Tensor = None) -> torch.BoolTensor:
+ """
+ Compute the edge mask of a depth map. The edge is defined as the pixels whose neighbors have a large difference in depth.
+
+ Args:
+ depth (torch.Tensor): shape (..., height, width), linear depth map
+ atol (float): absolute tolerance
+ rtol (float): relative tolerance
+ kernel_size (int): size of the pooling window used to compare neighboring depths
+ mask (torch.Tensor): optional boolean validity mask of the same spatial shape
+
+ Returns:
+ edge (torch.Tensor): shape (..., height, width) of dtype torch.bool
+ """
+ shape = depth.shape
+ depth = depth.reshape(-1, 1, *shape[-2:])
+ if mask is not None:
+ mask = mask.reshape(-1, 1, *shape[-2:])
+
+ if mask is None:
+ diff = (F.max_pool2d(depth, kernel_size, stride=1, padding=kernel_size // 2) + F.max_pool2d(-depth, kernel_size, stride=1, padding=kernel_size // 2))
+ else:
+ diff = (F.max_pool2d(torch.where(mask, depth, -torch.inf), kernel_size, stride=1, padding=kernel_size // 2) + F.max_pool2d(torch.where(mask, -depth, -torch.inf), kernel_size, stride=1, padding=kernel_size // 2))
+
+ edge = torch.zeros_like(depth, dtype=torch.bool)
+ if atol is not None:
+ edge |= diff > atol
+ if rtol is not None:
+ edge |= (diff / depth).nan_to_num_() > rtol
+ edge = edge.reshape(*shape)
+ return edge
+
+def normalized_view_plane_uv(width: int, height: int, aspect_ratio: float = None, dtype: torch.dtype = None, device: torch.device = None) -> torch.Tensor:
+ "UV with left-top corner as (-width / diagonal, -height / diagonal) and right-bottom corner as (width / diagonal, height / diagonal)"
+ if aspect_ratio is None:
+ aspect_ratio = width / height
+
+ span_x = aspect_ratio / (1 + aspect_ratio ** 2) ** 0.5
+ span_y = 1 / (1 + aspect_ratio ** 2) ** 0.5
+
+ u = torch.linspace(-span_x * (width - 1) / width, span_x * (width - 1) / width, width, dtype=dtype, device=device)
+ v = torch.linspace(-span_y * (height - 1) / height, span_y * (height - 1) / height, height, dtype=dtype, device=device)
+ u, v = torch.meshgrid(u, v, indexing='xy')
+ uv = torch.stack([u, v], dim=-1)
+ return uv
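+
+# Note: span_x**2 + span_y**2 == 1, i.e. the grid is normalized by the half
+# diagonal of the view plane. A sketch of typical usage:
+#   uv = normalized_view_plane_uv(640, 480)   # shape (480, 640, 2)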
+
+def solve_optimal_focal_shift(uv: np.ndarray, xyz: np.ndarray):
+ "Solve `min |focal * xy / (z + shift) - uv|` with respect to shift and focal"
+ from scipy.optimize import least_squares
+ uv, xy, z = uv.reshape(-1, 2), xyz[..., :2].reshape(-1, 2), xyz[..., 2].reshape(-1)
+
+ def fn(uv: np.ndarray, xy: np.ndarray, z: np.ndarray, shift: np.ndarray):
+        xy_proj = xy / (z + shift)[:, None]
+ f = (xy_proj * uv).sum() / np.square(xy_proj).sum()
+ err = (f * xy_proj - uv).ravel()
+ return err
+
+ solution = least_squares(partial(fn, uv, xy, z), x0=0, ftol=1e-3, method='lm')
+ optim_shift = solution['x'].squeeze().astype(np.float32)
+
+    xy_proj = xy / (z + optim_shift)[:, None]
+ optim_focal = (xy_proj * uv).sum() / np.square(xy_proj).sum()
+
+ return optim_shift, optim_focal
+
+
+def solve_optimal_shift(uv: np.ndarray, xyz: np.ndarray, focal: float):
+ "Solve `min |focal * xy / (z + shift) - uv|` with respect to shift"
+ from scipy.optimize import least_squares
+ uv, xy, z = uv.reshape(-1, 2), xyz[..., :2].reshape(-1, 2), xyz[..., 2].reshape(-1)
+
+ def fn(uv: np.ndarray, xy: np.ndarray, z: np.ndarray, shift: np.ndarray):
+        xy_proj = xy / (z + shift)[:, None]
+ err = (focal * xy_proj - uv).ravel()
+ return err
+
+ solution = least_squares(partial(fn, uv, xy, z), x0=0, ftol=1e-3, method='lm')
+ optim_shift = solution['x'].squeeze().astype(np.float32)
+
+ return optim_shift
+
+def recover_focal_shift(points: torch.Tensor, mask: torch.Tensor = None, focal: torch.Tensor = None, downsample_size: Tuple[int, int] = (64, 64)):
+ """
+ Recover the depth map and FoV from a point map with unknown z shift and focal.
+
+ Note that it assumes:
+ - the optical center is at the center of the map
+ - the map is undistorted
+ - the map is isometric in the x and y directions
+
+ ### Parameters:
+ - `points: torch.Tensor` of shape (..., H, W, 3)
+ - `downsample_size: Tuple[int, int]` in (height, width), the size of the downsampled map. Downsampling produces approximate solution and is efficient for large maps.
+
+ ### Returns:
+ - `focal`: torch.Tensor of shape (...) the estimated focal length, relative to the half diagonal of the map
+ - `shift`: torch.Tensor of shape (...) Z-axis shift to translate the point map to camera space
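+
+    ### Example (a sketch; assumes a point map predicted up to an unknown Z shift):
+    - `focal, shift = recover_focal_shift(points, mask)`; then `points[..., 2] + shift[..., None, None]`
+      is depth in camera space, and the focal length in pixels is `focal * (H**2 + W**2) ** 0.5 / 2`.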
+ """
+ shape = points.shape
+ height, width = points.shape[-3], points.shape[-2]
+ diagonal = (height ** 2 + width ** 2) ** 0.5
+
+ points = points.reshape(-1, *shape[-3:])
+ mask = None if mask is None else mask.reshape(-1, *shape[-3:-1])
+ focal = focal.reshape(-1) if focal is not None else None
+ uv = normalized_view_plane_uv(width, height, dtype=points.dtype, device=points.device) # (H, W, 2)
+
+ points_lr = F.interpolate(points.permute(0, 3, 1, 2), downsample_size, mode='nearest').permute(0, 2, 3, 1)
+ uv_lr = F.interpolate(uv.unsqueeze(0).permute(0, 3, 1, 2), downsample_size, mode='nearest').squeeze(0).permute(1, 2, 0)
+ mask_lr = None if mask is None else F.interpolate(mask.to(torch.float32).unsqueeze(1), downsample_size, mode='nearest').squeeze(1) > 0
+
+ uv_lr_np = uv_lr.cpu().numpy()
+ points_lr_np = points_lr.detach().cpu().numpy()
+ focal_np = focal.cpu().numpy() if focal is not None else None
+ mask_lr_np = None if mask is None else mask_lr.cpu().numpy()
+ optim_shift, optim_focal = [], []
+ for i in range(points.shape[0]):
+ points_lr_i_np = points_lr_np[i] if mask is None else points_lr_np[i][mask_lr_np[i]]
+ uv_lr_i_np = uv_lr_np if mask is None else uv_lr_np[mask_lr_np[i]]
+ if uv_lr_i_np.shape[0] < 2:
+ optim_focal.append(1)
+ optim_shift.append(0)
+ continue
+ if focal is None:
+ optim_shift_i, optim_focal_i = solve_optimal_focal_shift(uv_lr_i_np, points_lr_i_np)
+ optim_focal.append(float(optim_focal_i))
+ else:
+ optim_shift_i = solve_optimal_shift(uv_lr_i_np, points_lr_i_np, focal_np[i])
+ optim_shift.append(float(optim_shift_i))
+ optim_shift = torch.tensor(optim_shift, device=points.device, dtype=points.dtype).reshape(shape[:-3])
+
+ if focal is None:
+ optim_focal = torch.tensor(optim_focal, device=points.device, dtype=points.dtype).reshape(shape[:-3])
+ else:
+ optim_focal = focal.reshape(shape[:-3])
+
+ return optim_focal, optim_shift
\ No newline at end of file
diff --git a/unish/pipeline.py b/unish/pipeline.py
new file mode 100644
index 0000000000000000000000000000000000000000..b3cc4cc6c120f77b8b5925321b2cd59a3db76d8b
--- /dev/null
+++ b/unish/pipeline.py
@@ -0,0 +1,244 @@
+import torch
+import torch.nn as nn
+import logging
+from huggingface_hub import PyTorchModelHubMixin
+
+from unish.pi3.models.pi3 import Pi3
+from unish.heads.vit import vit
+from unish.heads.human_head_cliff import HumanHeadCliff
+from unish.heads.align_net import AlignNet
+from unish.utils.data_utils import rotmat_to_aa, depth_to_points
+from unish.utils.smpl_utils import transform_smpl
+from unish.pi3.utils.geometry import recover_focal_shift, se3_inverse
+import utils3d
+
+logger = logging.getLogger(__name__)
+
+class UniSHPipeline(nn.Module, PyTorchModelHubMixin):
+ def __init__(self, dtype='float32'):
+ super().__init__()
+ self.pi3 = Pi3()
+
+ self.human_head = HumanHeadCliff()
+ self.alignnet = AlignNet()
+ self.backbone = vit()
+
+ def get_human_prediction(self, human_patches, bbox_info):
+ b, s, c, h, w = human_patches.shape
+ human_patches = human_patches.reshape(b * s, c, h, w)
+ features = self.backbone(human_patches[..., 32:-32])
+ _, c, h_patch, w_patch = features.shape
+ features = features.reshape(b, s, c, h_patch, w_patch).float()
+ smpl_feats = self.human_head(features, bbox_info)
+ return smpl_feats
+
+ def get_normalize_factor(self, local_points, masks):
+ B, N, H, W, _ = local_points.shape
+
+ # normalize predict points
+ all_pts = local_points.clone()
+ all_pts[~masks] = 0
+ all_pts = all_pts.reshape(B, N, -1, 3)
+ all_dis = all_pts.norm(dim=-1)
+ norm_factor = all_dis.sum(dim=[-1, -2]) / (masks.float().sum(dim=[-1, -2, -3]) + 1e-8)
+
+ return norm_factor.view(B)
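+
+    # Note: norm_factor is the mean distance of valid (masked) points from the
+    # origin, so dividing by it normalizes the reconstruction to unit average scale.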
+
+ def scale_predictions(self, predictions, scale):
+
+ batch_size, num_views = predictions["world_points"].shape[:2]
+
+ norm_factor = self.get_normalize_factor(predictions["local_points"], predictions["masks"])
+ # Ensure scale has the correct shape [batch_size] for element-wise division
+ scale = scale.squeeze(-1) # [batch, 1] -> [batch]
+ scale /= norm_factor # (batch_size)
+ predictions["world_points"] *= scale.view(batch_size, 1, 1, 1, 1)
+ predictions["local_points"] *= scale.view(batch_size, 1, 1, 1, 1)
+ predictions["c2ws"][:, :, :3, 3] *= scale.view(batch_size, 1, 1)
+ predictions["c2ws_cano"][:, :, :3, 3] *= scale.view(batch_size, 1, 1)
+ predictions["w2cs_cano"] = se3_inverse(predictions["c2ws_cano"])
+ predictions["trans_cam"] *= scale.view(batch_size, 1, 1)
+ predictions["trans_world"] *= scale.view(batch_size, 1, 1)
+
+ return predictions, scale
+
+ def get_depth_intrinsic_shift(self, predictions):
+
+ points = predictions["local_points"]
+ masks = predictions["masks"]
+
+ with torch.autocast(device_type='cuda', enabled=False, dtype=torch.float32):
+ points = points.to(torch.float32)
+ focal, shift = recover_focal_shift(points, masks)
+
+ original_height, original_width = points.shape[-3:-1]
+ aspect_ratio = original_width / original_height
+ fx, fy = focal / 2 * (1 + aspect_ratio ** 2) ** 0.5 / aspect_ratio, focal / 2 * (1 + aspect_ratio ** 2) ** 0.5
+
+ # Ensure all inputs are on the same device for utils3d
+ device = points.device
+ cx_norm = torch.tensor(0.5, device=device)
+ cy_norm = torch.tensor(0.5, device=device)
+ cx = torch.tensor(original_width / 2.0, device=device)
+ cy = torch.tensor(original_height / 2.0, device=device)
+
+ intrinsics_normed = utils3d.torch.intrinsics_from_focal_center(fx, fy, cx_norm, cy_norm)
+ intrinsics = utils3d.torch.intrinsics_from_focal_center(fx * original_width, fy * original_height, cx, cy)
+ shifted_points = points.clone()
+ shifted_points[..., 2] += shift[..., None, None]
+        masks &= shifted_points[..., 2] > 0  # in case depth contains negative values (which should never happen in practice)
+ depth = shifted_points[..., 2].clone()
+
+ return depth, intrinsics, intrinsics_normed, shift
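+
+    # Note: `intrinsics_normed` expresses focal and center in normalized image
+    # coordinates (cx = cy = 0.5), while `intrinsics` is in pixels; both are
+    # derived from the same focal recovered by recover_focal_shift.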
+
+ def forward(self, batch):
+ batch_size, num_views = batch["images"].shape[:2]
+ geometry_results = self.pi3(batch["images"])
+
+ with torch.amp.autocast(device_type='cuda', enabled=False, dtype=torch.float32):
+
+ for k, v in geometry_results.items():
+ if isinstance(v, torch.Tensor):
+ geometry_results[k] = v.to(torch.float32)
+
+ # Predict Cameras
+ c2ws_cano = geometry_results["c2ws_cano"] # (B, S, 4, 4)
+ w2cs_cano = se3_inverse(c2ws_cano)
+ geometry_results["point_conf"] = torch.sigmoid(geometry_results["point_conf"])
+ masks = geometry_results["point_conf"] > 0.1
+
+ # Predict SMPL
+ self.human_head.to(torch.float32)
+
+ human_predict = self.get_human_prediction(batch["human_patches"], batch["bbox_info"])
+
+ if isinstance(human_predict, dict):
+ cam_token = human_predict["token_out"].view(batch_size, num_views, 1, 1024)
+ else:
+ cam_token = human_predict.view(batch_size, num_views, 1, 1024)
+
+ # predict scale and smpl z-trans
+ hidden = geometry_results["hidden"].view(batch_size, num_views, -1, geometry_results["hidden"].shape[-1]) # [1, 10, 782, 2048]
+ self.alignnet.to(torch.float32)
+ align_results = self.alignnet(hidden, cam_token)
+ pred_scale = align_results["scale"] # [1]
+
+ if isinstance(human_predict, dict):
+ align_results["betas"] = human_predict["betas"]
+ align_results["pose_cam"] = human_predict["pose_cam"]
+
+ align_results_world = transform_smpl(align_results, w2cs_cano, copy_dict=True)
+
+ ret_dict = {
+ "pose_world": align_results_world["pose_world"], # [B, S, 72]
+ "trans_world": align_results_world["trans_world"], # [B, S, 3]
+ "pose_cam": align_results["pose_cam"], # [B, S, 72]
+ "trans_cam": align_results["trans_cam"], # [B, S, 3]
+ "betas": align_results["betas"], # [B, S, 10]
+ "world_points": geometry_results["world_points"], # [B, S, H, W, 3]
+ "local_points": geometry_results["local_points"], # [B, S, H, W, 3]
+ "point_conf": geometry_results["point_conf"].squeeze(-1), # [B, S, H, W]
+ "c2ws": geometry_results["c2ws"], # [B, S, 4, 4]
+ "c2ws_cano": geometry_results["c2ws_cano"], # [B, S, 4, 4]
+ "scale": pred_scale, # [B]
+ "w2cs_cano": w2cs_cano, # [B, S, 4, 4]
+ "masks": masks.squeeze(-1), # [B, S, H, W]
+ }
+
+ return ret_dict
+
+ def inference(self, batch):
+ batch_size, num_views = batch["images"].shape[:2]
+ geometry_results = self.pi3(batch["images"])
+
+ with torch.amp.autocast(device_type='cuda', enabled=False, dtype=torch.float32):
+
+ for k, v in geometry_results.items():
+ if isinstance(v, torch.Tensor):
+ geometry_results[k] = v.to(torch.float32)
+
+ # Predict Cameras
+ c2ws_cano = geometry_results["c2ws_cano"] # (B, S, 4, 4)
+ w2cs_cano = se3_inverse(c2ws_cano)
+ geometry_results["point_conf"] = torch.sigmoid(geometry_results["point_conf"])
+ masks = geometry_results["point_conf"] > 0.1
+ geometry_results["masks"] = masks.squeeze(-1)
+
+ _, pred_intrinsics, _, _ = self.get_depth_intrinsic_shift(geometry_results)
+
+ # Predict SMPL
+ self.human_head.to(torch.float32)
+
+ bbox_info = batch["bbox_info"]
+
+ if "intrinsics" in batch:
+ intrinsics_gt = batch["intrinsics"]
+ focal_gt = (intrinsics_gt[..., 0, 0] + intrinsics_gt[..., 1, 1]) / 2
+ focal_pred = (pred_intrinsics[..., 0, 0] + pred_intrinsics[..., 1, 1]) / 2
+ focal_gt = focal_gt.view(batch_size, -1, 1).to(bbox_info)
+ focal_pred = focal_pred.view(batch_size, -1, 1).to(bbox_info)
+ bbox_info = bbox_info * focal_gt / focal_pred
+ else:
+ intrinsics = pred_intrinsics
+ focal = (intrinsics[..., 0, 0] + intrinsics[..., 1, 1]) / 2
+ focal = focal.view(batch_size, -1, 1).to(bbox_info)
+ bbox_info = bbox_info / focal
+
+ human_predict = self.get_human_prediction(batch["human_patches"], bbox_info)
+
+ if isinstance(human_predict, dict):
+ cam_token = human_predict["token_out"].view(batch_size, num_views, 1, 1024)
+ else:
+ cam_token = human_predict.view(batch_size, num_views, 1, 1024)
+
+ # predict scale and smpl z-trans
+ hidden = geometry_results["hidden"].view(batch_size, num_views, -1, geometry_results["hidden"].shape[-1]) # [1, 10, 782, 2048]
+ self.alignnet.to(torch.float32)
+ align_results = self.alignnet(hidden, cam_token)
+ pred_scale = align_results["scale"] # [1]
+
+ if isinstance(human_predict, dict):
+ align_results["betas"] = human_predict["betas"]
+ align_results["pose_cam"] = human_predict["pose_cam"]
+
+ align_results_world = transform_smpl(align_results, w2cs_cano, copy_dict=True)
+
+ pred_pose_world = align_results_world["pose_world"]
+ pred_pose_world_aa = rotmat_to_aa(pred_pose_world.reshape(-1, 3, 3))
+ pred_pose_world_aa = pred_pose_world_aa.reshape(batch_size, num_views, -1)
+ align_results_world["pose_world"] = pred_pose_world_aa
+
+ pred_pose_cam = align_results["pose_cam"]
+ pred_pose_cam_aa = rotmat_to_aa(pred_pose_cam.reshape(-1, 3, 3))
+ pred_pose_cam_aa = pred_pose_cam_aa.reshape(batch_size, num_views, -1)
+ align_results["pose_cam"] = pred_pose_cam_aa
+
+ align_results["betas"] = align_results["betas"].repeat(1, num_views, 1)
+
+ ret_dict = {
+ "pose_world": align_results_world["pose_world"], # [B, S, 72]
+ "trans_world": align_results_world["trans_world"], # [B, S, 3]
+ "pose_cam": align_results["pose_cam"], # [B, S, 72]
+ "trans_cam": align_results["trans_cam"], # [B, S, 3]
+ "betas": align_results["betas"], # [B, S, 10]
+ "world_points": geometry_results["world_points"], # [B, S, H, W, 3]
+ "local_points": geometry_results["local_points"], # [B, S, H, W, 3]
+ "point_conf": geometry_results["point_conf"].squeeze(-1), # [B, S, H, W]
+ "c2ws": geometry_results["c2ws"], # [B, S, 4, 4]
+ "c2ws_cano": geometry_results["c2ws_cano"], # [B, S, 4, 4]
+ "scale": pred_scale, # [B]
+ "w2cs_cano": w2cs_cano, # [B, S, 4, 4]
+ "masks": masks.squeeze(-1), # [B, S, H, W]
+ "intrinsics": pred_intrinsics, # [B, S, 3, 3]
+ }
+
+ ret_dict, scale = self.scale_predictions(ret_dict, ret_dict["scale"])
+ depth_map, intrinsics, intrinsics_normed, shift = self.get_depth_intrinsic_shift(ret_dict)
+ ret_dict["trans_cam"][..., 2] += shift.view(1, -1)
+ ret_dict["depth_map"] = depth_map
+ ret_dict["point_map"] = depth_to_points(depth_map, intrinsics=intrinsics_normed)
+ ret_dict["intrinsics"] = intrinsics
+
+ return ret_dict
+
\ No newline at end of file
diff --git a/unish/utils/constants.py b/unish/utils/constants.py
new file mode 100644
index 0000000000000000000000000000000000000000..497afca3e138c2880260bf17e454a068403869aa
--- /dev/null
+++ b/unish/utils/constants.py
@@ -0,0 +1,45 @@
+import numpy as np
+
+SMPL_MEAN_PARAMS = {
+ 'shape': np.array([ 0.20560974, 0.33556297, -0.35068282, 0.35612896, 0.41754073,
+ 0.03088791, 0.30475676, 0.23613405, 0.20912662, 0.31212646],
+ dtype=np.float32),
+ 'pose': np.array([ 1.00000000e+00, 0.00000000e+00, 0.00000000e+00, -1.00000048e+00,
+ 0.00000000e+00, 1.50995817e-07, 9.95594144e-01, -9.35074911e-02,
+ 8.96214917e-02, 9.70808983e-01, -2.75727212e-02, -2.20876053e-01,
+ 9.95853722e-01, 8.32496583e-02, -7.22568333e-02, 9.68795717e-01,
+ 5.52654527e-02, -2.33461723e-01, 9.99805868e-01, -1.16801877e-02,
+ 1.55893452e-02, 9.61456239e-01, -1.20510980e-02, 2.74709910e-01,
+ 9.93516386e-01, 7.95017406e-02, -1.06274955e-01, 9.03366446e-01,
+ 4.03832123e-02, 4.21436191e-01, 9.98174787e-01, -5.75351082e-02,
+ 5.92384301e-02, 8.73740733e-01, 1.17418924e-02, 4.82977092e-01,
+ 9.99480844e-01, -6.16004784e-03, 5.43189887e-03, 9.99719322e-01,
+ 3.17586996e-02, 2.28755195e-02, 9.89863694e-01, 5.47103174e-02,
+ -5.37955277e-02, 9.98496652e-01, -1.31437480e-01, 3.35601554e-03,
+ 9.75578487e-01, -1.29129276e-01, 1.39408067e-01, 9.89138842e-01,
+ 1.69740841e-01, -7.02139735e-02, 9.99953926e-01, -8.90268944e-03,
+ 8.85364786e-03, 9.99870598e-01, 3.71519849e-03, 1.33989686e-02,
+ 9.81357634e-01, -1.24362208e-01, 9.10880938e-02, 9.72318351e-01,
+ -1.69234112e-01, -1.97815821e-01, 9.75871921e-01, 2.00855538e-01,
+ -2.04031423e-01, 9.78503823e-01, -7.77502581e-02, -4.67682816e-02,
+ 9.99965727e-01, 3.49923270e-03, -3.45715135e-03, 9.99978244e-01,
+ 7.52193388e-03, -5.59064560e-03, 9.55124676e-01, 2.70941108e-01,
+ -2.68006951e-01, 9.62574720e-01, 1.26131222e-01, -6.38862699e-03,
+ 9.52408850e-01, -2.37964272e-01, 2.28888422e-01, 9.71029997e-01,
+ -2.01314509e-01, -2.17657294e-02, 9.98640895e-01, -3.41732055e-02,
+ 3.14728543e-02, 9.97235835e-01, 4.15433869e-02, 6.59763440e-02,
+ 7.28100538e-01, 6.24727070e-01, -6.39606535e-01, 7.67192006e-01,
+ 2.46521905e-01, 1.45370275e-01, 7.56920695e-01, -5.90696871e-01,
+ 6.05992019e-01, 7.94557154e-01, -2.44631916e-01, 1.40556693e-01,
+ 5.88521481e-01, -2.55633533e-01, 1.52344644e-01, 9.66765523e-01,
+ 7.93998480e-01, 3.98525596e-03, 5.74924588e-01, 2.02400237e-01,
+ -2.33697295e-01, 9.69179034e-01, -7.84121990e-01, -1.40449643e-01,
+ 9.93295610e-01, -4.65713926e-02, 5.34210391e-02, 9.96592045e-01,
+ 1.02518320e-01, -6.80837333e-02, 9.93379056e-01, 6.87826574e-02,
+ -7.23023862e-02, 9.96751606e-01, -8.92773867e-02, -4.18949872e-02,
+ 9.82127666e-01, 1.82622463e-01, -1.74190253e-01, 9.73588109e-01,
+ 7.12950081e-02, -1.37022391e-01, 9.73974586e-01, -2.03668222e-01,
+ 1.94157705e-01, 9.76112187e-01, -1.16945654e-01, -7.56589174e-02],
+ dtype=np.float32),
+ 'cam': np.array([0.9, 0. , 0. ], dtype=np.float32)
+}
diff --git a/unish/utils/data_utils.py b/unish/utils/data_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..4dc82ce92e58430a9c1faf75210e112a00d90797
--- /dev/null
+++ b/unish/utils/data_utils.py
@@ -0,0 +1,690 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+import cv2
+import numpy as np
+import PIL
+import torch
+import torch.nn.functional as F
+from typing import Tuple
+try:
+ lanczos = PIL.Image.Resampling.LANCZOS
+ bicubic = PIL.Image.Resampling.BICUBIC
+except AttributeError:
+ lanczos = PIL.Image.LANCZOS
+ bicubic = PIL.Image.BICUBIC
+
+
+def depth_to_world_coords_points(
+ depth_map: np.ndarray,
+ extrinsic: np.ndarray,
+ intrinsic: np.ndarray,
+ eps=1e-5,
+) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
+ """
+ Converts a depth map to world coordinates (HxWx3) given the camera extrinsic and intrinsic.
+ Returns both the world coordinates and the intermediate camera coordinates,
+ as well as a mask for valid depth.
+
+ Args:
+ depth_map (np.ndarray):
+ Depth map of shape (H, W).
+ extrinsic (np.ndarray):
+ Extrinsic matrix of shape (3, 4), representing the camera pose in OpenCV convention (w2c).
+ intrinsic (np.ndarray):
+ Intrinsic matrix of shape (3, 3).
+ eps (float):
+ Small epsilon for thresholding valid depth.
+
+ Returns:
+ tuple[np.ndarray, np.ndarray, np.ndarray]:
+ (world_coords_points, cam_coords_points, point_mask)
+
+ - world_coords_points: (H, W, 3) array of 3D points in world frame.
+ - cam_coords_points: (H, W, 3) array of 3D points in camera frame.
+ - point_mask: (H, W) boolean array where True indicates valid (non-zero) depth.
+ """
+ if depth_map is None:
+ return None, None, None
+
+ # Valid depth mask
+ point_mask = depth_map > eps
+
+ # Convert depth map to camera coordinates
+ cam_coords_points = depth_to_cam_coords_points(depth_map, intrinsic)
+
+ # The extrinsic is camera-from-world, so invert it to transform camera->world
+ cam_to_world_extrinsic = closed_form_inverse_se3(extrinsic[None])[0]
+ R_cam_to_world = cam_to_world_extrinsic[:3, :3]
+ t_cam_to_world = cam_to_world_extrinsic[:3, 3]
+
+ # Apply the rotation and translation to the camera coordinates
+ world_coords_points = (
+ np.dot(cam_coords_points, R_cam_to_world.T) + t_cam_to_world
+ ) # HxWx3, 3x3 -> HxWx3
+ # world_coords_points = np.einsum("ij,hwj->hwi", R_cam_to_world, cam_coords_points) + t_cam_to_world
+
+ return world_coords_points, cam_coords_points, point_mask
+
+
+def depth_to_cam_coords_points(
+ depth_map: np.ndarray, intrinsic: np.ndarray
+) -> np.ndarray:
+ """
+ Unprojects a depth map into camera coordinates, returning (H, W, 3).
+
+ Args:
+ depth_map (np.ndarray):
+ Depth map of shape (H, W).
+ intrinsic (np.ndarray):
+ 3x3 camera intrinsic matrix.
+ Assumes zero skew and standard OpenCV layout:
+ [ fx 0 cx ]
+ [ 0 fy cy ]
+ [ 0 0 1 ]
+
+ Returns:
+ np.ndarray:
+ An (H, W, 3) array, where each pixel is mapped to (x, y, z) in the camera frame.
+ """
+ H, W = depth_map.shape
+ assert intrinsic.shape == (3, 3), "Intrinsic matrix must be 3x3"
+ assert (
+ intrinsic[0, 1] == 0 and intrinsic[1, 0] == 0
+ ), "Intrinsic matrix must have zero skew"
+
+ # Intrinsic parameters
+ fu, fv = intrinsic[0, 0], intrinsic[1, 1]
+ cu, cv = intrinsic[0, 2], intrinsic[1, 2]
+
+ # Generate grid of pixel coordinates
+ u, v = np.meshgrid(np.arange(W), np.arange(H))
+
+ # Unproject to camera coordinates
+ x_cam = (u - cu) * depth_map / fu
+ y_cam = (v - cv) * depth_map / fv
+ z_cam = depth_map
+
+ # Stack to form camera coordinates
+ return np.stack((x_cam, y_cam, z_cam), axis=-1).astype(np.float32)
+
+def aa_to_rotmat(theta: torch.Tensor):
+ """
+ Convert axis-angle representation to rotation matrix.
+ Works by first converting it to a quaternion.
+ Args:
+ theta (torch.Tensor): Tensor of shape (B, 3) containing axis-angle representations.
+ Returns:
+ torch.Tensor: Corresponding rotation matrices with shape (B, 3, 3).
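+
+    Example (a minimal sketch): a 90-degree rotation about z maps x onto y:
+        >>> R = aa_to_rotmat(torch.tensor([[0.0, 0.0, torch.pi / 2]]))
+        >>> torch.allclose(R[0] @ torch.tensor([1.0, 0.0, 0.0]), torch.tensor([0.0, 1.0, 0.0]), atol=1e-6)
+        True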
+ """
+ norm = torch.norm(theta + 1e-8, p = 2, dim = 1)
+ angle = torch.unsqueeze(norm, -1)
+ normalized = torch.div(theta, angle)
+ angle = angle * 0.5
+ v_cos = torch.cos(angle)
+ v_sin = torch.sin(angle)
+ quat = torch.cat([v_cos, v_sin * normalized], dim = 1)
+ return quat_to_rotmat(quat)
+
+def quat_to_rotmat(quat: torch.Tensor) -> torch.Tensor:
+ """
+ Convert quaternion representation to rotation matrix.
+ Args:
+        quat (torch.Tensor): shape (B, 4), ordered as (w, x, y, z).
+ Returns:
+ torch.Tensor: Corresponding rotation matrices with shape (B, 3, 3).
+ """
+ norm_quat = quat
+ norm_quat = norm_quat/norm_quat.norm(p=2, dim=1, keepdim=True)
+ w, x, y, z = norm_quat[:,0], norm_quat[:,1], norm_quat[:,2], norm_quat[:,3]
+
+ B = quat.size(0)
+
+ w2, x2, y2, z2 = w.pow(2), x.pow(2), y.pow(2), z.pow(2)
+ wx, wy, wz = w*x, w*y, w*z
+ xy, xz, yz = x*y, x*z, y*z
+
+ rotMat = torch.stack([w2 + x2 - y2 - z2, 2*xy - 2*wz, 2*wy + 2*xz,
+ 2*wz + 2*xy, w2 - x2 + y2 - z2, 2*yz - 2*wx,
+ 2*xz - 2*wy, 2*wx + 2*yz, w2 - x2 - y2 + z2], dim=1).view(B, 3, 3)
+ return rotMat
+
+def rotmat_to_aa(rotation_matrix):
+ """
+ This function is borrowed from https://github.com/kornia/kornia
+
+    Convert a 3x3 or 3x4 rotation matrix to a Rodrigues (axis-angle) vector
+
+ Args:
+ rotation_matrix (Tensor): rotation matrix.
+
+ Returns:
+ Tensor: Rodrigues vector transformation.
+
+ Shape:
+        - Input: :math:`(N, 3, 3)` or :math:`(N, 3, 4)`
+ - Output: :math:`(N, 3)`
+
+ Example:
+        >>> input = torch.rand(2, 3, 4) # Nx3x4
+ >>> output = tgm.rotation_matrix_to_angle_axis(input) # Nx3
+ """
+ if rotation_matrix.shape[1:] == (3,3):
+ rot_mat = rotation_matrix.reshape(-1, 3, 3)
+ hom = torch.tensor([0, 0, 1], dtype=torch.float32,
+ device=rotation_matrix.device).reshape(1, 3, 1).expand(rot_mat.shape[0], -1, -1)
+ rotation_matrix = torch.cat([rot_mat, hom], dim=-1)
+
+ quaternion = rotation_matrix_to_quaternion(rotation_matrix)
+ aa = quaternion_to_angle_axis(quaternion)
+ # Handle any remaining NaN or inf values
+ aa = torch.where(torch.isfinite(aa), aa, torch.zeros_like(aa))
+ return aa
+
+def quaternion_to_angle_axis(quaternion: torch.Tensor) -> torch.Tensor:
+ """
+ This function is borrowed from https://github.com/kornia/kornia
+
+ Convert quaternion vector to angle axis of rotation.
+
+ Adapted from ceres C++ library: ceres-solver/include/ceres/rotation.h
+
+ Args:
+ quaternion (torch.Tensor): tensor with quaternions.
+
+ Return:
+ torch.Tensor: tensor with angle axis of rotation.
+
+ Shape:
+ - Input: :math:`(*, 4)` where `*` means, any number of dimensions
+ - Output: :math:`(*, 3)`
+
+ Example:
+ >>> quaternion = torch.rand(2, 4) # Nx4
+ >>> angle_axis = tgm.quaternion_to_angle_axis(quaternion) # Nx3
+ """
+ if not torch.is_tensor(quaternion):
+ raise TypeError("Input type is not a torch.Tensor. Got {}".format(
+ type(quaternion)))
+
+ if not quaternion.shape[-1] == 4:
+ raise ValueError("Input must be a tensor of shape Nx4 or 4. Got {}"
+ .format(quaternion.shape))
+ # unpack input and compute conversion
+ q1: torch.Tensor = quaternion[..., 1]
+ q2: torch.Tensor = quaternion[..., 2]
+ q3: torch.Tensor = quaternion[..., 3]
+ sin_squared_theta: torch.Tensor = q1 * q1 + q2 * q2 + q3 * q3
+
+ # Add numerical stability: ensure sin_squared_theta is non-negative
+ eps = 1e-8
+ sin_squared_theta = torch.clamp(sin_squared_theta, min=0.0)
+ sin_theta: torch.Tensor = torch.sqrt(sin_squared_theta + eps)
+ cos_theta: torch.Tensor = quaternion[..., 0]
+ two_theta: torch.Tensor = 2.0 * torch.where(
+ cos_theta < 0.0,
+ torch.atan2(-sin_theta, -cos_theta),
+ torch.atan2(sin_theta, cos_theta))
+
+ # Add numerical stability: avoid division by very small sin_theta values
+ threshold = 1e-6
+ sin_theta_safe = torch.where(sin_theta < threshold, threshold, sin_theta)
+ k_pos: torch.Tensor = two_theta / sin_theta_safe
+ k_neg: torch.Tensor = 2.0 * torch.ones_like(sin_theta)
+
+ # Use a more conservative threshold for numerical stability
+ k: torch.Tensor = torch.where(sin_squared_theta > threshold * threshold, k_pos, k_neg)
+
+ angle_axis: torch.Tensor = torch.zeros_like(quaternion)[..., :3]
+ angle_axis[..., 0] += q1 * k
+ angle_axis[..., 1] += q2 * k
+ angle_axis[..., 2] += q3 * k
+ return angle_axis
+
+
+def rotation_matrix_to_quaternion(rotation_matrix, eps=1e-6):
+ """
+ This function is borrowed from https://github.com/kornia/kornia
+
+ Convert 3x4 rotation matrix to 4d quaternion vector
+
+ This algorithm is based on algorithm described in
+ https://github.com/KieranWynn/pyquaternion/blob/master/pyquaternion/quaternion.py#L201
+
+ Args:
+ rotation_matrix (Tensor): the rotation matrix to convert.
+
+ Return:
+ Tensor: the rotation in quaternion
+
+ Shape:
+ - Input: :math:`(N, 3, 4)`
+ - Output: :math:`(N, 4)`
+
+ Example:
+ >>> input = torch.rand(4, 3, 4) # Nx3x4
+ >>> output = tgm.rotation_matrix_to_quaternion(input) # Nx4
+ """
+ if not torch.is_tensor(rotation_matrix):
+ raise TypeError("Input type is not a torch.Tensor. Got {}".format(
+ type(rotation_matrix)))
+
+ if len(rotation_matrix.shape) > 3:
+ raise ValueError(
+ "Input size must be a three dimensional tensor. Got {}".format(
+ rotation_matrix.shape))
+ if not rotation_matrix.shape[-2:] == (3, 4):
+ raise ValueError(
+ "Input size must be a N x 3 x 4 tensor. Got {}".format(
+ rotation_matrix.shape))
+
+ rmat_t = torch.transpose(rotation_matrix, 1, 2)
+
+ mask_d2 = rmat_t[:, 2, 2] < eps
+
+ mask_d0_d1 = rmat_t[:, 0, 0] > rmat_t[:, 1, 1]
+ mask_d0_nd1 = rmat_t[:, 0, 0] < -rmat_t[:, 1, 1]
+
+ t0 = 1 + rmat_t[:, 0, 0] - rmat_t[:, 1, 1] - rmat_t[:, 2, 2]
+ q0 = torch.stack([rmat_t[:, 1, 2] - rmat_t[:, 2, 1],
+ t0, rmat_t[:, 0, 1] + rmat_t[:, 1, 0],
+ rmat_t[:, 2, 0] + rmat_t[:, 0, 2]], -1)
+ t0_rep = t0.repeat(4, 1).t()
+
+ t1 = 1 - rmat_t[:, 0, 0] + rmat_t[:, 1, 1] - rmat_t[:, 2, 2]
+ q1 = torch.stack([rmat_t[:, 2, 0] - rmat_t[:, 0, 2],
+ rmat_t[:, 0, 1] + rmat_t[:, 1, 0],
+ t1, rmat_t[:, 1, 2] + rmat_t[:, 2, 1]], -1)
+ t1_rep = t1.repeat(4, 1).t()
+
+ t2 = 1 - rmat_t[:, 0, 0] - rmat_t[:, 1, 1] + rmat_t[:, 2, 2]
+ q2 = torch.stack([rmat_t[:, 0, 1] - rmat_t[:, 1, 0],
+ rmat_t[:, 2, 0] + rmat_t[:, 0, 2],
+ rmat_t[:, 1, 2] + rmat_t[:, 2, 1], t2], -1)
+ t2_rep = t2.repeat(4, 1).t()
+
+ t3 = 1 + rmat_t[:, 0, 0] + rmat_t[:, 1, 1] + rmat_t[:, 2, 2]
+ q3 = torch.stack([t3, rmat_t[:, 1, 2] - rmat_t[:, 2, 1],
+ rmat_t[:, 2, 0] - rmat_t[:, 0, 2],
+ rmat_t[:, 0, 1] - rmat_t[:, 1, 0]], -1)
+ t3_rep = t3.repeat(4, 1).t()
+
+ mask_c0 = mask_d2 * mask_d0_d1
+ mask_c1 = mask_d2 * ~mask_d0_d1
+ mask_c2 = ~mask_d2 * mask_d0_nd1
+ mask_c3 = ~mask_d2 * ~mask_d0_nd1
+ mask_c0 = mask_c0.view(-1, 1).type_as(q0)
+ mask_c1 = mask_c1.view(-1, 1).type_as(q1)
+ mask_c2 = mask_c2.view(-1, 1).type_as(q2)
+ mask_c3 = mask_c3.view(-1, 1).type_as(q3)
+
+ q = q0 * mask_c0 + q1 * mask_c1 + q2 * mask_c2 + q3 * mask_c3
+
+ # Add numerical stability to avoid division by zero or very small numbers
+ denominator = torch.sqrt(t0_rep * mask_c0 + t1_rep * mask_c1 + # noqa
+ t2_rep * mask_c2 + t3_rep * mask_c3) # noqa
+ denominator = torch.clamp(denominator, min=eps)
+ q /= denominator
+ q *= 0.5
+
+ # Normalize quaternion to unit length for additional stability
+ q_norm = torch.norm(q, p=2, dim=-1, keepdim=True)
+ q_norm = torch.clamp(q_norm, min=eps)
+ q = q / q_norm
+
+ return q
+
+def rot6d_to_rotmat(x: torch.Tensor) -> torch.Tensor:
+ """
+ Convert 6D rotation representation to 3x3 rotation matrix.
+ Based on Zhou et al., "On the Continuity of Rotation Representations in Neural Networks", CVPR 2019
+ Args:
+ x (torch.Tensor): (B,6) Batch of 6-D rotation representations.
+ Returns:
+ torch.Tensor: Batch of corresponding rotation matrices with shape (B,3,3).
+ """
+ assert len(x.shape) == 2 and x.shape[1] == 6
+ x = x.reshape(-1,2,3).permute(0, 2, 1).contiguous() # (B, 6) -> (B, 3, 2)
+ a1 = x[:, :, 0]
+ a2 = x[:, :, 1]
+ b1 = F.normalize(a1)
+ b2 = F.normalize(a2 - torch.einsum('bi,bi->b', b1, a2).unsqueeze(-1) * b1)
+ b3 = torch.cross(b1, b2, dim=1)
+ return torch.stack((b1, b2, b3), dim=-1)
+
+def rotmat_to_rot6d(x: torch.Tensor) -> torch.Tensor:
+ """
+ Convert 3x3 rotation matrix to 6D rotation representation.
+ Based on Zhou et al., "On the Continuity of Rotation Representations in Neural Networks", CVPR 2019
+ Args:
+ x (torch.Tensor): (B,3,3) Batch of 3x3 rotation matrices.
+ Returns:
+ torch.Tensor: Batch of corresponding 6D rotation representations with shape (B,6).
+ """
+ assert len(x.shape) == 3 and x.shape[1] == 3 and x.shape[2] == 3
+    # Output in column order: the first 3 elements are the first column,
+    # the last 3 the second column, matching the reshape in rot6d_to_rotmat.
+ return torch.cat([x[:, :, 0], x[:, :, 1]], dim=-1)
+
+def aa_to_rot6d(x: torch.Tensor) -> torch.Tensor:
+ """
+ Convert axis-angle representation to 6D rotation representation.
+ Args:
+ x (torch.Tensor): (B,3) Batch of axis-angle representations.
+ Returns:
+ torch.Tensor: Batch of corresponding 6D rotation representations with shape (B,6).
+ """
+ assert len(x.shape) == 2 and x.shape[1] == 3
+ x = aa_to_rotmat(x)
+ x = rotmat_to_rot6d(x)
+ return x
+
+def rot6d_to_aa(x: torch.Tensor) -> torch.Tensor:
+ """
+ Convert 6D rotation representation to axis-angle representation.
+ Args:
+ x (torch.Tensor): (B,6) Batch of 6D rotation representations.
+ Returns:
+ torch.Tensor: Batch of corresponding axis-angle representations with shape (B,3).
+ """
+ assert len(x.shape) == 2 and x.shape[1] == 6
+ x = rot6d_to_rotmat(x)
+ x = rotmat_to_aa(x)
+ return x
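+
+# Round-trip sketch: axis-angle -> 6D -> axis-angle is (numerically) the identity
+# for rotation angles in (0, pi), e.g.:
+#   aa = torch.tensor([[0.3, -0.2, 0.1]])
+#   assert torch.allclose(rot6d_to_aa(aa_to_rot6d(aa)), aa, atol=1e-4)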
+
+def rotate_2d(pt_2d: np.array, rot_rad: float) -> np.array:
+ """
+ Rotate a 2D point on the x-y plane.
+ Args:
+ pt_2d (np.array): Input 2D point with shape (2,).
+        rot_rad (float): Rotation angle in radians.
+ Returns:
+ np.array: Rotated 2D point.
+ """
+ x = pt_2d[0]
+ y = pt_2d[1]
+ sn, cs = np.sin(rot_rad), np.cos(rot_rad)
+ xx = x * cs - y * sn
+ yy = x * sn + y * cs
+ return np.array([xx, yy], dtype=np.float32)
+
+def convert_to_full_img_cam(pred_cam, bbox_height, bbox_center, img_w, img_h, focal_length):
+ s, tx, ty = pred_cam[:, 0], pred_cam[:, 1], pred_cam[:, 2]
+ tz = 2. * focal_length / (bbox_height * s)
+ cx = 2. * (bbox_center[:, 0] - (img_w / 2.)) / (s * bbox_height)
+ cy = 2. * (bbox_center[:, 1] - (img_h / 2.)) / (s * bbox_height)
+ trans = torch.stack([tx + cx, ty + cy, tz], dim=-1)
+ return trans
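+
+# Reading aid for the convention above: (s, tx, ty) is a weak-perspective camera in
+# the bbox crop; depth is recovered as tz = 2 * focal_length / (s * bbox_height),
+# and tx/ty are shifted by the bbox-center offset from the image center, yielding a
+# translation in the full-image camera frame.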
+
+
+def gen_trans_from_patch_cv(c_x: float, c_y: float,
+ src_width: float, src_height: float,
+ dst_width: float, dst_height: float,
+ scale: float, rot: float) -> np.array:
+ """
+ Create transformation matrix for the bounding box crop.
+ Args:
+ c_x (float): Bounding box center x coordinate in the original image.
+ c_y (float): Bounding box center y coordinate in the original image.
+ src_width (float): Bounding box width.
+ src_height (float): Bounding box height.
+ dst_width (float): Output box width.
+ dst_height (float): Output box height.
+ scale (float): Rescaling factor for the bounding box (augmentation).
+ rot (float): Random rotation applied to the box.
+ Returns:
+ trans (np.array): Target geometric transformation.
+ """
+ # augment size with scale
+ src_w = src_width * scale
+ src_h = src_height * scale
+ src_center = np.zeros(2)
+ src_center[0] = c_x
+ src_center[1] = c_y
+ # augment rotation
+ rot_rad = np.pi * rot / 180
+ src_downdir = rotate_2d(np.array([0, src_h * 0.5], dtype=np.float32), rot_rad)
+ src_rightdir = rotate_2d(np.array([src_w * 0.5, 0], dtype=np.float32), rot_rad)
+
+ dst_w = dst_width
+ dst_h = dst_height
+ dst_center = np.array([dst_w * 0.5, dst_h * 0.5], dtype=np.float32)
+ dst_downdir = np.array([0, dst_h * 0.5], dtype=np.float32)
+ dst_rightdir = np.array([dst_w * 0.5, 0], dtype=np.float32)
+
+ src = np.zeros((3, 2), dtype=np.float32)
+ src[0, :] = src_center
+ src[1, :] = src_center + src_downdir
+ src[2, :] = src_center + src_rightdir
+
+ dst = np.zeros((3, 2), dtype=np.float32)
+ dst[0, :] = dst_center
+ dst[1, :] = dst_center + dst_downdir
+ dst[2, :] = dst_center + dst_rightdir
+
+ trans = cv2.getAffineTransform(np.float32(src), np.float32(dst))
+
+ return trans
+
+def generate_image_patch_cv2(img: np.array, c_x: float, c_y: float,
+ bb_width: float, bb_height: float,
+ patch_width: float, patch_height: float,
+ do_flip: bool, scale: float, rot: float,
+ border_mode=cv2.BORDER_CONSTANT, border_value=0) -> Tuple[np.array, np.array]:
+ """
+ Crop the input image and return the crop and the corresponding transformation matrix.
+ Args:
+ img (np.array): Input image of shape (H, W, 3)
+ c_x (float): Bounding box center x coordinate in the original image.
+ c_y (float): Bounding box center y coordinate in the original image.
+ bb_width (float): Bounding box width.
+ bb_height (float): Bounding box height.
+ patch_width (float): Output box width.
+ patch_height (float): Output box height.
+ do_flip (bool): Whether to flip image or not.
+ scale (float): Rescaling factor for the bounding box (augmentation).
+ rot (float): Random rotation applied to the box.
+ Returns:
+        img_patch (np.array): Cropped image patch of shape (patch_height, patch_width, 3)
+ trans (np.array): Transformation matrix.
+ """
+
+ img_height, img_width, img_channels = img.shape
+ if do_flip:
+ img = img[:, ::-1, :]
+ c_x = img_width - c_x - 1
+
+ trans = gen_trans_from_patch_cv(c_x, c_y, bb_width, bb_height, patch_width, patch_height, scale, rot)
+
+ img_patch = cv2.warpAffine(img, trans, (int(patch_width), int(patch_height)),
+ flags=cv2.INTER_LINEAR,
+ borderMode=border_mode,
+ borderValue=border_value,
+ )
+    # Force borderMode=cv2.BORDER_CONSTANT for the alpha channel
+ if (img.shape[2] == 4) and (border_mode != cv2.BORDER_CONSTANT):
+ img_patch[:,:,3] = cv2.warpAffine(img[:,:,3], trans, (int(patch_width), int(patch_height)),
+ flags=cv2.INTER_LINEAR,
+ borderMode=cv2.BORDER_CONSTANT,
+ )
+
+ return img_patch, trans
+
+def depth_to_pointmap(depth_map: torch.Tensor, intrinsics: torch.Tensor) -> torch.Tensor:
+ """
+ Convert depth map and camera intrinsics to 3D point map.
+
+ Args:
+ depth_map (torch.Tensor): Depth map tensor of shape (B, S, H, W, 1) or (B, S, H, W)
+ intrinsics (torch.Tensor): Camera intrinsics tensor of shape (B, S, 3, 3)
+
+ Returns:
+ torch.Tensor: Point map tensor of shape (B, S, H, W, 3) where last dim is [X, Y, Z]
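+
+    Example (a sketch with toy shapes):
+        >>> depth = torch.ones(1, 2, 4, 4)
+        >>> K = torch.eye(3).expand(1, 2, 3, 3)
+        >>> depth_to_pointmap(depth, K).shape
+        torch.Size([1, 2, 4, 4, 3])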
+ """
+ # Handle different input shapes for depth_map
+ if depth_map.dim() == 5: # (B, S, H, W, 1)
+ depth = depth_map.squeeze(-1) # (B, S, H, W)
+ elif depth_map.dim() == 4: # (B, S, H, W)
+ depth = depth_map
+ else:
+ raise ValueError(f"Expected depth_map to have 4 or 5 dimensions, got {depth_map.dim()}")
+
+ B, S, H, W = depth.shape
+ device = depth.device
+ dtype = depth.dtype
+
+ # Extract camera parameters from intrinsics
+ # intrinsics shape: (B, S, 3, 3)
+ fx = intrinsics[..., 0, 0] # (B, S)
+ fy = intrinsics[..., 1, 1] # (B, S)
+ cx = intrinsics[..., 0, 2] # (B, S)
+ cy = intrinsics[..., 1, 2] # (B, S)
+
+ # Create coordinate grids
+ # Generate u, v coordinates for all pixels
+ v_coords, u_coords = torch.meshgrid(
+ torch.arange(H, dtype=dtype, device=device),
+ torch.arange(W, dtype=dtype, device=device),
+ indexing='ij'
+ )
+
+ # Expand to match batch dimensions: (1, 1, H, W)
+ u_coords = u_coords.unsqueeze(0).unsqueeze(0).expand(B, S, H, W)
+ v_coords = v_coords.unsqueeze(0).unsqueeze(0).expand(B, S, H, W)
+
+ # Expand camera parameters to match spatial dimensions: (B, S, H, W)
+ fx_expanded = fx.unsqueeze(-1).unsqueeze(-1).expand(B, S, H, W)
+ fy_expanded = fy.unsqueeze(-1).unsqueeze(-1).expand(B, S, H, W)
+ cx_expanded = cx.unsqueeze(-1).unsqueeze(-1).expand(B, S, H, W)
+ cy_expanded = cy.unsqueeze(-1).unsqueeze(-1).expand(B, S, H, W)
+
+ # Convert pixel coordinates to normalized camera coordinates
+ # X = (u - cx) * Z / fx
+ # Y = (v - cy) * Z / fy
+ # Z = depth
+ X = (u_coords - cx_expanded) * depth / fx_expanded
+ Y = (v_coords - cy_expanded) * depth / fy_expanded
+ Z = depth
+
+ # Stack to create point map: (B, S, H, W, 3)
+ point_map = torch.stack([X, Y, Z], dim=-1)
+
+ return point_map
+
+def closed_form_inverse_se3(se3, R=None, T=None):
+ """
+ Compute the inverse of each 4x4 (or 3x4) SE3 matrix in a batch.
+
+ If `R` and `T` are provided, they must correspond to the rotation and translation
+ components of `se3`. Otherwise, they will be extracted from `se3`.
+
+ Args:
+ se3: Nx4x4 or Nx3x4 array or tensor of SE3 matrices.
+ R (optional): Nx3x3 array or tensor of rotation matrices.
+ T (optional): Nx3x1 array or tensor of translation vectors.
+
+ Returns:
+ Inverted SE3 matrices with the same type and device as `se3`.
+
+ Shapes:
+ se3: (N, 4, 4)
+ R: (N, 3, 3)
+ T: (N, 3, 1)
+ """
+ # Check if se3 is a numpy array or a torch tensor
+ is_numpy = isinstance(se3, np.ndarray)
+
+ # Validate shapes
+ if se3.shape[-2:] != (4, 4) and se3.shape[-2:] != (3, 4):
+ raise ValueError(f"se3 must be of shape (N,4,4), got {se3.shape}.")
+
+ # Extract R and T if not provided
+ if R is None:
+ R = se3[:, :3, :3] # (N,3,3)
+ if T is None:
+ T = se3[:, :3, 3:] # (N,3,1)
+
+ # Transpose R
+ if is_numpy:
+ # Compute the transpose of the rotation for NumPy
+ R_transposed = np.transpose(R, (0, 2, 1))
+ # -R^T t for NumPy
+ top_right = -np.matmul(R_transposed, T)
+ inverted_matrix = np.tile(np.eye(4), (len(R), 1, 1))
+ else:
+ R_transposed = R.transpose(1, 2) # (N,3,3)
+ top_right = -torch.bmm(R_transposed, T) # (N,3,1)
+ inverted_matrix = torch.eye(4, 4)[None].repeat(len(R), 1, 1)
+ inverted_matrix = inverted_matrix.to(R.dtype).to(R.device)
+
+ inverted_matrix[:, :3, :3] = R_transposed
+ inverted_matrix[:, :3, 3:] = top_right
+
+ return inverted_matrix
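+
+# Sanity sketch: composing a pose with its inverse gives (approximately) identity:
+#   T = np.eye(4)[None].copy(); T[:, :3, 3] = 1.0
+#   assert np.allclose(closed_form_inverse_se3(T) @ T, np.eye(4))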
+
+def image_uv(height: int, width: int, left: int = None, top: int = None, right: int = None, bottom: int = None, device: torch.device = None, dtype: torch.dtype = None) -> torch.Tensor:
+ """
+ Get image space UV grid, ranging in [0, 1].
+
+ >>> image_uv(10, 10):
+ [[[0.05, 0.05], [0.15, 0.05], ..., [0.95, 0.05]],
+ [[0.05, 0.15], [0.15, 0.15], ..., [0.95, 0.15]],
+ ... ... ...
+ [[0.05, 0.95], [0.15, 0.95], ..., [0.95, 0.95]]]
+
+ Args:
+ width (int): image width
+ height (int): image height
+
+ Returns:
+ torch.Tensor: shape (height, width, 2)
+ """
+ if left is None: left = 0
+ if top is None: top = 0
+ if right is None: right = width
+ if bottom is None: bottom = height
+ u = torch.linspace((left + 0.5) / width, (right - 0.5) / width, right - left, device=device, dtype=dtype)
+ v = torch.linspace((top + 0.5) / height, (bottom - 0.5) / height, bottom - top, device=device, dtype=dtype)
+ u, v = torch.meshgrid(u, v, indexing='xy')
+ uv = torch.stack([u, v], dim=-1)
+ return uv
+
+def unproject_cv(
+ uv_coord: torch.Tensor,
+ depth: torch.Tensor = None,
+ extrinsics: torch.Tensor = None,
+ intrinsics: torch.Tensor = None
+) -> torch.Tensor:
+ """
+ Unproject uv coordinates to 3D view space following the OpenCV convention
+
+ Args:
+ uv_coord (torch.Tensor): [..., N, 2] uv coordinates, value ranging in [0, 1].
+ The origin (0., 0.) is corresponding to the left & top
+ depth (torch.Tensor): [..., N] depth value
+ extrinsics (torch.Tensor): [..., 4, 4] extrinsics matrix
+ intrinsics (torch.Tensor): [..., 3, 3] intrinsics matrix
+
+ Returns:
+ points (torch.Tensor): [..., N, 3] 3d points
+ """
+ assert intrinsics is not None, "intrinsics matrix is required"
+ points = torch.cat([uv_coord, torch.ones_like(uv_coord[..., :1])], dim=-1)
+ points = points @ torch.inverse(intrinsics).transpose(-2, -1)
+ if depth is not None:
+ points = points * depth[..., None]
+ if extrinsics is not None:
+ points = torch.cat([points, torch.ones_like(points[..., :1])], dim=-1)
+ points = (points @ torch.inverse(extrinsics).transpose(-2, -1))[..., :3]
+ return points
+
+def depth_to_points(depth: torch.Tensor, intrinsics: torch.Tensor, extrinsics: torch.Tensor = None):
+ height, width = depth.shape[-2:]
+ uv = image_uv(width=width, height=height, dtype=depth.dtype, device=depth.device)
+ pts = unproject_cv(uv, depth, intrinsics=intrinsics[..., None, :, :], extrinsics=extrinsics[..., None, :, :] if extrinsics is not None else None)
+ return pts
\ No newline at end of file
diff --git a/unish/utils/inference_utils.py b/unish/utils/inference_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..8d881c1747902312e605bee4cc8713cab3b46e7b
--- /dev/null
+++ b/unish/utils/inference_utils.py
@@ -0,0 +1,2056 @@
+import torch
+import torch.nn.functional as F
+import numpy as np
+from tqdm import tqdm
+import os
+import cv2
+import trimesh
+import open3d as o3d
+import glob
+import tempfile
+import shutil
+from PIL import Image
+from ultralytics import YOLO
+from sam2.sam2_video_predictor import SAM2VideoPredictor
+import logging
+from huggingface_hub import hf_hub_download
+
+
+from unish.pipeline import UniSHPipeline
+from safetensors.torch import load_file
+from unish.utils.data_utils import generate_image_patch_cv2, closed_form_inverse_se3
+from unish.utils.smpl_utils import SMPLWrapper, transform_smpl
+
+logger = logging.getLogger(__name__)
+
+def aa_to_quat(axis_angle: torch.Tensor) -> torch.Tensor:
+ """
+ Convert axis-angle representation to quaternion (w, x, y, z).
+ Args:
+ axis_angle (torch.Tensor): Tensor of shape (B, 3) containing axis-angle representations.
+ Returns:
+ torch.Tensor: Corresponding quaternions with shape (B, 4) in (w, x, y, z) format.
+ """
+ norm = torch.norm(axis_angle + 1e-8, p=2, dim=1)
+ angle = torch.unsqueeze(norm, -1)
+ normalized = torch.div(axis_angle, angle)
+ angle = angle * 0.5
+ v_cos = torch.cos(angle)
+ v_sin = torch.sin(angle)
+ quat = torch.cat([v_cos, v_sin * normalized], dim=1) # [w, x, y, z]
+ return quat
+
+
+def quat_to_aa(quaternion: torch.Tensor) -> torch.Tensor:
+ """
+ Convert quaternion to axis-angle representation.
+ Args:
+ quaternion (torch.Tensor): Tensor of shape (B, 4) in (w, x, y, z) format.
+ Returns:
+ torch.Tensor: Corresponding axis-angle with shape (B, 3).
+ """
+ # Normalize quaternion
+ quaternion = quaternion / quaternion.norm(p=2, dim=1, keepdim=True)
+
+ w, x, y, z = quaternion[:, 0], quaternion[:, 1], quaternion[:, 2], quaternion[:, 3]
+ sin_squared_theta = x * x + y * y + z * z
+
+ sin_theta = torch.sqrt(sin_squared_theta)
+ cos_theta = w
+ two_theta = 2.0 * torch.where(
+ cos_theta < 0.0,
+ torch.atan2(-sin_theta, -cos_theta),
+ torch.atan2(sin_theta, cos_theta)
+ )
+
+    # Guard against division by ~0 for near-identity rotations
+    # (mirrors the thresholding in data_utils.quaternion_to_angle_axis)
+    sin_theta_safe = torch.clamp(sin_theta, min=1e-6)
+    k_pos = two_theta / sin_theta_safe
+    k_neg = 2.0 * torch.ones_like(sin_theta)
+    k = torch.where(sin_squared_theta > 1e-12, k_pos, k_neg)
+
+ axis_angle = torch.zeros_like(quaternion)[:, :3]
+ axis_angle[:, 0] = x * k
+ axis_angle[:, 1] = y * k
+ axis_angle[:, 2] = z * k
+ return axis_angle
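+
+# Round-trip sketch: aa_to_quat followed by quat_to_aa recovers the input for
+# rotation angles in (0, pi), e.g.:
+#   aa = torch.tensor([[0.1, 0.2, 0.3]])
+#   assert torch.allclose(quat_to_aa(aa_to_quat(aa)), aa, atol=1e-5)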
+
+def load_model(checkpoint_path=None):
+ """
+ Load the trained model.
+ Supports loading from local path OR automatic download from Hugging Face.
+ """
+
+ logger.info("Creating model...")
+ model = UniSHPipeline()
+
+ if checkpoint_path is None:
+ logger.info(f"No checkpoint_path provided. Downloading default model from HuggingFace")
+ checkpoint_path = hf_hub_download(repo_id="Murphyyyy/UniSH/UniSH", filename="checkpoints/unish_release.safetensors")
+
+ elif not os.path.exists(checkpoint_path):
+ logger.info(f"File '{checkpoint_path}' not found locally. Trying to download from Hugging Face...")
+ try:
+ checkpoint_path = hf_hub_download(repo_id="Murphyyyy/UniSH", filename="checkpoints/unish_release.safetensors")
+ except Exception as e:
+            raise FileNotFoundError(f"Checkpoint not found locally at '{checkpoint_path}' and failed to download from Hugging Face. Error: {e}")
+
+ logger.info(f"Loading model weights from {checkpoint_path}...")
+
+ if checkpoint_path.endswith(('.safetensor', '.safetensors')):
+ state_dict = load_file(checkpoint_path)
+ else:
+ state_dict = torch.load(checkpoint_path, map_location='cpu')
+
+ model.load_state_dict(state_dict, strict=True)
+
+ weight_dtype = torch.bfloat16
+ if weight_dtype != torch.float32:
+ logger.info(f"Casting model to {weight_dtype}...")
+ for module in model.modules():
+ if hasattr(module, 'weight') and module.weight is not None:
+ module.weight.data = module.weight.data.to(dtype=weight_dtype)
+ if hasattr(module, 'bias') and module.bias is not None:
+ module.bias.data = module.bias.data.to(dtype=weight_dtype)
+
+ for param_name, param in module.named_parameters(recurse=False):
+ if param_name not in ['weight', 'bias']:
+ param.data = param.data.to(dtype=weight_dtype)
+
+ return model
+
+def resize_frames(frames, target_size=518):
+ """Resize frames with proper scaling logic (following load_and_preprocess_data_from_npz_pi3)"""
+
+ # Process frames with proper resizing logic
+ processed_frames = []
+ for frame in frames:
+ # Convert to torch tensor and reorder to (C, H, W)
+ frame_tensor = torch.from_numpy(frame).float() / 255.0 # Normalize to 0-1
+ frame_tensor = frame_tensor.permute(2, 0, 1) # Shape: (3, H, W)
+
+ # Get original size
+ original_height, original_width = frame_tensor.shape[1], frame_tensor.shape[2]
+
+ # Determine if video is landscape or portrait and scale the longer dimension to target_size
+ if original_width >= original_height:
+ # Landscape video: scale width to target_size
+ new_width = target_size
+ new_height = round(original_height * (new_width / original_width) / 14) * 14
+ else:
+ # Portrait video: scale height to target_size
+ new_height = target_size
+ new_width = round(original_width * (new_height / original_height) / 14) * 14
+
+ # Resize frame
+ frame_resized = torch.nn.functional.interpolate(
+ frame_tensor.unsqueeze(0), # Add batch dimension
+ size=(new_height, new_width),
+ mode='bilinear',
+ align_corners=False
+ ).squeeze(0) # Remove batch dimension
+
+ # Center crop if any dimension is larger than target_size (crop mode)
+ if new_height > target_size:
+ start_y = (new_height - target_size) // 2
+ frame_resized = frame_resized[:, start_y : start_y + target_size, :]
+ if new_width > target_size:
+ start_x = (new_width - target_size) // 2
+ frame_resized = frame_resized[:, :, start_x : start_x + target_size]
+
+ processed_frames.append(frame_resized)
+
+ # Stack all processed frames
+ rgbs = torch.stack(processed_frames) # Shape: (N, 3, H, W)
+
+ return rgbs
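+
+# Note: each resized side is snapped to a multiple of 14 (presumably the backbone's
+# ViT patch size); e.g. a 1920x1080 input becomes 518 x 294
+# (round(1080 * 518 / 1920 / 14) * 14 = 294) before any center crop.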
+
+def generate_patches_and_bbox(rgbs, human_boxes, box_scale=1.0):
+ """Generate human patches and bbox info from processed RGB frames and human bounding boxes"""
+
+ height, width = rgbs.shape[2], rgbs.shape[3]
+ num_frames = rgbs.shape[0]
+ num_humans = human_boxes.shape[0] if len(human_boxes.shape) > 0 else 0
+
+ if num_humans == 0:
+ raise ValueError("No humans detected in the video. Cannot generate patches and bbox info.")
+
+ # human_boxes shape: (num_humans, num_frames, 4) in xyxy format [0,1]
+ all_human_patches = []
+ all_bbox_info = []
+
+ # Process each human separately
+ for human_idx in range(num_humans):
+ human_boxes_single = human_boxes[human_idx] # Shape: (num_frames, 4)
+
+ # Calculate center coordinates and box sizes from normalized xyxy format
+ box_centers = []
+ box_sizes = []
+ for frame_idx in range(num_frames):
+ # Get normalized bbox: [x1, y1, x2, y2] in [0,1] range
+ x1, y1, x2, y2 = human_boxes_single[frame_idx]
+
+ # Convert to pixel coordinates
+ x1_pixel = x1 * width
+ y1_pixel = y1 * height
+ x2_pixel = x2 * width
+ y2_pixel = y2 * height
+
+ # Calculate center coordinates in pixels
+ center_x = (x1_pixel + x2_pixel) / 2.0
+ center_y = (y1_pixel + y2_pixel) / 2.0
+
+ # Calculate box size as the longer edge in pixels
+ bbox_width = x2_pixel - x1_pixel
+ bbox_height = y2_pixel - y1_pixel
+ box_size = max(bbox_width, bbox_height) * box_scale
+
+ box_centers.append([center_x, center_y])
+ box_sizes.append(box_size)
+
+ box_centers = torch.tensor(box_centers).float() # Shape: (num_frames, 2)
+ box_sizes = torch.tensor(box_sizes).float() # Shape: (num_frames,)
+
+ # Calculate bbox_info for human head prediction
+ cx, cy = box_centers[:, 0], box_centers[:, 1]
+ bbox_info = torch.stack([cx - width / 2.0, cy - height / 2.0, box_sizes], dim=-1)
+ all_bbox_info.append(bbox_info)
+
+ # Generate human patches using calculated centers and sizes
+ np_rgbs = rgbs.permute(0, 2, 3, 1).numpy()
+ rgbs_patch = []
+ for i in range(num_frames):
+ img_patch_cv, _ = generate_image_patch_cv2(
+ np_rgbs[i],
+ box_centers[i, 0], box_centers[i, 1],
+ box_sizes[i], box_sizes[i],
+ 256, 256,
+ False, 1.0, 0,
+ border_mode=cv2.BORDER_CONSTANT
+ )
+ rgbs_patch.append(img_patch_cv)
+ rgbs_patch = torch.from_numpy(np.stack(rgbs_patch)).permute(0, 3, 1, 2)
+ all_human_patches.append(rgbs_patch)
+
+ # Stack all humans' data
+ # all_human_patches: list of (num_frames, 3, 256, 256) -> (num_humans, num_frames, 3, 256, 256)
+ all_human_patches = torch.stack(all_human_patches, dim=0)
+ # all_bbox_info: list of (num_frames, 3) -> (num_humans, num_frames, 3)
+ all_bbox_info = torch.stack(all_bbox_info, dim=0)
+
+ return all_human_patches, all_bbox_info
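+
+# Note: bbox_info is a CLIFF-style conditioning vector (bbox-center offset from the
+# image center plus bbox size, in pixels); at inference time it is additionally
+# normalized by the predicted (or ground-truth) focal length before reaching the
+# human head.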
+
+def save_smpl_meshes_per_frame(results, output_dir, body_models_path="body_models/"):
+ """Save SMPL meshes (not just points) for each frame using world coordinates - supports multiple humans"""
+ seq_name = results['seq_name']
+ selected_views = results['selected_views']
+ num_humans = results.get('num_humans', 1)
+
+ smpl_meshes_dir = os.path.join(output_dir, seq_name, "smpl_meshes_per_frame")
+ os.makedirs(smpl_meshes_dir, exist_ok=True)
+
+ # Initialize SMPL visualizer
+ device = torch.device('cpu') # Work on CPU for file saving
+ smpl_visualizer = SMPLWrapper(
+ model_folder=body_models_path,
+ model_type='smpl',
+ device=device,
+ dtype=torch.float32
+ )
+
+ # Process each human
+ all_human_vertices = []
+ for human_idx in range(num_humans):
+ if num_humans > 1:
+ # Multi-human case: use the _all_humans data
+ pred_pose_world_aa = results['pred_pose_world_aa_all_humans'][human_idx]
+ pred_trans_world = results['pred_trans_world_all_humans'][human_idx]
+ pred_betas = results['pred_betas_all_humans'][human_idx]
+ else:
+ # Single human case: use the original data
+ pred_pose_world_aa = results['pred_pose_world_aa']
+ pred_trans_world = results['pred_trans_world']
+ pred_betas = results['pred_betas']
+
+ # Get vertices for all frames in batch using world coordinates
+ vertices, joints = smpl_visualizer.get_batch_vertices(
+ pred_pose_world_aa, pred_betas, pred_trans_world, "neutral"
+ )
+ all_human_vertices.append(vertices)
+
+ # Get SMPL faces (same for all frames and humans)
+ smpl_faces = smpl_visualizer.models['neutral'].faces.astype(np.int32)
+
+ # Save each frame as mesh PLY file
+ for i in tqdm(range(len(selected_views)), desc="Processing SMPL meshes per frame"):
+ frame_idx = selected_views[i]
+
+ if num_humans == 1:
+ # Single human: save individual mesh
+ frame_vertices = all_human_vertices[0][i].cpu().numpy() # [V, 3]
+
+ # Create vertex colors (skin color)
+ vertex_colors = np.ones((len(frame_vertices), 3)) * [0.8, 0.6, 0.4]
+ vertex_colors = (vertex_colors * 255).astype(np.uint8)
+
+ # Create mesh
+ mesh = trimesh.Trimesh(
+ vertices=frame_vertices,
+ faces=smpl_faces,
+ vertex_colors=vertex_colors
+ )
+
+ # Save mesh
+ mesh_filename = f"smpl_mesh_frame_{frame_idx:04d}.ply"
+ mesh_path = os.path.join(smpl_meshes_dir, mesh_filename)
+ mesh.export(mesh_path)
+ else:
+ # Multiple humans: save combined mesh and individual meshes
+ combined_vertices = []
+ combined_faces = []
+ combined_colors = []
+ vertex_offset = 0
+
+ # Colors for different humans
+ colors_palette = [
+ [0.8, 0.6, 0.4], # Skin color for human 0
+ [0.2, 0.8, 0.2], # Green for human 1
+ [0.2, 0.2, 0.8], # Blue for human 2
+ [0.8, 0.2, 0.2], # Red for human 3
+ [0.8, 0.8, 0.2], # Yellow for human 4
+ [0.8, 0.2, 0.8], # Magenta for human 5
+ [0.2, 0.8, 0.8], # Cyan for human 6
+ ]
+
+ for human_idx in range(num_humans):
+ frame_vertices = all_human_vertices[human_idx][i].cpu().numpy() # [V, 3]
+
+ # Add vertices to combined mesh
+ combined_vertices.append(frame_vertices)
+
+ # Add faces with vertex offset
+ faces_with_offset = smpl_faces + vertex_offset
+ combined_faces.append(faces_with_offset)
+ vertex_offset += len(frame_vertices)
+
+ # Create colors for this human
+ human_color = colors_palette[human_idx % len(colors_palette)]
+ vertex_colors = np.ones((len(frame_vertices), 3)) * human_color
+ vertex_colors = (vertex_colors * 255).astype(np.uint8)
+ combined_colors.append(vertex_colors)
+
+ # Also save individual mesh for each human
+ individual_mesh = trimesh.Trimesh(
+ vertices=frame_vertices,
+ faces=smpl_faces,
+ vertex_colors=vertex_colors
+ )
+ individual_mesh_filename = f"human_{human_idx:02d}_smpl_mesh_frame_{frame_idx:04d}.ply"
+ individual_mesh_path = os.path.join(smpl_meshes_dir, individual_mesh_filename)
+ individual_mesh.export(individual_mesh_path)
+
+ # Create and save combined mesh
+ if combined_vertices:
+ combined_vertices_array = np.concatenate(combined_vertices, axis=0)
+ combined_faces_array = np.concatenate(combined_faces, axis=0)
+ combined_colors_array = np.concatenate(combined_colors, axis=0)
+
+ combined_mesh = trimesh.Trimesh(
+ vertices=combined_vertices_array,
+ faces=combined_faces_array,
+ vertex_colors=combined_colors_array
+ )
+
+ combined_mesh_filename = f"combined_smpl_mesh_frame_{frame_idx:04d}.ply"
+ combined_mesh_path = os.path.join(smpl_meshes_dir, combined_mesh_filename)
+ combined_mesh.export(combined_mesh_path)
+
+def save_scene_only_point_clouds(scene_only_point_clouds, output_dir, seq_name):
+ """
+ Save scene-only point clouds (without human regions) as PLY files.
+
+ Args:
+ scene_only_point_clouds: List of Open3D point clouds for scene-only
+ output_dir: Output directory
+ seq_name: Sequence name
+ """
+ scene_only_dir = os.path.join(
+ output_dir, seq_name, "scene_only_point_clouds")
+ os.makedirs(scene_only_dir, exist_ok=True)
+
+ for i, scene_pcd in enumerate(scene_only_point_clouds):
+ if len(scene_pcd.points) > 0:
+ ply_path = os.path.join(
+ scene_only_dir, f"scene_only_frame_{i:04d}.ply")
+ o3d.io.write_point_cloud(ply_path, scene_pcd)
+
+def save_human_point_clouds(complete_scene_point_clouds, scene_only_point_clouds, output_dir, seq_name, results=None):
+ """
+ Save human point clouds (extracted using human masks from point cloud) as PLY files.
+ Single human processing.
+
+ Args:
+ complete_scene_point_clouds: List of complete scene point clouds (currently unused)
+ scene_only_point_clouds: List of scene-only point clouds (currently unused)
+ output_dir: Output directory
+ seq_name: Sequence name
+ results: Results dictionary containing human masks and point maps
+ """
+ human_only_dir = os.path.join(
+ output_dir, seq_name, "human_only_point_clouds")
+ os.makedirs(human_only_dir, exist_ok=True)
+
+ # Try to use human masks for more accurate extraction
+ if results is not None and 'human_masks' in results and 'point_map' in results:
+ point_map = results['point_map'] # [S, H, W, 3] - Camera coordinates
+ depth_conf = results['depth_conf'] # [S, H, W, 1]
+ human_masks = results['human_masks'] # [num_humans, S, H, W]
+ # [S, 3, 4] or [S, 4, 4] - World-to-camera transformation
+ extrinsics = results['extrinsics']
+ # [S, 3, H, W] - Original RGB images
+ rgb_images = results.get('rgb_images', None)
+
+ for frame_idx in range(point_map.shape[0]):
+ human_pcd = o3d.geometry.PointCloud()
+
+ # Extract human points using masks
+ # [H, W, 3] - Camera coordinates
+ frame_points = point_map[frame_idx].numpy()
+ frame_conf = depth_conf[frame_idx].squeeze(-1).numpy() # [H, W]
+
+ # Get extrinsics for this frame
+ extr = extrinsics[frame_idx].numpy() # [3, 4] or [4, 4]
+ if extr.shape[0] == 3: # Convert [3, 4] to [4, 4]
+ extr = np.vstack([extr, [0, 0, 0, 1]])
+
+ # Convert camera coordinates to world coordinates
+ cam_to_world = closed_form_inverse_se3(extr[None])[0] # [4, 4]
+ R_cam_to_world = cam_to_world[:3, :3] # [3, 3]
+ t_cam_to_world = cam_to_world[:3, 3] # [3]
+
+ # Transform all points to world coordinates
+ frame_points_world = np.dot(
+ frame_points, R_cam_to_world.T) + t_cam_to_world # [H, W, 3]
+
+ # Single human processing
+ target_h, target_w = frame_points.shape[0], frame_points.shape[1]
+ combined_human_mask = np.zeros((target_h, target_w), dtype=bool)
+
+ human_idx = 0
+ if frame_idx < human_masks.shape[1]:
+ # [H_orig, W_orig]
+ human_mask = human_masks[human_idx,
+ frame_idx].cpu().numpy()
+
+ # Resize human mask to match point_map dimensions
+ if human_mask.shape != (target_h, target_w):
+ # Resize mask using nearest neighbor to preserve binary values
+ human_mask_resized = cv2.resize(human_mask.astype(np.uint8),
+ (target_w, target_h),
+ interpolation=cv2.INTER_NEAREST).astype(bool)
+ else:
+ human_mask_resized = human_mask
+
+ combined_human_mask = human_mask_resized
+
+ # Apply confidence threshold (fixed at 0.05 here) and human mask
+ valid_mask = (frame_conf > 0.05) & combined_human_mask
+
+ if np.any(valid_mask):
+ # Get valid human points in world coordinates
+ # [N, 3] - World coordinates
+ valid_points_world = frame_points_world[valid_mask]
+
+ # Filter out zero points (in original camera coordinates)
+ # [N, 3] - Camera coordinates for zero check
+ valid_points_cam = frame_points[valid_mask]
+ non_zero_mask = np.any(valid_points_cam != 0, axis=1)
+ if np.any(non_zero_mask):
+ # [N, 3] - World coordinates
+ human_points_world = valid_points_world[non_zero_mask]
+
+ # Get colors from RGB image if available
+ if rgb_images is not None and frame_idx < rgb_images.shape[0]:
+ rgb_image = rgb_images[frame_idx].permute(
+ 1, 2, 0).numpy() # [H, W, 3], values in [0, 1]
+
+ # Get 2D coordinates of valid human points
+ valid_coords_2d = np.where(valid_mask)
+ y_coords, x_coords = valid_coords_2d
+
+ # Apply non-zero mask to get final coordinates
+ y_coords_final = y_coords[non_zero_mask]
+ x_coords_final = x_coords[non_zero_mask]
+
+ # Extract colors from RGB image
+ # [N, 3], values in [0, 1]
+ human_colors = rgb_image[y_coords_final,
+ x_coords_final]
+ else:
+ # Fallback to red color if RGB not available
+ human_colors = np.tile(
+ [0.8, 0.2, 0.2], (len(human_points_world), 1))
+
+ human_pcd.points = o3d.utility.Vector3dVector(
+ human_points_world)
+ human_pcd.colors = o3d.utility.Vector3dVector(human_colors)
+
+ # Save human point cloud
+ if len(human_pcd.points) > 0:
+ ply_path = os.path.join(
+ human_only_dir, f"human_frame_{frame_idx:04d}.ply")
+ o3d.io.write_point_cloud(ply_path, human_pcd)
+
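+# Note on the coordinate transform used above: with world-to-camera extrinsics
+# w2c = [R | t], inverting gives camera-to-world, so p_world = R_c2w @ p_cam + t_c2w.
+# For row-stacked points P of shape [N, 3] this is written as P @ R_c2w.T + t_c2w,
+# which is the form used throughout this file.
+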
+def save_camera_parameters_per_frame(results, output_dir, seq_name):
+ """
+ Save camera parameters (extrinsics and intrinsics) for each frame as NPZ files.
+
+ Args:
+ results: Results dictionary containing camera parameters
+ output_dir: Output directory
+ seq_name: Sequence name
+ """
+ seq_dir = os.path.join(output_dir, seq_name)
+ os.makedirs(seq_dir, exist_ok=True)
+
+ extrinsics = results.get('extrinsics', None)
+ intrinsics = results.get('pred_intrinsics', None)
+
+ if extrinsics is None or intrinsics is None:
+ logger.warning("Camera parameters not found in results")
+ return
+
+ # Convert to numpy if needed
+ if hasattr(extrinsics, 'cpu'):
+ extrinsics = extrinsics.cpu().numpy()
+ if hasattr(intrinsics, 'cpu'):
+ intrinsics = intrinsics.cpu().numpy()
+
+ num_frames = len(extrinsics)
+
+ # Ensure extrinsics is 4x4 for all frames
+ if extrinsics.shape[-2] == 3: # [S, 3, 4] -> [S, 4, 4]
+ bottom_row = np.tile([0, 0, 0, 1], (num_frames, 1, 1))
+ extrinsics = np.concatenate([extrinsics, bottom_row], axis=-2)
+
+ # Save all camera parameters in one NPZ file
+ summary_file = os.path.join(seq_dir, "camera_parameters.npz")
+ np.savez(summary_file,
+ # [S, 4, 4] - World-to-Camera (w2c) transformation matrices
+ extrinsics=extrinsics,
+ intrinsics=intrinsics, # [S, 3, 3] - Camera intrinsic matrices
+ num_frames=num_frames)
+
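+# Reading the saved parameters back is a plain NPZ load (a sketch; the path is
+# hypothetical):
+#
+#   data = np.load("outputs/<seq_name>/camera_parameters.npz")
+#   w2c = data["extrinsics"]          # [S, 4, 4] world-to-camera
+#   c2w = np.linalg.inv(w2c)          # camera-to-world
+#   centers = c2w[:, :3, 3]           # [S, 3] camera positions in world space
+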
+def segment_human(frames, human_idx, yolo_ckpt="yolo11n.pt", sam2_model="facebook/sam2-hiera-large"):
+ """
+ Segment human using YOLO detection + SAM2 video segmentation
+
+ Args:
+ frames: List of RGB frames (numpy arrays in RGB format)
+ human_idx: Index of the human to segment (0-based)
+ yolo_ckpt: Path to YOLO checkpoint
+ sam2_model: SAM2 model name
+
+ Returns:
+ human_boxes: (1, num_frames, 4) in xyxy format ranging from 0 to 1
+ human_masks: (1, num_frames, H, W) boolean masks
+ """
+
+ logger.info(f"Segmenting human {human_idx} from {len(frames)} frames...")
+
+ # Step 1: Use YOLO to detect humans in the first frame
+ yolo_model = YOLO(yolo_ckpt)
+
+ # Convert first frame to BGR for YOLO (frames are in RGB format)
+ first_frame_bgr = cv2.cvtColor(frames[0], cv2.COLOR_RGB2BGR)
+
+ # Run YOLO detection on first frame
+ results = yolo_model(first_frame_bgr)
+ first_frame_result = results[0]
+
+ # Extract human detections from first frame
+ human_detections = []
+ if first_frame_result.boxes is not None and len(first_frame_result.boxes) > 0:
+ boxes = first_frame_result.boxes
+ classes = boxes.cls
+ confs = boxes.conf
+ xyxyn_boxes = boxes.xyxyn # Normalized coordinates [0,1]
+
+ if classes is not None and confs is not None and xyxyn_boxes is not None:
+ for i in range(len(classes)):
+ if classes[i] == 0: # person class
+ confidence = float(confs[i].item())
+ # [x1, y1, x2, y2] in [0,1]
+ bbox = xyxyn_boxes[i].cpu().numpy()
+ human_detections.append((confidence, bbox))
+
+ # Sort by confidence and select the human_idx-th human
+ human_detections.sort(key=lambda x: x[0], reverse=True)
+
+ if len(human_detections) <= human_idx:
+ raise ValueError(
+ f"Only {len(human_detections)} humans detected, but requested human_idx={human_idx}")
+
+ # [x1, y1, x2, y2] in [0,1]
+ selected_human_bbox = human_detections[human_idx][1]
+ selected_human_conf = human_detections[human_idx][0]
+
+ logger.info(
+ f"Selected human {human_idx} with confidence {selected_human_conf:.3f}")
+ logger.info(f"Initial bbox: {selected_human_bbox}")
+
+ # Release YOLO model
+ del yolo_model
+ torch.cuda.empty_cache()
+
+ # Step 2: Prepare frames for SAM2 (save to temporary directory)
+ temp_dir = tempfile.mkdtemp()
+ try:
+ # Save frames to temporary directory
+ for i, frame in enumerate(frames):
+ frame_path = os.path.join(temp_dir, f"{i:08d}.jpg")
+ # Convert RGB to BGR for saving
+ frame_bgr = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
+ cv2.imwrite(frame_path, frame_bgr)
+
+ # Step 3: Initialize SAM2
+ predictor = SAM2VideoPredictor.from_pretrained(sam2_model)
+ inference_state = predictor.init_state(video_path=temp_dir)
+ predictor.reset_state(inference_state)
+
+ # Step 4: Convert normalized bbox to pixel coordinates for SAM2
+ frame_height, frame_width = frames[0].shape[:2]
+ x1, y1, x2, y2 = selected_human_bbox
+ box_pixel = [
+ x1 * frame_width, # x_min
+ y1 * frame_height, # y_min
+ x2 * frame_width, # x_max
+ y2 * frame_height # y_max
+ ]
+
+ logger.info(f"SAM2 input box (pixels): {box_pixel}")
+
+ # Step 5: Add box to SAM2 for first frame
+ ann_frame_idx = 0 # First frame
+ ann_obj_id = 0 # Object ID
+
+ _, out_obj_ids, out_mask_logits = predictor.add_new_points_or_box(
+ inference_state=inference_state,
+ frame_idx=ann_frame_idx,
+ obj_id=ann_obj_id,
+ box=box_pixel,
+ )
+
+ # Step 6: Propagate segmentation through all frames
+ masks = []
+ bboxes = []
+
+ for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(inference_state):
+ # Get mask for our object
+ mask = (out_mask_logits[ann_obj_id] > 0.0).squeeze(0).cpu().numpy()
+
+ # Apply dilation to clean up the mask (like in segment_example.py)
+ kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
+ mask = cv2.dilate(mask.astype(np.uint8), kernel, iterations=5)
+ mask = mask.astype(bool)
+
+ masks.append(mask)
+
+ # Calculate bbox from mask
+ if np.any(mask):
+ # Find bounding box of the mask
+ rows = np.any(mask, axis=1)
+ cols = np.any(mask, axis=0)
+
+ if np.any(rows) and np.any(cols):
+ y_min, y_max = np.where(rows)[0][[0, -1]]
+ x_min, x_max = np.where(cols)[0][[0, -1]]
+
+ # Normalize to [0, 1]
+ bbox = [
+ x_min / frame_width,
+ y_min / frame_height,
+ x_max / frame_width,
+ y_max / frame_height
+ ]
+ else:
+ # Use previous bbox if mask is empty
+ bbox = bboxes[-1] if bboxes else selected_human_bbox.tolist()
+ else:
+ # Use previous bbox if mask is empty
+ bbox = bboxes[-1] if bboxes else selected_human_bbox.tolist()
+
+ bboxes.append(bbox)
+
+ # Release SAM2 model
+ del predictor
+ torch.cuda.empty_cache()
+
+ finally:
+ # Clean up temporary directory
+ shutil.rmtree(temp_dir, ignore_errors=True)
+
+ # Step 7: Format output
+ num_frames = len(frames)
+
+ # Convert to tensors with correct shapes
+ human_boxes = torch.tensor(bboxes).unsqueeze(0) # (1, num_frames, 4)
+ human_masks = torch.tensor(np.stack(masks)).unsqueeze(
+ 0) # (1, num_frames, H, W)
+
+ logger.info(f"Segmentation completed. Output shapes:")
+ logger.info(f" human_boxes: {human_boxes.shape}")
+ logger.info(f" human_masks: {human_masks.shape}")
+
+ return human_boxes, human_masks
+
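+# Minimal usage sketch (assumes `frames` is already a list of RGB numpy arrays):
+#
+#   boxes, masks = segment_human(frames, human_idx=0)
+#   # boxes: (1, num_frames, 4) normalized xyxy; masks: (1, num_frames, H, W) bool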
+
+def extract_frames(video_path, fps, human_idx, start_idx=None, end_idx=None, original_fps=30.0, yolo_ckpt="yolo11n.pt", sam2_model="facebook/sam2-hiera-large"):
+ """Extract frames from video or directory at specified fps
+
+ Args:
+ video_path: Path to video file or directory
+ fps: Target fps for frame extraction
+ human_idx: Human index for segmentation
+ start_idx: Start frame index (default: None, process from beginning)
+ end_idx: End frame index (default: None, process to end)
+ original_fps: Original fps of the image sequence (default: 30.0, used only for directory input)
+ yolo_ckpt: Path to YOLO checkpoint
+ sam2_model: SAM2 model name
+ """
+
+ if not os.path.exists(video_path):
+ raise FileNotFoundError(f"Path not found: {video_path}")
+
+ # Determine frame sampling interval based on input type
+ if os.path.isdir(video_path):
+ logger.info(f"Processing directory: {video_path} at {fps} fps...")
+ # For directory, use provided original fps
+ frames = extract_frames_from_directory(
+ video_path, fps, original_fps=original_fps, start_idx=start_idx, end_idx=end_idx)
+ else:
+ logger.info(f"Processing video file: {video_path} at {fps} fps...")
+ # extract_frames_from_video reads the source fps itself, so no probe is needed here
+ frames = extract_frames_from_video(
+ video_path, fps, start_idx=start_idx, end_idx=end_idx)
+
+ num_frames = len(frames)
+ human_boxes, human_masks = segment_human(frames, human_idx, yolo_ckpt=yolo_ckpt, sam2_model=sam2_model)
+
+ # frames: list of num_frames RGB arrays, each (H, W, 3)
+ # human_boxes: (1, num_frames, 4) in xyxy format ranging from 0 to 1
+ # human_masks: (1, num_frames, H, W)
+ return frames, human_boxes, human_masks
+
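+# Example call (a sketch; the path is hypothetical): sample a clip at 6 fps and
+# segment the highest-confidence person:
+#
+#   frames, boxes, masks = extract_frames("video.mp4", fps=6, human_idx=0)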
+
+def extract_frames_from_directory(directory_path, fps, original_fps=30.0, start_idx=None, end_idx=None):
+ """Extract frames from directory of images at specified fps
+
+ Args:
+ directory_path: Path to directory containing image files
+ fps: Target fps for frame extraction
+ original_fps: Original fps of the image sequence (default: 30.0)
+ start_idx: Start frame index (default: None, process from beginning)
+ end_idx: End frame index (default: None, process to end)
+ """
+ # Supported image extensions
+ image_extensions = ['*.jpg', '*.jpeg', '*.png', '*.bmp', '*.tiff', '*.tif']
+
+ # Find all image files in the directory
+ image_files = []
+ for ext in image_extensions:
+ image_files.extend(glob.glob(os.path.join(directory_path, ext)))
+ image_files.extend(
+ glob.glob(os.path.join(directory_path, ext.upper())))
+
+ if len(image_files) == 0:
+ raise ValueError(
+ f"No image files found in directory: {directory_path}")
+
+ # Sort files to ensure consistent ordering
+ image_files.sort()
+ total_images = len(image_files)
+
+ logger.info(f"Found {total_images} images in directory")
+ logger.info(f"Assuming original fps: {original_fps}, target fps: {fps}")
+
+ # Calculate frame interval for desired fps
+ # Assume the image sequence has original_fps
+ # To get target fps, we need to sample every (original_fps / target_fps) frames
+ frame_interval = max(1, int(round(original_fps / fps)))
+
+ logger.info(
+ f"Frame interval: {frame_interval} (sampling every {frame_interval} frames)")
+
+ # First, extract frames based on fps
+ frames = []
+ extracted_count = 0
+
+ for i, image_file in enumerate(image_files):
+ if i % frame_interval == 0:
+ # Load image
+ frame = cv2.imread(image_file)
+ if frame is None:
+ logger.warning(f"Could not load image {image_file}")
+ continue
+
+ # Convert BGR to RGB for further processing
+ frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+ frames.append(frame_rgb)
+ extracted_count += 1
+
+ if len(frames) == 0:
+ raise ValueError("No frames extracted from directory")
+
+ logger.info(
+ f"Extracted {len(frames)} frames from directory (every {frame_interval} images)")
+
+ # Then apply start_idx and end_idx filtering on the extracted frames
+ total_extracted_frames = len(frames)
+
+ if start_idx is not None:
+ start_idx = max(0, start_idx)
+ else:
+ start_idx = 0
+
+ if end_idx is not None:
+ end_idx = min(total_extracted_frames, end_idx)
+ else:
+ end_idx = total_extracted_frames
+
+ # Filter extracted frames based on indices
+ frames = frames[start_idx:end_idx]
+ filtered_frames = len(frames)
+
+ logger.info(
+ f"Applied frame range filtering: frames {start_idx} to {end_idx-1} ({filtered_frames} frames)")
+
+ return frames
+
+
+def extract_frames_from_video(video_path, fps, start_idx=None, end_idx=None):
+ """Extract frames from video file at specified fps
+
+ Args:
+ video_path: Path to video file
+ fps: Target fps for frame extraction
+ start_idx: Start frame index (default: None, process from beginning)
+ end_idx: End frame index (default: None, process to end)
+ """
+
+ # Extract frames from video
+ cap = cv2.VideoCapture(video_path)
+ if not cap.isOpened():
+ raise ValueError(f"Cannot open video file: {video_path}")
+
+ # Get video properties
+ video_fps = cap.get(cv2.CAP_PROP_FPS)
+ total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+
+ logger.info(f"Video has {total_frames} frames at {video_fps} fps")
+
+ # Calculate frame interval for desired fps
+ frame_interval = max(1, int(round(video_fps / fps)))
+
+ logger.info(
+ f"Frame interval: {frame_interval} (sampling every {frame_interval} frames)")
+
+ # First, extract frames based on fps from the entire video
+ frames = []
+ frame_count = 0
+ extracted_count = 0
+
+ while frame_count < total_frames:
+ ret, frame = cap.read()
+ if not ret:
+ break
+
+ if frame_count % frame_interval == 0:
+ # Convert BGR to RGB for further processing
+ frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+ frames.append(frame_rgb)
+ extracted_count += 1
+
+ frame_count += 1
+
+ cap.release()
+
+ if len(frames) == 0:
+ raise ValueError("No frames extracted from video")
+
+ logger.info(
+ f"Extracted {len(frames)} frames from video (every {frame_interval} frames)")
+
+ # Then apply start_idx and end_idx filtering on the extracted frames
+ total_extracted_frames = len(frames)
+
+ if start_idx is not None:
+ start_idx = max(0, start_idx)
+ else:
+ start_idx = 0
+
+ if end_idx is not None:
+ end_idx = min(total_extracted_frames, end_idx)
+ else:
+ end_idx = total_extracted_frames
+
+ # Filter extracted frames based on indices
+ frames = frames[start_idx:end_idx]
+ filtered_frames = len(frames)
+
+ logger.info(
+ f"Applied frame range filtering: frames {start_idx} to {end_idx-1} ({filtered_frames} frames)")
+
+ return frames
+
+
+def process_video(video_path, fps, human_idx, target_size=518, bbox_scale=1.0, start_idx=None, end_idx=None, original_fps=30.0, yolo_ckpt="yolo11n.pt", sam2_model="facebook/sam2-hiera-large"):
+ """Process video by extracting frames at specified fps
+
+ Args:
+ video_path: Path to video file or directory
+ fps: Target fps for frame extraction
+ human_idx: Human index for segmentation
+ target_size: Target size for frame processing (default: 518)
+ bbox_scale: Scale factor for bounding box size (default: 1.0)
+ start_idx: Start frame index (default: None, process from beginning)
+ end_idx: End frame index (default: None, process to end)
+ original_fps: Original fps of the image sequence (default: 30.0, used only for directory input)
+ yolo_ckpt: Path to YOLO checkpoint
+ sam2_model: SAM2 model name
+ """
+ # Step 1: Extract frames and segment the human
+ frames, human_boxes, human_masks = extract_frames(
+ video_path, fps, human_idx, start_idx=start_idx, end_idx=end_idx, original_fps=original_fps,
+ yolo_ckpt=yolo_ckpt, sam2_model=sam2_model)
+
+ # Step 2: Resize frames
+ rgbs = resize_frames(frames, target_size)
+
+ # Step 3: Generate human patches and bbox info
+ # human_boxes: tensor of shape (num_humans, num_frames, 4) in xyxy format ranging from 0 to 1
+ rgbs_patch, bbox_info = generate_patches_and_bbox(
+ rgbs, human_boxes, box_scale=bbox_scale)
+
+ # Step 4: Prepare return data
+ data_basename = os.path.splitext(os.path.basename(video_path))[
+ 0] # e.g., "downtown_arguing_00_0"
+ seq_name = data_basename # Use the full basename including person ID
+ selected_views = list(range(len(frames)))
+
+ return {
+ 'seq_name': seq_name,
+ 'images': rgbs,
+ 'human_patches': rgbs_patch,
+ 'bbox_info': bbox_info,
+ 'selected_views': selected_views,
+ 'human_boxes': human_boxes,
+ 'human_masks': human_masks,
+ }
+
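+# End-to-end preprocessing sketch (hypothetical path); the returned dict feeds
+# run_inference below:
+#
+#   data_dict = process_video("video.mp4", fps=6, human_idx=0)
+#   results = run_inference(unish_model, data_dict, device="cuda")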
+
+def run_inference(unish_model, data_dict, device, chunk_size=30):
+ """Run inference on the loaded scene data - processed for single human"""
+ seq_name = data_dict['seq_name']
+ rgbs = data_dict['images'].to(device)
+ # [num_humans, num_frames, 3, 256, 256]
+ human_patches = data_dict['human_patches'].to(device)
+ # [num_humans, num_frames, 3]
+ bbox_info = data_dict['bbox_info'].to(device)
+ selected_views = data_dict['selected_views']
+
+ total_frames = rgbs.shape[0]
+ human_idx = 0
+
+ # Initialize lists to store chunk results
+ chunk_results = {
+ 'depth_map': [],
+ 'point_map': [],
+ 'extrinsics': [],
+ 'pred_intrinsics': [],
+ 'depth_conf': [],
+ 'pred_pose_aa': [],
+ 'pred_trans_cam': [],
+ 'pred_betas': [],
+ }
+
+ # Process frames in chunks
+ num_chunks = (total_frames + chunk_size -
+ 1) // chunk_size # Ceiling division
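+ # e.g., 65 frames with chunk_size=30 -> 3 chunks of 30, 30, and 5 frames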
+ for chunk_idx in range(num_chunks):
+ start_idx = chunk_idx * chunk_size
+ end_idx = min(start_idx + chunk_size, total_frames)
+
+ # Extract chunk data
+ rgbs_chunk = rgbs[start_idx:end_idx] # [chunk_frames, 3, H, W]
+ # [chunk_frames, 3, 256, 256]
+ human_patches_chunk = human_patches[human_idx, start_idx:end_idx]
+ bbox_info_chunk = bbox_info[human_idx,
+ start_idx:end_idx] # [chunk_frames, 3]
+
+ # Add batch dimension for model input
+ rgbs_batched = rgbs_chunk.unsqueeze(
+ 0) # [1, chunk_frames, 3, H, W]
+ human_patches_single = human_patches_chunk.unsqueeze(
+ 0).to(torch.bfloat16) # [1, chunk_frames, 3, 256, 256]
+ bbox_info_single = bbox_info_chunk.unsqueeze(
+ 0) # [1, chunk_frames, 3]
+
+ input_dict = {
+ "images": rgbs_batched,
+ "human_patches": human_patches_single,
+ "bbox_info": bbox_info_single,
+ }
+
+ unish_model.eval()
+ with torch.no_grad():
+ predictions = unish_model.inference(input_dict)
+
+ # Store chunk predictions (remove batch dimension and move to CPU)
+ chunk_results['depth_map'].append(
+ predictions["depth_map"].squeeze(0).cpu()) # [chunk_frames, H, W, 1]
+ chunk_results['point_map'].append(
+ predictions["point_map"].squeeze(0).cpu()) # [chunk_frames, H, W, 3]
+ chunk_results['extrinsics'].append(
+ predictions["w2cs_cano"].squeeze(0).cpu()) # [chunk_frames, 4, 4]
+ chunk_results['pred_intrinsics'].append(
+ predictions["intrinsics"].squeeze(0).cpu()) # [chunk_frames, 3, 3]
+ chunk_results['depth_conf'].append(
+ predictions["point_conf"].squeeze(0).cpu()) # [chunk_frames, H, W, 1]
+ chunk_results['pred_pose_aa'].append(
+ predictions["pose_cam"].squeeze(0).cpu()) # [chunk_frames, 165]
+ chunk_results['pred_trans_cam'].append(
+ predictions["trans_cam"].squeeze(0).cpu()) # [chunk_frames, 3]
+ chunk_results['pred_betas'].append(
+ predictions["betas"].squeeze(0).cpu()) # [chunk_frames, 10]
+
+ # Concatenate all chunk results
+ human_result = {}
+ for key in chunk_results:
+ # Concatenate along frame dimension
+ human_result[key] = torch.cat(chunk_results[key], dim=0)
+
+ # Transform pred_trans_cam from camera coordinates to world coordinates
+ # Create SMPL dict for transformation
+ smpl_dict = {
+ # Add batch dimension: [1, S, 165]
+ 'pose_cam': human_result['pred_pose_aa'].unsqueeze(0),
+ # Add batch dimension: [1, S, 3]
+ 'trans_cam': human_result['pred_trans_cam'].unsqueeze(0),
+ # Add batch dimension: [1, S, 10]
+ 'betas': human_result['pred_betas'].unsqueeze(0)
+ }
+
+ # Transform to world coordinates using extrinsics
+ smpl_world_dict = transform_smpl(smpl_dict, human_result['extrinsics'].unsqueeze(
+ 0)) # Add batch dimension to extrinsics
+
+ # Extract world coordinates (remove batch dimension)
+ human_result['pred_trans_world'] = smpl_world_dict['trans_world'].squeeze(
+ 0) # [S, 3]
+ human_result['pred_pose_world_aa'] = smpl_world_dict['pose_world'].squeeze(
+ 0) # [S, 165]
+
+ main_result = human_result
+
+ # Add metadata
+ main_result['seq_name'] = seq_name
+ main_result['selected_views'] = selected_views
+ # Add original RGB images for point cloud coloring
+ main_result['rgb_images'] = data_dict['images'].cpu()
+ # Add human segmentation masks for point cloud separation
+ main_result['human_masks'] = data_dict['human_masks']
+ # Add human bounding boxes for visualization
+ main_result['human_boxes'] = data_dict['human_boxes']
+ return main_result
+
+def generate_mixed_geometries_in_memory(results, body_models_path="body_models/", fps=6, conf_thres=0.1):
+ """
+ Generate mixed geometries (scene point clouds + SMPL meshes) for visualization.
+ Returns lists of scene point clouds and SMPL meshes for each frame.
+
+ Args:
+ results: inference results dictionary
+ body_models_path: path to SMPL body models
+ fps: FPS of the input video and point cloud frequency (default: 6)
+ conf_thres: confidence threshold for point cloud generation
+
+ Returns:
+ viz_scene_point_clouds: list of Open3D point clouds for the complete scene (world coordinates)
+ viz_smpl_meshes: list of per-frame lists of Open3D triangle meshes for SMPL (world coordinates)
+ viz_scene_only_point_clouds: list of Open3D point clouds excluding human regions (world coordinates)
+ smpl_points_for_camera: list of SMPL vertices in camera coordinates (used for camera targeting)
+ """
+ seq_name = results['seq_name']
+ point_map_original = results['point_map'] # [S, H, W, 3] - 3D points
+ extrinsics = results['extrinsics'] # [S, 3, 4]
+ intrinsics = results['pred_intrinsics'] # [S, 3, 3]
+ selected_views = results['selected_views']
+ rgb_images = results['rgb_images'] # [S, 3, H, W]
+ # [S, H, W, 1] - same fps as point_map
+ depth_conf_original = results['depth_conf']
+
+ rgb_images_original = rgb_images
+
+ # Initialize SMPL visualizer
+ device = torch.device('cpu') # Work on CPU for visualization
+ smpl_visualizer = SMPLWrapper(
+ model_folder=body_models_path,
+ model_type='smpl',
+ device=device,
+ dtype=torch.float32
+ )
+
+ # Get SMPL faces (same for all frames and humans)
+ smpl_faces = smpl_visualizer.models['neutral'].faces.astype(np.int32)
+
+ # Get SMPL vertices and joints for single human using camera coordinates
+ all_human_vertices = []
+ human_idx = 0
+
+ # Single human case
+ pred_pose_cam_aa = results['pred_pose_aa']
+ pred_trans_cam = results['pred_trans_cam']
+ pred_betas = results['pred_betas']
+
+ # Get vertices for all frames using camera coordinates (joints are not needed here)
+ vertices, _ = smpl_visualizer.get_batch_vertices(
+ pred_pose_cam_aa, pred_betas, pred_trans_cam, "neutral"
+ )
+ all_human_vertices.append(vertices)
+
+ # Output lists
+ # scene_point_clouds = [] # Scene point clouds (one per frame)
+ # smpl_meshes = [] # SMPL meshes (list of lists: [frame][human])
+ # smpl_points_for_camera = [] # Camera coordinates
+ # smpl_joints_for_camera = [] # Camera coordinates
+ # smpl_points_for_world = [] # World coordinates
+ # smpl_joints_for_world = [] # World coordinates
+
+ # Collect SMPL points for camera targeting (from first human)
+ # Convert all vertices to a list of points in camera coordinates
+ if len(all_human_vertices) > 0:
+ # [S, V, 3] -> [S*V, 3]
+ all_verts_np = all_human_vertices[0].cpu().numpy().reshape(-1, 3)
+ # Convert to list of (3,) arrays to match original format expected by run_visualization
+ smpl_points_for_camera = list(all_verts_np)
+ else:
+ smpl_points_for_camera = []
+
+ # Generate current frame only scene point clouds (at original fps, no downsampling)
+ print("Generating visualization scene point clouds at original fps...")
+ # Complete scene point clouds (all valid points)
+ viz_scene_point_clouds = []
+ # Scene-only point clouds (excluding human regions)
+ viz_scene_only_point_clouds = []
+ viz_smpl_meshes = []
+
+ # Use original fps data for current frame visualization
+ original_point_map = point_map_original # [S_original, H, W, 3]
+ original_depth_conf = depth_conf_original # [S_original, H, W, 1]
+ # [S_original, H, W]
+ original_conf_mask = original_depth_conf.squeeze(-1) > conf_thres
+
+ # Get original RGB images for coloring
+ original_rgb_for_coloring = rgb_images_original
+
+ human_masks_data = results['human_masks']
+
+ for i in range(original_point_map.shape[0]):
+ # Create scene point cloud for current frame
+ points_3d = original_point_map[i] # [H, W, 3]
+ conf_mask_frame = original_conf_mask[i] # [H, W]
+
+ # Get valid points (complete scene)
+ valid_points = points_3d[conf_mask_frame] # [N, 3]
+
+ # Create complete scene point cloud
+ if len(valid_points) > 0:
+ # Create complete point cloud
+ complete_scene_pcd = o3d.geometry.PointCloud()
+ complete_scene_pcd.points = o3d.utility.Vector3dVector(
+ valid_points.cpu().numpy())
+
+ # Add colors from RGB image
+ if i < len(original_rgb_for_coloring):
+ rgb_frame = original_rgb_for_coloring[i] # [3, H, W]
+ if rgb_frame.dim() == 3 and rgb_frame.shape[0] == 3:
+ # Convert from [3, H, W] to [H, W, 3] and normalize
+ rgb_frame = rgb_frame.permute(1, 2, 0) # [H, W, 3]
+ rgb_frame = rgb_frame.cpu().numpy()
+ if rgb_frame.max() > 1.0:
+ rgb_frame = rgb_frame / 255.0 # Normalize to [0, 1]
+
+ # Get colors for valid points
+ valid_colors = rgb_frame[conf_mask_frame] # [N, 3]
+ complete_scene_pcd.colors = o3d.utility.Vector3dVector(
+ valid_colors)
+ else:
+ # Default gray color if RGB format is unexpected
+ default_colors = np.tile(
+ [0.7, 0.7, 0.7], (len(valid_points), 1))
+ complete_scene_pcd.colors = o3d.utility.Vector3dVector(
+ default_colors)
+ else:
+ # Default gray color if no RGB available
+ default_colors = np.tile(
+ [0.7, 0.7, 0.7], (len(valid_points), 1))
+ complete_scene_pcd.colors = o3d.utility.Vector3dVector(
+ default_colors)
+ else:
+ # Empty point cloud
+ complete_scene_pcd = o3d.geometry.PointCloud()
+
+ viz_scene_point_clouds.append(complete_scene_pcd)
+
+ # Create scene-only point cloud (excluding human regions) if masks are available
+ if human_masks_data is not None:
+ # Combine all human masks for this frame
+ combined_human_mask = torch.zeros_like(
+ conf_mask_frame, dtype=torch.bool)
+
+ # Single human mask
+ human_idx = 0
+ if i < human_masks_data.shape[1]: # Check frame index bounds
+ human_mask = human_masks_data[human_idx, i] # [H, W]
+ # Resize mask if needed to match point cloud resolution
+ if human_mask.shape != conf_mask_frame.shape:
+ human_mask_np = human_mask.cpu().numpy().astype(np.uint8)
+ target_h, target_w = conf_mask_frame.shape
+ human_mask_resized = cv2.resize(
+ human_mask_np, (target_w, target_h), interpolation=cv2.INTER_NEAREST)
+ human_mask = torch.from_numpy(
+ human_mask_resized.astype(bool))
+ combined_human_mask |= human_mask.cpu()
+
+ # Create scene-only mask (valid points AND NOT human regions)
+ scene_only_mask = conf_mask_frame & (~combined_human_mask)
+ scene_only_points = points_3d[scene_only_mask] # [N_scene, 3]
+
+ if len(scene_only_points) > 0:
+ # Create scene-only point cloud
+ scene_only_pcd = o3d.geometry.PointCloud()
+ scene_only_pcd.points = o3d.utility.Vector3dVector(
+ scene_only_points.cpu().numpy())
+
+ # Add colors from RGB image
+ if i < len(original_rgb_for_coloring):
+ rgb_frame = original_rgb_for_coloring[i] # [3, H, W]
+ if rgb_frame.dim() == 3 and rgb_frame.shape[0] == 3:
+ # Convert from [3, H, W] to [H, W, 3] and normalize
+ rgb_frame = rgb_frame.permute(1, 2, 0) # [H, W, 3]
+ rgb_frame = rgb_frame.cpu().numpy()
+ if rgb_frame.max() > 1.0:
+ # Normalize to [0, 1]
+ rgb_frame = rgb_frame / 255.0
+
+ # Get colors for scene-only points
+ # [N_scene, 3]
+ scene_only_colors = rgb_frame[scene_only_mask]
+ scene_only_pcd.colors = o3d.utility.Vector3dVector(
+ scene_only_colors)
+ else:
+ # Default gray color if RGB format is unexpected
+ default_colors = np.tile(
+ [0.7, 0.7, 0.7], (len(scene_only_points), 1))
+ scene_only_pcd.colors = o3d.utility.Vector3dVector(
+ default_colors)
+ else:
+ # Default gray color if no RGB available
+ default_colors = np.tile(
+ [0.7, 0.7, 0.7], (len(scene_only_points), 1))
+ scene_only_pcd.colors = o3d.utility.Vector3dVector(
+ default_colors)
+ else:
+ # Empty scene-only point cloud
+ scene_only_pcd = o3d.geometry.PointCloud()
+ else:
+ # No masks available, use complete scene as scene-only
+ scene_only_pcd = complete_scene_pcd
+
+ viz_scene_only_point_clouds.append(scene_only_pcd)
+
+ # Generate SMPL meshes at original fps (no upsampling)
+ print("Generating SMPL meshes at original fps for visualization...")
+ for i in range(original_point_map.shape[0]):
+ frame_smpl_meshes = []
+
+ # Use direct frame mapping (no upsampling)
+ smpl_idx = min(i, len(all_human_vertices[0]) - 1)
+
+ # Single human
+ human_idx = 0
+ # Get SMPL vertices and joints for this frame and human - use correct mapping
+ # [num_vertices, 3]
+ frame_vertices = all_human_vertices[human_idx][smpl_idx].cpu(
+ ).numpy()
+
+ # Create SMPL mesh
+ smpl_mesh = o3d.geometry.TriangleMesh()
+ smpl_mesh.vertices = o3d.utility.Vector3dVector(frame_vertices)
+ smpl_mesh.triangles = o3d.utility.Vector3iVector(smpl_faces)
+
+ # Single human: use reddish color
+ human_color = np.array([0.8, 0.2, 0.2])
+
+ # Set vertex colors
+ vertex_colors = np.tile(human_color, (len(frame_vertices), 1))
+ smpl_mesh.vertex_colors = o3d.utility.Vector3dVector(vertex_colors)
+
+ # Compute normals for better rendering
+ smpl_mesh.compute_vertex_normals()
+
+ frame_smpl_meshes.append(smpl_mesh)
+
+ viz_smpl_meshes.append(frame_smpl_meshes)
+
+ # Transform scene point clouds and SMPL meshes to world coordinates
+ for i in range(len(viz_scene_point_clouds)):
+ # Use direct frame mapping (no upsampling)
+ extr_idx = min(i, len(extrinsics) - 1)
+
+ extr = extrinsics[extr_idx].numpy() # [3, 4] or [4, 4]
+ if extr.shape[0] == 3: # Convert [3, 4] to [4, 4] if needed
+ extr = np.vstack([extr, [0, 0, 0, 1]])
+
+ # Get camera-to-world transformation
+ cam_to_world_extr = closed_form_inverse_se3(extr[None])[0]
+ R_cam_to_world = cam_to_world_extr[:3, :3]
+ t_cam_to_world = cam_to_world_extr[:3, 3]
+
+ # Transform complete scene point cloud to world coordinates
+ scene_pcd = viz_scene_point_clouds[i]
+ if len(scene_pcd.points) > 0:
+ points_cam = np.asarray(scene_pcd.points) # Camera coordinates
+ # Transform to world coordinates
+ points_world = np.dot(
+ points_cam, R_cam_to_world.T) + t_cam_to_world
+ scene_pcd.points = o3d.utility.Vector3dVector(points_world)
+
+ # Transform scene-only point cloud to world coordinates
+ scene_only_pcd = viz_scene_only_point_clouds[i]
+ if len(scene_only_pcd.points) > 0:
+ points_cam = np.asarray(
+ scene_only_pcd.points) # Camera coordinates
+ # Transform to world coordinates
+ points_world = np.dot(
+ points_cam, R_cam_to_world.T) + t_cam_to_world
+ scene_only_pcd.points = o3d.utility.Vector3dVector(points_world)
+
+ # Transform SMPL meshes to world coordinates
+ for smpl_mesh in viz_smpl_meshes[i]:
+ if len(smpl_mesh.vertices) > 0:
+ vertices_cam = np.asarray(
+ smpl_mesh.vertices) # Camera coordinates
+ # Transform to world coordinates
+ vertices_world = np.dot(
+ vertices_cam, R_cam_to_world.T) + t_cam_to_world
+ smpl_mesh.vertices = o3d.utility.Vector3dVector(vertices_world)
+ # Recompute normals after transformation
+ smpl_mesh.compute_vertex_normals()
+
+ return viz_scene_point_clouds, viz_smpl_meshes, viz_scene_only_point_clouds, smpl_points_for_camera
+
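+# Typical wiring between the two stages (a sketch; see run_visualization below
+# for the full signature):
+#
+#   scene_pcds, smpl_meshes, scene_only_pcds, smpl_pts = \
+#       generate_mixed_geometries_in_memory(results)
+#   run_visualization(scene_pcds, smpl_meshes, smpl_pts, output_dir,
+#                     results['seq_name'], results=results,
+#                     scene_only_point_clouds=scene_only_pcds)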
+
+def create_mp4_from_frames(frames, output_path, fps):
+ """
+ Create MP4 video from a list of RGB frames.
+
+ Args:
+ frames: List of RGB numpy arrays (H, W, 3)
+ output_path: Path to save the MP4 file
+ fps: Frames per second
+ """
+ if not frames:
+ return
+
+ height, width = frames[0].shape[:2]
+ fourcc = cv2.VideoWriter_fourcc(*'mp4v')
+ video_writer = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
+
+ for frame in frames:
+ frame_bgr = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
+ video_writer.write(frame_bgr)
+
+ video_writer.release()
+
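+# Usage sketch: create_mp4_from_frames(rendered_frames, "out.mp4", fps=6).
+# The 'mp4v' fourcc (MPEG-4 Part 2) is typically available in stock OpenCV builds.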
+
+def run_visualization(scene_point_clouds, smpl_meshes, smpl_points_for_camera, output_dir, seq_name, fps=6, rgb_images=None, human_boxes=None, chunk_size=30, results=None, use_predicted_camera=True, scene_only_point_clouds=None, conf_thres=0.1):
+ """
+ Visualize per-frame scene point clouds and human geometry without accumulating
+ points across frames. Two visualizations are rendered per frame:
+ 1. Scene point clouds + human point clouds (human points are extracted from
+ the complete point cloud using the segmentation masks)
+ 2. Scene-only point clouds (human regions removed) + SMPL meshes
+ """
+
+ # Calculate frame counts - use original fps data directly
+ total_frames = len(scene_point_clouds)
+
+ print(f"Starting visualization:")
+ print(f" - Total frames: {total_frames} at {fps} fps")
+
+ # Configuration
+ FIXED_ANGLE = 270
+ CAMERA_DISTANCE_MULTIPLIER = 2.0
+ CAMERA_HEIGHT_OFFSET_MULTIPLIER = 0.3
+ RENDER_WIDTH = 960
+ RENDER_HEIGHT = 540
+ FPS = fps # Use original fps
+ POINT_SIZE = 4.0
+
+ # Setup output paths - create separate folders for current frame only results
+ viz_folder = os.path.join(output_dir, seq_name, "visualization")
+
+ # Type 1: Scene point clouds + human point clouds (renamed from complete)
+ scene_human_frames_folder = os.path.join(
+ viz_folder, "scene_human_frames")
+ scene_human_gifs_folder = os.path.join(viz_folder, "scene_human_gifs")
+ scene_human_video_path = os.path.join(
+ viz_folder, "scene_human_video.mp4")
+ scene_human_gif_path = os.path.join(
+ scene_human_gifs_folder, "scene_human_animation.gif")
+ scene_human_gif_mp4_path = os.path.join(
+ scene_human_gifs_folder, "scene_human_animation.mp4")
+
+ # Type 2: Scene point clouds + SMPL meshes (renamed from scene_only)
+ scene_smpl_frames_folder = os.path.join(
+ viz_folder, "scene_smpl_frames")
+ scene_smpl_gifs_folder = os.path.join(
+ viz_folder, "scene_smpl_gifs")
+ scene_smpl_video_path = os.path.join(
+ viz_folder, "scene_smpl_video.mp4")
+ scene_smpl_gif_path = os.path.join(
+ scene_smpl_gifs_folder, "scene_smpl_animation.gif")
+ scene_smpl_gif_mp4_path = os.path.join(
+ scene_smpl_gifs_folder, "scene_smpl_animation.mp4")
+
+ # Common folders
+ bbox_frames_folder = os.path.join(viz_folder, "bbox_frames")
+ input_frames_folder = os.path.join(viz_folder, "input_frames")
+ input_gif_path = os.path.join(
+ viz_folder, "input_frames_animation.gif")
+ input_gif_mp4_path = os.path.join(
+ viz_folder, "input_frames_animation.mp4")
+
+ # Create all directories
+ os.makedirs(scene_human_frames_folder, exist_ok=True)
+ os.makedirs(scene_human_gifs_folder, exist_ok=True)
+ os.makedirs(scene_smpl_frames_folder, exist_ok=True)
+ os.makedirs(scene_smpl_gifs_folder, exist_ok=True)
+ os.makedirs(bbox_frames_folder, exist_ok=True)
+ os.makedirs(input_frames_folder, exist_ok=True)
+
+ # Setup camera parameters; the defaults below act as a fallback when
+ # predicted parameters are unavailable
+ camera_positions = []
+ camera_rotations = []
+ camera_intrinsics = []
+ fixed_target = np.array([0.0, 0.0, 0.0], dtype=np.float32)
+ fixed_camera_pos = np.array([0.0, 2.0, 5.0], dtype=np.float32)
+ if use_predicted_camera and results is not None:
+ print("Using predicted camera parameters from model")
+ pred_extrinsics = results['extrinsics']
+ pred_intrinsics = results['pred_intrinsics']
+
+ if pred_extrinsics is None or pred_intrinsics is None:
+ raise KeyError("Missing camera parameters")
+ if len(pred_extrinsics) == 0 or len(pred_intrinsics) == 0:
+ raise ValueError("Empty camera parameters")
+
+ # Map frames to camera parameters
+ camera_positions = []
+ camera_rotations = []
+ camera_intrinsics = []
+
+ for i in range(total_frames):
+ # Use direct mapping or closest available parameter
+ camera_idx = min(i, len(pred_extrinsics) - 1)
+
+ extr = pred_extrinsics[camera_idx].numpy()
+ intr = pred_intrinsics[camera_idx].numpy()
+
+ if extr.shape[0] == 3:
+ extr = np.vstack([extr, [0, 0, 0, 1]])
+
+ cam_to_world = closed_form_inverse_se3(extr[None])[0]
+
+ camera_pos_orig = cam_to_world[:3, 3]
+ camera_rot_orig = cam_to_world[:3, :3]
+
+ camera_positions.append(camera_pos_orig.copy())
+ camera_rotations.append(camera_rot_orig.copy())
+ camera_intrinsics.append(intr)
+
+ # Calculate fixed camera target (needed for both predicted and fixed camera modes)
+ if smpl_points_for_camera and len(smpl_points_for_camera) > 0:
+ all_smpl_points = []
+ for frame_points in smpl_points_for_camera:
+ if hasattr(frame_points, '__len__') and len(frame_points) > 0:
+ if isinstance(frame_points, np.ndarray):
+ all_smpl_points.extend(frame_points.tolist())
+ else:
+ all_smpl_points.extend(frame_points)
+
+ if all_smpl_points:
+ all_smpl_points = np.array(all_smpl_points)
+ # Ensure we have a proper 2D array with shape (N, 3)
+ if all_smpl_points.ndim == 1:
+ all_smpl_points = all_smpl_points.reshape(-1, 3)
+ center = np.mean(all_smpl_points, axis=0)
+ # Ensure center is a 3D vector
+ if center.shape[0] != 3:
+ print(
+ f"Warning: Invalid center shape {center.shape}, using default [0, 0, 0]")
+ fixed_target = np.array([0.0, 0.0, 0.0], dtype=np.float32)
+ else:
+ fixed_target = center.astype(np.float32)
+
+ # Calculate fixed camera position only if not using predicted camera
+ if not use_predicted_camera:
+ bbox_min = np.min(all_smpl_points, axis=0)
+ bbox_max = np.max(all_smpl_points, axis=0)
+ bbox_size = np.linalg.norm(bbox_max - bbox_min)
+ camera_distance = bbox_size * CAMERA_DISTANCE_MULTIPLIER
+
+ angle_rad = np.radians(FIXED_ANGLE)
+ camera_x = center[0] + camera_distance * np.cos(angle_rad)
+ camera_z = center[2] + camera_distance * np.sin(angle_rad)
+ camera_y = center[1] + bbox_size * \
+ CAMERA_HEIGHT_OFFSET_MULTIPLIER
+
+ fixed_camera_pos = np.array(
+ [camera_x, camera_y, camera_z], dtype=np.float32)
+ else:
+ fixed_target = np.array([0.0, 0.0, 0.0], dtype=np.float32)
+ if not use_predicted_camera:
+ fixed_camera_pos = np.array([0.0, 2.0, 5.0], dtype=np.float32)
+ else:
+ fixed_target = np.array([0.0, 0.0, 0.0], dtype=np.float32)
+ if not use_predicted_camera:
+ fixed_camera_pos = np.array([0.0, 2.0, 5.0], dtype=np.float32)
+
+ # Setup renderer (same as inference_video_eval.py)
+ render = o3d.visualization.rendering.OffscreenRenderer(
+ RENDER_WIDTH, RENDER_HEIGHT)
+
+ # Setup materials
+ point_mat = o3d.visualization.rendering.MaterialRecord()
+ point_mat.shader = "defaultUnlit" # Unlit to avoid lighting effects on points
+ point_mat.point_size = POINT_SIZE
+
+ mesh_mat = o3d.visualization.rendering.MaterialRecord()
+ mesh_mat.shader = "defaultUnlit" # Change to unlit to avoid lighting effects
+ mesh_mat.base_color = [0.8, 0.6, 0.4, 1.0] # Skin-like color
+ mesh_mat.base_metallic = 0.0
+ mesh_mat.base_roughness = 0.8
+
+ # Setup camera projection - will be updated per frame if using predicted camera
+ default_fov = 60
+ render.scene.camera.set_projection(default_fov, RENDER_WIDTH/RENDER_HEIGHT, 0.1, 100.0,
+ o3d.visualization.rendering.Camera.FovType.Vertical)
+
+ # Setup lighting and background
+ render.scene.set_background([1.0, 1.0, 1.0, 1.0])
+
+ # Try to disable all lighting effects to ensure pure white background
+ try:
+ render.scene.set_lighting(
+ o3d.visualization.rendering.Open3DScene.LightingProfile.NO_SHADOWS, np.array([0, 0, 0]))
+
+ try:
+ render.scene.scene.enable_sun_light(False)
+ except Exception:
+ pass
+
+ try:
+ render.scene.scene.set_indirect_light_intensity(1.0)
+ except Exception:
+ pass
+
+ except Exception as e:
+ print(f"Warning: Could not set lighting profile: {e}")
+ # Fallback: just set background without lighting
+ pass
+
+ scene_human_frames = []
+ scene_smpl_frames = [] # Scene point clouds + SMPL meshes
+ bbox_frames = []
+ input_gif_frames = []
+
+ # Main rendering loop - Create both visualizations
+ for i in tqdm(range(total_frames), desc="Rendering visualization"):
+ # Set camera position and intrinsics
+ if use_predicted_camera and i < len(camera_positions):
+ camera_pos = np.array(camera_positions[i], dtype=np.float32)
+ camera_rot = np.array(camera_rotations[i], dtype=np.float32)
+ camera_intr = np.array(camera_intrinsics[i], dtype=np.float32)
+
+ forward = camera_rot[:, 2]
+ up = -camera_rot[:, 1] # Flip Y axis for correct orientation
+
+ # Target point is camera position + forward direction
+ target = camera_pos + forward
+
+ # Set camera using look_at: look_at(target, eye, up)
+ render.scene.camera.look_at(target, camera_pos, up)
+
+ # Update camera intrinsics if available
+ if camera_intr is not None and camera_intr.shape == (3, 3):
+ # Calculate FOV from intrinsics
+ fx = camera_intr[0, 0]
+ fy = camera_intr[1, 1]
+ # Use vertical FOV based on fy
+ fov_y_rad = 2 * np.arctan(RENDER_HEIGHT / (2 * fy))
+ fov_y_deg = np.degrees(fov_y_rad)
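+ # e.g., RENDER_HEIGHT=540 with fy=540 gives fov_y = 2*atan(0.5) ~= 53.13 deg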
+
+ # Clamp FOV to reasonable range
+ fov_y_deg = np.clip(fov_y_deg, 10, 120)
+
+ render.scene.camera.set_projection(fov_y_deg, RENDER_WIDTH/RENDER_HEIGHT, 0.1, 100.0,
+ o3d.visualization.rendering.Camera.FovType.Vertical)
+
+ else:
+ # Use fixed camera
+ up_vector = np.array([0.0, 1.0, 0.0], dtype=np.float32)
+ render.scene.camera.look_at(
+ fixed_target, fixed_camera_pos, up_vector)
+
+ # Reset to default FOV for fixed camera
+ render.scene.camera.set_projection(default_fov, RENDER_WIDTH/RENDER_HEIGHT, 0.1, 100.0,
+ o3d.visualization.rendering.Camera.FovType.Vertical)
+
+ # === Type 1: Complete point clouds (scene + human points from mask) ===
+ render.scene.clear_geometry()
+ # Ensure white background is set for each frame
+ render.scene.set_background([1.0, 1.0, 1.0, 1.0])
+
+ # Add scene point cloud (based on rendering mode)
+ current_scene_pcd = scene_point_clouds[i]
+ if len(current_scene_pcd.points) > 0:
+ render.scene.add_geometry(
+ f"scene_pointcloud_{i}", current_scene_pcd, point_mat)
+
+ # Add human point clouds (extracted from complete point cloud using masks)
+ if results and 'human_masks' in results and 'point_map' in results and 'rgb_images' in results:
+ human_masks_data = results['human_masks'] # [num_humans, S, H, W]
+ point_map = results['point_map'] # [S, H, W, 3]
+ rgb_images = results['rgb_images'] # [S, 3, H, W]
+ depth_conf = results['depth_conf'] # [S, H, W, 1]
+ extrinsics = results['extrinsics'] # [S, 3, 4] or [S, 4, 4]
+
+ if i < point_map.shape[0] and i < human_masks_data.shape[1]:
+ # Get current frame data
+ points_3d = point_map[i] # [H, W, 3] - in camera coordinates
+ conf_map = depth_conf[i].squeeze(-1) # [H, W]
+ conf_mask_frame = conf_map > conf_thres
+
+ # Get camera-to-world transformation for this frame
+ extr = extrinsics[i].numpy() # [3, 4] or [4, 4]
+ if extr.shape[0] == 3: # Convert [3, 4] to [4, 4] if needed
+ extr = np.vstack([extr, [0, 0, 0, 1]])
+
+ # Get camera-to-world transformation
+ cam_to_world_extr = closed_form_inverse_se3(extr[None])[0]
+ R_cam_to_world = cam_to_world_extr[:3, :3]
+ t_cam_to_world = cam_to_world_extr[:3, 3]
+
+ # Extract human points using masks
+ # Single human processing
+ human_idx = 0
+ if human_idx < human_masks_data.shape[0]:
+ human_mask = human_masks_data[human_idx, i] # [H, W]
+
+ # Resize mask if needed to match point cloud resolution
+ if human_mask.shape != conf_mask_frame.shape:
+ human_mask_np = human_mask.cpu().numpy().astype(np.uint8)
+ target_h, target_w = conf_mask_frame.shape
+ human_mask_resized = cv2.resize(
+ human_mask_np, (target_w, target_h), interpolation=cv2.INTER_NEAREST)
+ human_mask = torch.from_numpy(
+ human_mask_resized.astype(bool))
+
+ # Create human mask (valid points AND human regions)
+ human_point_mask = conf_mask_frame & human_mask.cpu()
+ # [N_human, 3] - camera coordinates
+ human_points_cam = points_3d[human_point_mask]
+
+ if len(human_points_cam) > 0:
+ # Transform human points to world coordinates (same as scene points)
+ human_points_cam_np = human_points_cam.cpu().numpy()
+ human_points_world = np.dot(
+ human_points_cam_np, R_cam_to_world.T) + t_cam_to_world
+
+ # Create human point cloud in world coordinates
+ human_pcd = o3d.geometry.PointCloud()
+ human_pcd.points = o3d.utility.Vector3dVector(
+ human_points_world)
+
+ # Add colors from RGB image
+ if i < rgb_images.shape[0]:
+ rgb_frame = rgb_images[i] # [3, H, W]
+ if rgb_frame.dim() == 3 and rgb_frame.shape[0] == 3:
+ # Convert from [3, H, W] to [H, W, 3] and normalize
+ rgb_frame = rgb_frame.permute(
+ 1, 2, 0) # [H, W, 3]
+ rgb_frame = rgb_frame.cpu().numpy()
+ if rgb_frame.max() > 1.0:
+ # Normalize to [0, 1]
+ rgb_frame = rgb_frame / 255.0
+
+ # Get colors for human points
+ # [N_human, 3]
+ human_colors = rgb_frame[human_point_mask]
+ human_pcd.colors = o3d.utility.Vector3dVector(
+ human_colors)
+ else:
+ # Default red color for human points
+ num_points = len(human_points_cam)
+ human_colors = np.zeros((num_points, 3))
+ human_colors[:, 0] = 0.8 # Red
+ human_colors[:, 1] = 0.2 # Green
+ human_colors[:, 2] = 0.2 # Blue
+ human_pcd.colors = o3d.utility.Vector3dVector(
+ human_colors)
+ else:
+ # Default red color for human points
+ num_points = len(human_points_cam)
+ human_colors = np.zeros((num_points, 3))
+ human_colors[:, 0] = 0.8 # Red
+ human_colors[:, 1] = 0.2 # Green
+ human_colors[:, 2] = 0.2 # Blue
+ human_pcd.colors = o3d.utility.Vector3dVector(
+ human_colors)
+
+ render.scene.add_geometry(
+ f"human_pointcloud_{i}_human_{human_idx}", human_pcd, point_mat)
+
+ # Render complete visualization
+ complete_img = render.render_to_image()
+ complete_img_array = np.asarray(complete_img)
+
+ # With the fixed camera the render comes out mirrored, so flip it vertically and horizontally
+ if not use_predicted_camera:
+ complete_img_array = np.flipud(complete_img_array)
+ complete_img_array = np.fliplr(complete_img_array)
+
+ complete_img_bgr = cv2.cvtColor(complete_img_array, cv2.COLOR_RGB2BGR)
+
+ # Save complete frame
+ scene_human_frame_path = os.path.join(
+ scene_human_frames_folder, f"scene_human_frame_{i:04d}.png")
+ cv2.imwrite(scene_human_frame_path, complete_img_bgr)
+ scene_human_frames.append(complete_img_array)
+
+ # === Type 2: Scene-only point clouds + SMPL meshes ===
+ render.scene.clear_geometry()
+ # Ensure white background is set for each frame
+ render.scene.set_background([1.0, 1.0, 1.0, 1.0])
+
+ # Add scene-only point cloud
+ current_scene_only_pcd = scene_only_point_clouds[i]
+ if len(current_scene_only_pcd.points) > 0:
+ render.scene.add_geometry(
+ f"scene_only_pointcloud_{i}", current_scene_only_pcd, point_mat)
+
+ # Add SMPL meshes (same as complete visualization)
+ if i < len(smpl_meshes):
+ frame_meshes = smpl_meshes[i]
+ if isinstance(frame_meshes, list):
+ for human_idx, mesh in enumerate(frame_meshes):
+ if hasattr(mesh, 'vertices') and len(mesh.vertices) > 0:
+ render.scene.add_geometry(
+ f"smpl_mesh_{i}_human_{human_idx}", mesh, mesh_mat)
+ else:
+ if hasattr(frame_meshes, 'vertices') and len(frame_meshes.vertices) > 0:
+ render.scene.add_geometry(
+ f"smpl_mesh_{i}", frame_meshes, mesh_mat)
+
+ # Render scene-only visualization
+ scene_only_img = render.render_to_image()
+ scene_only_img_array = np.asarray(scene_only_img)
+
+ # With the fixed camera the render comes out mirrored, so flip it vertically and horizontally
+ if not use_predicted_camera:
+ scene_only_img_array = np.flipud(scene_only_img_array)
+ scene_only_img_array = np.fliplr(scene_only_img_array)
+
+ scene_only_img_bgr = cv2.cvtColor(
+ scene_only_img_array, cv2.COLOR_RGB2BGR)
+
+ # Save scene-only frame
+ scene_smpl_frame_path = os.path.join(
+ scene_smpl_frames_folder, f"scene_smpl_frame_{i:04d}.png")
+ cv2.imwrite(scene_smpl_frame_path, scene_only_img_bgr)
+ scene_smpl_frames.append(scene_only_img_array)
+
+ # Save bbox frame if available
+ if human_boxes is not None and rgb_images is not None and i < len(rgb_images):
+ # Convert tensor to numpy array for OpenCV operations
+ bbox_frame_tensor = rgb_images[i].clone()
+ if bbox_frame_tensor.dim() == 3 and bbox_frame_tensor.shape[0] == 3:
+ # Convert from [3, H, W] to [H, W, 3]
+ bbox_frame_tensor = bbox_frame_tensor.permute(1, 2, 0)
+
+ bbox_frame = bbox_frame_tensor.cpu().numpy()
+ if bbox_frame.max() <= 1.0:
+ bbox_frame = (bbox_frame * 255).astype(np.uint8)
+ else:
+ bbox_frame = bbox_frame.astype(np.uint8)
+
+ # Ensure the array is contiguous for OpenCV
+ bbox_frame = np.ascontiguousarray(bbox_frame)
+
+ # Draw bounding boxes for single human (using correct indexing)
+ img_height, img_width = bbox_frame.shape[:2]
+
+ human_idx = 0
+ if human_idx < human_boxes.shape[0]:
+ # Get normalized bbox: [x1, y1, x2, y2] in [0,1] range
+ x1, y1, x2, y2 = human_boxes[human_idx, i]
+
+ # Convert to pixel coordinates
+ x1_pixel = int(x1 * img_width)
+ y1_pixel = int(y1 * img_height)
+ x2_pixel = int(x2 * img_width)
+ y2_pixel = int(y2 * img_height)
+
+ # Clamp to image boundaries
+ x1_pixel = max(0, min(x1_pixel, img_width))
+ y1_pixel = max(0, min(y1_pixel, img_height))
+ x2_pixel = max(0, min(x2_pixel, img_width))
+ y2_pixel = max(0, min(y2_pixel, img_height))
+
+ # Choose color for each human (single color for now)
+ color = (0, 255, 0) # Green
+
+ # Draw bounding box
+ cv2.rectangle(bbox_frame, (x1_pixel, y1_pixel),
+ (x2_pixel, y2_pixel), color, 2)
+
+ # Add human ID label
+ label = f"Human {human_idx}"
+ label_size = cv2.getTextSize(
+ label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)[0]
+ cv2.rectangle(bbox_frame, (x1_pixel, y1_pixel - label_size[1] - 5),
+ (x1_pixel + label_size[0], y1_pixel), color, -1)
+ cv2.putText(bbox_frame, label, (x1_pixel, y1_pixel - 5),
+ cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
+
+ bbox_frame_path = os.path.join(
+ bbox_frames_folder, f"bbox_frame_{i:04d}.png")
+ cv2.imwrite(bbox_frame_path, cv2.cvtColor(
+ bbox_frame, cv2.COLOR_RGB2BGR))
+ bbox_frames.append(bbox_frame)
+
+ # Save input frame
+ if rgb_images is not None and i < len(rgb_images):
+ input_frame_path = os.path.join(
+ input_frames_folder, f"input_frame_{i:04d}.png")
+ # Ensure the frame is a valid numpy array
+ frame = rgb_images[i]
+ if isinstance(frame, np.ndarray) and frame.size > 0:
+ cv2.imwrite(input_frame_path, cv2.cvtColor(
+ frame, cv2.COLOR_RGB2BGR))
+ input_gif_frames.append(frame)
+ elif isinstance(frame, torch.Tensor):
+ # Convert tensor to numpy with proper shape handling
+ frame_np = frame.cpu().numpy()
+
+ # Handle different tensor formats
+ if frame_np.ndim == 3:
+ if frame_np.shape[0] == 3: # [3, H, W] format
+ frame_np = frame_np.transpose(
+ 1, 2, 0) # Convert to [H, W, 3]
+ # else: already in [H, W, 3] format
+ # [1, 3, H, W] format
+ elif frame_np.ndim == 4 and frame_np.shape[0] == 1:
+ frame_np = frame_np.squeeze(0).transpose(
+ 1, 2, 0) # Convert to [H, W, 3]
+
+ # Ensure proper data type and range
+ if frame_np.max() <= 1.0:
+ frame_np = (frame_np * 255).astype(np.uint8)
+ else:
+ frame_np = frame_np.astype(np.uint8)
+
+ # Ensure frame has 3 channels for RGB
+ if frame_np.shape[-1] == 3:
+ cv2.imwrite(input_frame_path, cv2.cvtColor(
+ frame_np, cv2.COLOR_RGB2BGR))
+ input_gif_frames.append(frame_np)
+ else:
+ print(
+ f"Warning: Frame has {frame_np.shape[-1]} channels, expected 3. Skipping frame at index {i}")
+ print(
+ f"Frame shape: {frame_np.shape}, dtype: {frame_np.dtype}")
+ print(f"Original tensor shape: {frame.shape}")
+ continue
+ else:
+ print(
+ f"Warning: Skipping invalid frame at index {i}: {type(frame)}")
+
+ # Create videos for both visualizations
+ print("Creating videos...")
+ fourcc = cv2.VideoWriter_fourcc(*'mp4v')
+
+ # Complete visualization video
+ print(" - Scene point clouds + human point clouds video...")
+ scene_human_video_writer = cv2.VideoWriter(
+ scene_human_video_path, fourcc, FPS, (RENDER_WIDTH, RENDER_HEIGHT))
+ for frame in scene_human_frames:
+ frame_bgr = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
+ scene_human_video_writer.write(frame_bgr)
+ scene_human_video_writer.release()
+
+ # Scene-only visualization video
+ print(" - Scene point clouds + SMPL meshes video...")
+ scene_smpl_video_writer = cv2.VideoWriter(
+ scene_smpl_video_path, fourcc, FPS, (RENDER_WIDTH, RENDER_HEIGHT))
+ for frame in scene_smpl_frames:
+ frame_bgr = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
+ scene_smpl_video_writer.write(frame_bgr)
+ scene_smpl_video_writer.release()
+
+ # Create GIFs and MP4s for both visualizations
+ print("Creating GIFs and MP4s...")
+
+ # Scene-human visualization GIF and MP4
+ print(" - Scene point clouds + human point clouds GIF...")
+ scene_human_frames_pil = [Image.fromarray(frame) for frame in scene_human_frames]
+ scene_human_frames_pil[0].save(
+ scene_human_gif_path,
+ save_all=True,
+ append_images=scene_human_frames_pil[1:],
+ duration=1000//FPS,
+ loop=0
+ )
+ print(" - Scene point clouds + human point clouds MP4...")
+ create_mp4_from_frames(scene_human_frames, scene_human_gif_mp4_path, FPS)
+
+ # Scene-SMPL visualization GIF and MP4
+ print(" - Scene point clouds + SMPL meshes GIF...")
+ scene_smpl_frames_pil = [Image.fromarray(
+ frame) for frame in scene_smpl_frames]
+ scene_smpl_frames_pil[0].save(
+ scene_smpl_gif_path,
+ save_all=True,
+ append_images=scene_smpl_frames_pil[1:],
+ duration=1000//FPS,
+ loop=0
+ )
+ print(" - Scene point clouds + SMPL meshes MP4...")
+ create_mp4_from_frames(scene_smpl_frames, scene_smpl_gif_mp4_path, FPS)
+
+ if input_gif_frames:
+ print("Creating input frames GIF...")
+ input_frames_pil = [Image.fromarray(frame)
+ for frame in input_gif_frames]
+ input_frames_pil[0].save(
+ input_gif_path,
+ save_all=True,
+ append_images=input_frames_pil[1:],
+ duration=1000//FPS,
+ loop=0
+ )
+ print("Creating input frames MP4...")
+ create_mp4_from_frames(input_gif_frames, input_gif_mp4_path, FPS)
+
+ # Create chunk GIFs and MP4s for both visualizations
+ print("Creating chunk GIFs and MP4s...")
+
+ # Scene-human visualization chunk GIFs and MP4s
+ print(" - Scene point clouds + human point clouds chunk GIFs and MP4s...")
+ scene_human_chunk_gifs_folder = os.path.join(scene_human_gifs_folder, "chunks")
+ os.makedirs(scene_human_chunk_gifs_folder, exist_ok=True)
+
+ for chunk_idx in range(0, total_frames, chunk_size):
+ end_idx = min(chunk_idx + chunk_size, total_frames)
+ chunk_frames = scene_human_frames[chunk_idx:end_idx]
+
+ if chunk_frames:
+ chunk_gif_path = os.path.join(
+ scene_human_chunk_gifs_folder, f"scene_human_chunk_{chunk_idx//chunk_size:02d}_frames_{chunk_idx:04d}-{end_idx-1:04d}.gif")
+ chunk_mp4_path = os.path.join(
+ scene_human_chunk_gifs_folder, f"scene_human_chunk_{chunk_idx//chunk_size:02d}_frames_{chunk_idx:04d}-{end_idx-1:04d}.mp4")
+ chunk_frames_pil = [Image.fromarray(
+ frame) for frame in chunk_frames]
+ chunk_frames_pil[0].save(
+ chunk_gif_path,
+ save_all=True,
+ append_images=chunk_frames_pil[1:],
+ duration=1000//FPS,
+ loop=0
+ )
+ create_mp4_from_frames(chunk_frames, chunk_mp4_path, FPS)
+
+ # Scene-SMPL visualization chunk GIFs and MP4s
+ print(" - Scene point clouds + SMPL meshes chunk GIFs and MP4s...")
+ scene_smpl_chunk_gifs_folder = os.path.join(
+ scene_smpl_gifs_folder, "chunks")
+ os.makedirs(scene_smpl_chunk_gifs_folder, exist_ok=True)
+
+ for chunk_idx in range(0, total_frames, chunk_size):
+ end_idx = min(chunk_idx + chunk_size, total_frames)
+ chunk_frames = scene_smpl_frames[chunk_idx:end_idx]
+
+ if chunk_frames:
+ chunk_gif_path = os.path.join(
+ scene_smpl_chunk_gifs_folder, f"scene_smpl_chunk_{chunk_idx//chunk_size:02d}_frames_{chunk_idx:04d}-{end_idx-1:04d}.gif")
+ chunk_mp4_path = os.path.join(
+ scene_smpl_chunk_gifs_folder, f"scene_smpl_chunk_{chunk_idx//chunk_size:02d}_frames_{chunk_idx:04d}-{end_idx-1:04d}.mp4")
+ chunk_frames_pil = [Image.fromarray(
+ frame) for frame in chunk_frames]
+ chunk_frames_pil[0].save(
+ chunk_gif_path,
+ save_all=True,
+ append_images=chunk_frames_pil[1:],
+ duration=1000//FPS,
+ loop=0
+ )
+ create_mp4_from_frames(chunk_frames, chunk_mp4_path, FPS)
+
+ # Create input chunk GIFs and MP4s if available
+ if input_gif_frames:
+ print("Creating input frames chunk GIFs and MP4s...")
+ input_chunk_gifs_folder = os.path.join(
+ viz_folder, "input_chunk_gifs")
+ os.makedirs(input_chunk_gifs_folder, exist_ok=True)
+
+ for chunk_idx in range(0, len(input_gif_frames), chunk_size):
+ end_idx = min(chunk_idx + chunk_size, len(input_gif_frames))
+ chunk_input_frames = input_gif_frames[chunk_idx:end_idx]
+
+ if chunk_input_frames:
+ chunk_input_gif_path = os.path.join(
+ input_chunk_gifs_folder, f"input_chunk_{chunk_idx//chunk_size:02d}_frames_{chunk_idx:04d}-{end_idx-1:04d}.gif")
+ chunk_input_mp4_path = os.path.join(
+ input_chunk_gifs_folder, f"input_chunk_{chunk_idx//chunk_size:02d}_frames_{chunk_idx:04d}-{end_idx-1:04d}.mp4")
+ chunk_input_frames_pil = [Image.fromarray(
+ frame) for frame in chunk_input_frames]
+ chunk_input_frames_pil[0].save(
+ chunk_input_gif_path,
+ save_all=True,
+ append_images=chunk_input_frames_pil[1:],
+ duration=1000//FPS,
+ loop=0
+ )
+ create_mp4_from_frames(
+ chunk_input_frames, chunk_input_mp4_path, FPS)
+ print(f" - Input chunk GIF saved: {chunk_input_gif_path}")
+ print(f" - Input chunk MP4 saved: {chunk_input_mp4_path}")
+
+ print(f"Visualization completed!")
diff --git a/unish/utils/renderer.py b/unish/utils/renderer.py
new file mode 100644
index 0000000000000000000000000000000000000000..967e52c3ae6c2ade8a329dc4afbb7e2fc9924d7e
--- /dev/null
+++ b/unish/utils/renderer.py
@@ -0,0 +1,617 @@
+import torch
+import numpy as np
+
+from pytorch3d.renderer import (
+ PerspectiveCameras,
+ TexturesVertex,
+ PointLights,
+ Materials,
+ RasterizationSettings,
+ MeshRenderer,
+ MeshRasterizer,
+ SoftPhongShader,
+)
+from pytorch3d.structures import Meshes
+from pytorch3d.structures.meshes import join_meshes_as_scene
+from pytorch3d.renderer.cameras import look_at_rotation
+from pytorch3d.transforms import axis_angle_to_matrix
+
+
+colors_str_map = {
+    "gray": [0.8, 0.8, 0.8],
+    "green": [39, 194, 128],  # 0-255 ints; normalized to [0, 1] in render_mesh
+}
+
+
+def overlay_image_onto_background(image, mask, bbox, background):
+ if isinstance(image, torch.Tensor):
+ image = image.detach().cpu().numpy()
+ if isinstance(mask, torch.Tensor):
+ mask = mask.detach().cpu().numpy()
+
+ out_image = background.copy()
+ bbox = bbox[0].int().cpu().numpy().copy()
+
+    # Ensure bbox coordinates stay within the valid image range
+ h_bg, w_bg = background.shape[:2]
+ left, top, right, bottom = bbox[0], bbox[1], bbox[2], bbox[3]
+
+    # Clamp bbox coordinates to the background image bounds
+    left = max(0, min(left, w_bg))
+    top = max(0, min(top, h_bg))
+    right = max(left + 1, min(right, w_bg))  # ensure right > left
+    bottom = max(top + 1, min(bottom, h_bg))  # ensure bottom > top
+
+ roi_image = out_image[top:bottom, left:right]
+
+    # Check whether roi_image is empty
+ if roi_image.size == 0:
+ print(f"Warning: ROI image is empty. bbox: [{left}, {top}, {right}, {bottom}], bg_shape: {background.shape}")
+ return out_image
+
+    # Check that the image and mask dimensions match roi_image
+ expected_height, expected_width = bottom - top, right - left
+
+ if image.shape[:2] != (expected_height, expected_width):
+ print(f"Warning: Image shape mismatch. Expected: ({expected_height}, {expected_width}), Got: {image.shape[:2]}")
+        # Size mismatch: fall back to the unmodified background
+        return out_image
+
+ if mask.shape[:2] != (expected_height, expected_width):
+ print(f"Warning: Mask shape mismatch. Expected: ({expected_height}, {expected_width}), Got: {mask.shape[:2]}")
+        # Size mismatch: fall back to the unmodified background
+        return out_image
+
+    # Apply the mask safely
+ try:
+ roi_image[mask] = image[mask]
+ out_image[top:bottom, left:right] = roi_image
+ except Exception as e:
+ print(f"Error in overlay operation: {e}")
+ print(f"roi_image shape: {roi_image.shape}, mask shape: {mask.shape}, image shape: {image.shape}")
+        # On failure, return the original background
+        return out_image
+
+ return out_image
+
+def overlay_depth_onto_background(depth_map, mask, bbox, background_shape, background_value=0.0):
+ """
+ Overlay depth map onto full-size background with differentiable torch operations
+
+ Args:
+ depth_map: torch.Tensor, depth map from rasterizer
+ mask: torch.Tensor, object mask (boolean or float)
+ bbox: torch.Tensor, bounding box coordinates [left, top, right, bottom]
+ background_shape: tuple, (height, width) of output
+ background_value: float, background depth value
+
+ Returns:
+ torch.Tensor, full-size depth map
+ torch.Tensor, full-size mask
+ """
+ device = depth_map.device
+ dtype = depth_map.dtype
+
+ # Ensure inputs are torch tensors
+ if not isinstance(depth_map, torch.Tensor):
+ depth_map = torch.tensor(depth_map, device=device, dtype=dtype)
+ if not isinstance(mask, torch.Tensor):
+ mask = torch.tensor(mask, device=device, dtype=torch.float32)
+
+ h_bg, w_bg = background_shape
+
+ # Create full-size background tensors
+ out_depth = torch.full((h_bg, w_bg), background_value, device=device, dtype=dtype)
+ out_mask = torch.zeros((h_bg, w_bg), device=device, dtype=mask.dtype)
+
+ # Extract bbox coordinates and clamp to valid range
+ if bbox.dim() > 1:
+ bbox = bbox[0] # Take first bbox if batched
+
+ bbox = bbox.to(torch.int32)
+ left = torch.clamp(bbox[0], 0, w_bg)
+ top = torch.clamp(bbox[1], 0, h_bg)
+ right = torch.clamp(bbox[2], left + 1, w_bg)
+ bottom = torch.clamp(bbox[3], top + 1, h_bg)
+
+ # Calculate expected dimensions
+ expected_height = bottom - top
+ expected_width = right - left
+
+ # Check if dimensions match
+ if (depth_map.shape[0] != expected_height or depth_map.shape[1] != expected_width):
+ print(f"Warning: Depth map shape mismatch. Expected: ({expected_height}, {expected_width}), Got: {depth_map.shape[:2]}")
+ return out_depth, out_mask
+
+ if (mask.shape[0] != expected_height or mask.shape[1] != expected_width):
+ print(f"Warning: Mask shape mismatch. Expected: ({expected_height}, {expected_width}), Got: {mask.shape[:2]}")
+ return out_depth, out_mask
+
+    # Convert coordinates to integers for indexing (non-differentiable but necessary).
+    # Do this outside the try block so the except clause can safely reference them.
+    left_int = left.item()
+    top_int = top.item()
+    right_int = right.item()
+    bottom_int = bottom.item()
+
+    # Overlay depth map and mask onto the background; the slice assignment
+    # keeps gradients flowing through depth_map and mask
+    try:
+        out_depth[top_int:bottom_int, left_int:right_int] = depth_map
+        out_mask[top_int:bottom_int, left_int:right_int] = mask
+    except Exception as e:
+        print(f"Error in depth overlay operation: {e}")
+        print(f"Target region shape: ({bottom_int - top_int}, {right_int - left_int}), depth_map shape: {depth_map.shape}")
+
+ return out_depth, out_mask
+
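+
+# Usage sketch (illustrative, with hypothetical shapes): paste a rasterized
+# 48x64 depth crop back into a 480x640 frame via its bounding box.
+def _example_overlay_depth():
+    depth_crop = torch.rand(48, 64)
+    mask_crop = torch.ones(48, 64, dtype=torch.bool)
+    bbox = torch.tensor([100, 200, 164, 248])  # [left, top, right, bottom]
+    full_depth, full_mask = overlay_depth_onto_background(
+        depth_crop, mask_crop, bbox, background_shape=(480, 640))
+    return full_depth, full_mask  # both (480, 640)
+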
+
+def update_intrinsics_from_bbox(K_org, bbox):
+ device, dtype = K_org.device, K_org.dtype
+
+ K = torch.zeros((K_org.shape[0], 4, 4)).to(device=device, dtype=dtype)
+ K[:, :3, :3] = K_org.clone()
+ K[:, 2, 2] = 0
+ K[:, 2, -1] = 1
+ K[:, -1, 2] = 1
+
+    image_sizes = []
+    for idx, box in enumerate(bbox):
+        left, upper, right, lower = box
+ cx, cy = K[idx, 0, 2], K[idx, 1, 2]
+
+ new_cx = cx - left
+ new_cy = cy - upper
+ new_height = max(lower - upper, 1)
+ new_width = max(right - left, 1)
+ new_cx = new_width - new_cx
+ new_cy = new_height - new_cy
+
+ K[idx, 0, 2] = new_cx
+ K[idx, 1, 2] = new_cy
+ image_sizes.append((int(new_height), int(new_width)))
+
+ return K, image_sizes
+
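+
+# Worked example (illustrative numbers): with K holding (fx, fy, cx, cy) =
+# (500, 500, 320, 240) and bbox [left, top, right, bottom] = [100, 60, 420, 380],
+# the principal point first shifts into the crop, (cx - left, cy - top) =
+# (220, 180), and is then mirrored to (width - 220, height - 180) = (100, 140),
+# matching the flip applied to rendered images downstream. The returned image
+# size for this bbox is (320, 320).
+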
+
+def perspective_projection(x3d, K, R=None, T=None):
+    if R is not None:
+        x3d = torch.matmul(R, x3d.transpose(1, 2)).transpose(1, 2)
+    if T is not None:
+        x3d = x3d + T.transpose(1, 2)
+
+ z_coords = x3d[..., 2:]
+
+ valid_z_mask = (torch.abs(z_coords) >= 1e-5) & torch.isfinite(z_coords)
+
+    # Replace invalid depths with a small signed epsilon so the division below never hits zero
+    eps = torch.full_like(z_coords, 1e-5)
+    z_coords_safe = torch.where(valid_z_mask, z_coords, torch.where(z_coords < 0, -eps, eps))
+
+ x3d_safe = x3d.clone()
+ x3d_safe[..., 2:] = z_coords_safe
+ x2d = torch.div(x3d_safe, x3d_safe[..., 2:])
+ x2d = torch.matmul(K, x2d.transpose(-1, -2)).transpose(-1, -2)[..., :2]
+
+ final_valid_mask = valid_z_mask.squeeze(-1) & torch.isfinite(x2d).all(dim=-1, keepdim=True).squeeze(-1)
+
+ x2d_masked = torch.where(final_valid_mask.unsqueeze(-1), x2d, torch.tensor(-999.0, device=x2d.device, dtype=x2d.dtype))
+
+ return x2d_masked
+
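+
+# Usage sketch (illustrative): project two 3D points with identity extrinsics.
+# Points behind or too close to the camera come back as the sentinel -999.
+def _example_perspective_projection():
+    K = torch.tensor([[[500.0, 0.0, 320.0],
+                       [0.0, 500.0, 240.0],
+                       [0.0, 0.0, 1.0]]])     # (1, 3, 3)
+    x3d = torch.tensor([[[0.0, 0.0, 2.0],
+                         [0.1, -0.1, 2.5]]])  # (1, 2, 3)
+    return perspective_projection(x3d, K)     # (1, 2, 2); first point -> (320, 240)
+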
+
+def compute_bbox_from_points(X, img_w, img_h, scaleFactor=1.2):
+ if X.numel() == 0:
+ print("Warning: Empty points for bbox computation, using full image bbox")
+ bbox = torch.tensor([[0, 0, img_w, img_h]]).float()
+ return bbox
+
+ if len(X.shape) == 3:
+ X_flat = X.reshape(-1, X.shape[-1]) # (batch_size * num_points, 2)
+ elif len(X.shape) == 2:
+ X_flat = X
+ else:
+ print(f"Warning: Unexpected X shape {X.shape}, using full image bbox")
+ bbox = torch.tensor([[0, 0, img_w, img_h]]).float()
+ return bbox
+
+ valid_mask = torch.isfinite(X_flat).all(dim=-1)
+ if not valid_mask.any():
+ print("Warning: No valid points for bbox computation, using full image bbox")
+ bbox = torch.tensor([[0, 0, img_w, img_h]]).float()
+ return bbox
+
+ X_valid = X_flat[valid_mask]
+ if X_valid.numel() == 0:
+ print("Warning: No valid points after filtering, using full image bbox")
+ bbox = torch.tensor([[0, 0, img_w, img_h]]).float()
+ return bbox
+
+ img_w_tensor = torch.tensor(img_w, dtype=X_valid.dtype, device=X_valid.device)
+ img_h_tensor = torch.tensor(img_h, dtype=X_valid.dtype, device=X_valid.device)
+
+ left = torch.clamp(X_valid[:, 0].min(), min=0, max=img_w_tensor)
+ right = torch.clamp(X_valid[:, 0].max(), min=0, max=img_w_tensor)
+ top = torch.clamp(X_valid[:, 1].min(), min=0, max=img_h_tensor)
+ bottom = torch.clamp(X_valid[:, 1].max(), min=0, max=img_h_tensor)
+
+ if left >= right:
+ left = torch.clamp(left - 10, min=0, max=img_w_tensor)
+ right = torch.clamp(left + 20, min=1, max=img_w_tensor)
+ if top >= bottom:
+ top = torch.clamp(top - 10, min=0, max=img_h_tensor)
+ bottom = torch.clamp(top + 20, min=1, max=img_h_tensor)
+
+ cx = (left + right) / 2
+ cy = (top + bottom) / 2
+ width = right - left
+ height = bottom - top
+
+ img_w_tensor = torch.tensor(img_w, dtype=cx.dtype, device=cx.device)
+ img_h_tensor = torch.tensor(img_h, dtype=cy.dtype, device=cy.device)
+ scaleFactor_tensor = torch.tensor(scaleFactor, dtype=cx.dtype, device=cx.device)
+
+ new_left = torch.clamp(cx - width / 2 * scaleFactor_tensor, min=0, max=img_w_tensor - 1)
+ new_right = torch.clamp(cx + width / 2 * scaleFactor_tensor, min=1, max=img_w_tensor)
+ new_top = torch.clamp(cy - height / 2 * scaleFactor_tensor, min=0, max=img_h_tensor - 1)
+ new_bottom = torch.clamp(cy + height / 2 * scaleFactor_tensor, min=1, max=img_h_tensor)
+
+ if new_left >= new_right:
+ new_left = torch.tensor(0, dtype=new_left.dtype, device=new_left.device)
+ new_right = torch.tensor(max(1, min(img_w, 100)), dtype=new_right.dtype, device=new_right.device)
+ if new_top >= new_bottom:
+ new_top = torch.tensor(0, dtype=new_top.dtype, device=new_top.device)
+ new_bottom = torch.tensor(max(1, min(img_h, 100)), dtype=new_bottom.dtype, device=new_bottom.device)
+
+ bbox = torch.stack((new_left.detach(), new_top.detach(), new_right.detach(), new_bottom.detach())).int().float()
+
+ if bbox.dim() == 1:
+ bbox = bbox.unsqueeze(0)
+
+ return bbox
+
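+
+# Usage sketch (illustrative): the bbox of two points, scaled 1.2x around its
+# center and clamped to the image, comes back as [[90., 90., 210., 210.]].
+def _example_compute_bbox():
+    pts = torch.tensor([[100.0, 100.0], [200.0, 200.0]])
+    return compute_bbox_from_points(pts, img_w=640, img_h=480, scaleFactor=1.2)
+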
+
+class Renderer:
+ def __init__(self, width, height, focal_length=None, device="cuda", faces=None, K=None, bin_size=0):
+ """set bin_size to 0 for no binning"""
+ self.width = width
+ self.height = height
+ self.bin_size = bin_size
+        assert (focal_length is not None) ^ (K is not None), "exactly one of focal_length and K must be provided"
+
+ self.device = device
+ if faces is not None:
+ if isinstance(faces, np.ndarray):
+                faces = torch.from_numpy(faces.astype("int"))
+ self.faces = faces.unsqueeze(0).to(self.device)
+
+ self.initialize_camera_params(focal_length, K)
+ self.lights = PointLights(device=device, location=[[0.0, 0.0, -10.0]])
+ self.create_renderer()
+
+ def create_renderer(self):
+ self.renderer = MeshRenderer(
+ rasterizer=MeshRasterizer(
+ raster_settings=RasterizationSettings(
+ image_size=self.image_sizes[0], blur_radius=1e-5, bin_size=self.bin_size
+ ),
+ ),
+ shader=SoftPhongShader(
+ device=self.device,
+ lights=self.lights,
+ ),
+ )
+
+ def create_camera(self, R=None, T=None):
+ if R is not None:
+ self.R = R.clone().view(1, 3, 3).to(self.device)
+ if T is not None:
+ self.T = T.clone().view(1, 3).to(self.device)
+
+ return PerspectiveCameras(
+ device=self.device, R=self.R.mT, T=self.T, K=self.K_full, image_size=self.image_sizes, in_ndc=False
+ )
+
+ def initialize_camera_params(self, focal_length, K):
+ # Extrinsics
+ self.R = torch.diag(torch.tensor([1, 1, 1])).float().to(self.device).unsqueeze(0)
+
+ self.T = torch.tensor([0, 0, 0]).unsqueeze(0).float().to(self.device)
+
+ # Intrinsics
+ if K is not None:
+ self.K = K.float().reshape(1, 3, 3).to(self.device)
+ else:
+ assert focal_length is not None, "focal_length or K should be provided"
+ self.K = (
+ torch.tensor([[focal_length, 0, self.width / 2], [0, focal_length, self.height / 2], [0, 0, 1]])
+ .float()
+ .reshape(1, 3, 3)
+ .to(self.device)
+ )
+ self.bboxes = torch.tensor([[0, 0, self.width, self.height]]).float()
+ self.K_full, self.image_sizes = update_intrinsics_from_bbox(self.K, self.bboxes)
+ self.cameras = self.create_camera()
+
+ def set_intrinsic(self, K):
+ self.K = K.reshape(1, 3, 3)
+
+ def update_bbox(self, x3d, scale=2.0, mask=None):
+ """Update bbox of cameras from the given 3d points
+
+ x3d: input 3D keypoints (or vertices), (num_frames, num_points, 3)
+ """
+
+ if x3d.size(-1) != 3:
+ x2d = x3d.unsqueeze(0)
+ else:
+ x2d = perspective_projection(x3d.unsqueeze(0), self.K, self.R, self.T.reshape(1, 3, 1))
+
+ if mask is not None:
+ x2d = x2d[:, ~mask]
+
+ bbox = compute_bbox_from_points(x2d, self.width, self.height, scale)
+ self.bboxes = bbox
+
+ self.K_full, self.image_sizes = update_intrinsics_from_bbox(self.K, bbox)
+ self.cameras = self.create_camera()
+ self.create_renderer()
+
+    def reset_bbox(self):
+ bbox = torch.zeros((1, 4)).float().to(self.device)
+ bbox[0, 2] = self.width
+ bbox[0, 3] = self.height
+ self.bboxes = bbox
+
+ self.K_full, self.image_sizes = update_intrinsics_from_bbox(self.K, bbox)
+ self.cameras = self.create_camera()
+ self.create_renderer()
+
+ def render_mesh(self, vertices, background=None, colors=None, VI=50):
+ if colors is None:
+ colors = [0.8, 0.8, 0.8]
+ self.update_bbox(vertices[::VI], scale=1.2)
+ vertices = vertices.unsqueeze(0)
+
+        if isinstance(colors, torch.Tensor):
+            # Per-vertex colors; fall back to gray for the specular material below
+            verts_features = colors.to(device=vertices.device, dtype=vertices.dtype)
+            colors = [0.8, 0.8, 0.8]
+ else:
+ if colors[0] > 1:
+ colors = [c / 255.0 for c in colors]
+ verts_features = torch.tensor(colors).reshape(1, 1, 3).to(device=vertices.device, dtype=vertices.dtype)
+ verts_features = verts_features.repeat(1, vertices.shape[1], 1)
+ textures = TexturesVertex(verts_features=verts_features)
+
+ mesh = Meshes(
+ verts=vertices,
+ faces=self.faces,
+ textures=textures,
+ )
+
+ materials = Materials(device=self.device, specular_color=(colors,), shininess=0)
+
+ results = torch.flip(self.renderer(mesh, materials=materials, cameras=self.cameras, lights=self.lights), [1, 2])
+ image = results[0, ..., :3] * 255
+ mask = results[0, ..., -1] > 1e-3
+
+ if background is None:
+ background = np.ones((self.height, self.width, 3)).astype(np.uint8) * 255
+
+ image = overlay_image_onto_background(image, mask, self.bboxes, background.copy())
+ self.reset_bbox()
+ return image
+
+ def render_with_ground(self, verts, colors, cameras, lights, faces=None):
+ """
+ :param verts (N, V, 3), potential multiple people
+ :param colors (N, 3) or (N, V, 3)
+        :param faces (N, F, 3), optional; otherwise self.faces will be used
+ """
+ # Sanity check of input verts, colors and faces: (B, V, 3), (B, F, 3), (B, V, 3)
+ N, V, _ = verts.shape
+ if faces is None:
+ faces = self.faces.clone().expand(N, -1, -1)
+ else:
+ assert len(faces.shape) == 3, "faces should have shape of (N, F, 3)"
+
+ assert len(colors.shape) in [2, 3]
+ if len(colors.shape) == 2:
+ assert len(colors) == N, "colors of shape 2 should be (N, 3)"
+ colors = colors[:, None]
+ colors = colors.expand(N, V, -1)[..., :3]
+
+        # (V, 3), (F, 3), (V, 3); self.ground_geometry must be assigned before calling this
+        gv, gf, gc = self.ground_geometry
+ verts = list(torch.unbind(verts, dim=0)) + [gv]
+ faces = list(torch.unbind(faces, dim=0)) + [gf]
+ colors = list(torch.unbind(colors, dim=0)) + [gc[..., :3]]
+ mesh = create_meshes(verts, faces, colors)
+
+ materials = Materials(device=self.device, shininess=0)
+
+ results = self.renderer(mesh, cameras=cameras, lights=lights, materials=materials)
+ image = (results[0, ..., :3].cpu().numpy() * 255).astype(np.uint8)
+
+ return image
+
+ def render_depth_only(self, vertices, VI=50, return_visible_vertices=False):
+ """
+ Render only the depth map without RGB computation.
+
+ Args:
+ vertices: Mesh vertices
+ VI: Vertex interval for bbox computation
+ return_visible_vertices: If True, also return visible vertex coordinates
+
+        Returns:
+            tuple: (depth_map, mask) or (depth_map, mask, visible_vertices)
+                - depth_map: Full-size depth map, torch.Tensor of shape (height, width)
+                - mask: Full-size object mask, torch.Tensor of shape (height, width)
+                - visible_vertices: torch.Tensor of visible vertex coordinates [N, 3]
+ """
+ self.update_bbox(vertices[::VI], scale=1.2)
+ vertices = vertices.unsqueeze(0)
+
+ # Create a simple mesh for depth rendering
+ verts_features = torch.ones(1, vertices.shape[1], 3, device=vertices.device, dtype=vertices.dtype)
+ textures = TexturesVertex(verts_features=verts_features)
+
+ mesh = Meshes(
+ verts=vertices,
+ faces=self.faces,
+ textures=textures,
+ )
+
+ # Get rasterizer fragments for depth information
+ fragments = self.renderer.rasterizer(mesh, cameras=self.cameras)
+
+ # Extract depth map from fragments
+ depth_map = torch.flip(fragments.zbuf[0, ..., 0], [0, 1]) # Flip to match image orientation
+
+ # Create mask from valid depth values
+ mask = torch.flip(fragments.pix_to_face[0, ..., 0] >= 0, [0, 1]) # Valid faces have non-negative indices
+
+ # Handle invalid depth values
+ depth_map = torch.where(depth_map < 0, torch.tensor(0.0, device=depth_map.device), depth_map)
+
+ visible_vertices = None
+ if return_visible_vertices:
+ # Extract visible vertices from fragments
+ visible_vertices = self._get_visible_vertices(fragments, mesh, return_coords=True)
+
+ # Apply overlay processing to ensure consistent size with other render methods
+ depth_map_full, mask_full = overlay_depth_onto_background(
+ depth_map, mask, self.bboxes, (self.height, self.width), background_value=0.0
+ )
+
+ self.reset_bbox()
+
+ if return_visible_vertices:
+ return depth_map_full, mask_full, visible_vertices
+ else:
+ return depth_map_full, mask_full
+
+ def _get_visible_vertices(self, fragments, mesh, return_coords=False):
+ """
+ Get visible vertex indices from rasterizer fragments.
+
+ Args:
+ fragments: Rasterizer fragments containing pix_to_face
+ mesh: The mesh object used for rendering
+ return_coords: If True, return vertex coordinates instead of indices
+
+ Returns:
+ torch.Tensor: Tensor of unique visible vertex indices or coordinates
+ """
+ # Get the face indices for each pixel
+ pix_to_face = fragments.pix_to_face[0] # Remove batch dimension
+
+ # Find pixels that have valid faces (face_id >= 0)
+ valid_mask = pix_to_face >= 0
+
+ # Get faces tensor
+ faces = mesh.faces_list()[0] # Get faces tensor (shape: [num_faces, 3])
+
+ # Only consider the first (closest) face for each pixel
+ closest_faces = pix_to_face[..., 0] # [height, width]
+ valid_closest = valid_mask[..., 0] # [height, width]
+
+ # Get unique visible face IDs
+ visible_face_ids = closest_faces[valid_closest].unique()
+
+ # Get all visible vertex IDs from visible faces
+ # Use indexing to maintain differentiability
+ visible_face_vertices = faces[visible_face_ids] # [num_visible_faces, 3]
+ visible_vertex_indices = visible_face_vertices.flatten().unique() # Flatten and get unique vertices
+
+ if return_coords:
+ # Return actual vertex coordinates
+ vertices = mesh.verts_list()[0] # [num_vertices, 3]
+ visible_vertex_coords = vertices[visible_vertex_indices] # [num_visible_vertices, 3]
+ return visible_vertex_coords
+ else:
+ # Return vertex indices
+ return visible_vertex_indices
+
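+
+# Usage sketch (illustrative): `verts` (V, 3) and `faces` (F, 3) are assumed to
+# come from an SMPL body model, `K` is a (3, 3) intrinsic matrix.
+def _example_renderer(verts, faces, K):
+    renderer = Renderer(width=640, height=480, device="cuda", faces=faces, K=K)
+    image = renderer.render_mesh(verts)              # (480, 640, 3) over white
+    depth, mask = renderer.render_depth_only(verts)  # (480, 640) each
+    return image, depth, mask
+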
+
+def create_meshes(verts, faces, colors):
+ """
+ :param verts (B, V, 3)
+ :param faces (B, F, 3)
+ :param colors (B, V, 3)
+ """
+ textures = TexturesVertex(verts_features=colors)
+ meshes = Meshes(verts=verts, faces=faces, textures=textures)
+ return join_meshes_as_scene(meshes)
+
+
+def get_global_cameras(verts, device="cuda", distance=5, position=(-5.0, 5.0, 0.0)):
+ """This always put object at the center of view"""
+ positions = torch.tensor([position]).repeat(len(verts), 1)
+ targets = verts.mean(1)
+
+ directions = targets - positions
+ directions = directions / torch.norm(directions, dim=-1).unsqueeze(-1) * distance
+ positions = targets - directions
+
+ rotation = look_at_rotation(positions, targets).mT
+ translation = -(rotation @ positions.unsqueeze(-1)).squeeze(-1)
+
+ lights = PointLights(device=device, location=[position])
+ return rotation, translation, lights
+
+
+def get_global_cameras_static(
+ verts, beta=4.0, cam_height_degree=30, target_center_height=1.0, use_long_axis=False, vec_rot=45, device="cuda"
+):
+ L, V, _ = verts.shape
+
+ # Compute target trajectory, denote as center + scale
+ targets = verts.mean(1) # (L, 3)
+ targets[:, 1] = 0 # project to xz-plane
+ target_center = targets.mean(0) # (3,)
+ target_scale, target_idx = torch.norm(targets - target_center, dim=-1).max(0)
+
+ # a 45 degree vec from longest axis
+ if use_long_axis:
+ long_vec = targets[target_idx] - target_center # (x, 0, z)
+ long_vec = long_vec / torch.norm(long_vec)
+ R = axis_angle_to_matrix(torch.tensor([0, np.pi / 4, 0])).to(long_vec)
+ vec = R @ long_vec
+ else:
+ vec_rad = vec_rot / 180 * np.pi
+ vec = torch.tensor([np.sin(vec_rad), 0, np.cos(vec_rad)]).float()
+ vec = vec / torch.norm(vec)
+
+ # Compute camera position (center + scale * vec * beta) + y=4
+ target_scale = max(target_scale, 1.0) * beta
+ position = target_center + vec * target_scale
+ position[1] = target_scale * np.tan(np.pi * cam_height_degree / 180) + target_center_height
+
+ # Compute camera rotation and translation
+ positions = position.unsqueeze(0).repeat(L, 1)
+ target_centers = target_center.unsqueeze(0).repeat(L, 1)
+ target_centers[:, 1] = target_center_height
+ rotation = look_at_rotation(positions, target_centers).mT
+ translation = -(rotation @ positions.unsqueeze(-1)).squeeze(-1)
+
+ lights = PointLights(device=device, location=[position.tolist()])
+ return rotation, translation, lights
+
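+
+# Usage sketch (illustrative): a fixed side-view camera for a whole sequence of
+# world-space vertices `verts` with shape (L, V, 3).
+def _example_static_camera(verts):
+    rotation, translation, lights = get_global_cameras_static(
+        verts, beta=4.0, cam_height_degree=30, vec_rot=45, device="cuda")
+    # rotation: (L, 3, 3) world-to-camera rotations, translation: (L, 3)
+    return rotation, translation, lights
+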
+
+def get_ground_params_from_points(root_points, vert_points):
+ """xz-plane is the ground plane
+ Args:
+ root_points: (L, 3), to decide center
+ vert_points: (L, V, 3), to decide scale
+ """
+ root_max = root_points.max(0)[0] # (3,)
+ root_min = root_points.min(0)[0] # (3,)
+ cx, _, cz = (root_max + root_min) / 2.0
+
+    vert_max = vert_points.reshape(-1, 3).max(0)[0]  # (3,)
+    vert_min = vert_points.reshape(-1, 3).min(0)[0]  # (3,)
+ scale = (vert_max - vert_min)[[0, 2]].max()
+ return float(scale), float(cx), float(cz)
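+
+
+# Usage sketch (illustrative): a hypothetical consumer turning the returned
+# parameters into the corners of a ground quad on the xz-plane.
+def _example_ground_quad(root_points, vert_points):
+    scale, cx, cz = get_ground_params_from_points(root_points, vert_points)
+    half = scale / 2.0
+    return np.array([[cx - half, 0.0, cz - half],
+                     [cx + half, 0.0, cz - half],
+                     [cx + half, 0.0, cz + half],
+                     [cx - half, 0.0, cz + half]])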
diff --git a/unish/utils/smpl_utils.py b/unish/utils/smpl_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..7575419a390a71c208e6cbc33d14afae3f84d580
--- /dev/null
+++ b/unish/utils/smpl_utils.py
@@ -0,0 +1,393 @@
+import numpy as np
+import cv2
+from PIL import Image
+import torch
+import torch.nn.functional as F
+import trimesh
+import os
+
+from unish.utils.renderer import Renderer
+from unish.utils.data_utils import closed_form_inverse_se3, rotmat_to_aa, aa_to_rotmat
+
+
+class SMPLWrapper:
+ def __init__(self, model_folder='body_models/', model_type='smplx', device='cpu', dtype=torch.float32):
+ """
+ Initialize SMPL visualizer with SMPL or SMPL-X models
+
+ Args:
+ model_folder (str): Path to model folder (should contain smpl/ or smplx/ subfolders)
+ model_type (str): Model type, either 'smpl' or 'smplx'
+ device (str or torch.device): Device to run models on ('cpu', 'cuda', or torch.device object)
+ dtype (torch.dtype): Data type for the models (default: torch.float32)
+ """
+ import smplx
+
+ self.model_folder = model_folder
+ self.model_type = model_type.lower()
+        # Ensure device is converted to a torch.device object
+ if isinstance(device, str):
+ self.device = torch.device(device)
+ else:
+ self.device = device
+ self.dtype = dtype
+ self.models = {}
+
+ # Initialize models for different genders on specified device
+ if self.model_type == 'smplx':
+ model_path = os.path.join(model_folder, 'smplx/models') if 'smplx' not in model_folder else model_folder
+            self.models['male'] = smplx.create(model_path, model_type='smplx',
+                                               gender='male',
+                                               ext='npz',
+                                               flat_hand_mean=True,
+                                               num_betas=11,
+                                               use_pca=False).to(self.device, dtype=self.dtype)
+
+ self.models['female'] = smplx.create(model_path, model_type='smplx',
+ gender='female',
+ ext='npz',
+ num_betas=11,
+ flat_hand_mean=True,
+ use_pca=False).to(self.device, dtype=self.dtype)
+
+ self.models['neutral'] = smplx.create(model_path, model_type='smplx',
+ gender='neutral',
+ ext='npz',
+ flat_hand_mean=True,
+ num_betas=11,
+ use_pca=False).to(self.device, dtype=self.dtype)
+
+ elif self.model_type == 'smpl':
+ model_path = os.path.join(model_folder, 'smpl') if 'smpl' not in model_folder else model_folder
+ self.models['male'] = smplx.create(model_path, model_type='smpl',
+ gender='male',
+ ext='pkl',
+ num_betas=10).to(self.device, dtype=self.dtype)
+
+ self.models['female'] = smplx.create(model_path, model_type='smpl',
+ gender='female',
+ ext='pkl',
+ num_betas=10).to(self.device, dtype=self.dtype)
+
+ self.models['neutral'] = smplx.create(model_path, model_type='smpl',
+ gender='neutral',
+ ext='pkl',
+ num_betas=10).to(self.device, dtype=self.dtype)
+ else:
+ raise ValueError(f"Unsupported model type: {model_type}. Please use 'smpl' or 'smplx'.")
+
+ def get_vertices(self, poses, betas, trans, gender):
+ """
+ Get model vertices and joints for given parameters
+
+ Args:
+ poses (torch.Tensor): Pose parameters (72 for SMPL, 165 for SMPL-X)
+ betas (torch.Tensor): Shape parameters (10 for SMPL, 11 for SMPL-X)
+ trans (torch.Tensor): Translation parameters
+ gender (str): Gender of the model ('male', 'female', or 'neutral')
+
+ Returns:
+ tuple: (vertices, joints)
+ """
+ if gender not in self.models:
+ raise ValueError('Please provide gender as male, female, or neutral')
+
+ if isinstance(poses, np.ndarray):
+ poses = torch.from_numpy(poses)
+ if isinstance(betas, np.ndarray):
+ betas = torch.from_numpy(betas)
+ if isinstance(trans, np.ndarray):
+ trans = torch.from_numpy(trans)
+
+ if len(poses.shape) == 1:
+ poses = poses.unsqueeze(0)
+ if len(betas.shape) == 1:
+ betas = betas.unsqueeze(0)
+ if len(trans.shape) == 1:
+ trans = trans.unsqueeze(0)
+
+ # Move data to the same device as the model
+ poses = poses.to(self.device, dtype=self.dtype)
+ betas = betas.to(self.device, dtype=self.dtype)
+ trans = trans.to(self.device, dtype=self.dtype)
+
+ if self.model_type == 'smplx':
+ # SMPL-X parameters
+ model_out = self.models[gender](
+ betas=betas,
+ global_orient=poses[:, :3],
+ body_pose=poses[:, 3:66],
+ left_hand_pose=poses[:, 75:120],
+ right_hand_pose=poses[:, 120:165],
+ jaw_pose=poses[:, 66:69],
+ leye_pose=poses[:, 69:72],
+ reye_pose=poses[:, 72:75],
+ transl=trans
+ )
+ elif self.model_type == 'smpl':
+ # SMPL parameters (72 dimensions: 3 for global_orient + 69 for body_pose)
+ model_out = self.models[gender](
+ betas=betas,
+ global_orient=poses[:, :3],
+ body_pose=poses[:, 3:72],
+ transl=trans
+ )
+
+ return model_out.vertices[0], model_out.joints[0]
+
+ def get_batch_vertices(self, poses, betas, trans, gender):
+ """
+        Get model vertices and joints for batched parameters (a more efficient batched version)
+
+ Args:
+ poses (torch.Tensor): Pose parameters [B, 72/165] (72 for SMPL, 165 for SMPL-X)
+ betas (torch.Tensor): Shape parameters [B, 10/11] (10 for SMPL, 11 for SMPL-X)
+ trans (torch.Tensor): Translation parameters [B, 3]
+ gender (str): Gender of the model ('male', 'female', or 'neutral')
+
+ Returns:
+ tuple: (vertices [B, V, 3], joints [B, J, 3])
+ """
+
+ assert len(poses.shape) == 2 and len(betas.shape) == 2 and len(trans.shape) == 2, "poses, betas, trans should be 2D"
+
+ if gender not in self.models:
+ raise ValueError('Please provide gender as male, female, or neutral')
+
+ # Move data to the same device and dtype as the model
+ poses = poses.to(self.device, dtype=self.dtype)
+ betas = betas.to(self.device, dtype=self.dtype)
+ trans = trans.to(self.device, dtype=self.dtype)
+
+ if self.model_type == 'smplx':
+ # SMPL-X parameters
+ model_out = self.models[gender](
+ betas=betas,
+ global_orient=poses[:, :3],
+ body_pose=poses[:, 3:66],
+ left_hand_pose=poses[:, 75:120],
+ right_hand_pose=poses[:, 120:165],
+ jaw_pose=poses[:, 66:69],
+ leye_pose=poses[:, 69:72],
+ reye_pose=poses[:, 72:75],
+ transl=trans
+ )
+ elif self.model_type == 'smpl':
+ # SMPL parameters (72 dimensions: 3 for global_orient + 69 for body_pose)
+ model_out = self.models[gender](
+ betas=betas,
+ global_orient=poses[:, :3],
+ body_pose=poses[:, 3:72],
+ transl=trans
+ )
+
+ return model_out.vertices, model_out.joints
+
+ def render(self, poses, betas, trans, gender, K, background=None, w2c=None):
+ """
+ Render SMPL model with given parameters
+
+ Args:
+ poses (torch.Tensor): Pose parameters (72 for SMPL, 165 for SMPL-X)
+ betas (torch.Tensor): Shape parameters (10 for SMPL, 11 for SMPL-X)
+ trans (torch.Tensor): Translation parameters
+ gender (str): Gender of the model
+ K (torch.Tensor): Camera intrinsic matrix
+ background (numpy.ndarray, optional): Background image
+ w2c (torch.Tensor, optional): Transformation matrix from world to camera
+
+ Returns:
+ tuple: (rendered_image, vertices)
+ """
+
+ extr = torch.eye(4) if w2c is None else w2c
+
+ vertices, joints = self.get_vertices(poses, betas, trans, gender)
+
+        if background is None:
+            width, height = K[0, 2] * 2, K[1, 2] * 2
+            background = np.zeros((int(height), int(width), 3), dtype=np.uint8)
+ else:
+ height, width = background.shape[:2]
+
+        renderer = Renderer(int(width), int(height), device=self.device, faces=self.models[gender].faces, K=K)
+ renderer.create_camera(R=extr[:3, :3], T=extr[:3, 3])
+
+ vertices_float32 = vertices.float().to(self.device)
+ render_img = renderer.render_mesh(vertices_float32, background, [0.8, 0.8, 0.8])
+ return render_img, vertices
+
+ def get_smpl_depth(self, vertices, K, extr=None, width=None, height=None, return_visible_vertices=False):
+ """
+ Get depth map and mask from SMPL vertices
+
+ Args:
+ vertices (torch.Tensor): SMPL vertices [V, 3]
+ K (torch.Tensor): Camera intrinsic matrix [3, 3]
+ extr (torch.Tensor, optional): Camera extrinsic matrix [4, 4]. If None, uses identity
+ width (int, optional): Image width. If None, uses 2*K[0,2]
+            height (int, optional): Image height. If None, uses 2*K[1,2]
+            return_visible_vertices (bool, optional): If True, also return visible vertex coordinates
+
+        Returns:
+            tuple: (depth_map, mask) or (depth_map, mask, visible_vertices)
+                - depth_map (torch.Tensor): Full-size depth map, shape (height, width)
+                - mask (torch.Tensor): Full-size object mask, shape (height, width)
+ """
+ # Set default extrinsic matrix if not provided
+ if extr is None:
+ extr = torch.eye(4, device=vertices.device, dtype=vertices.dtype)
+
+ if isinstance(K, np.ndarray):
+ K = torch.from_numpy(K)
+
+ # Get image dimensions from camera intrinsics if not provided
+ if width is None:
+ width = int(K[0, 2] * 2)
+ if height is None:
+ height = int(K[1, 2] * 2)
+
+ # Ensure vertices are on the correct device
+ vertices = vertices.to(self.device, dtype=self.dtype)
+ K = K.to(self.device, dtype=self.dtype)
+ extr = extr.to(self.device, dtype=self.dtype)
+
+ # Create renderer instance
+ renderer = Renderer(width, height, device=self.device, faces=self.models['neutral'].faces, K=K)
+
+ # Set camera pose from extrinsic matrix
+ R = extr[:3, :3] # Rotation matrix
+ T = extr[:3, 3] # Translation vector
+ renderer.create_camera(R=R, T=T)
+
+ # Render depth only
+ if return_visible_vertices:
+ depth_map, mask, visible_vertices = renderer.render_depth_only(vertices, return_visible_vertices=return_visible_vertices)
+ return depth_map, mask, visible_vertices
+ else:
+ depth_map, mask = renderer.render_depth_only(vertices, return_visible_vertices=return_visible_vertices)
+ return depth_map, mask
+
+ def get_smplx_vertices(self, poses, betas, trans, gender):
+ """Deprecated: Use get_vertices() instead. This method is kept for backward compatibility."""
+ return self.get_vertices(poses, betas, trans, gender)
+
+ def get_smplx_batch_vertices(self, poses, betas, trans, gender):
+ """Deprecated: Use get_batch_vertices() instead. This method is kept for backward compatibility."""
+ return self.get_batch_vertices(poses, betas, trans, gender)
+
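+
+# Usage sketch (illustrative; assumes SMPL model files are available under
+# body_models/): batched forward for a 30-frame sequence, plus a depth render.
+def _example_smpl_wrapper():
+    wrapper = SMPLWrapper(model_folder='body_models/', model_type='smpl')
+    poses = torch.zeros(30, 72)   # axis-angle, 24 joints x 3
+    betas = torch.zeros(30, 10)
+    trans = torch.zeros(30, 3)
+    verts, joints = wrapper.get_batch_vertices(poses, betas, trans, 'neutral')
+    K = torch.tensor([[500.0, 0.0, 320.0], [0.0, 500.0, 240.0], [0.0, 0.0, 1.0]])
+    depth, mask = wrapper.get_smpl_depth(verts[0], K)  # (480, 640) each
+    return verts, joints, depth, mask
+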
+def transform_smpl(smpl_dict, extrinsics, copy_dict=True):
+ """
+ Transform SMPL parameters from camera coordinate system to world coordinate system.
+
+ Args:
+ smpl_dict (dict): Dictionary containing SMPL parameters in camera coordinates
+ - 'pose_cam': Pose parameters as rotation matrices (B, S, N, 3, 3) or axis-angle (B, S, N*3)
+ - 'trans_cam': Translation parameters (B, S, 3)
+ - 'betas': Shape parameters (B, S, 10 or 11) - unchanged by coordinate transform
+ extrinsics (torch.Tensor): Camera extrinsic matrix (B, S, 4, 4) or (4, 4)
+ Transformation matrix from world to camera coordinates
+ copy_dict (bool): Whether to create a copy of the input dict (default: True)
+
+ Returns:
+ dict: Transformed SMPL dictionary with parameters in world coordinates
+ - 'pose_world': Transformed pose parameters in world coordinates
+ - 'trans_world': Transformed translation parameters in world coordinates
+ - 'betas': Shape parameters (unchanged)
+ """
+
+ # Create a copy to avoid modifying the original dictionary
+ if copy_dict:
+ transformed_dict = {}
+ for key, value in smpl_dict.items():
+ if torch.is_tensor(value):
+ transformed_dict[key] = value.clone()
+ else:
+ transformed_dict[key] = value
+ else:
+ transformed_dict = smpl_dict
+
+ # Get batch and sequence dimensions from camera coordinate parameters
+ pose_cam = smpl_dict['pose_cam']
+ trans_cam = smpl_dict['trans_cam']
+
+ batch_size = pose_cam.shape[0]
+ seq_len = pose_cam.shape[1]
+
+ # Handle extrinsics shape - ensure it has batch and sequence dimensions
+ if len(extrinsics.shape) == 2: # (4, 4)
+ extrinsics = extrinsics.unsqueeze(0).unsqueeze(0).expand(batch_size, seq_len, -1, -1)
+ elif len(extrinsics.shape) == 3: # (B, 4, 4)
+ extrinsics = extrinsics.unsqueeze(1).expand(-1, seq_len, -1, -1)
+ elif len(extrinsics.shape) == 4: # (B, S, 4, 4)
+ pass # Already correct shape
+ else:
+ raise ValueError(f"Unsupported extrinsics shape: {extrinsics.shape}")
+
+ # Use closed-form inverse for SE3 matrices instead of torch.inverse
+ extrinsics_flat = extrinsics.view(-1, 4, 4)
+ extrinsics_inv_flat = closed_form_inverse_se3(extrinsics_flat)
+ cam_to_world_extrinsics = extrinsics_inv_flat.view(batch_size, seq_len, 4, 4)
+
+ # Extract rotation and translation from camera-to-world extrinsics
+ cam_to_world_R = cam_to_world_extrinsics[:, :, :3, :3] # (B, S, 3, 3)
+ cam_to_world_t = cam_to_world_extrinsics[:, :, :3, 3] # (B, S, 3)
+
+ # Transform translation from camera space to world space
+ # world_trans = R * cam_trans + t
+ trans_cam_flat = trans_cam.view(batch_size * seq_len, 3, 1) # (B*S, 3, 1)
+ cam_to_world_R_flat = cam_to_world_R.view(batch_size * seq_len, 3, 3) # (B*S, 3, 3)
+ cam_to_world_t_flat = cam_to_world_t.view(batch_size * seq_len, 3, 1) # (B*S, 3, 1)
+
+ # Apply rotation and translation
+ transformed_trans = torch.bmm(cam_to_world_R_flat, trans_cam_flat) + cam_to_world_t_flat
+ transformed_dict['trans_world'] = transformed_trans.view(batch_size, seq_len, 3)
+
+ # Transform pose parameters
+ # Only transform the root joint (global_orient), keep body_pose unchanged
+ if len(pose_cam.shape) == 5 and pose_cam.shape[-2:] == (3, 3):
+ # Rotation matrix format (B, S, N, 3, 3)
+ pose_world = pose_cam.clone()
+
+ # Only transform the first joint (root/global_orient) using similarity transformation
+ root_joint_rot = pose_cam[:, :, 0] # (B, S, 3, 3)
+        root_joint_flat = root_joint_rot.view(batch_size * seq_len, 3, 3)  # (B*S, 3, 3)
+
+ # Apply coordinate transformation: R_world = R_cam_to_world @ R_cam
+ # For rotation matrices, coordinate transformation is direct multiplication
+ transformed_root_rot = torch.bmm(cam_to_world_R_flat, root_joint_flat) # R_cam_to_world @ R_cam
+
+ # Replace only the root joint, keep all other joints unchanged
+ pose_world[:, :, 0] = transformed_root_rot.view(batch_size, seq_len, 3, 3)
+ transformed_dict['pose_world'] = pose_world
+
+ elif len(pose_cam.shape) == 3:
+ # Axis-angle format (B, S, N*3) - typically 72 for SMPL or 165 for SMPL-X
+ pose_world = pose_cam.clone()
+
+ # Only transform the first 3 parameters (root joint / global_orient)
+ root_joint_aa = pose_cam[:, :, :3] # (B, S, 3)
+ root_joint_flat = root_joint_aa.view(batch_size * seq_len, 3) # (B*S, 3)
+
+ # Convert root joint from axis-angle to rotation matrix
+        root_joint_rotmat = aa_to_rotmat(root_joint_flat)  # (B*S, 3, 3)
+
+ # Apply coordinate transformation: R_world = R_cam_to_world @ R_cam
+ # For rotation matrices, coordinate transformation is direct multiplication
+ transformed_root_rotmat = torch.bmm(cam_to_world_R_flat, root_joint_rotmat) # R_cam_to_world @ R_cam
+
+ # Convert back to axis-angle
+ transformed_root_aa = rotmat_to_aa(transformed_root_rotmat) # (B*S, 3)
+
+ # Replace only the first 3 parameters (root joint), keep all others unchanged
+ pose_world[:, :, :3] = transformed_root_aa.view(batch_size, seq_len, 3)
+ transformed_dict['pose_world'] = pose_world
+
+ else:
+ raise ValueError(f"Unsupported pose format with shape: {pose_cam.shape}")
+
+ # Shape parameters (betas) remain unchanged as they are not affected by coordinate transformations
+ # No need to transform betas
+
+ return transformed_dict
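+
+
+# Usage sketch (illustrative): lift per-frame camera-space SMPL estimates into
+# world coordinates with a shared world-to-camera extrinsic matrix.
+def _example_transform_smpl():
+    smpl_dict = {
+        'pose_cam': torch.zeros(1, 10, 72),  # (B, S, 72) axis-angle
+        'trans_cam': torch.zeros(1, 10, 3),  # (B, S, 3)
+        'betas': torch.zeros(1, 10, 10),     # unchanged by the transform
+    }
+    extrinsics = torch.eye(4)                # identity world-to-camera
+    out = transform_smpl(smpl_dict, extrinsics)
+    return out['pose_world'], out['trans_world']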