Spaces:
Build error
Build error
murphylmf
commited on
Commit
·
ae166e6
1
Parent(s):
eaea719
Init
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- README.md +61 -7
- app.py +457 -0
- environment.yml +14 -0
- inference.py +186 -0
- install.sh +74 -0
- packages.txt +7 -0
- requirements.txt +22 -0
- static/teaser.svg +0 -0
- unish/__pycache__/pipeline.cpython-310.pyc +0 -0
- unish/heads/__pycache__/align_net.cpython-310.pyc +0 -0
- unish/heads/__pycache__/dpt_head.cpython-310.pyc +0 -0
- unish/heads/__pycache__/head_act.cpython-310.pyc +0 -0
- unish/heads/__pycache__/human_head_cliff.cpython-310.pyc +0 -0
- unish/heads/__pycache__/pose_transformer.cpython-310.pyc +0 -0
- unish/heads/__pycache__/t_cond_mlp.cpython-310.pyc +0 -0
- unish/heads/__pycache__/utils.cpython-310.pyc +0 -0
- unish/heads/__pycache__/vit.cpython-310.pyc +0 -0
- unish/heads/align_net.py +571 -0
- unish/heads/dpt_head.py +500 -0
- unish/heads/head_act.py +125 -0
- unish/heads/human_head_cliff.py +97 -0
- unish/heads/pose_transformer.py +364 -0
- unish/heads/t_cond_mlp.py +199 -0
- unish/heads/utils.py +108 -0
- unish/heads/vit.py +346 -0
- unish/pi3/models/__pycache__/pi3.cpython-310.pyc +0 -0
- unish/pi3/models/dinov2/__init__.py +6 -0
- unish/pi3/models/dinov2/__pycache__/__init__.cpython-310.pyc +0 -0
- unish/pi3/models/dinov2/hub/__init__.py +4 -0
- unish/pi3/models/dinov2/hub/__pycache__/__init__.cpython-310.pyc +0 -0
- unish/pi3/models/dinov2/hub/__pycache__/backbones.cpython-310.pyc +0 -0
- unish/pi3/models/dinov2/hub/__pycache__/utils.cpython-310.pyc +0 -0
- unish/pi3/models/dinov2/hub/backbones.py +156 -0
- unish/pi3/models/dinov2/hub/utils.py +39 -0
- unish/pi3/models/dinov2/layers/__init__.py +11 -0
- unish/pi3/models/dinov2/layers/__pycache__/__init__.cpython-310.pyc +0 -0
- unish/pi3/models/dinov2/layers/__pycache__/attention.cpython-310.pyc +0 -0
- unish/pi3/models/dinov2/layers/__pycache__/block.cpython-310.pyc +0 -0
- unish/pi3/models/dinov2/layers/__pycache__/dino_head.cpython-310.pyc +0 -0
- unish/pi3/models/dinov2/layers/__pycache__/drop_path.cpython-310.pyc +0 -0
- unish/pi3/models/dinov2/layers/__pycache__/layer_scale.cpython-310.pyc +0 -0
- unish/pi3/models/dinov2/layers/__pycache__/mlp.cpython-310.pyc +0 -0
- unish/pi3/models/dinov2/layers/__pycache__/patch_embed.cpython-310.pyc +0 -0
- unish/pi3/models/dinov2/layers/__pycache__/swiglu_ffn.cpython-310.pyc +0 -0
- unish/pi3/models/dinov2/layers/attention.py +89 -0
- unish/pi3/models/dinov2/layers/block.py +259 -0
- unish/pi3/models/dinov2/layers/dino_head.py +58 -0
- unish/pi3/models/dinov2/layers/drop_path.py +34 -0
- unish/pi3/models/dinov2/layers/layer_scale.py +27 -0
- unish/pi3/models/dinov2/layers/mlp.py +40 -0
README.md
CHANGED
|
@@ -1,13 +1,67 @@
|
|
| 1 |
---
|
| 2 |
-
title: UniSH
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
-
colorTo:
|
| 6 |
sdk: gradio
|
| 7 |
-
sdk_version:
|
| 8 |
app_file: app.py
|
| 9 |
pinned: false
|
| 10 |
-
license:
|
| 11 |
---
|
| 12 |
|
| 13 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
+
title: UniSH (Unified Scene & Human Reconstruction)
|
| 3 |
+
emoji: 🏃♂️
|
| 4 |
+
colorFrom: blue
|
| 5 |
+
colorTo: purple
|
| 6 |
sdk: gradio
|
| 7 |
+
sdk_version: 5.0.0
|
| 8 |
app_file: app.py
|
| 9 |
pinned: false
|
| 10 |
+
license: cc-by-nc-4.0
|
| 11 |
---
|
| 12 |
|
| 13 |
+
# UniSH: Unifying Scene and Human Reconstruction in a Feed-Forward Pass
|
| 14 |
+
|
| 15 |
+
<div align="center">
|
| 16 |
+
|
| 17 |
+
Mengfei Li<sup>1</sup>, Peng Li<sup>1</sup>, Zheng Zhang<sup>2</sup>, Jiahao Lu<sup>1</sup>, Chengfeng Zhao<sup>1</sup>, Wei Xue<sup>1</sup>, <br>
|
| 18 |
+
Qifeng Liu<sup>1</sup>, Sida Peng<sup>3</sup>, Wenxiao Zhang<sup>1</sup>, Wenhan Luo<sup>1</sup>, Yuan Liu<sup>1†</sup>, Yike Guo<sup>1†</sup>
|
| 19 |
+
|
| 20 |
+
<sup>1</sup>The Hong Kong University of Science and Technology, <sup>2</sup>Beijing University of Posts and Telecommunications, <sup>3</sup>Zhejiang University
|
| 21 |
+
|
| 22 |
+
<a href="https://murphylmf.github.io/UniSH/"><img src="https://img.shields.io/badge/Project-Page-8A2BE2" alt="Project Page"></a>
|
| 23 |
+
<a href="https://arxiv.org/abs/2601.01222"><img src="https://img.shields.io/badge/arXiv-2601.01222-b31b1b.svg" alt="arXiv"></a>
|
| 24 |
+
<a href="https://github.com/murphylmf/UniSH"><img src="https://img.shields.io/badge/GitHub-Code-black.svg" alt="Code"></a>
|
| 25 |
+
|
| 26 |
+
</div>
|
| 27 |
+
|
| 28 |
+
## Abstract
|
| 29 |
+
|
| 30 |
+
We present UniSH, a unified, feed-forward framework for joint metric-scale 3D scene and human reconstruction. A key challenge in this domain is the scarcity of large-scale, annotated real-world data, forcing a reliance on synthetic datasets. This reliance introduces a significant sim-to-real domain gap, leading to poor generalization, low-fidelity human geometry, and poor alignment on in-the-wild videos.
|
| 31 |
+
|
| 32 |
+
To address this, we propose an innovative training paradigm that effectively leverages unlabeled in-the-wild data. Our framework bridges strong, disparate priors from scene reconstruction and HMR, and is trained with two core components: (1) a robust distillation strategy to refine human surface details by distilling high-frequency details from an expert depth model, and (2) a two-stage supervision scheme, which first learns coarse localization on synthetic data, then fine-tunes on real data by directly optimizing the geometric correspondence between the SMPL mesh and the human point cloud. This approach enables our feed-forward model to jointly recover high-fidelity scene geometry, human point clouds, camera parameters, and coherent, metric-scale SMPL bodies, all in a single forward pass. Extensive experiments demonstrate that our model achieves state-of-the-art performance on human-centric scene reconstruction and delivers highly competitive results on global human motion estimation, comparing favorably against both optimization-based frameworks and HMR-only methods.
|
| 33 |
+
|
| 34 |
+
## Method
|
| 35 |
+
|
| 36 |
+

|
| 37 |
+
|
| 38 |
+
**The network architecture of UniSH.**
|
| 39 |
+
UniSH takes a monocular video as input. The video frames are processed by the **Reconstruction Branch** to predict per-frame camera extrinsics *E*, confidence maps *C*, and pointmaps *P*. Camera intrinsics *K* are derived from the pointmaps. Human crops from the video are fed into the **Human Body Branch** along with *K* to estimate global SMPL shape parameters *β* and per-frame pose parameters *θ<sub>i</sub>*. Features from both branches are processed by **AlignNet** to predict the global scene scale *s* and per-frame SMPL translations *t<sub>i</sub>* for coherent scene and human alignment.
|
| 40 |
+
|
| 41 |
+
## Usage
|
| 42 |
+
|
| 43 |
+
This Space provides an interactive demo for UniSH.
|
| 44 |
+
|
| 45 |
+
1. **Upload a Video**: Upload a monocular video containing a human.
|
| 46 |
+
2. **Set Duration**: Choose the duration to process (default: 3 seconds).
|
| 47 |
+
3. **Run Inference**: Click "Run Inference" to generate the 3D reconstruction.
|
| 48 |
+
4. **Visualize**: The result will be displayed in an interactive 3D viewer where you can rotate, pan, and zoom.
|
| 49 |
+
|
| 50 |
+
## BibTeX
|
| 51 |
+
|
| 52 |
+
```bibtex
|
| 53 |
+
@misc{li2026unishunifyingscenehuman,
|
| 54 |
+
title={UniSH: Unifying Scene and Human Reconstruction in a Feed-Forward Pass},
|
| 55 |
+
author={Mengfei Li and Peng Li and Zheng Zhang and Jiahao Lu and Chengfeng Zhao and Wei Xue and Qifeng Liu and Sida Peng and Wenxiao Zhang and Wenhan Luo and Yuan Liu and Yike Guo},
|
| 56 |
+
year={2026},
|
| 57 |
+
eprint={2601.01222},
|
| 58 |
+
archivePrefix={arXiv},
|
| 59 |
+
primaryClass={cs.CV},
|
| 60 |
+
url={https://arxiv.org/abs/2601.01222},
|
| 61 |
+
}
|
| 62 |
+
```
|
| 63 |
+
|
| 64 |
+
## Acknowledgements
|
| 65 |
+
|
| 66 |
+
This website is licensed under a [Creative Commons Attribution-ShareAlike 4.0 International License](http://creativecommons.org/licenses/by-sa/4.0/).
|
| 67 |
+
Template borrowed from [Nerfies](https://github.com/nerfies/nerfies.github.io).
|
app.py
ADDED
|
@@ -0,0 +1,457 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
import spaces
|
| 3 |
+
import os
|
| 4 |
+
import sys
|
| 5 |
+
import shutil
|
| 6 |
+
import tempfile
|
| 7 |
+
import torch
|
| 8 |
+
import cv2
|
| 9 |
+
import subprocess
|
| 10 |
+
import numpy as np
|
| 11 |
+
import trimesh
|
| 12 |
+
from huggingface_hub import hf_hub_download
|
| 13 |
+
|
| 14 |
+
# Add current directory to path
|
| 15 |
+
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
| 16 |
+
|
| 17 |
+
from unish.utils.inference_utils import (
|
| 18 |
+
load_model,
|
| 19 |
+
process_video,
|
| 20 |
+
run_inference,
|
| 21 |
+
generate_mixed_geometries_in_memory,
|
| 22 |
+
save_smpl_meshes_per_frame,
|
| 23 |
+
save_scene_only_point_clouds,
|
| 24 |
+
save_human_point_clouds,
|
| 25 |
+
save_camera_parameters_per_frame
|
| 26 |
+
)
|
| 27 |
+
|
| 28 |
+
MODEL = None
|
| 29 |
+
BODY_MODELS_PATH = "body_models/"
|
| 30 |
+
|
| 31 |
+
def download_smpl_assets(body_models_path):
|
| 32 |
+
"""
|
| 33 |
+
Download SMPL models from private repository if they don't exist.
|
| 34 |
+
The path logic mimics SMPLWrapper's expectation:
|
| 35 |
+
1. SMPLWrapper appends 'smpl' if not present in body_models_path.
|
| 36 |
+
2. smplx library expects another 'smpl' folder inside that (or appends it).
|
| 37 |
+
Based on existing structure 'body_models/smpl/smpl/SMPL_*.pkl', the target dir is constructed below.
|
| 38 |
+
"""
|
| 39 |
+
if 'smpl' not in body_models_path:
|
| 40 |
+
model_path = os.path.join(body_models_path, 'smpl')
|
| 41 |
+
else:
|
| 42 |
+
model_path = body_models_path
|
| 43 |
+
|
| 44 |
+
# smplx looks for a 'smpl' folder inside the given model_path
|
| 45 |
+
target_dir = os.path.join(model_path, 'smpl')
|
| 46 |
+
|
| 47 |
+
os.makedirs(target_dir, exist_ok=True)
|
| 48 |
+
|
| 49 |
+
files = ["SMPL_NEUTRAL.pkl", "SMPL_MALE.pkl", "SMPL_FEMALE.pkl"]
|
| 50 |
+
token = os.environ.get("SMPL_DOWNLOAD_TOKEN")
|
| 51 |
+
|
| 52 |
+
for filename in files:
|
| 53 |
+
file_path = os.path.join(target_dir, filename)
|
| 54 |
+
if not os.path.exists(file_path):
|
| 55 |
+
if not token:
|
| 56 |
+
print(f"Warning: SMPL_DOWNLOAD_TOKEN not set. Cannot download {filename}.")
|
| 57 |
+
continue
|
| 58 |
+
|
| 59 |
+
print(f"Downloading {filename} to {target_dir}...")
|
| 60 |
+
try:
|
| 61 |
+
hf_hub_download(
|
| 62 |
+
repo_id="Murphyyyy/UniSH-Private-Assets",
|
| 63 |
+
filename=filename,
|
| 64 |
+
local_dir=target_dir,
|
| 65 |
+
token=token
|
| 66 |
+
)
|
| 67 |
+
except Exception as e:
|
| 68 |
+
print(f"Failed to download {filename}: {e}")
|
| 69 |
+
|
| 70 |
+
def pack_sequence_to_glb(base_dir, output_path, start_frame=0, end_frame=60, scene_rate=0.5):
|
| 71 |
+
scene = trimesh.Scene()
|
| 72 |
+
|
| 73 |
+
print(f">>> Packing frames {start_frame} to {end_frame}...")
|
| 74 |
+
|
| 75 |
+
valid_count = 0
|
| 76 |
+
|
| 77 |
+
for i in range(start_frame, end_frame):
|
| 78 |
+
frame_node_name = f"frame_{valid_count}"
|
| 79 |
+
|
| 80 |
+
s_path = os.path.join(base_dir, "scene_only_point_clouds", f"scene_only_frame_{i:04d}.ply")
|
| 81 |
+
h_path = os.path.join(base_dir, "human_only_point_clouds", f"human_frame_{i:04d}.ply")
|
| 82 |
+
smpl_path = os.path.join(base_dir, "smpl_meshes_per_frame", f"smpl_mesh_frame_{i:04d}.ply")
|
| 83 |
+
|
| 84 |
+
if not (os.path.exists(h_path) or os.path.exists(smpl_path)):
|
| 85 |
+
continue
|
| 86 |
+
|
| 87 |
+
scene.graph.update(frame_node_name, parent="world")
|
| 88 |
+
|
| 89 |
+
if os.path.exists(smpl_path):
|
| 90 |
+
try:
|
| 91 |
+
smpl = trimesh.load(smpl_path)
|
| 92 |
+
flesh_color = [255, 160, 122, 255]
|
| 93 |
+
smpl.visual.vertex_colors = np.tile(flesh_color, (len(smpl.vertices), 1))
|
| 94 |
+
|
| 95 |
+
scene.add_geometry(smpl, node_name=f"{frame_node_name}_smpl", parent_node_name=frame_node_name)
|
| 96 |
+
except Exception as e:
|
| 97 |
+
pass
|
| 98 |
+
|
| 99 |
+
if os.path.exists(h_path):
|
| 100 |
+
try:
|
| 101 |
+
human = trimesh.load(h_path)
|
| 102 |
+
if isinstance(human, trimesh.PointCloud):
|
| 103 |
+
scene.add_geometry(human, node_name=f"{frame_node_name}_human", parent_node_name=frame_node_name)
|
| 104 |
+
except: pass
|
| 105 |
+
|
| 106 |
+
if os.path.exists(s_path):
|
| 107 |
+
try:
|
| 108 |
+
s_obj = trimesh.load(s_path)
|
| 109 |
+
if isinstance(s_obj, trimesh.PointCloud):
|
| 110 |
+
total_pts = len(s_obj.vertices)
|
| 111 |
+
if total_pts > 0:
|
| 112 |
+
if scene_rate < 0.99:
|
| 113 |
+
count = int(total_pts * scene_rate)
|
| 114 |
+
if count > 100:
|
| 115 |
+
idx = np.random.choice(total_pts, count, replace=False)
|
| 116 |
+
s_obj = trimesh.PointCloud(s_obj.vertices[idx], colors=s_obj.colors[idx])
|
| 117 |
+
scene.add_geometry(s_obj, node_name=f"{frame_node_name}_scene", parent_node_name=frame_node_name)
|
| 118 |
+
except: pass
|
| 119 |
+
|
| 120 |
+
valid_count += 1
|
| 121 |
+
|
| 122 |
+
if valid_count == 0:
|
| 123 |
+
print("Error: No valid frames found.")
|
| 124 |
+
return
|
| 125 |
+
|
| 126 |
+
try:
|
| 127 |
+
rot = trimesh.transformations.rotation_matrix(np.radians(-90), [1, 0, 0])
|
| 128 |
+
scene.apply_transform(rot)
|
| 129 |
+
except: pass
|
| 130 |
+
|
| 131 |
+
os.makedirs(os.path.dirname(output_path), exist_ok=True)
|
| 132 |
+
print(f">>> Exporting to {output_path}...")
|
| 133 |
+
scene.export(output_path)
|
| 134 |
+
print(f">>> Done! Saved {valid_count} frames.")
|
| 135 |
+
|
| 136 |
+
def get_player_html(glb_abs_path):
|
| 137 |
+
html_content = f"""
|
| 138 |
+
<!DOCTYPE html>
|
| 139 |
+
<html>
|
| 140 |
+
<head>
|
| 141 |
+
<meta charset="utf-8">
|
| 142 |
+
<title>UniSH Viewer</title>
|
| 143 |
+
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/bulma@0.9.4/css/bulma.min.css">
|
| 144 |
+
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.4.0/css/all.min.css">
|
| 145 |
+
<style>
|
| 146 |
+
#canvas-container {{
|
| 147 |
+
width: 100%;
|
| 148 |
+
height: 600px;
|
| 149 |
+
background: #f5f5f5;
|
| 150 |
+
border-radius: 8px;
|
| 151 |
+
position: relative;
|
| 152 |
+
overflow: hidden;
|
| 153 |
+
box-shadow: inset 0 0 20px rgba(0,0,0,0.05);
|
| 154 |
+
}}
|
| 155 |
+
.slider {{
|
| 156 |
+
width: 100%;
|
| 157 |
+
}}
|
| 158 |
+
</style>
|
| 159 |
+
<script type="importmap">
|
| 160 |
+
{{
|
| 161 |
+
"imports": {{
|
| 162 |
+
"three": "https://unpkg.com/three@0.158.0/build/three.module.js",
|
| 163 |
+
"three/addons/": "https://unpkg.com/three@0.158.0/examples/jsm/"
|
| 164 |
+
}}
|
| 165 |
+
}}
|
| 166 |
+
</script>
|
| 167 |
+
</head>
|
| 168 |
+
<body>
|
| 169 |
+
<div class="box" style="padding: 10px; background: #f5f5f5;">
|
| 170 |
+
<div id="canvas-container">
|
| 171 |
+
<div id="loading-overlay" style="position: absolute; top:0; left:0; width:100%; height:100%; background: rgba(0,0,0,0.7); color: white; display: flex; flex-direction: column; justify-content: center; align-items: center; z-index: 10;">
|
| 172 |
+
<span class="icon is-large"><i class="fas fa-spinner fa-pulse"></i></span>
|
| 173 |
+
<p style="margin-top: 10px;">Loading 3D Sequence...</p>
|
| 174 |
+
</div>
|
| 175 |
+
</div>
|
| 176 |
+
|
| 177 |
+
<div class="columns is-vcentered is-mobile" style="margin-top: 10px; padding: 0 10px;">
|
| 178 |
+
<div class="column is-narrow">
|
| 179 |
+
<button id="play-btn" class="button is-dark is-rounded is-small">
|
| 180 |
+
<span class="icon is-small"><i class="fas fa-play"></i></span>
|
| 181 |
+
</button>
|
| 182 |
+
</div>
|
| 183 |
+
<div class="column">
|
| 184 |
+
<input id="frame-slider" class="slider is-fullwidth is-circle is-dark" step="1" min="0" max="0" value="0" type="range">
|
| 185 |
+
</div>
|
| 186 |
+
<div class="column is-narrow">
|
| 187 |
+
<span id="frame-count" class="tag is-light" style="width: 80px;">Frame: 0</span>
|
| 188 |
+
</div>
|
| 189 |
+
</div>
|
| 190 |
+
</div>
|
| 191 |
+
|
| 192 |
+
<script type="module">
|
| 193 |
+
import * as THREE from 'three';
|
| 194 |
+
import {{ OrbitControls }} from 'three/addons/controls/OrbitControls.js';
|
| 195 |
+
import {{ GLTFLoader }} from 'three/addons/loaders/GLTFLoader.js';
|
| 196 |
+
|
| 197 |
+
// Inject the model path using f-string from Python
|
| 198 |
+
const MODEL_PATH = "/file={glb_abs_path}";
|
| 199 |
+
const FPS = 10;
|
| 200 |
+
|
| 201 |
+
let scene, camera, renderer, controls;
|
| 202 |
+
let frames = [];
|
| 203 |
+
let currentFrame = 0;
|
| 204 |
+
let isPlaying = false;
|
| 205 |
+
let intervalId = null;
|
| 206 |
+
|
| 207 |
+
const container = document.getElementById('canvas-container');
|
| 208 |
+
const slider = document.getElementById('frame-slider');
|
| 209 |
+
const playBtn = document.getElementById('play-btn');
|
| 210 |
+
const frameLabel = document.getElementById('frame-count');
|
| 211 |
+
const loadingOverlay = document.getElementById('loading-overlay');
|
| 212 |
+
|
| 213 |
+
init();
|
| 214 |
+
|
| 215 |
+
function init() {{
|
| 216 |
+
scene = new THREE.Scene();
|
| 217 |
+
scene.background = new THREE.Color(0xf5f5f5);
|
| 218 |
+
|
| 219 |
+
camera = new THREE.PerspectiveCamera(50, container.clientWidth / container.clientHeight, 0.1, 1000);
|
| 220 |
+
camera.position.set(-0.000, -4.272, 0.000);
|
| 221 |
+
|
| 222 |
+
renderer = new THREE.WebGLRenderer({{ antialias: true, alpha: true }});
|
| 223 |
+
renderer.setSize(container.clientWidth, container.clientHeight);
|
| 224 |
+
renderer.setPixelRatio(window.devicePixelRatio);
|
| 225 |
+
|
| 226 |
+
renderer.shadowMap.enabled = false;
|
| 227 |
+
renderer.useLegacyLights = false;
|
| 228 |
+
|
| 229 |
+
container.appendChild(renderer.domElement);
|
| 230 |
+
|
| 231 |
+
const hemiLight = new THREE.HemisphereLight(0xffffff, 0x444444, 3.0);
|
| 232 |
+
scene.add(hemiLight);
|
| 233 |
+
|
| 234 |
+
const dirLight = new THREE.DirectionalLight(0xffffff, 3.0);
|
| 235 |
+
dirLight.position.set(5, 10, 7);
|
| 236 |
+
scene.add(dirLight);
|
| 237 |
+
|
| 238 |
+
const frontLight = new THREE.DirectionalLight(0xffffff, 2.0);
|
| 239 |
+
frontLight.position.set(0, 0, 5);
|
| 240 |
+
scene.add(frontLight);
|
| 241 |
+
|
| 242 |
+
controls = new OrbitControls(camera, renderer.domElement);
|
| 243 |
+
controls.enableDamping = true;
|
| 244 |
+
controls.dampingFactor = 0.05;
|
| 245 |
+
|
| 246 |
+
controls.target.set(0.000, 0.000, 0.000);
|
| 247 |
+
|
| 248 |
+
const loader = new GLTFLoader();
|
| 249 |
+
console.log("Loading:", MODEL_PATH);
|
| 250 |
+
|
| 251 |
+
loader.load(MODEL_PATH, function (gltf) {{
|
| 252 |
+
const root = gltf.scene;
|
| 253 |
+
scene.add(root);
|
| 254 |
+
|
| 255 |
+
frames = [];
|
| 256 |
+
root.traverse((node) => {{
|
| 257 |
+
|
| 258 |
+
if (node.isMesh) {{
|
| 259 |
+
node.geometry.computeVertexNormals();
|
| 260 |
+
if (node.geometry.attributes.color) {{
|
| 261 |
+
node.geometry.deleteAttribute('color');
|
| 262 |
+
}}
|
| 263 |
+
node.material = new THREE.MeshStandardMaterial({{
|
| 264 |
+
color: 0xff9966,
|
| 265 |
+
roughness: 0.4,
|
| 266 |
+
metalness: 0.0,
|
| 267 |
+
side: THREE.DoubleSide
|
| 268 |
+
}});
|
| 269 |
+
node.material.vertexColors = false;
|
| 270 |
+
}}
|
| 271 |
+
|
| 272 |
+
if (node.isPoints) {{
|
| 273 |
+
if (node.name.toLowerCase().includes('scene')) {{
|
| 274 |
+
node.material.size = 0.05;
|
| 275 |
+
node.material.sizeAttenuation = true;
|
| 276 |
+
}}
|
| 277 |
+
if (node.name.toLowerCase().includes('human')) {{
|
| 278 |
+
node.material.size = 0.005;
|
| 279 |
+
}}
|
| 280 |
+
}}
|
| 281 |
+
|
| 282 |
+
if (node.name && node.name.startsWith('frame_')) {{
|
| 283 |
+
const parts = node.name.split('_');
|
| 284 |
+
if (parts.length === 2 && !isNaN(parseInt(parts[1]))) {{
|
| 285 |
+
const idx = parseInt(parts[1]);
|
| 286 |
+
frames[idx] = node;
|
| 287 |
+
node.visible = false;
|
| 288 |
+
}}
|
| 289 |
+
}}
|
| 290 |
+
}});
|
| 291 |
+
|
| 292 |
+
frames = frames.filter(n => n !== undefined);
|
| 293 |
+
console.log(`Loaded ${{frames.length}} frames.`);
|
| 294 |
+
|
| 295 |
+
if (frames.length > 0) {{
|
| 296 |
+
slider.max = frames.length - 1;
|
| 297 |
+
loadingOverlay.style.display = 'none';
|
| 298 |
+
showFrame(0);
|
| 299 |
+
}} else {{
|
| 300 |
+
loadingOverlay.innerHTML = "<p>No frames found.</p>";
|
| 301 |
+
}}
|
| 302 |
+
|
| 303 |
+
}}, undefined, function (error) {{
|
| 304 |
+
console.error(error);
|
| 305 |
+
loadingOverlay.innerHTML = "<p>Error loading model.</p>";
|
| 306 |
+
}});
|
| 307 |
+
|
| 308 |
+
window.addEventListener('resize', onWindowResize);
|
| 309 |
+
animate();
|
| 310 |
+
}}
|
| 311 |
+
|
| 312 |
+
function showFrame(idx) {{
|
| 313 |
+
if (!frames[idx]) return;
|
| 314 |
+
if (frames[currentFrame]) frames[currentFrame].visible = false;
|
| 315 |
+
frames[idx].visible = true;
|
| 316 |
+
currentFrame = idx;
|
| 317 |
+
slider.value = idx;
|
| 318 |
+
frameLabel.innerText = `Frame: ${{idx}}`;
|
| 319 |
+
}}
|
| 320 |
+
|
| 321 |
+
function togglePlay() {{
|
| 322 |
+
if (frames.length === 0) return;
|
| 323 |
+
isPlaying = !isPlaying;
|
| 324 |
+
|
| 325 |
+
const icon = playBtn.querySelector('.fa-play, .fa-pause');
|
| 326 |
+
|
| 327 |
+
if (isPlaying) {{
|
| 328 |
+
if(icon) {{ icon.classList.remove('fa-play'); icon.classList.add('fa-pause'); }}
|
| 329 |
+
intervalId = setInterval(() => {{
|
| 330 |
+
let next = currentFrame + 1;
|
| 331 |
+
if (next >= frames.length) next = 0;
|
| 332 |
+
showFrame(next);
|
| 333 |
+
}}, 1000 / FPS);
|
| 334 |
+
}} else {{
|
| 335 |
+
if(icon) {{ icon.classList.remove('fa-pause'); icon.classList.add('fa-play'); }}
|
| 336 |
+
clearInterval(intervalId);
|
| 337 |
+
}}
|
| 338 |
+
}}
|
| 339 |
+
|
| 340 |
+
slider.addEventListener('input', (e) => {{
|
| 341 |
+
if (isPlaying) togglePlay();
|
| 342 |
+
showFrame(parseInt(e.target.value));
|
| 343 |
+
}});
|
| 344 |
+
playBtn.addEventListener('click', togglePlay);
|
| 345 |
+
|
| 346 |
+
function onWindowResize() {{
|
| 347 |
+
camera.aspect = container.clientWidth / container.clientHeight;
|
| 348 |
+
camera.updateProjectionMatrix();
|
| 349 |
+
renderer.setSize(container.clientWidth, container.clientHeight);
|
| 350 |
+
}}
|
| 351 |
+
|
| 352 |
+
function animate() {{
|
| 353 |
+
requestAnimationFrame(animate);
|
| 354 |
+
controls.update();
|
| 355 |
+
renderer.render(scene, camera);
|
| 356 |
+
}}
|
| 357 |
+
</script>
|
| 358 |
+
</body>
|
| 359 |
+
</html>
|
| 360 |
+
"""
|
| 361 |
+
return html_content
|
| 362 |
+
|
| 363 |
+
@spaces.GPU(duration=120)
|
| 364 |
+
def predict(video_path, duration_seconds=3.0):
|
| 365 |
+
global MODEL
|
| 366 |
+
|
| 367 |
+
# 0. Setup directories
|
| 368 |
+
output_dir = tempfile.mkdtemp()
|
| 369 |
+
|
| 370 |
+
# 1. Trim video
|
| 371 |
+
duration = min(float(duration_seconds), 10.0)
|
| 372 |
+
trimmed_video_path = os.path.join(output_dir, "input_trimmed.mp4")
|
| 373 |
+
|
| 374 |
+
cmd = [
|
| 375 |
+
"ffmpeg", "-i", video_path,
|
| 376 |
+
"-t", str(duration),
|
| 377 |
+
"-c:v", "libx264", "-c:a", "aac",
|
| 378 |
+
trimmed_video_path, "-y"
|
| 379 |
+
]
|
| 380 |
+
subprocess.run(cmd, check=True)
|
| 381 |
+
|
| 382 |
+
# 2. Load Model
|
| 383 |
+
if MODEL is None:
|
| 384 |
+
MODEL = load_model()
|
| 385 |
+
|
| 386 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 387 |
+
MODEL.to(device)
|
| 388 |
+
MODEL.eval()
|
| 389 |
+
|
| 390 |
+
# 3. Process Video
|
| 391 |
+
fps = 6.0
|
| 392 |
+
target_size = 518
|
| 393 |
+
human_idx = 0
|
| 394 |
+
bbox_scale = 1.0
|
| 395 |
+
|
| 396 |
+
# Check and download SMPL assets
|
| 397 |
+
download_smpl_assets(BODY_MODELS_PATH)
|
| 398 |
+
|
| 399 |
+
data_dict = process_video(
|
| 400 |
+
trimmed_video_path, fps, human_idx, target_size,
|
| 401 |
+
bbox_scale=bbox_scale
|
| 402 |
+
)
|
| 403 |
+
|
| 404 |
+
# 4. Run Inference
|
| 405 |
+
results = run_inference(MODEL, data_dict, device, chunk_size=30)
|
| 406 |
+
|
| 407 |
+
# 5. Generate Geometries & Save
|
| 408 |
+
seq_name = results['seq_name']
|
| 409 |
+
|
| 410 |
+
viz_scene_point_clouds, viz_smpl_meshes, viz_scene_only_point_clouds, smpl_points_for_camera = generate_mixed_geometries_in_memory(
|
| 411 |
+
results, BODY_MODELS_PATH, fps=fps, conf_thres=0.1
|
| 412 |
+
)
|
| 413 |
+
|
| 414 |
+
# Save to disk
|
| 415 |
+
save_smpl_meshes_per_frame(results, output_dir, BODY_MODELS_PATH)
|
| 416 |
+
save_scene_only_point_clouds(viz_scene_only_point_clouds, output_dir, seq_name)
|
| 417 |
+
save_human_point_clouds(viz_scene_point_clouds, viz_scene_only_point_clouds, output_dir, seq_name, results)
|
| 418 |
+
|
| 419 |
+
# 6. Pack to GLB
|
| 420 |
+
base_dir = os.path.join(output_dir, seq_name)
|
| 421 |
+
output_glb_path = os.path.join(output_dir, "output.glb")
|
| 422 |
+
|
| 423 |
+
num_frames = len(viz_scene_point_clouds)
|
| 424 |
+
|
| 425 |
+
pack_sequence_to_glb(
|
| 426 |
+
base_dir,
|
| 427 |
+
output_glb_path,
|
| 428 |
+
start_frame=0,
|
| 429 |
+
end_frame=num_frames,
|
| 430 |
+
scene_rate=0.5
|
| 431 |
+
)
|
| 432 |
+
|
| 433 |
+
return get_player_html(output_glb_path)
|
| 434 |
+
|
| 435 |
+
with gr.Blocks() as demo:
|
| 436 |
+
gr.Markdown("# UniSH Demo")
|
| 437 |
+
gr.Markdown("Upload a video to reconstruct scene and human in 3D.")
|
| 438 |
+
|
| 439 |
+
with gr.Row():
|
| 440 |
+
with gr.Column():
|
| 441 |
+
input_video = gr.Video(label="Input Video")
|
| 442 |
+
duration_slider = gr.Slider(minimum=1, maximum=10, value=3, step=1, label="Duration to Process (seconds)")
|
| 443 |
+
submit_btn = gr.Button("Run Inference", variant="primary")
|
| 444 |
+
|
| 445 |
+
with gr.Column():
|
| 446 |
+
output_html = gr.HTML(label="3D Result", min_height=600)
|
| 447 |
+
|
| 448 |
+
submit_btn.click(
|
| 449 |
+
predict,
|
| 450 |
+
inputs=[input_video, duration_slider],
|
| 451 |
+
outputs=[output_html]
|
| 452 |
+
)
|
| 453 |
+
|
| 454 |
+
demo.queue()
|
| 455 |
+
demo.launch()
|
| 456 |
+
|
| 457 |
+
|
environment.yml
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: unish
|
| 2 |
+
channels:
|
| 3 |
+
- conda-forge
|
| 4 |
+
- defaults
|
| 5 |
+
dependencies:
|
| 6 |
+
- python=3.10
|
| 7 |
+
- pip
|
| 8 |
+
- git
|
| 9 |
+
- ninja
|
| 10 |
+
- mesalib
|
| 11 |
+
- libgl-devel
|
| 12 |
+
- libegl-devel
|
| 13 |
+
- gxx_linux-64=11.*
|
| 14 |
+
- ffmpeg
|
inference.py
ADDED
|
@@ -0,0 +1,186 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import os
|
| 3 |
+
import torch
|
| 4 |
+
import numpy as np
|
| 5 |
+
import random
|
| 6 |
+
import logging
|
| 7 |
+
from unish.utils.inference_utils import *
|
| 8 |
+
|
| 9 |
+
def setup_seed(seed):
|
| 10 |
+
torch.manual_seed(seed)
|
| 11 |
+
torch.cuda.manual_seed_all(seed)
|
| 12 |
+
np.random.seed(seed)
|
| 13 |
+
random.seed(seed)
|
| 14 |
+
torch.backends.cudnn.deterministic = True
|
| 15 |
+
|
| 16 |
+
def setup_logging(output_dir):
|
| 17 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 18 |
+
|
| 19 |
+
# Create logger
|
| 20 |
+
logger = logging.getLogger()
|
| 21 |
+
logger.setLevel(logging.INFO)
|
| 22 |
+
|
| 23 |
+
# Create handlers
|
| 24 |
+
c_handler = logging.StreamHandler()
|
| 25 |
+
f_handler = logging.FileHandler(os.path.join(output_dir, 'inference.log'), mode='w')
|
| 26 |
+
|
| 27 |
+
# Create formatters and add it to handlers
|
| 28 |
+
c_format = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
|
| 29 |
+
f_format = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
|
| 30 |
+
c_handler.setFormatter(c_format)
|
| 31 |
+
f_handler.setFormatter(f_format)
|
| 32 |
+
|
| 33 |
+
# Add handlers to the logger
|
| 34 |
+
logger.addHandler(c_handler)
|
| 35 |
+
logger.addHandler(f_handler)
|
| 36 |
+
|
| 37 |
+
return logger
|
| 38 |
+
|
| 39 |
+
def main():
|
| 40 |
+
parser = argparse.ArgumentParser(description="Video Inference Script")
|
| 41 |
+
parser.add_argument("--video_path", type=str, required=True,
|
| 42 |
+
help="Path to the input video file or directory containing images")
|
| 43 |
+
parser.add_argument("--fps", type=float, default=6.0,
|
| 44 |
+
help="Target FPS for frame extraction (default: 6.0)")
|
| 45 |
+
parser.add_argument("--original_fps", type=float, default=30.0,
|
| 46 |
+
help="Original FPS of the image sequence (default: 30.0, used only for directory input)")
|
| 47 |
+
parser.add_argument("--target_size", type=int, default=518,
|
| 48 |
+
help="Target size for frame processing (default: 518)")
|
| 49 |
+
parser.add_argument("--checkpoint", type=str, default="checkpoints/unish_release.safetensors",
|
| 50 |
+
help="Path to the model checkpoint")
|
| 51 |
+
parser.add_argument("--output_dir", type=str, default="inference_results_video",
|
| 52 |
+
help="Output directory for results")
|
| 53 |
+
parser.add_argument("--body_models_path", type=str, default="body_models/",
|
| 54 |
+
help="Path to SMPL body models")
|
| 55 |
+
parser.add_argument("--device", type=str, default="cuda",
|
| 56 |
+
help="Device to run inference on")
|
| 57 |
+
parser.add_argument("--save_results", action="store_true", default=True,
|
| 58 |
+
help="Save additional results including smpl_points_for_camera (default: True)")
|
| 59 |
+
parser.add_argument("--chunk_size", type=int, default=30,
|
| 60 |
+
help="Number of frames to process in each chunk during inference (default: 30)")
|
| 61 |
+
parser.add_argument("--gpu_id", type=int, default=0,
|
| 62 |
+
help="GPU ID to use for inference (default: 0)")
|
| 63 |
+
parser.add_argument("--camera_mode", type=str, default="fixed",
|
| 64 |
+
choices=["predicted", "fixed"],
|
| 65 |
+
help="Camera mode: 'predicted' uses model-predicted camera parameters, "
|
| 66 |
+
"'fixed' uses a fixed camera angle (default: predicted)")
|
| 67 |
+
parser.add_argument("--human_idx", type=int, default=0,
|
| 68 |
+
help="Human index to process (default: 0)")
|
| 69 |
+
parser.add_argument("--start_idx", type=int, default=None,
|
| 70 |
+
help="Start frame index for processing (default: None, process from beginning)")
|
| 71 |
+
parser.add_argument("--end_idx", type=int, default=None,
|
| 72 |
+
help="End frame index for processing (default: None, process to end)")
|
| 73 |
+
parser.add_argument("--bbox_scale", type=float, default=1.0,
|
| 74 |
+
help="Scale factor for bounding box size (default: 1.0)")
|
| 75 |
+
parser.add_argument("--conf_thres", type=float, default=0.1,
|
| 76 |
+
help="Confidence threshold for point cloud generation (default: 0.1)")
|
| 77 |
+
|
| 78 |
+
# New arguments
|
| 79 |
+
parser.add_argument("--seed", type=int, default=42, help="Random seed for reproducibility")
|
| 80 |
+
parser.add_argument("--yolo_ckpt", type=str, default="ckpts/yolo11n.pt", help="Path to YOLO checkpoint")
|
| 81 |
+
parser.add_argument("--sam2_model", type=str, default="facebook/sam2-hiera-large", help="SAM2 model name or path")
|
| 82 |
+
|
| 83 |
+
args = parser.parse_args()
|
| 84 |
+
|
| 85 |
+
# Setup seed
|
| 86 |
+
setup_seed(args.seed)
|
| 87 |
+
|
| 88 |
+
# Setup logging
|
| 89 |
+
logger = setup_logging(args.output_dir)
|
| 90 |
+
|
| 91 |
+
# Setup device
|
| 92 |
+
if torch.cuda.is_available():
|
| 93 |
+
if args.device == "cuda":
|
| 94 |
+
# Use specified GPU ID
|
| 95 |
+
device = torch.device(f"cuda:{args.gpu_id}")
|
| 96 |
+
# Set the current CUDA device
|
| 97 |
+
torch.cuda.set_device(args.gpu_id)
|
| 98 |
+
logger.info(
|
| 99 |
+
f"Using GPU {args.gpu_id}: {torch.cuda.get_device_name(args.gpu_id)}")
|
| 100 |
+
else:
|
| 101 |
+
device = torch.device(args.device)
|
| 102 |
+
else:
|
| 103 |
+
device = torch.device("cpu")
|
| 104 |
+
logger.info("CUDA not available, using CPU")
|
| 105 |
+
|
| 106 |
+
logger.info(f"Using device: {device}")
|
| 107 |
+
|
| 108 |
+
# Load model
|
| 109 |
+
logger.info("Loading model...")
|
| 110 |
+
model = load_model(args.checkpoint)
|
| 111 |
+
model = model.to(device)
|
| 112 |
+
model.eval()
|
| 113 |
+
|
| 114 |
+
# Process video
|
| 115 |
+
logger.info(f"Processing video: {args.video_path}")
|
| 116 |
+
data_dict = process_video(
|
| 117 |
+
args.video_path, args.fps, args.human_idx, args.target_size,
|
| 118 |
+
bbox_scale=args.bbox_scale, start_idx=args.start_idx, end_idx=args.end_idx,
|
| 119 |
+
original_fps=args.original_fps,
|
| 120 |
+
yolo_ckpt=args.yolo_ckpt, sam2_model=args.sam2_model
|
| 121 |
+
)
|
| 122 |
+
|
| 123 |
+
# Run inference
|
| 124 |
+
results = run_inference(model, data_dict, device, args.chunk_size)
|
| 125 |
+
|
| 126 |
+
# Create output directory
|
| 127 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
| 128 |
+
|
| 129 |
+
viz_scene_point_clouds, viz_smpl_meshes, viz_scene_only_point_clouds, smpl_points_for_camera = generate_mixed_geometries_in_memory(
|
| 130 |
+
results, args.body_models_path, fps=args.fps, conf_thres=args.conf_thres
|
| 131 |
+
)
|
| 132 |
+
|
| 133 |
+
# Determine camera mode based on arguments
|
| 134 |
+
use_predicted_camera = (args.camera_mode == "predicted")
|
| 135 |
+
logger.info(f"Using {args.camera_mode} camera mode")
|
| 136 |
+
|
| 137 |
+
original_rgb_images = results['rgb_images']
|
| 138 |
+
|
| 139 |
+
if original_rgb_images is not None:
|
| 140 |
+
if hasattr(original_rgb_images, 'permute'): # It's a torch tensor
|
| 141 |
+
original_rgb_images = original_rgb_images.permute(
|
| 142 |
+
0, 2, 3, 1).cpu().numpy() # [S, H, W, 3]
|
| 143 |
+
elif not isinstance(original_rgb_images, np.ndarray):
|
| 144 |
+
original_rgb_images = np.array(original_rgb_images)
|
| 145 |
+
|
| 146 |
+
# Ensure proper data type and range
|
| 147 |
+
if original_rgb_images.max() <= 1.0:
|
| 148 |
+
original_rgb_images = (
|
| 149 |
+
original_rgb_images * 255).astype(np.uint8)
|
| 150 |
+
|
| 151 |
+
original_human_boxes = data_dict['human_boxes']
|
| 152 |
+
|
| 153 |
+
run_visualization(viz_scene_point_clouds, viz_smpl_meshes, smpl_points_for_camera,
|
| 154 |
+
args.output_dir, results['seq_name'],
|
| 155 |
+
fps=args.fps, # Use original fps
|
| 156 |
+
rgb_images=original_rgb_images,
|
| 157 |
+
human_boxes=original_human_boxes,
|
| 158 |
+
chunk_size=args.chunk_size, # Use original chunk size
|
| 159 |
+
results=results,
|
| 160 |
+
use_predicted_camera=use_predicted_camera,
|
| 161 |
+
scene_only_point_clouds=viz_scene_only_point_clouds,
|
| 162 |
+
conf_thres=args.conf_thres)
|
| 163 |
+
|
| 164 |
+
if args.save_results:
|
| 165 |
+
|
| 166 |
+
logger.info("Creating SMPL meshes per frame...")
|
| 167 |
+
save_smpl_meshes_per_frame(
|
| 168 |
+
results, args.output_dir, args.body_models_path)
|
| 169 |
+
|
| 170 |
+
logger.info("Saving scene point clouds (without human)...")
|
| 171 |
+
save_scene_only_point_clouds(
|
| 172 |
+
viz_scene_only_point_clouds, args.output_dir, results['seq_name'])
|
| 173 |
+
|
| 174 |
+
logger.info("Saving human point clouds...")
|
| 175 |
+
save_human_point_clouds(viz_scene_point_clouds,
|
| 176 |
+
viz_scene_only_point_clouds, args.output_dir, results['seq_name'], results)
|
| 177 |
+
|
| 178 |
+
logger.info("Saving camera parameters per frame...")
|
| 179 |
+
save_camera_parameters_per_frame(
|
| 180 |
+
results, args.output_dir, results['seq_name'])
|
| 181 |
+
|
| 182 |
+
logger.info(f"Inference completed! Results saved to {args.output_dir}")
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
if __name__ == "__main__":
|
| 186 |
+
main()
|
install.sh
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
set -e
|
| 3 |
+
|
| 4 |
+
# ==========================================
|
| 5 |
+
# UniSH Auto-Install Script
|
| 6 |
+
# ==========================================
|
| 7 |
+
|
| 8 |
+
get_cuda_version() {
|
| 9 |
+
if [ ! -z "$1" ]; then echo "$1"; return; fi
|
| 10 |
+
if command -v nvidia-smi &> /dev/null; then
|
| 11 |
+
DRIVER_CUDA_MAJOR=$(nvidia-smi | grep "CUDA Version" | awk -F'CUDA Version:' '{print $2}' | awk -F'.' '{print $1}' | tr -d '[:space:]')
|
| 12 |
+
if [ "$DRIVER_CUDA_MAJOR" == "12" ]; then echo "12.1"; elif [ "$DRIVER_CUDA_MAJOR" == "11" ]; then echo "11.8"; else echo "12.1"; fi
|
| 13 |
+
else echo "12.1"; fi
|
| 14 |
+
}
|
| 15 |
+
|
| 16 |
+
if [[ -z "$CONDA_PREFIX" ]]; then
|
| 17 |
+
echo "❌ Error: Please activate the conda environment first!"
|
| 18 |
+
exit 1
|
| 19 |
+
fi
|
| 20 |
+
|
| 21 |
+
TARGET_CUDA=$(get_cuda_version "$1")
|
| 22 |
+
echo "========================================"
|
| 23 |
+
echo " Detected/Selected CUDA: $TARGET_CUDA"
|
| 24 |
+
echo "========================================"
|
| 25 |
+
|
| 26 |
+
if [[ "$TARGET_CUDA" == "12.1" ]]; then TORCH_INDEX_URL="https://download.pytorch.org/whl/cu121";
|
| 27 |
+
elif [[ "$TARGET_CUDA" == "11.8" ]]; then TORCH_INDEX_URL="https://download.pytorch.org/whl/cu118";
|
| 28 |
+
else TORCH_INDEX_URL=""; fi
|
| 29 |
+
|
| 30 |
+
echo "[1/6] Installing PyTorch 2.4.1 (CUDA $TARGET_CUDA)..."
|
| 31 |
+
pip install torch==2.4.1 torchvision==0.19.1 --index-url $TORCH_INDEX_URL
|
| 32 |
+
|
| 33 |
+
echo "[2/6] Installing Safe Requirements..."
|
| 34 |
+
pip install -r requirements.txt
|
| 35 |
+
|
| 36 |
+
echo "[3/6] Installing Custom Utils3D..."
|
| 37 |
+
pip install "git+https://github.com/EasternJournalist/utils3d.git@3fab839f0be9931dac7c8488eb0e1600c236e183"
|
| 38 |
+
|
| 39 |
+
echo "[4/6] Installing Heavy Dependencies..."
|
| 40 |
+
pip install open3d==0.19.0 --no-deps
|
| 41 |
+
pip install ultralytics==8.3.227 --no-deps
|
| 42 |
+
pip install timm==1.0.24 --no-deps
|
| 43 |
+
|
| 44 |
+
echo "[5/6] Installing MMCV & PyTorch3D..."
|
| 45 |
+
pip install mmcv==2.2.0 --no-deps --no-binary mmcv
|
| 46 |
+
pip install "git+https://github.com/facebookresearch/pytorch3d.git@stable" --no-build-isolation
|
| 47 |
+
|
| 48 |
+
echo "[6/6] Installing SAM 2 (With Setuptools Fix)..."
|
| 49 |
+
|
| 50 |
+
pip install setuptools==69.5.1 wheel
|
| 51 |
+
rm -rf _tmp_install_sam2
|
| 52 |
+
|
| 53 |
+
mkdir -p _tmp_install_sam2
|
| 54 |
+
cd _tmp_install_sam2
|
| 55 |
+
|
| 56 |
+
echo " -> Cloning SAM 2..."
|
| 57 |
+
git clone https://github.com/facebookresearch/segment-anything-2.git --depth 1
|
| 58 |
+
cd segment-anything-2
|
| 59 |
+
|
| 60 |
+
echo " -> Patching setup.py..."
|
| 61 |
+
python -c "
|
| 62 |
+
path = 'setup.py'
|
| 63 |
+
with open(path, 'r') as f: c = f.read()
|
| 64 |
+
c = c.replace('torch>=2.5.1', 'torch>=2.4.1')
|
| 65 |
+
with open(path, 'w') as f: f.write(c)
|
| 66 |
+
"
|
| 67 |
+
pip install . --no-deps --no-build-isolation
|
| 68 |
+
cd ../..
|
| 69 |
+
rm -rf _tmp_install_sam2
|
| 70 |
+
|
| 71 |
+
echo "========================================"
|
| 72 |
+
echo "Installation Complete!"
|
| 73 |
+
python -c "import torch; print(f'PyTorch: {torch.__version__} | CUDA: {torch.version.cuda}')"
|
| 74 |
+
echo "========================================"
|
packages.txt
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
ffmpeg
|
| 2 |
+
libgl1-mesa-glx
|
| 3 |
+
libglib2.0-0
|
| 4 |
+
libegl1-mesa
|
| 5 |
+
xvfb
|
| 6 |
+
|
| 7 |
+
|
requirements.txt
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch==2.4.1
|
| 2 |
+
torchvision==0.19.1
|
| 3 |
+
numpy
|
| 4 |
+
scipy
|
| 5 |
+
trimesh
|
| 6 |
+
tqdm
|
| 7 |
+
opencv-python-headless
|
| 8 |
+
pillow
|
| 9 |
+
gradio
|
| 10 |
+
spaces
|
| 11 |
+
ninja
|
| 12 |
+
einops
|
| 13 |
+
safetensors
|
| 14 |
+
huggingface_hub
|
| 15 |
+
open3d==0.19.0
|
| 16 |
+
ultralytics==8.3.227
|
| 17 |
+
timm==1.0.24
|
| 18 |
+
git+https://github.com/EasternJournalist/utils3d.git@3fab839f0be9931dac7c8488eb0e1600c236e183
|
| 19 |
+
mmcv==2.2.0 --find-links https://download.openmmlab.com/mmcv/dist/cu121/torch2.4/index.html
|
| 20 |
+
pytorch3d @ https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/py310_cu121_pyt241/pytorch3d-0.7.8-cp310-cp310-linux_x86_64.whl
|
| 21 |
+
git+https://github.com/facebookresearch/segment-anything-2.git
|
| 22 |
+
smplx
|
static/teaser.svg
ADDED
|
|
unish/__pycache__/pipeline.cpython-310.pyc
ADDED
|
Binary file (6.53 kB). View file
|
|
|
unish/heads/__pycache__/align_net.cpython-310.pyc
ADDED
|
Binary file (13.6 kB). View file
|
|
|
unish/heads/__pycache__/dpt_head.cpython-310.pyc
ADDED
|
Binary file (12.6 kB). View file
|
|
|
unish/heads/__pycache__/head_act.cpython-310.pyc
ADDED
|
Binary file (3.11 kB). View file
|
|
|
unish/heads/__pycache__/human_head_cliff.cpython-310.pyc
ADDED
|
Binary file (2.92 kB). View file
|
|
|
unish/heads/__pycache__/pose_transformer.cpython-310.pyc
ADDED
|
Binary file (10.9 kB). View file
|
|
|
unish/heads/__pycache__/t_cond_mlp.cpython-310.pyc
ADDED
|
Binary file (6.08 kB). View file
|
|
|
unish/heads/__pycache__/utils.cpython-310.pyc
ADDED
|
Binary file (3.14 kB). View file
|
|
|
unish/heads/__pycache__/vit.cpython-310.pyc
ADDED
|
Binary file (11.2 kB). View file
|
|
|
unish/heads/align_net.py
ADDED
|
@@ -0,0 +1,571 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
import math
|
| 5 |
+
import numpy as np
|
| 6 |
+
|
| 7 |
+
from unish.utils.data_utils import rot6d_to_rotmat
|
| 8 |
+
from unish.utils.constants import SMPL_MEAN_PARAMS
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class TimeStepRoPE1D(nn.Module):
|
| 12 |
+
"""1D RoPE for timestep embedding, similar to pi3's RoPE2D but for 1D time sequence"""
|
| 13 |
+
|
| 14 |
+
def __init__(self, freq=100.0):
|
| 15 |
+
super().__init__()
|
| 16 |
+
self.base = freq
|
| 17 |
+
self.cache = {}
|
| 18 |
+
self.max_train_len = 120
|
| 19 |
+
|
| 20 |
+
def get_cos_sin(self, D, seq_len, device, dtype):
|
| 21 |
+
if (D, seq_len, device, dtype) in self.cache:
|
| 22 |
+
return self.cache[D, seq_len, device, dtype]
|
| 23 |
+
|
| 24 |
+
if seq_len <= self.max_train_len:
|
| 25 |
+
assert D % 2 == 0
|
| 26 |
+
|
| 27 |
+
inv_freq = 1.0 / (self.base ** (torch.arange(0, D, 2).float().to(device) / D))
|
| 28 |
+
t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype)
|
| 29 |
+
freqs = torch.einsum("i,j->ij", t, inv_freq).to(dtype)
|
| 30 |
+
|
| 31 |
+
freqs = torch.cat((freqs, freqs), dim=-1)
|
| 32 |
+
cos = freqs.cos() # (seq_len, D)
|
| 33 |
+
sin = freqs.sin() # (seq_len, D)
|
| 34 |
+
self.cache[D, seq_len, device, dtype] = (cos, sin)
|
| 35 |
+
return cos, sin
|
| 36 |
+
|
| 37 |
+
else:
|
| 38 |
+
cos_train, sin_train = self.get_cos_sin(D, self.max_train_len, device, dtype)
|
| 39 |
+
cos_train_res = cos_train.transpose(0, 1).unsqueeze(0)
|
| 40 |
+
sin_train_res = sin_train.transpose(0, 1).unsqueeze(0)
|
| 41 |
+
|
| 42 |
+
# [1, D, max_train_len] -> [1, D, seq_len]
|
| 43 |
+
cos_interp = F.interpolate(cos_train_res, size=seq_len, mode='linear', align_corners=True)
|
| 44 |
+
sin_interp = F.interpolate(sin_train_res, size=seq_len, mode='linear', align_corners=True)
|
| 45 |
+
|
| 46 |
+
# [1, D, seq_len] -> [seq_len, D]
|
| 47 |
+
cos_final = cos_interp.squeeze(0).transpose(0, 1)
|
| 48 |
+
sin_final = sin_interp.squeeze(0).transpose(0, 1)
|
| 49 |
+
|
| 50 |
+
self.cache[D, seq_len, device, dtype] = (cos_final, sin_final)
|
| 51 |
+
return cos_final, sin_final
|
| 52 |
+
|
| 53 |
+
@staticmethod
|
| 54 |
+
def rotate_half(x):
|
| 55 |
+
x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
|
| 56 |
+
return torch.cat((-x2, x1), dim=-1)
|
| 57 |
+
|
| 58 |
+
def apply_rope1d(self, tokens, pos1d, cos, sin):
|
| 59 |
+
"""Apply 1D RoPE to tokens based on 1D positions"""
|
| 60 |
+
cos = torch.nn.functional.embedding(pos1d, cos)[:, None, :, :] # [batch, 1, seq_len, D]
|
| 61 |
+
sin = torch.nn.functional.embedding(pos1d, sin)[:, None, :, :] # [batch, 1, seq_len, D]
|
| 62 |
+
return (tokens * cos) + (self.rotate_half(tokens) * sin)
|
| 63 |
+
|
| 64 |
+
def forward(self, tokens, positions):
|
| 65 |
+
"""
|
| 66 |
+
Apply 1D RoPE to tokens based on timestep positions.
|
| 67 |
+
Args:
|
| 68 |
+
tokens: [batch, num_heads, seq_len, head_dim]
|
| 69 |
+
positions: [batch, seq_len] - timestep positions (0, 1, 2, ...)
|
| 70 |
+
Returns:
|
| 71 |
+
tokens with RoPE applied: [batch, num_heads, seq_len, head_dim]
|
| 72 |
+
"""
|
| 73 |
+
head_dim = tokens.size(3)
|
| 74 |
+
assert head_dim % 2 == 0, "head_dim should be a multiple of two"
|
| 75 |
+
assert positions.ndim == 2 # [batch, seq_len]
|
| 76 |
+
|
| 77 |
+
cos, sin = self.get_cos_sin(head_dim, int(positions.max()) + 1, tokens.device, tokens.dtype)
|
| 78 |
+
|
| 79 |
+
return self.apply_rope1d(tokens, positions.long(), cos, sin)
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
class TransformerDecoderLayer(nn.Module):
|
| 83 |
+
"""单层Transformer Decoder with RoPE support"""
|
| 84 |
+
|
| 85 |
+
def __init__(self, hidden_dim=512, num_heads=8, ff_dim=1024, dropout=0.1, use_rope=True):
|
| 86 |
+
super().__init__()
|
| 87 |
+
|
| 88 |
+
self.use_rope = use_rope
|
| 89 |
+
self.hidden_dim = hidden_dim
|
| 90 |
+
self.num_heads = num_heads
|
| 91 |
+
self.head_dim = hidden_dim // num_heads
|
| 92 |
+
|
| 93 |
+
if use_rope:
|
| 94 |
+
self.self_attention = None
|
| 95 |
+
self.cross_attention = None
|
| 96 |
+
|
| 97 |
+
self.self_q_proj = nn.Linear(hidden_dim, hidden_dim, bias=True)
|
| 98 |
+
self.self_k_proj = nn.Linear(hidden_dim, hidden_dim, bias=True)
|
| 99 |
+
self.self_v_proj = nn.Linear(hidden_dim, hidden_dim, bias=True)
|
| 100 |
+
self.self_out_proj = nn.Linear(hidden_dim, hidden_dim, bias=True)
|
| 101 |
+
|
| 102 |
+
self.cross_q_proj = nn.Linear(hidden_dim, hidden_dim, bias=True)
|
| 103 |
+
self.cross_k_proj = nn.Linear(hidden_dim, hidden_dim, bias=True)
|
| 104 |
+
self.cross_v_proj = nn.Linear(hidden_dim, hidden_dim, bias=True)
|
| 105 |
+
self.cross_out_proj = nn.Linear(hidden_dim, hidden_dim, bias=True)
|
| 106 |
+
|
| 107 |
+
# RoPE for timestep embedding
|
| 108 |
+
self.timestep_rope = TimeStepRoPE1D(freq=100.0)
|
| 109 |
+
else:
|
| 110 |
+
self.self_attention = nn.MultiheadAttention(
|
| 111 |
+
embed_dim=hidden_dim,
|
| 112 |
+
num_heads=num_heads,
|
| 113 |
+
dropout=dropout,
|
| 114 |
+
batch_first=True
|
| 115 |
+
)
|
| 116 |
+
|
| 117 |
+
self.cross_attention = nn.MultiheadAttention(
|
| 118 |
+
embed_dim=hidden_dim,
|
| 119 |
+
num_heads=num_heads,
|
| 120 |
+
dropout=dropout,
|
| 121 |
+
batch_first=True
|
| 122 |
+
)
|
| 123 |
+
|
| 124 |
+
self.feed_forward = nn.Sequential(
|
| 125 |
+
nn.Linear(hidden_dim, ff_dim),
|
| 126 |
+
nn.ReLU(),
|
| 127 |
+
nn.Dropout(dropout),
|
| 128 |
+
nn.Linear(ff_dim, hidden_dim),
|
| 129 |
+
nn.Dropout(dropout)
|
| 130 |
+
)
|
| 131 |
+
|
| 132 |
+
self.norm1 = nn.LayerNorm(hidden_dim) # for self attention
|
| 133 |
+
self.norm2 = nn.LayerNorm(hidden_dim) # for cross attention
|
| 134 |
+
self.norm3 = nn.LayerNorm(hidden_dim) # for feed forward
|
| 135 |
+
|
| 136 |
+
# Dropout
|
| 137 |
+
self.dropout = nn.Dropout(dropout)
|
| 138 |
+
self.attn_dropout = nn.Dropout(dropout)
|
| 139 |
+
|
| 140 |
+
# Scale factor for attention
|
| 141 |
+
self.scale = self.head_dim ** -0.5
|
| 142 |
+
|
| 143 |
+
# Gradient checkpointing flag
|
| 144 |
+
self.use_gradient_checkpoint = False
|
| 145 |
+
|
| 146 |
+
def gradient_checkpointing_enable(self):
|
| 147 |
+
"""Enable gradient checkpointing for memory optimization."""
|
| 148 |
+
self.use_gradient_checkpoint = True
|
| 149 |
+
|
| 150 |
+
def _rope_attention(self, q_proj, k_proj, v_proj, out_proj, query, key, value, timestep_pos=None):
|
| 151 |
+
"""Apply RoPE-based attention using torch.nn.functional.scaled_dot_product_attention"""
|
| 152 |
+
batch_size, seq_len, _ = query.shape
|
| 153 |
+
|
| 154 |
+
# Project Q, K, V
|
| 155 |
+
q = q_proj(query).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
|
| 156 |
+
k = k_proj(key).view(batch_size, key.shape[1], self.num_heads, self.head_dim).transpose(1, 2)
|
| 157 |
+
v = v_proj(value).view(batch_size, value.shape[1], self.num_heads, self.head_dim).transpose(1, 2)
|
| 158 |
+
|
| 159 |
+
# Apply RoPE to Q and K if timestep positions are provided
|
| 160 |
+
if timestep_pos is not None and self.use_rope:
|
| 161 |
+
# For self-attention, both q and k use the same timestep positions
|
| 162 |
+
if query.shape == key.shape: # self-attention case
|
| 163 |
+
q = self.timestep_rope(q, timestep_pos)
|
| 164 |
+
k = self.timestep_rope(k, timestep_pos)
|
| 165 |
+
else: # cross-attention case
|
| 166 |
+
# Only apply RoPE to query (cam_token), key/value are spatial features
|
| 167 |
+
q = self.timestep_rope(q, timestep_pos)
|
| 168 |
+
|
| 169 |
+
attn_output = F.scaled_dot_product_attention(
|
| 170 |
+
q, k, v,
|
| 171 |
+
dropout_p=self.attn_dropout.p if self.training else 0.0,
|
| 172 |
+
scale=self.scale
|
| 173 |
+
)
|
| 174 |
+
|
| 175 |
+
# Reshape output
|
| 176 |
+
attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.hidden_dim)
|
| 177 |
+
|
| 178 |
+
# Output projection
|
| 179 |
+
return out_proj(attn_output)
|
| 180 |
+
|
| 181 |
+
def forward(self, query, key, value, self_attn_mask=None, cross_attn_mask=None, timestep_pos=None):
|
| 182 |
+
"""
|
| 183 |
+
Args:
|
| 184 |
+
query: [batch, num_views, hidden_dim]
|
| 185 |
+
key: [batch, num_views, hidden_dim]
|
| 186 |
+
value: [batch, num_views, hidden_dim]
|
| 187 |
+
timestep_pos: [batch, num_views] - timestep positions for RoPE
|
| 188 |
+
"""
|
| 189 |
+
if self.use_gradient_checkpoint and self.training:
|
| 190 |
+
from torch.utils.checkpoint import checkpoint
|
| 191 |
+
|
| 192 |
+
if self.use_rope:
|
| 193 |
+
# 1. Self Attention + Residual with RoPE (with gradient checkpointing)
|
| 194 |
+
self_attn_output = checkpoint(
|
| 195 |
+
self._rope_attention,
|
| 196 |
+
self.self_q_proj, self.self_k_proj, self.self_v_proj, self.self_out_proj,
|
| 197 |
+
query, query, query, timestep_pos,
|
| 198 |
+
use_reentrant=False
|
| 199 |
+
)
|
| 200 |
+
query = self.norm1(query + self.dropout(self_attn_output))
|
| 201 |
+
|
| 202 |
+
# 2. Cross Attention + Residual with RoPE (with gradient checkpointing)
|
| 203 |
+
cross_attn_output = checkpoint(
|
| 204 |
+
self._rope_attention,
|
| 205 |
+
self.cross_q_proj, self.cross_k_proj, self.cross_v_proj, self.cross_out_proj,
|
| 206 |
+
query, key, value, timestep_pos,
|
| 207 |
+
use_reentrant=False
|
| 208 |
+
)
|
| 209 |
+
query = self.norm2(query + self.dropout(cross_attn_output))
|
| 210 |
+
else:
|
| 211 |
+
# 1. Self Attention + Residual (with gradient checkpointing)
|
| 212 |
+
def self_attn_fn(q, k, v):
|
| 213 |
+
out, _ = self.self_attention(q, k, v, attn_mask=self_attn_mask)
|
| 214 |
+
return out
|
| 215 |
+
self_attn_output = checkpoint(self_attn_fn, query, query, query, use_reentrant=False)
|
| 216 |
+
query = self.norm1(query + self.dropout(self_attn_output))
|
| 217 |
+
|
| 218 |
+
# 2. Cross Attention + Residual (with gradient checkpointing)
|
| 219 |
+
def cross_attn_fn(q, k, v):
|
| 220 |
+
out, _ = self.cross_attention(q, k, v, attn_mask=cross_attn_mask)
|
| 221 |
+
return out
|
| 222 |
+
cross_attn_output = checkpoint(cross_attn_fn, query, key, value, use_reentrant=False)
|
| 223 |
+
query = self.norm2(query + self.dropout(cross_attn_output))
|
| 224 |
+
|
| 225 |
+
# 3. Feed Forward + Residual (with gradient checkpointing)
|
| 226 |
+
ff_output = checkpoint(self.feed_forward, query, use_reentrant=False)
|
| 227 |
+
query = self.norm3(query + ff_output)
|
| 228 |
+
else:
|
| 229 |
+
# Original implementation without gradient checkpointing
|
| 230 |
+
if self.use_rope:
|
| 231 |
+
# 1. Self Attention + Residual with RoPE
|
| 232 |
+
self_attn_output = self._rope_attention(
|
| 233 |
+
self.self_q_proj, self.self_k_proj, self.self_v_proj, self.self_out_proj,
|
| 234 |
+
query, query, query, timestep_pos
|
| 235 |
+
)
|
| 236 |
+
query = self.norm1(query + self.dropout(self_attn_output))
|
| 237 |
+
|
| 238 |
+
# 2. Cross Attention + Residual with RoPE
|
| 239 |
+
cross_attn_output = self._rope_attention(
|
| 240 |
+
self.cross_q_proj, self.cross_k_proj, self.cross_v_proj, self.cross_out_proj,
|
| 241 |
+
query, key, value, timestep_pos
|
| 242 |
+
)
|
| 243 |
+
query = self.norm2(query + self.dropout(cross_attn_output))
|
| 244 |
+
else:
|
| 245 |
+
# 1. Self Attention + Residual (original implementation)
|
| 246 |
+
self_attn_output, _ = self.self_attention(query, query, query, attn_mask=self_attn_mask)
|
| 247 |
+
query = self.norm1(query + self.dropout(self_attn_output))
|
| 248 |
+
|
| 249 |
+
# 2. Cross Attention + Residual (original implementation)
|
| 250 |
+
cross_attn_output, _ = self.cross_attention(query, key, value, attn_mask=cross_attn_mask)
|
| 251 |
+
query = self.norm2(query + self.dropout(cross_attn_output))
|
| 252 |
+
|
| 253 |
+
# 3. Feed Forward + Residual
|
| 254 |
+
ff_output = self.feed_forward(query)
|
| 255 |
+
query = self.norm3(query + ff_output)
|
| 256 |
+
|
| 257 |
+
return query
|
| 258 |
+
|
| 259 |
+
|
| 260 |
+
class CrossViewTransformerDecoderLayer(nn.Module):
|
| 261 |
+
"""Cross-view Transformer Decoder Layer for V4 - handles concatenated tokens from multiple views"""
|
| 262 |
+
|
| 263 |
+
def __init__(self, hidden_dim=512, num_heads=8, ff_dim=1024, dropout=0.1, use_rope=True):
|
| 264 |
+
super().__init__()
|
| 265 |
+
|
| 266 |
+
self.use_rope = use_rope
|
| 267 |
+
self.hidden_dim = hidden_dim
|
| 268 |
+
self.num_heads = num_heads
|
| 269 |
+
self.head_dim = hidden_dim // num_heads
|
| 270 |
+
|
| 271 |
+
if use_rope:
|
| 272 |
+
self.self_attention = None
|
| 273 |
+
self.cross_attention = None
|
| 274 |
+
|
| 275 |
+
# Self-attention components
|
| 276 |
+
self.self_q_proj = nn.Linear(hidden_dim, hidden_dim, bias=True)
|
| 277 |
+
self.self_k_proj = nn.Linear(hidden_dim, hidden_dim, bias=True)
|
| 278 |
+
self.self_v_proj = nn.Linear(hidden_dim, hidden_dim, bias=True)
|
| 279 |
+
self.self_out_proj = nn.Linear(hidden_dim, hidden_dim, bias=True)
|
| 280 |
+
|
| 281 |
+
# Cross-attention components
|
| 282 |
+
self.cross_q_proj = nn.Linear(hidden_dim, hidden_dim, bias=True)
|
| 283 |
+
self.cross_k_proj = nn.Linear(hidden_dim, hidden_dim, bias=True)
|
| 284 |
+
self.cross_v_proj = nn.Linear(hidden_dim, hidden_dim, bias=True)
|
| 285 |
+
self.cross_out_proj = nn.Linear(hidden_dim, hidden_dim, bias=True)
|
| 286 |
+
|
| 287 |
+
# RoPE for timestep embedding
|
| 288 |
+
self.timestep_rope = TimeStepRoPE1D(freq=100.0)
|
| 289 |
+
else:
|
| 290 |
+
# Self Attention层
|
| 291 |
+
self.self_attention = nn.MultiheadAttention(
|
| 292 |
+
embed_dim=hidden_dim,
|
| 293 |
+
num_heads=num_heads,
|
| 294 |
+
dropout=dropout,
|
| 295 |
+
batch_first=True
|
| 296 |
+
)
|
| 297 |
+
|
| 298 |
+
# Cross Attention层
|
| 299 |
+
self.cross_attention = nn.MultiheadAttention(
|
| 300 |
+
embed_dim=hidden_dim,
|
| 301 |
+
num_heads=num_heads,
|
| 302 |
+
dropout=dropout,
|
| 303 |
+
batch_first=True
|
| 304 |
+
)
|
| 305 |
+
|
| 306 |
+
self.feed_forward = nn.Sequential(
|
| 307 |
+
nn.Linear(hidden_dim, ff_dim),
|
| 308 |
+
nn.ReLU(),
|
| 309 |
+
nn.Dropout(dropout),
|
| 310 |
+
nn.Linear(ff_dim, hidden_dim),
|
| 311 |
+
nn.Dropout(dropout)
|
| 312 |
+
)
|
| 313 |
+
|
| 314 |
+
self.norm1 = nn.LayerNorm(hidden_dim) # for self attention
|
| 315 |
+
self.norm2 = nn.LayerNorm(hidden_dim) # for cross attention
|
| 316 |
+
self.norm3 = nn.LayerNorm(hidden_dim) # for feed forward
|
| 317 |
+
|
| 318 |
+
self.dropout = nn.Dropout(dropout)
|
| 319 |
+
self.attn_dropout = nn.Dropout(dropout)
|
| 320 |
+
|
| 321 |
+
self.scale = self.head_dim ** -0.5
|
| 322 |
+
|
| 323 |
+
self.use_gradient_checkpoint = False
|
| 324 |
+
|
| 325 |
+
def gradient_checkpointing_enable(self):
|
| 326 |
+
"""Enable gradient checkpointing for memory optimization."""
|
| 327 |
+
self.use_gradient_checkpoint = True
|
| 328 |
+
|
| 329 |
+
def _rope_attention(self, q_proj, k_proj, v_proj, out_proj, query, key, value, query_timestep_pos=None, key_timestep_pos=None):
|
| 330 |
+
"""Apply RoPE-based attention for cross-view scenarios using torch.nn.functional.scaled_dot_product_attention"""
|
| 331 |
+
batch_size, query_seq_len, _ = query.shape
|
| 332 |
+
_, key_seq_len, _ = key.shape
|
| 333 |
+
|
| 334 |
+
# Project Q, K, V
|
| 335 |
+
q = q_proj(query).view(batch_size, query_seq_len, self.num_heads, self.head_dim).transpose(1, 2)
|
| 336 |
+
k = k_proj(key).view(batch_size, key_seq_len, self.num_heads, self.head_dim).transpose(1, 2)
|
| 337 |
+
v = v_proj(value).view(batch_size, key_seq_len, self.num_heads, self.head_dim).transpose(1, 2)
|
| 338 |
+
|
| 339 |
+
# Apply RoPE to Q and K if timestep positions are provided
|
| 340 |
+
if self.use_rope:
|
| 341 |
+
if query_timestep_pos is not None:
|
| 342 |
+
q_scale = q[:, :, 0:1, :] # [batch, num_heads, 1, head_dim] - scale token
|
| 343 |
+
q_cam = q[:, :, 1:, :] # [batch, num_heads, num_views, head_dim] - cam tokens
|
| 344 |
+
|
| 345 |
+
cam_timestep_pos = query_timestep_pos[:, 1:]
|
| 346 |
+
q_cam_rope = self.timestep_rope(q_cam, cam_timestep_pos)
|
| 347 |
+
|
| 348 |
+
q = torch.cat([q_scale, q_cam_rope], dim=2)
|
| 349 |
+
if key_timestep_pos is not None:
|
| 350 |
+
k = self.timestep_rope(k, key_timestep_pos)
|
| 351 |
+
|
| 352 |
+
attn_output = F.scaled_dot_product_attention(
|
| 353 |
+
q, k, v,
|
| 354 |
+
dropout_p=self.attn_dropout.p if self.training else 0.0,
|
| 355 |
+
scale=self.scale
|
| 356 |
+
)
|
| 357 |
+
|
| 358 |
+
# Reshape output
|
| 359 |
+
attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, query_seq_len, self.hidden_dim)
|
| 360 |
+
|
| 361 |
+
# Output projection
|
| 362 |
+
return out_proj(attn_output)
|
| 363 |
+
|
| 364 |
+
def forward(self, query, key, value, query_timestep_pos=None, key_timestep_pos=None):
|
| 365 |
+
"""
|
| 366 |
+
Args:
|
| 367 |
+
query: [batch, num_queries, hidden_dim] - cam tokens + scale token
|
| 368 |
+
key: [batch, num_views * num_tokens, hidden_dim] - concatenated feature tokens from all views
|
| 369 |
+
value: [batch, num_views * num_tokens, hidden_dim] - concatenated feature tokens from all views
|
| 370 |
+
query_timestep_pos: [batch, num_queries] - timestep positions for query tokens
|
| 371 |
+
key_timestep_pos: [batch, num_views * num_tokens] - timestep positions for key/value tokens
|
| 372 |
+
"""
|
| 373 |
+
if self.use_gradient_checkpoint and self.training:
|
| 374 |
+
from torch.utils.checkpoint import checkpoint
|
| 375 |
+
|
| 376 |
+
if self.use_rope:
|
| 377 |
+
# 1. Self Attention + Residual with RoPE (with gradient checkpointing)
|
| 378 |
+
self_attn_output = checkpoint(
|
| 379 |
+
self._rope_attention,
|
| 380 |
+
self.self_q_proj, self.self_k_proj, self.self_v_proj, self.self_out_proj,
|
| 381 |
+
query, query, query, query_timestep_pos, query_timestep_pos,
|
| 382 |
+
use_reentrant=False
|
| 383 |
+
)
|
| 384 |
+
query = self.norm1(query + self.dropout(self_attn_output))
|
| 385 |
+
|
| 386 |
+
# 2. Cross Attention + Residual with RoPE (with gradient checkpointing)
|
| 387 |
+
cross_attn_output = checkpoint(
|
| 388 |
+
self._rope_attention,
|
| 389 |
+
self.cross_q_proj, self.cross_k_proj, self.cross_v_proj, self.cross_out_proj,
|
| 390 |
+
query, key, value, query_timestep_pos, key_timestep_pos,
|
| 391 |
+
use_reentrant=False
|
| 392 |
+
)
|
| 393 |
+
query = self.norm2(query + self.dropout(cross_attn_output))
|
| 394 |
+
else:
|
| 395 |
+
# 1. Self Attention + Residual (with gradient checkpointing)
|
| 396 |
+
def self_attn_fn(q, k, v):
|
| 397 |
+
out, _ = self.self_attention(q, k, v)
|
| 398 |
+
return out
|
| 399 |
+
self_attn_output = checkpoint(self_attn_fn, query, query, query, use_reentrant=False)
|
| 400 |
+
query = self.norm1(query + self.dropout(self_attn_output))
|
| 401 |
+
|
| 402 |
+
# 2. Cross Attention + Residual (with gradient checkpointing)
|
| 403 |
+
def cross_attn_fn(q, k, v):
|
| 404 |
+
out, _ = self.cross_attention(q, k, v)
|
| 405 |
+
return out
|
| 406 |
+
cross_attn_output = checkpoint(cross_attn_fn, query, key, value, use_reentrant=False)
|
| 407 |
+
query = self.norm2(query + self.dropout(cross_attn_output))
|
| 408 |
+
|
| 409 |
+
# 3. Feed Forward + Residual (with gradient checkpointing)
|
| 410 |
+
ff_output = checkpoint(self.feed_forward, query, use_reentrant=False)
|
| 411 |
+
query = self.norm3(query + ff_output)
|
| 412 |
+
else:
|
| 413 |
+
# Original implementation without gradient checkpointing
|
| 414 |
+
if self.use_rope:
|
| 415 |
+
# 1. Self Attention + Residual with RoPE
|
| 416 |
+
self_attn_output = self._rope_attention(
|
| 417 |
+
self.self_q_proj, self.self_k_proj, self.self_v_proj, self.self_out_proj,
|
| 418 |
+
query, query, query, query_timestep_pos, query_timestep_pos
|
| 419 |
+
)
|
| 420 |
+
query = self.norm1(query + self.dropout(self_attn_output))
|
| 421 |
+
|
| 422 |
+
# 2. Cross Attention + Residual with RoPE
|
| 423 |
+
cross_attn_output = self._rope_attention(
|
| 424 |
+
self.cross_q_proj, self.cross_k_proj, self.cross_v_proj, self.cross_out_proj,
|
| 425 |
+
query, key, value, query_timestep_pos, key_timestep_pos
|
| 426 |
+
)
|
| 427 |
+
query = self.norm2(query + self.dropout(cross_attn_output))
|
| 428 |
+
else:
|
| 429 |
+
# 1. Self Attention + Residual (original implementation)
|
| 430 |
+
self_attn_output, _ = self.self_attention(query, query, query)
|
| 431 |
+
query = self.norm1(query + self.dropout(self_attn_output))
|
| 432 |
+
|
| 433 |
+
# 2. Cross Attention + Residual (original implementation)
|
| 434 |
+
cross_attn_output, _ = self.cross_attention(query, key, value)
|
| 435 |
+
query = self.norm2(query + self.dropout(cross_attn_output))
|
| 436 |
+
|
| 437 |
+
# 3. Feed Forward + Residual
|
| 438 |
+
ff_output = self.feed_forward(query)
|
| 439 |
+
query = self.norm3(query + ff_output)
|
| 440 |
+
|
| 441 |
+
return query
|
| 442 |
+
|
| 443 |
+
|
| 444 |
+
class AlignNet(nn.Module):
|
| 445 |
+
def __init__(self, aggregated_dim=2048, cam_dim=1024, hidden_dim=512, num_heads=8, ff_dim=512, dropout=0.1, use_rope=True, num_decoder_layers=2):
|
| 446 |
+
super().__init__()
|
| 447 |
+
|
| 448 |
+
self.use_rope = use_rope
|
| 449 |
+
self.hidden_dim = hidden_dim
|
| 450 |
+
self.num_decoder_layers = num_decoder_layers
|
| 451 |
+
|
| 452 |
+
self.scale_token = nn.Parameter(torch.randn(1, 1, hidden_dim) * 0.02)
|
| 453 |
+
|
| 454 |
+
self.cam_feature_adapter = nn.Sequential(
|
| 455 |
+
nn.LayerNorm(cam_dim),
|
| 456 |
+
nn.Linear(cam_dim, hidden_dim),
|
| 457 |
+
nn.ReLU(),
|
| 458 |
+
nn.Dropout(dropout)
|
| 459 |
+
)
|
| 460 |
+
|
| 461 |
+
self.patch_feature_adapter = nn.Sequential(
|
| 462 |
+
nn.LayerNorm(aggregated_dim),
|
| 463 |
+
nn.Linear(aggregated_dim, hidden_dim),
|
| 464 |
+
nn.ReLU(),
|
| 465 |
+
nn.Dropout(dropout)
|
| 466 |
+
)
|
| 467 |
+
self.register_feature_adapter = nn.Sequential(
|
| 468 |
+
nn.LayerNorm(aggregated_dim),
|
| 469 |
+
nn.Linear(aggregated_dim, hidden_dim),
|
| 470 |
+
nn.ReLU(),
|
| 471 |
+
nn.Dropout(dropout)
|
| 472 |
+
)
|
| 473 |
+
|
| 474 |
+
self.decoder_layers = nn.ModuleList([
|
| 475 |
+
CrossViewTransformerDecoderLayer(hidden_dim, num_heads, ff_dim, dropout, use_rope=use_rope)
|
| 476 |
+
for _ in range(num_decoder_layers)
|
| 477 |
+
])
|
| 478 |
+
|
| 479 |
+
mean_params = SMPL_MEAN_PARAMS
|
| 480 |
+
init_body_pose = torch.from_numpy(mean_params['pose'].astype(np.float32)).unsqueeze(0)
|
| 481 |
+
init_betas = torch.from_numpy(mean_params['shape'].astype('float32')).unsqueeze(0)
|
| 482 |
+
init_cam = torch.from_numpy(mean_params['cam'].astype(np.float32)).unsqueeze(0)
|
| 483 |
+
self.register_buffer('init_body_pose', init_body_pose)
|
| 484 |
+
self.register_buffer('init_betas', init_betas)
|
| 485 |
+
self.register_buffer('init_cam', init_cam)
|
| 486 |
+
|
| 487 |
+
self.trans_head = nn.Linear(hidden_dim, 3)
|
| 488 |
+
|
| 489 |
+
self.scale_head = nn.Linear(hidden_dim, 1)
|
| 490 |
+
|
| 491 |
+
self.joint_conversion_fn = rot6d_to_rotmat
|
| 492 |
+
|
| 493 |
+
def gradient_checkpointing_enable(self):
|
| 494 |
+
"""Enable gradient checkpointing for memory optimization."""
|
| 495 |
+
for layer in self.decoder_layers:
|
| 496 |
+
if hasattr(layer, 'gradient_checkpointing_enable'):
|
| 497 |
+
layer.gradient_checkpointing_enable()
|
| 498 |
+
|
| 499 |
+
def forward(self, hidden_tokens, cam_token, fps=6.0):
|
| 500 |
+
batch_size, num_views, num_tokens, _ = hidden_tokens.shape
|
| 501 |
+
|
| 502 |
+
register_tokens = hidden_tokens[:, :, :5, :]
|
| 503 |
+
patch_tokens = hidden_tokens[:, :, 5:, :]
|
| 504 |
+
|
| 505 |
+
if cam_token.dim() == 4:
|
| 506 |
+
cam_token = cam_token.squeeze(2) # [batch, num_views, 1, 1024] -> [batch, num_views, 1024]
|
| 507 |
+
|
| 508 |
+
cam_adapted = self.cam_feature_adapter(cam_token) # [batch, num_views, hidden_dim]
|
| 509 |
+
|
| 510 |
+
patch_tokens_reshaped = patch_tokens.view(batch_size * num_views, patch_tokens.shape[2], -1) # [batch*num_views, 777, 2048]
|
| 511 |
+
patch_adapted_tokens = self.patch_feature_adapter(patch_tokens_reshaped) # [batch*num_views, 777, hidden_dim]
|
| 512 |
+
patch_adapted_tokens = patch_adapted_tokens.view(batch_size, num_views, patch_tokens.shape[2], -1) # [batch, num_views, 777, hidden_dim]
|
| 513 |
+
|
| 514 |
+
register_tokens_reshaped = register_tokens.view(batch_size * num_views, 5, -1) # [batch*num_views, 5, 2048]
|
| 515 |
+
register_adapted_tokens = self.register_feature_adapter(register_tokens_reshaped) # [batch*num_views, 5, hidden_dim]
|
| 516 |
+
register_adapted_tokens = register_adapted_tokens.view(batch_size, num_views, 5, -1) # [batch, num_views, 5, hidden_dim]
|
| 517 |
+
|
| 518 |
+
fused_features_per_view = torch.cat([register_adapted_tokens, patch_adapted_tokens], dim=2) # [batch, num_views, 782, hidden_dim]
|
| 519 |
+
|
| 520 |
+
concatenated_features = fused_features_per_view.view(batch_size, num_views * num_tokens, -1)
|
| 521 |
+
|
| 522 |
+
scale_token_expanded = self.scale_token.expand(batch_size, -1, -1)
|
| 523 |
+
|
| 524 |
+
query_tokens = torch.cat([scale_token_expanded, cam_adapted], dim=1)
|
| 525 |
+
|
| 526 |
+
if self.use_rope:
|
| 527 |
+
base_fps = 6.0
|
| 528 |
+
|
| 529 |
+
time_scale = base_fps / fps
|
| 530 |
+
|
| 531 |
+
scale_timestep = torch.zeros((batch_size, 1), device=cam_adapted.device, dtype=torch.long)
|
| 532 |
+
|
| 533 |
+
cam_timestep_float = torch.arange(num_views, device=cam_adapted.device, dtype=torch.float32) * time_scale
|
| 534 |
+
cam_timestep = cam_timestep_float.round().long().unsqueeze(0).expand(batch_size, -1)
|
| 535 |
+
query_timestep_pos = torch.cat([scale_timestep, cam_timestep], dim=1) # [batch, 1 + num_views]
|
| 536 |
+
|
| 537 |
+
key_timestep_base_float = torch.arange(num_views, device=cam_adapted.device, dtype=torch.float32) * time_scale
|
| 538 |
+
key_timestep_base = key_timestep_base_float.round().long()
|
| 539 |
+
key_timestep_pos = key_timestep_base.unsqueeze(1).expand(-1, num_tokens).flatten()
|
| 540 |
+
key_timestep_pos = key_timestep_pos.unsqueeze(0).expand(batch_size, -1) # [batch, num_views * num_tokens]
|
| 541 |
+
else:
|
| 542 |
+
query_timestep_pos = None
|
| 543 |
+
key_timestep_pos = None
|
| 544 |
+
|
| 545 |
+
decoder_output = query_tokens
|
| 546 |
+
for i, layer in enumerate(self.decoder_layers):
|
| 547 |
+
residual = decoder_output
|
| 548 |
+
|
| 549 |
+
decoder_output = layer(
|
| 550 |
+
decoder_output, concatenated_features, concatenated_features,
|
| 551 |
+
query_timestep_pos=query_timestep_pos, key_timestep_pos=key_timestep_pos
|
| 552 |
+
)
|
| 553 |
+
|
| 554 |
+
decoder_output = decoder_output + residual
|
| 555 |
+
|
| 556 |
+
scale_output = decoder_output[:, 0, :]
|
| 557 |
+
cam_outputs = decoder_output[:, 1:, :]
|
| 558 |
+
|
| 559 |
+
scale_logits = self.scale_head(scale_output) # [batch, 1]
|
| 560 |
+
scale = F.softplus(scale_logits)
|
| 561 |
+
|
| 562 |
+
trans_raw = self.trans_head(cam_outputs) # [batch, num_views, 3]
|
| 563 |
+
xy, z = trans_raw.split([2, 1], dim=-1) # xy: [batch, num_views, 2], z: [batch, num_views, 1]
|
| 564 |
+
z = torch.exp(z)
|
| 565 |
+
trans = torch.cat([xy * z, z], dim=-1) # [batch, num_views, 3]
|
| 566 |
+
|
| 567 |
+
|
| 568 |
+
return {
|
| 569 |
+
"scale": scale, # [batch, 1]
|
| 570 |
+
"trans_cam": trans, # [batch, num_views, 3]
|
| 571 |
+
}
|
unish/heads/dpt_head.py
ADDED
|
@@ -0,0 +1,500 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
# Inspired by https://github.com/DepthAnything/Depth-Anything-V2
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
import os
|
| 12 |
+
from typing import List, Dict, Tuple, Union
|
| 13 |
+
|
| 14 |
+
import torch
|
| 15 |
+
import torch.nn as nn
|
| 16 |
+
import torch.nn.functional as F
|
| 17 |
+
from .head_act import activate_head
|
| 18 |
+
from .utils import create_uv_grid, position_grid_to_embed
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class DPTHead(nn.Module):
|
| 22 |
+
"""
|
| 23 |
+
DPT Head for dense prediction tasks.
|
| 24 |
+
|
| 25 |
+
This implementation follows the architecture described in "Vision Transformers for Dense Prediction"
|
| 26 |
+
(https://arxiv.org/abs/2103.13413). The DPT head processes features from a vision transformer
|
| 27 |
+
backbone and produces dense predictions by fusing multi-scale features.
|
| 28 |
+
|
| 29 |
+
Args:
|
| 30 |
+
dim_in (int): Input dimension (channels).
|
| 31 |
+
patch_size (int, optional): Patch size. Default is 14.
|
| 32 |
+
output_dim (int, optional): Number of output channels. Default is 4.
|
| 33 |
+
activation (str, optional): Activation type. Default is "inv_log".
|
| 34 |
+
conf_activation (str, optional): Confidence activation type. Default is "expp1".
|
| 35 |
+
features (int, optional): Feature channels for intermediate representations. Default is 256.
|
| 36 |
+
out_channels (List[int], optional): Output channels for each intermediate layer.
|
| 37 |
+
intermediate_layer_idx (List[int], optional): Indices of layers from aggregated tokens used for DPT.
|
| 38 |
+
pos_embed (bool, optional): Whether to use positional embedding. Default is True.
|
| 39 |
+
feature_only (bool, optional): If True, return features only without the last several layers and activation head. Default is False.
|
| 40 |
+
down_ratio (int, optional): Downscaling factor for the output resolution. Default is 1.
|
| 41 |
+
"""
|
| 42 |
+
|
| 43 |
+
def __init__(
|
| 44 |
+
self,
|
| 45 |
+
dim_in: int,
|
| 46 |
+
patch_size: int = 14,
|
| 47 |
+
output_dim: int = 4,
|
| 48 |
+
activation: str = "inv_log",
|
| 49 |
+
conf_activation: str = "expp1",
|
| 50 |
+
features: int = 256,
|
| 51 |
+
out_channels: List[int] = [256, 512, 1024, 1024],
|
| 52 |
+
intermediate_layer_idx: List[int] = [4, 11, 17, 23],
|
| 53 |
+
pos_embed: bool = True,
|
| 54 |
+
feature_only: bool = False,
|
| 55 |
+
down_ratio: int = 1,
|
| 56 |
+
) -> None:
|
| 57 |
+
super(DPTHead, self).__init__()
|
| 58 |
+
self.patch_size = patch_size
|
| 59 |
+
self.activation = activation
|
| 60 |
+
self.conf_activation = conf_activation
|
| 61 |
+
self.pos_embed = pos_embed
|
| 62 |
+
self.feature_only = feature_only
|
| 63 |
+
self.down_ratio = down_ratio
|
| 64 |
+
self.intermediate_layer_idx = intermediate_layer_idx
|
| 65 |
+
|
| 66 |
+
self.norm = nn.LayerNorm(dim_in)
|
| 67 |
+
|
| 68 |
+
# Projection layers for each output channel from tokens.
|
| 69 |
+
self.projects = nn.ModuleList(
|
| 70 |
+
[
|
| 71 |
+
nn.Conv2d(
|
| 72 |
+
in_channels=dim_in,
|
| 73 |
+
out_channels=oc,
|
| 74 |
+
kernel_size=1,
|
| 75 |
+
stride=1,
|
| 76 |
+
padding=0,
|
| 77 |
+
)
|
| 78 |
+
for oc in out_channels
|
| 79 |
+
]
|
| 80 |
+
)
|
| 81 |
+
|
| 82 |
+
# Resize layers for upsampling feature maps.
|
| 83 |
+
self.resize_layers = nn.ModuleList(
|
| 84 |
+
[
|
| 85 |
+
nn.ConvTranspose2d(
|
| 86 |
+
in_channels=out_channels[0], out_channels=out_channels[0], kernel_size=4, stride=4, padding=0
|
| 87 |
+
),
|
| 88 |
+
nn.ConvTranspose2d(
|
| 89 |
+
in_channels=out_channels[1], out_channels=out_channels[1], kernel_size=2, stride=2, padding=0
|
| 90 |
+
),
|
| 91 |
+
nn.Identity(),
|
| 92 |
+
nn.Conv2d(
|
| 93 |
+
in_channels=out_channels[3], out_channels=out_channels[3], kernel_size=3, stride=2, padding=1
|
| 94 |
+
),
|
| 95 |
+
]
|
| 96 |
+
)
|
| 97 |
+
|
| 98 |
+
self.scratch = _make_scratch(
|
| 99 |
+
out_channels,
|
| 100 |
+
features,
|
| 101 |
+
expand=False,
|
| 102 |
+
)
|
| 103 |
+
|
| 104 |
+
# Attach additional modules to scratch.
|
| 105 |
+
self.scratch.stem_transpose = None
|
| 106 |
+
self.scratch.refinenet1 = _make_fusion_block(features)
|
| 107 |
+
self.scratch.refinenet2 = _make_fusion_block(features)
|
| 108 |
+
self.scratch.refinenet3 = _make_fusion_block(features)
|
| 109 |
+
self.scratch.refinenet4 = _make_fusion_block(features, has_residual=False)
|
| 110 |
+
|
| 111 |
+
head_features_1 = features
|
| 112 |
+
head_features_2 = 32
|
| 113 |
+
|
| 114 |
+
if feature_only:
|
| 115 |
+
self.scratch.output_conv1 = nn.Conv2d(head_features_1, head_features_1, kernel_size=3, stride=1, padding=1)
|
| 116 |
+
else:
|
| 117 |
+
self.scratch.output_conv1 = nn.Conv2d(
|
| 118 |
+
head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1
|
| 119 |
+
)
|
| 120 |
+
conv2_in_channels = head_features_1 // 2
|
| 121 |
+
|
| 122 |
+
self.scratch.output_conv2 = nn.Sequential(
|
| 123 |
+
nn.Conv2d(conv2_in_channels, head_features_2, kernel_size=3, stride=1, padding=1),
|
| 124 |
+
nn.ReLU(inplace=True),
|
| 125 |
+
nn.Conv2d(head_features_2, output_dim, kernel_size=1, stride=1, padding=0),
|
| 126 |
+
)
|
| 127 |
+
|
| 128 |
+
def forward(
|
| 129 |
+
self,
|
| 130 |
+
aggregated_tokens_list: List[torch.Tensor],
|
| 131 |
+
images: torch.Tensor,
|
| 132 |
+
patch_start_idx: int,
|
| 133 |
+
frames_chunk_size: int = 8,
|
| 134 |
+
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
| 135 |
+
"""
|
| 136 |
+
Forward pass through the DPT head, supports processing by chunking frames.
|
| 137 |
+
Args:
|
| 138 |
+
aggregated_tokens_list (List[Tensor]): List of token tensors from different transformer layers.
|
| 139 |
+
images (Tensor): Input images with shape [B, S, 3, H, W], in range [0, 1].
|
| 140 |
+
patch_start_idx (int): Starting index for patch tokens in the token sequence.
|
| 141 |
+
Used to separate patch tokens from other tokens (e.g., camera or register tokens).
|
| 142 |
+
frames_chunk_size (int, optional): Number of frames to process in each chunk.
|
| 143 |
+
If None or larger than S, all frames are processed at once. Default: 8.
|
| 144 |
+
|
| 145 |
+
Returns:
|
| 146 |
+
Tensor or Tuple[Tensor, Tensor]:
|
| 147 |
+
- If feature_only=True: Feature maps with shape [B, S, C, H, W]
|
| 148 |
+
- Otherwise: Tuple of (predictions, confidence) both with shape [B, S, 1, H, W]
|
| 149 |
+
"""
|
| 150 |
+
B, S, _, H, W = images.shape
|
| 151 |
+
|
| 152 |
+
# If frames_chunk_size is not specified or greater than S, process all frames at once
|
| 153 |
+
if frames_chunk_size is None or frames_chunk_size >= S:
|
| 154 |
+
return self._forward_impl(aggregated_tokens_list, images, patch_start_idx)
|
| 155 |
+
|
| 156 |
+
# Otherwise, process frames in chunks to manage memory usage
|
| 157 |
+
assert frames_chunk_size > 0
|
| 158 |
+
|
| 159 |
+
# Process frames in batches
|
| 160 |
+
all_preds = []
|
| 161 |
+
all_conf = []
|
| 162 |
+
|
| 163 |
+
for frames_start_idx in range(0, S, frames_chunk_size):
|
| 164 |
+
frames_end_idx = min(frames_start_idx + frames_chunk_size, S)
|
| 165 |
+
|
| 166 |
+
# Process batch of frames
|
| 167 |
+
if self.feature_only:
|
| 168 |
+
chunk_output = self._forward_impl(
|
| 169 |
+
aggregated_tokens_list, images, patch_start_idx, frames_start_idx, frames_end_idx
|
| 170 |
+
)
|
| 171 |
+
all_preds.append(chunk_output)
|
| 172 |
+
else:
|
| 173 |
+
chunk_preds, chunk_conf = self._forward_impl(
|
| 174 |
+
aggregated_tokens_list, images, patch_start_idx, frames_start_idx, frames_end_idx
|
| 175 |
+
)
|
| 176 |
+
all_preds.append(chunk_preds)
|
| 177 |
+
all_conf.append(chunk_conf)
|
| 178 |
+
|
| 179 |
+
# Concatenate results along the sequence dimension
|
| 180 |
+
if self.feature_only:
|
| 181 |
+
return torch.cat(all_preds, dim=1)
|
| 182 |
+
else:
|
| 183 |
+
return torch.cat(all_preds, dim=1), torch.cat(all_conf, dim=1)
|
| 184 |
+
|
| 185 |
+
def _forward_impl(
|
| 186 |
+
self,
|
| 187 |
+
aggregated_tokens_list: List[torch.Tensor],
|
| 188 |
+
images: torch.Tensor,
|
| 189 |
+
patch_start_idx: int,
|
| 190 |
+
frames_start_idx: int = None,
|
| 191 |
+
frames_end_idx: int = None,
|
| 192 |
+
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
| 193 |
+
"""
|
| 194 |
+
Implementation of the forward pass through the DPT head.
|
| 195 |
+
|
| 196 |
+
This method processes a specific chunk of frames from the sequence.
|
| 197 |
+
|
| 198 |
+
Args:
|
| 199 |
+
aggregated_tokens_list (List[Tensor]): List of token tensors from different transformer layers.
|
| 200 |
+
images (Tensor): Input images with shape [B, S, 3, H, W].
|
| 201 |
+
patch_start_idx (int): Starting index for patch tokens.
|
| 202 |
+
frames_start_idx (int, optional): Starting index for frames to process.
|
| 203 |
+
frames_end_idx (int, optional): Ending index for frames to process.
|
| 204 |
+
|
| 205 |
+
Returns:
|
| 206 |
+
Tensor or Tuple[Tensor, Tensor]: Feature maps or (predictions, confidence).
|
| 207 |
+
"""
|
| 208 |
+
if frames_start_idx is not None and frames_end_idx is not None:
|
| 209 |
+
images = images[:, frames_start_idx:frames_end_idx].contiguous()
|
| 210 |
+
|
| 211 |
+
B, S, _, H, W = images.shape
|
| 212 |
+
|
| 213 |
+
patch_h, patch_w = H // self.patch_size, W // self.patch_size
|
| 214 |
+
|
| 215 |
+
out = []
|
| 216 |
+
dpt_idx = 0
|
| 217 |
+
|
| 218 |
+
for layer_idx in self.intermediate_layer_idx:
|
| 219 |
+
x = aggregated_tokens_list[layer_idx][:, :, patch_start_idx:]
|
| 220 |
+
|
| 221 |
+
x = x.to(self.projects[0].weight.dtype)
|
| 222 |
+
|
| 223 |
+
# Select frames if processing a chunk
|
| 224 |
+
if frames_start_idx is not None and frames_end_idx is not None:
|
| 225 |
+
x = x[:, frames_start_idx:frames_end_idx]
|
| 226 |
+
|
| 227 |
+
x = x.reshape(B * S, -1, x.shape[-1])
|
| 228 |
+
|
| 229 |
+
x = self.norm(x)
|
| 230 |
+
|
| 231 |
+
x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w))
|
| 232 |
+
|
| 233 |
+
x = self.projects[dpt_idx](x)
|
| 234 |
+
if self.pos_embed:
|
| 235 |
+
x = self._apply_pos_embed(x, W, H).to(self.projects[0].weight.dtype)
|
| 236 |
+
|
| 237 |
+
x = self.resize_layers[dpt_idx](x)
|
| 238 |
+
|
| 239 |
+
out.append(x)
|
| 240 |
+
dpt_idx += 1
|
| 241 |
+
|
| 242 |
+
# Fuse features from multiple layers.
|
| 243 |
+
out = self.scratch_forward(out)
|
| 244 |
+
# Interpolate fused output to match target image resolution.
|
| 245 |
+
out = custom_interpolate(
|
| 246 |
+
out,
|
| 247 |
+
(int(patch_h * self.patch_size / self.down_ratio), int(patch_w * self.patch_size / self.down_ratio)),
|
| 248 |
+
mode="bilinear",
|
| 249 |
+
align_corners=True,
|
| 250 |
+
)
|
| 251 |
+
|
| 252 |
+
if self.pos_embed:
|
| 253 |
+
out = self._apply_pos_embed(out, W, H).to(self.projects[0].weight.dtype)
|
| 254 |
+
|
| 255 |
+
if self.feature_only:
|
| 256 |
+
return out.view(B, S, *out.shape[1:])
|
| 257 |
+
|
| 258 |
+
out = self.scratch.output_conv2(out)
|
| 259 |
+
preds, conf = activate_head(out, activation=self.activation, conf_activation=self.conf_activation)
|
| 260 |
+
|
| 261 |
+
preds = preds.view(B, S, *preds.shape[1:])
|
| 262 |
+
conf = conf.view(B, S, *conf.shape[1:])
|
| 263 |
+
return preds, conf
|
| 264 |
+
|
| 265 |
+
def _apply_pos_embed(self, x: torch.Tensor, W: int, H: int, ratio: float = 0.1) -> torch.Tensor:
|
| 266 |
+
"""
|
| 267 |
+
Apply positional embedding to tensor x.
|
| 268 |
+
"""
|
| 269 |
+
patch_w = x.shape[-1]
|
| 270 |
+
patch_h = x.shape[-2]
|
| 271 |
+
pos_embed = create_uv_grid(patch_w, patch_h, aspect_ratio=W / H, dtype=x.dtype, device=x.device)
|
| 272 |
+
pos_embed = position_grid_to_embed(pos_embed, x.shape[1])
|
| 273 |
+
pos_embed = pos_embed * ratio
|
| 274 |
+
pos_embed = pos_embed.permute(2, 0, 1)[None].expand(x.shape[0], -1, -1, -1)
|
| 275 |
+
return x + pos_embed
|
| 276 |
+
|
| 277 |
+
def scratch_forward(self, features: List[torch.Tensor]) -> torch.Tensor:
|
| 278 |
+
"""
|
| 279 |
+
Forward pass through the fusion blocks.
|
| 280 |
+
|
| 281 |
+
Args:
|
| 282 |
+
features (List[Tensor]): List of feature maps from different layers.
|
| 283 |
+
|
| 284 |
+
Returns:
|
| 285 |
+
Tensor: Fused feature map.
|
| 286 |
+
"""
|
| 287 |
+
layer_1, layer_2, layer_3, layer_4 = features
|
| 288 |
+
|
| 289 |
+
layer_1_rn = self.scratch.layer1_rn(layer_1)
|
| 290 |
+
layer_2_rn = self.scratch.layer2_rn(layer_2)
|
| 291 |
+
layer_3_rn = self.scratch.layer3_rn(layer_3)
|
| 292 |
+
layer_4_rn = self.scratch.layer4_rn(layer_4)
|
| 293 |
+
|
| 294 |
+
out = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])
|
| 295 |
+
del layer_4_rn, layer_4
|
| 296 |
+
|
| 297 |
+
out = self.scratch.refinenet3(out, layer_3_rn, size=layer_2_rn.shape[2:])
|
| 298 |
+
del layer_3_rn, layer_3
|
| 299 |
+
|
| 300 |
+
out = self.scratch.refinenet2(out, layer_2_rn, size=layer_1_rn.shape[2:])
|
| 301 |
+
del layer_2_rn, layer_2
|
| 302 |
+
|
| 303 |
+
out = self.scratch.refinenet1(out, layer_1_rn)
|
| 304 |
+
del layer_1_rn, layer_1
|
| 305 |
+
|
| 306 |
+
out = self.scratch.output_conv1(out)
|
| 307 |
+
return out
|
| 308 |
+
|
| 309 |
+
|
| 310 |
+
################################################################################
|
| 311 |
+
# Modules
|
| 312 |
+
################################################################################
|
| 313 |
+
|
| 314 |
+
|
| 315 |
+
def _make_fusion_block(features: int, size: int = None, has_residual: bool = True, groups: int = 1) -> nn.Module:
|
| 316 |
+
return FeatureFusionBlock(
|
| 317 |
+
features,
|
| 318 |
+
nn.ReLU(inplace=True),
|
| 319 |
+
deconv=False,
|
| 320 |
+
bn=False,
|
| 321 |
+
expand=False,
|
| 322 |
+
align_corners=True,
|
| 323 |
+
size=size,
|
| 324 |
+
has_residual=has_residual,
|
| 325 |
+
groups=groups,
|
| 326 |
+
)
|
| 327 |
+
|
| 328 |
+
|
| 329 |
+
def _make_scratch(in_shape: List[int], out_shape: int, groups: int = 1, expand: bool = False) -> nn.Module:
|
| 330 |
+
scratch = nn.Module()
|
| 331 |
+
out_shape1 = out_shape
|
| 332 |
+
out_shape2 = out_shape
|
| 333 |
+
out_shape3 = out_shape
|
| 334 |
+
if len(in_shape) >= 4:
|
| 335 |
+
out_shape4 = out_shape
|
| 336 |
+
|
| 337 |
+
if expand:
|
| 338 |
+
out_shape1 = out_shape
|
| 339 |
+
out_shape2 = out_shape * 2
|
| 340 |
+
out_shape3 = out_shape * 4
|
| 341 |
+
if len(in_shape) >= 4:
|
| 342 |
+
out_shape4 = out_shape * 8
|
| 343 |
+
|
| 344 |
+
scratch.layer1_rn = nn.Conv2d(
|
| 345 |
+
in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
|
| 346 |
+
)
|
| 347 |
+
scratch.layer2_rn = nn.Conv2d(
|
| 348 |
+
in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
|
| 349 |
+
)
|
| 350 |
+
scratch.layer3_rn = nn.Conv2d(
|
| 351 |
+
in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
|
| 352 |
+
)
|
| 353 |
+
if len(in_shape) >= 4:
|
| 354 |
+
scratch.layer4_rn = nn.Conv2d(
|
| 355 |
+
in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
|
| 356 |
+
)
|
| 357 |
+
return scratch
|
| 358 |
+
|
| 359 |
+
|
| 360 |
+
class ResidualConvUnit(nn.Module):
|
| 361 |
+
"""Residual convolution module."""
|
| 362 |
+
|
| 363 |
+
def __init__(self, features, activation, bn, groups=1):
|
| 364 |
+
"""Init.
|
| 365 |
+
|
| 366 |
+
Args:
|
| 367 |
+
features (int): number of features
|
| 368 |
+
"""
|
| 369 |
+
super().__init__()
|
| 370 |
+
|
| 371 |
+
self.bn = bn
|
| 372 |
+
self.groups = groups
|
| 373 |
+
self.conv1 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
|
| 374 |
+
self.conv2 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
|
| 375 |
+
|
| 376 |
+
self.norm1 = None
|
| 377 |
+
self.norm2 = None
|
| 378 |
+
|
| 379 |
+
self.activation = activation
|
| 380 |
+
self.skip_add = nn.quantized.FloatFunctional()
|
| 381 |
+
|
| 382 |
+
def forward(self, x):
|
| 383 |
+
"""Forward pass.
|
| 384 |
+
|
| 385 |
+
Args:
|
| 386 |
+
x (tensor): input
|
| 387 |
+
|
| 388 |
+
Returns:
|
| 389 |
+
tensor: output
|
| 390 |
+
"""
|
| 391 |
+
|
| 392 |
+
out = self.activation(x)
|
| 393 |
+
out = self.conv1(out)
|
| 394 |
+
if self.norm1 is not None:
|
| 395 |
+
out = self.norm1(out)
|
| 396 |
+
|
| 397 |
+
out = self.activation(out)
|
| 398 |
+
out = self.conv2(out)
|
| 399 |
+
if self.norm2 is not None:
|
| 400 |
+
out = self.norm2(out)
|
| 401 |
+
|
| 402 |
+
return self.skip_add.add(out, x)
|
| 403 |
+
|
| 404 |
+
|
| 405 |
+
class FeatureFusionBlock(nn.Module):
|
| 406 |
+
"""Feature fusion block."""
|
| 407 |
+
|
| 408 |
+
def __init__(
|
| 409 |
+
self,
|
| 410 |
+
features,
|
| 411 |
+
activation,
|
| 412 |
+
deconv=False,
|
| 413 |
+
bn=False,
|
| 414 |
+
expand=False,
|
| 415 |
+
align_corners=True,
|
| 416 |
+
size=None,
|
| 417 |
+
has_residual=True,
|
| 418 |
+
groups=1,
|
| 419 |
+
):
|
| 420 |
+
"""Init.
|
| 421 |
+
|
| 422 |
+
Args:
|
| 423 |
+
features (int): number of features
|
| 424 |
+
"""
|
| 425 |
+
super(FeatureFusionBlock, self).__init__()
|
| 426 |
+
|
| 427 |
+
self.deconv = deconv
|
| 428 |
+
self.align_corners = align_corners
|
| 429 |
+
self.groups = groups
|
| 430 |
+
self.expand = expand
|
| 431 |
+
out_features = features
|
| 432 |
+
if self.expand == True:
|
| 433 |
+
out_features = features // 2
|
| 434 |
+
|
| 435 |
+
self.out_conv = nn.Conv2d(
|
| 436 |
+
features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=self.groups
|
| 437 |
+
)
|
| 438 |
+
|
| 439 |
+
if has_residual:
|
| 440 |
+
self.resConfUnit1 = ResidualConvUnit(features, activation, bn, groups=self.groups)
|
| 441 |
+
|
| 442 |
+
self.has_residual = has_residual
|
| 443 |
+
self.resConfUnit2 = ResidualConvUnit(features, activation, bn, groups=self.groups)
|
| 444 |
+
|
| 445 |
+
self.skip_add = nn.quantized.FloatFunctional()
|
| 446 |
+
self.size = size
|
| 447 |
+
|
| 448 |
+
def forward(self, *xs, size=None):
|
| 449 |
+
"""Forward pass.
|
| 450 |
+
|
| 451 |
+
Returns:
|
| 452 |
+
tensor: output
|
| 453 |
+
"""
|
| 454 |
+
output = xs[0]
|
| 455 |
+
|
| 456 |
+
if self.has_residual:
|
| 457 |
+
res = self.resConfUnit1(xs[1])
|
| 458 |
+
output = self.skip_add.add(output, res)
|
| 459 |
+
|
| 460 |
+
output = self.resConfUnit2(output)
|
| 461 |
+
|
| 462 |
+
if (size is None) and (self.size is None):
|
| 463 |
+
modifier = {"scale_factor": 2}
|
| 464 |
+
elif size is None:
|
| 465 |
+
modifier = {"size": self.size}
|
| 466 |
+
else:
|
| 467 |
+
modifier = {"size": size}
|
| 468 |
+
|
| 469 |
+
output = custom_interpolate(output, **modifier, mode="bilinear", align_corners=self.align_corners)
|
| 470 |
+
output = self.out_conv(output)
|
| 471 |
+
|
| 472 |
+
return output
|
| 473 |
+
|
| 474 |
+
|
| 475 |
+
def custom_interpolate(
|
| 476 |
+
x: torch.Tensor,
|
| 477 |
+
size: Tuple[int, int] = None,
|
| 478 |
+
scale_factor: float = None,
|
| 479 |
+
mode: str = "bilinear",
|
| 480 |
+
align_corners: bool = True,
|
| 481 |
+
) -> torch.Tensor:
|
| 482 |
+
"""
|
| 483 |
+
Custom interpolate to avoid INT_MAX issues in nn.functional.interpolate.
|
| 484 |
+
"""
|
| 485 |
+
if size is None:
|
| 486 |
+
size = (int(x.shape[-2] * scale_factor), int(x.shape[-1] * scale_factor))
|
| 487 |
+
|
| 488 |
+
INT_MAX = 1610612736
|
| 489 |
+
|
| 490 |
+
input_elements = size[0] * size[1] * x.shape[0] * x.shape[1]
|
| 491 |
+
|
| 492 |
+
if input_elements > INT_MAX:
|
| 493 |
+
chunks = torch.chunk(x, chunks=(input_elements // INT_MAX) + 1, dim=0)
|
| 494 |
+
interpolated_chunks = [
|
| 495 |
+
nn.functional.interpolate(chunk, size=size, mode=mode, align_corners=align_corners) for chunk in chunks
|
| 496 |
+
]
|
| 497 |
+
x = torch.cat(interpolated_chunks, dim=0)
|
| 498 |
+
return x.contiguous()
|
| 499 |
+
else:
|
| 500 |
+
return nn.functional.interpolate(x, size=size, mode=mode, align_corners=align_corners)
|
unish/heads/head_act.py
ADDED
|
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn.functional as F
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def activate_pose(pred_pose_enc, trans_act="linear", quat_act="linear", fl_act="linear"):
|
| 13 |
+
"""
|
| 14 |
+
Activate pose parameters with specified activation functions.
|
| 15 |
+
|
| 16 |
+
Args:
|
| 17 |
+
pred_pose_enc: Tensor containing encoded pose parameters [translation, quaternion, focal length]
|
| 18 |
+
trans_act: Activation type for translation component
|
| 19 |
+
quat_act: Activation type for quaternion component
|
| 20 |
+
fl_act: Activation type for focal length component
|
| 21 |
+
|
| 22 |
+
Returns:
|
| 23 |
+
Activated pose parameters tensor
|
| 24 |
+
"""
|
| 25 |
+
T = pred_pose_enc[..., :3]
|
| 26 |
+
quat = pred_pose_enc[..., 3:7]
|
| 27 |
+
fl = pred_pose_enc[..., 7:] # or fov
|
| 28 |
+
|
| 29 |
+
T = base_pose_act(T, trans_act)
|
| 30 |
+
quat = base_pose_act(quat, quat_act)
|
| 31 |
+
fl = base_pose_act(fl, fl_act) # or fov
|
| 32 |
+
|
| 33 |
+
pred_pose_enc = torch.cat([T, quat, fl], dim=-1)
|
| 34 |
+
|
| 35 |
+
return pred_pose_enc
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def base_pose_act(pose_enc, act_type="linear"):
|
| 39 |
+
"""
|
| 40 |
+
Apply basic activation function to pose parameters.
|
| 41 |
+
|
| 42 |
+
Args:
|
| 43 |
+
pose_enc: Tensor containing encoded pose parameters
|
| 44 |
+
act_type: Activation type ("linear", "inv_log", "exp", "relu")
|
| 45 |
+
|
| 46 |
+
Returns:
|
| 47 |
+
Activated pose parameters
|
| 48 |
+
"""
|
| 49 |
+
if act_type == "linear":
|
| 50 |
+
return pose_enc
|
| 51 |
+
elif act_type == "inv_log":
|
| 52 |
+
return inverse_log_transform(pose_enc)
|
| 53 |
+
elif act_type == "exp":
|
| 54 |
+
return torch.exp(pose_enc)
|
| 55 |
+
elif act_type == "relu":
|
| 56 |
+
return F.relu(pose_enc)
|
| 57 |
+
else:
|
| 58 |
+
raise ValueError(f"Unknown act_type: {act_type}")
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def activate_head(out, activation="norm_exp", conf_activation="expp1"):
|
| 62 |
+
"""
|
| 63 |
+
Process network output to extract 3D points and confidence values.
|
| 64 |
+
|
| 65 |
+
Args:
|
| 66 |
+
out: Network output tensor (B, C, H, W)
|
| 67 |
+
activation: Activation type for 3D points
|
| 68 |
+
conf_activation: Activation type for confidence values
|
| 69 |
+
|
| 70 |
+
Returns:
|
| 71 |
+
Tuple of (3D points tensor, confidence tensor)
|
| 72 |
+
"""
|
| 73 |
+
# Move channels from last dim to the 4th dimension => (B, H, W, C)
|
| 74 |
+
fmap = out.permute(0, 2, 3, 1) # B,H,W,C expected
|
| 75 |
+
|
| 76 |
+
# Split into xyz (first C-1 channels) and confidence (last channel)
|
| 77 |
+
xyz = fmap[:, :, :, :-1]
|
| 78 |
+
conf = fmap[:, :, :, -1]
|
| 79 |
+
|
| 80 |
+
if activation == "norm_exp":
|
| 81 |
+
d = xyz.norm(dim=-1, keepdim=True).clamp(min=1e-8)
|
| 82 |
+
xyz_normed = xyz / d
|
| 83 |
+
pts3d = xyz_normed * torch.expm1(d)
|
| 84 |
+
elif activation == "norm":
|
| 85 |
+
pts3d = xyz / xyz.norm(dim=-1, keepdim=True)
|
| 86 |
+
elif activation == "exp":
|
| 87 |
+
pts3d = torch.exp(xyz)
|
| 88 |
+
elif activation == "relu":
|
| 89 |
+
pts3d = F.relu(xyz)
|
| 90 |
+
elif activation == "inv_log":
|
| 91 |
+
pts3d = inverse_log_transform(xyz)
|
| 92 |
+
elif activation == "xy_inv_log":
|
| 93 |
+
xy, z = xyz.split([2, 1], dim=-1)
|
| 94 |
+
z = inverse_log_transform(z)
|
| 95 |
+
pts3d = torch.cat([xy * z, z], dim=-1)
|
| 96 |
+
elif activation == "sigmoid":
|
| 97 |
+
pts3d = torch.sigmoid(xyz)
|
| 98 |
+
elif activation == "linear":
|
| 99 |
+
pts3d = xyz
|
| 100 |
+
else:
|
| 101 |
+
raise ValueError(f"Unknown activation: {activation}")
|
| 102 |
+
|
| 103 |
+
if conf_activation == "expp1":
|
| 104 |
+
conf_out = 1 + conf.exp()
|
| 105 |
+
elif conf_activation == "expp0":
|
| 106 |
+
conf_out = conf.exp()
|
| 107 |
+
elif conf_activation == "sigmoid":
|
| 108 |
+
conf_out = torch.sigmoid(conf)
|
| 109 |
+
else:
|
| 110 |
+
raise ValueError(f"Unknown conf_activation: {conf_activation}")
|
| 111 |
+
|
| 112 |
+
return pts3d, conf_out
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
def inverse_log_transform(y):
|
| 116 |
+
"""
|
| 117 |
+
Apply inverse log transform: sign(y) * (exp(|y|) - 1)
|
| 118 |
+
|
| 119 |
+
Args:
|
| 120 |
+
y: Input tensor
|
| 121 |
+
|
| 122 |
+
Returns:
|
| 123 |
+
Transformed tensor
|
| 124 |
+
"""
|
| 125 |
+
return torch.sign(y) * (torch.expm1(torch.abs(y)))
|
unish/heads/human_head_cliff.py
ADDED
|
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
import numpy as np
|
| 5 |
+
import einops
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
from unish.utils.data_utils import rot6d_to_rotmat
|
| 9 |
+
from unish.utils.constants import SMPL_MEAN_PARAMS
|
| 10 |
+
from .pose_transformer import TransformerDecoder
|
| 11 |
+
|
| 12 |
+
TRANSFORMER_DECODER={'depth': 6,
|
| 13 |
+
'heads': 8,
|
| 14 |
+
'mlp_dim': 1024,
|
| 15 |
+
'dim_head': 64,
|
| 16 |
+
'dropout': 0.0,
|
| 17 |
+
'emb_dropout': 0.0,
|
| 18 |
+
'norm': 'layer',
|
| 19 |
+
'context_dim': 1280}
|
| 20 |
+
|
| 21 |
+
NUM_POSE_INPUT = 23
|
| 22 |
+
NUM_BETAS_INPUT = 10
|
| 23 |
+
NUM_BETAS = 10
|
| 24 |
+
NUM_POSE_PARAMS = 23
|
| 25 |
+
|
| 26 |
+
class HumanHeadCliff(nn.Module):
|
| 27 |
+
|
| 28 |
+
def __init__(self):
|
| 29 |
+
super().__init__()
|
| 30 |
+
self.joint_rep_dim = 6
|
| 31 |
+
npose = self.joint_rep_dim * (NUM_POSE_INPUT + 1)
|
| 32 |
+
self.npose = npose
|
| 33 |
+
transformer_args = dict(
|
| 34 |
+
num_tokens=1,
|
| 35 |
+
token_dim=(3 + npose + NUM_BETAS_INPUT + 3),
|
| 36 |
+
dim=1024,
|
| 37 |
+
)
|
| 38 |
+
transformer_args = (transformer_args | dict(TRANSFORMER_DECODER))
|
| 39 |
+
self.transformer = TransformerDecoder(
|
| 40 |
+
**transformer_args
|
| 41 |
+
)
|
| 42 |
+
dim=transformer_args['dim']
|
| 43 |
+
self.decpose = nn.Linear(dim, self.joint_rep_dim * (NUM_POSE_PARAMS + 1))
|
| 44 |
+
self.decshape = nn.Linear(dim, NUM_BETAS)
|
| 45 |
+
# self.deccam = nn.Linear(dim, 3)
|
| 46 |
+
# self.deckp = nn.Linear(dim, 88)
|
| 47 |
+
|
| 48 |
+
mean_params = SMPL_MEAN_PARAMS
|
| 49 |
+
init_body_pose = torch.from_numpy(mean_params['pose'].astype(np.float32)).unsqueeze(0)
|
| 50 |
+
init_betas = torch.from_numpy(mean_params['shape'].astype('float32')).unsqueeze(0)
|
| 51 |
+
init_cam = torch.from_numpy(mean_params['cam'].astype(np.float32)).unsqueeze(0)
|
| 52 |
+
self.register_buffer('init_body_pose', init_body_pose)
|
| 53 |
+
self.register_buffer('init_betas', init_betas)
|
| 54 |
+
self.register_buffer('init_cam', init_cam)
|
| 55 |
+
|
| 56 |
+
def gradient_checkpointing_enable(self):
|
| 57 |
+
"""Enable gradient checkpointing for memory optimization."""
|
| 58 |
+
if hasattr(self.transformer, 'gradient_checkpointing_enable'):
|
| 59 |
+
self.transformer.gradient_checkpointing_enable()
|
| 60 |
+
|
| 61 |
+
def forward(self, x, bbox_info, **kwargs):
|
| 62 |
+
"""
|
| 63 |
+
x: (B, N, C, H, W)
|
| 64 |
+
bbox_info: [cx / f, cy / f, box_size / f], (B, N, 3)
|
| 65 |
+
"""
|
| 66 |
+
|
| 67 |
+
batch_size, num_views = x.shape[:2]
|
| 68 |
+
x = einops.rearrange(x, 'b n c h w -> (b n) (h w) c')
|
| 69 |
+
|
| 70 |
+
init_body_pose = self.init_body_pose.expand(batch_size * num_views, -1)
|
| 71 |
+
init_betas = self.init_betas.expand(batch_size * num_views, -1)
|
| 72 |
+
init_cam = self.init_cam.expand(batch_size * num_views, -1)
|
| 73 |
+
bbox_info = bbox_info.view(-1, 3)
|
| 74 |
+
|
| 75 |
+
pred_body_pose = init_body_pose
|
| 76 |
+
pred_betas = init_betas
|
| 77 |
+
pred_cam = init_cam
|
| 78 |
+
token = torch.cat([bbox_info, pred_body_pose, pred_betas, pred_cam], dim=-1)[:, None, :]
|
| 79 |
+
|
| 80 |
+
# Pass through transformer
|
| 81 |
+
token_out = self.transformer(token, context=x)
|
| 82 |
+
token_out = token_out.squeeze(1) # (B, C)
|
| 83 |
+
|
| 84 |
+
pred_body_pose = self.decpose(token_out) + pred_body_pose
|
| 85 |
+
pred_betas = self.decshape(token_out) + pred_betas
|
| 86 |
+
|
| 87 |
+
joint_conversion_fn = rot6d_to_rotmat
|
| 88 |
+
|
| 89 |
+
pred_body_pose = pred_body_pose.view(-1, 6)
|
| 90 |
+
pred_body_pose = joint_conversion_fn(pred_body_pose).view(batch_size, num_views, -1)
|
| 91 |
+
pred_betas = pred_betas.view(batch_size, num_views, -1).mean(dim=1)
|
| 92 |
+
token_out = token_out.view(batch_size, num_views, -1)
|
| 93 |
+
|
| 94 |
+
pred_smpl_params = {'pose_cam': pred_body_pose,
|
| 95 |
+
'token_out': token_out,
|
| 96 |
+
'betas': pred_betas}
|
| 97 |
+
return pred_smpl_params
|
unish/heads/pose_transformer.py
ADDED
|
@@ -0,0 +1,364 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from inspect import isfunction
|
| 2 |
+
from typing import Callable, Optional
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from einops import rearrange
|
| 6 |
+
from einops.layers.torch import Rearrange
|
| 7 |
+
from torch import nn
|
| 8 |
+
|
| 9 |
+
from .t_cond_mlp import (
|
| 10 |
+
AdaptiveLayerNorm1D,
|
| 11 |
+
FrequencyEmbedder,
|
| 12 |
+
normalization_layer,
|
| 13 |
+
)
|
| 14 |
+
# from .vit import Attention, FeedForward
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def exists(val):
|
| 18 |
+
return val is not None
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def default(val, d):
|
| 22 |
+
if exists(val):
|
| 23 |
+
return val
|
| 24 |
+
return d() if isfunction(d) else d
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class PreNorm(nn.Module):
|
| 28 |
+
def __init__(self, dim: int, fn: Callable, norm: str = "layer", norm_cond_dim: int = -1):
|
| 29 |
+
super().__init__()
|
| 30 |
+
self.norm = normalization_layer(norm, dim, norm_cond_dim)
|
| 31 |
+
self.fn = fn
|
| 32 |
+
|
| 33 |
+
def forward(self, x: torch.Tensor, *args, **kwargs):
|
| 34 |
+
if isinstance(self.norm, AdaptiveLayerNorm1D):
|
| 35 |
+
return self.fn(self.norm(x, *args), **kwargs)
|
| 36 |
+
else:
|
| 37 |
+
return self.fn(self.norm(x), **kwargs)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class FeedForward(nn.Module):
|
| 41 |
+
def __init__(self, dim, hidden_dim, dropout=0.0):
|
| 42 |
+
super().__init__()
|
| 43 |
+
self.net = nn.Sequential(
|
| 44 |
+
nn.Linear(dim, hidden_dim),
|
| 45 |
+
nn.GELU(),
|
| 46 |
+
nn.Dropout(dropout),
|
| 47 |
+
nn.Linear(hidden_dim, dim),
|
| 48 |
+
nn.Dropout(dropout),
|
| 49 |
+
)
|
| 50 |
+
|
| 51 |
+
def forward(self, x):
|
| 52 |
+
return self.net(x)
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
class Attention(nn.Module):
|
| 56 |
+
def __init__(self, dim, heads=8, dim_head=64, dropout=0.0):
|
| 57 |
+
super().__init__()
|
| 58 |
+
inner_dim = dim_head * heads
|
| 59 |
+
project_out = not (heads == 1 and dim_head == dim)
|
| 60 |
+
|
| 61 |
+
self.heads = heads
|
| 62 |
+
self.scale = dim_head**-0.5
|
| 63 |
+
|
| 64 |
+
self.attend = nn.Softmax(dim=-1)
|
| 65 |
+
self.dropout = nn.Dropout(dropout)
|
| 66 |
+
|
| 67 |
+
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
|
| 68 |
+
|
| 69 |
+
self.to_out = (
|
| 70 |
+
nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout))
|
| 71 |
+
if project_out
|
| 72 |
+
else nn.Identity()
|
| 73 |
+
)
|
| 74 |
+
|
| 75 |
+
def forward(self, x):
|
| 76 |
+
qkv = self.to_qkv(x).chunk(3, dim=-1)
|
| 77 |
+
q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), qkv)
|
| 78 |
+
|
| 79 |
+
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
|
| 80 |
+
|
| 81 |
+
attn = self.attend(dots)
|
| 82 |
+
attn = self.dropout(attn)
|
| 83 |
+
|
| 84 |
+
out = torch.matmul(attn, v)
|
| 85 |
+
out = rearrange(out, "b h n d -> b n (h d)")
|
| 86 |
+
return self.to_out(out)
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
class CrossAttention(nn.Module):
|
| 90 |
+
def __init__(self, dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
|
| 91 |
+
super().__init__()
|
| 92 |
+
inner_dim = dim_head * heads
|
| 93 |
+
project_out = not (heads == 1 and dim_head == dim)
|
| 94 |
+
|
| 95 |
+
self.heads = heads
|
| 96 |
+
self.scale = dim_head**-0.5
|
| 97 |
+
|
| 98 |
+
self.attend = nn.Softmax(dim=-1)
|
| 99 |
+
self.dropout = nn.Dropout(dropout)
|
| 100 |
+
|
| 101 |
+
context_dim = default(context_dim, dim)
|
| 102 |
+
self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias=False)
|
| 103 |
+
self.to_q = nn.Linear(dim, inner_dim, bias=False)
|
| 104 |
+
|
| 105 |
+
self.to_out = (
|
| 106 |
+
nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout))
|
| 107 |
+
if project_out
|
| 108 |
+
else nn.Identity()
|
| 109 |
+
)
|
| 110 |
+
|
| 111 |
+
def forward(self, x, context=None):
|
| 112 |
+
context = default(context, x)
|
| 113 |
+
k, v = self.to_kv(context).chunk(2, dim=-1)
|
| 114 |
+
q = self.to_q(x)
|
| 115 |
+
q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), [q, k, v])
|
| 116 |
+
|
| 117 |
+
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
|
| 118 |
+
|
| 119 |
+
attn = self.attend(dots)
|
| 120 |
+
attn = self.dropout(attn)
|
| 121 |
+
|
| 122 |
+
out = torch.matmul(attn, v)
|
| 123 |
+
out = rearrange(out, "b h n d -> b n (h d)")
|
| 124 |
+
return self.to_out(out)
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
class Transformer(nn.Module):
|
| 128 |
+
def __init__(
|
| 129 |
+
self,
|
| 130 |
+
dim: int,
|
| 131 |
+
depth: int,
|
| 132 |
+
heads: int,
|
| 133 |
+
dim_head: int,
|
| 134 |
+
mlp_dim: int,
|
| 135 |
+
dropout: float = 0.0,
|
| 136 |
+
norm: str = "layer",
|
| 137 |
+
norm_cond_dim: int = -1,
|
| 138 |
+
):
|
| 139 |
+
super().__init__()
|
| 140 |
+
self.layers = nn.ModuleList([])
|
| 141 |
+
for _ in range(depth):
|
| 142 |
+
sa = Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)
|
| 143 |
+
ff = FeedForward(dim, mlp_dim, dropout=dropout)
|
| 144 |
+
self.layers.append(
|
| 145 |
+
nn.ModuleList(
|
| 146 |
+
[
|
| 147 |
+
PreNorm(dim, sa, norm=norm, norm_cond_dim=norm_cond_dim),
|
| 148 |
+
PreNorm(dim, ff, norm=norm, norm_cond_dim=norm_cond_dim),
|
| 149 |
+
]
|
| 150 |
+
)
|
| 151 |
+
)
|
| 152 |
+
|
| 153 |
+
def forward(self, x: torch.Tensor, *args):
|
| 154 |
+
for attn, ff in self.layers:
|
| 155 |
+
x = attn(x, *args) + x
|
| 156 |
+
x = ff(x, *args) + x
|
| 157 |
+
return x
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
class TransformerCrossAttn(nn.Module):
|
| 161 |
+
def __init__(
|
| 162 |
+
self,
|
| 163 |
+
dim: int,
|
| 164 |
+
depth: int,
|
| 165 |
+
heads: int,
|
| 166 |
+
dim_head: int,
|
| 167 |
+
mlp_dim: int,
|
| 168 |
+
dropout: float = 0.0,
|
| 169 |
+
norm: str = "layer",
|
| 170 |
+
norm_cond_dim: int = -1,
|
| 171 |
+
context_dim: Optional[int] = None,
|
| 172 |
+
):
|
| 173 |
+
super().__init__()
|
| 174 |
+
self.layers = nn.ModuleList([])
|
| 175 |
+
for _ in range(depth):
|
| 176 |
+
sa = Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)
|
| 177 |
+
ca = CrossAttention(
|
| 178 |
+
dim, context_dim=context_dim, heads=heads, dim_head=dim_head, dropout=dropout
|
| 179 |
+
)
|
| 180 |
+
ff = FeedForward(dim, mlp_dim, dropout=dropout)
|
| 181 |
+
self.layers.append(
|
| 182 |
+
nn.ModuleList(
|
| 183 |
+
[
|
| 184 |
+
PreNorm(dim, sa, norm=norm, norm_cond_dim=norm_cond_dim),
|
| 185 |
+
PreNorm(dim, ca, norm=norm, norm_cond_dim=norm_cond_dim),
|
| 186 |
+
PreNorm(dim, ff, norm=norm, norm_cond_dim=norm_cond_dim),
|
| 187 |
+
]
|
| 188 |
+
)
|
| 189 |
+
)
|
| 190 |
+
|
| 191 |
+
def forward(self, x: torch.Tensor, *args, context=None, context_list=None):
|
| 192 |
+
if context_list is None:
|
| 193 |
+
context_list = [context] * len(self.layers)
|
| 194 |
+
if len(context_list) != len(self.layers):
|
| 195 |
+
raise ValueError(f"len(context_list) != len(self.layers) ({len(context_list)} != {len(self.layers)})")
|
| 196 |
+
|
| 197 |
+
b, n = x.shape[:2]
|
| 198 |
+
|
| 199 |
+
for i, (self_attn, cross_attn, ff) in enumerate(self.layers):
|
| 200 |
+
x = self_attn(x, *args) + x
|
| 201 |
+
# TODO
|
| 202 |
+
# x = x.view(b*n, 1, -1)
|
| 203 |
+
x = cross_attn(x, *args, context=context_list[i]) + x
|
| 204 |
+
# x = x.view(b, n, -1)
|
| 205 |
+
x = ff(x, *args) + x
|
| 206 |
+
return x
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
class DropTokenDropout(nn.Module):
|
| 210 |
+
def __init__(self, p: float = 0.1):
|
| 211 |
+
super().__init__()
|
| 212 |
+
if p < 0 or p > 1:
|
| 213 |
+
raise ValueError(
|
| 214 |
+
"dropout probability has to be between 0 and 1, " "but got {}".format(p)
|
| 215 |
+
)
|
| 216 |
+
self.p = p
|
| 217 |
+
|
| 218 |
+
def forward(self, x: torch.Tensor):
|
| 219 |
+
# x: (batch_size, seq_len, dim)
|
| 220 |
+
if self.training and self.p > 0:
|
| 221 |
+
zero_mask = torch.full_like(x[0, :, 0], self.p).bernoulli().bool()
|
| 222 |
+
# TODO: permutation idx for each batch using torch.argsort
|
| 223 |
+
if zero_mask.any():
|
| 224 |
+
x = x[:, ~zero_mask, :]
|
| 225 |
+
return x
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
class ZeroTokenDropout(nn.Module):
|
| 229 |
+
def __init__(self, p: float = 0.1):
|
| 230 |
+
super().__init__()
|
| 231 |
+
if p < 0 or p > 1:
|
| 232 |
+
raise ValueError(
|
| 233 |
+
"dropout probability has to be between 0 and 1, " "but got {}".format(p)
|
| 234 |
+
)
|
| 235 |
+
self.p = p
|
| 236 |
+
|
| 237 |
+
def forward(self, x: torch.Tensor):
|
| 238 |
+
# x: (batch_size, seq_len, dim)
|
| 239 |
+
if self.training and self.p > 0:
|
| 240 |
+
zero_mask = torch.full_like(x[:, :, 0], self.p).bernoulli().bool()
|
| 241 |
+
# Zero-out the masked tokens
|
| 242 |
+
x[zero_mask, :] = 0
|
| 243 |
+
return x
|
| 244 |
+
|
| 245 |
+
|
| 246 |
+
class TransformerEncoder(nn.Module):
|
| 247 |
+
def __init__(
|
| 248 |
+
self,
|
| 249 |
+
num_tokens: int,
|
| 250 |
+
token_dim: int,
|
| 251 |
+
dim: int,
|
| 252 |
+
depth: int,
|
| 253 |
+
heads: int,
|
| 254 |
+
mlp_dim: int,
|
| 255 |
+
dim_head: int = 64,
|
| 256 |
+
dropout: float = 0.0,
|
| 257 |
+
emb_dropout: float = 0.0,
|
| 258 |
+
emb_dropout_type: str = "drop",
|
| 259 |
+
emb_dropout_loc: str = "token",
|
| 260 |
+
norm: str = "layer",
|
| 261 |
+
norm_cond_dim: int = -1,
|
| 262 |
+
token_pe_numfreq: int = -1,
|
| 263 |
+
):
|
| 264 |
+
super().__init__()
|
| 265 |
+
if token_pe_numfreq > 0:
|
| 266 |
+
token_dim_new = token_dim * (2 * token_pe_numfreq + 1)
|
| 267 |
+
self.to_token_embedding = nn.Sequential(
|
| 268 |
+
Rearrange("b n d -> (b n) d", n=num_tokens, d=token_dim),
|
| 269 |
+
FrequencyEmbedder(token_pe_numfreq, token_pe_numfreq - 1),
|
| 270 |
+
Rearrange("(b n) d -> b n d", n=num_tokens, d=token_dim_new),
|
| 271 |
+
nn.Linear(token_dim_new, dim),
|
| 272 |
+
)
|
| 273 |
+
else:
|
| 274 |
+
self.to_token_embedding = nn.Linear(token_dim, dim)
|
| 275 |
+
self.pos_embedding = nn.Parameter(torch.randn(1, num_tokens, dim))
|
| 276 |
+
if emb_dropout_type == "drop":
|
| 277 |
+
self.dropout = DropTokenDropout(emb_dropout)
|
| 278 |
+
elif emb_dropout_type == "zero":
|
| 279 |
+
self.dropout = ZeroTokenDropout(emb_dropout)
|
| 280 |
+
else:
|
| 281 |
+
raise ValueError(f"Unknown emb_dropout_type: {emb_dropout_type}")
|
| 282 |
+
self.emb_dropout_loc = emb_dropout_loc
|
| 283 |
+
|
| 284 |
+
self.transformer = Transformer(
|
| 285 |
+
dim, depth, heads, dim_head, mlp_dim, dropout, norm=norm, norm_cond_dim=norm_cond_dim
|
| 286 |
+
)
|
| 287 |
+
|
| 288 |
+
def forward(self, inp: torch.Tensor, *args, **kwargs):
|
| 289 |
+
x = inp
|
| 290 |
+
|
| 291 |
+
if self.emb_dropout_loc == "input":
|
| 292 |
+
x = self.dropout(x)
|
| 293 |
+
x = self.to_token_embedding(x)
|
| 294 |
+
|
| 295 |
+
if self.emb_dropout_loc == "token":
|
| 296 |
+
x = self.dropout(x)
|
| 297 |
+
b, n, _ = x.shape
|
| 298 |
+
x += self.pos_embedding[:, :n]
|
| 299 |
+
|
| 300 |
+
if self.emb_dropout_loc == "token_afterpos":
|
| 301 |
+
x = self.dropout(x)
|
| 302 |
+
x = self.transformer(x, *args)
|
| 303 |
+
return x
|
| 304 |
+
|
| 305 |
+
|
| 306 |
+
class TransformerDecoder(nn.Module):
|
| 307 |
+
def __init__(
|
| 308 |
+
self,
|
| 309 |
+
num_tokens: int,
|
| 310 |
+
token_dim: int,
|
| 311 |
+
dim: int,
|
| 312 |
+
depth: int,
|
| 313 |
+
heads: int,
|
| 314 |
+
mlp_dim: int,
|
| 315 |
+
dim_head: int = 64,
|
| 316 |
+
dropout: float = 0.0,
|
| 317 |
+
emb_dropout: float = 0.0,
|
| 318 |
+
emb_dropout_type: str = 'drop',
|
| 319 |
+
norm: str = "layer",
|
| 320 |
+
norm_cond_dim: int = -1,
|
| 321 |
+
context_dim: Optional[int] = None,
|
| 322 |
+
skip_token_embedding: bool = False,
|
| 323 |
+
):
|
| 324 |
+
super().__init__()
|
| 325 |
+
if not skip_token_embedding:
|
| 326 |
+
self.to_token_embedding = nn.Linear(token_dim, dim)
|
| 327 |
+
else:
|
| 328 |
+
self.to_token_embedding = nn.Identity()
|
| 329 |
+
if token_dim != dim:
|
| 330 |
+
raise ValueError(
|
| 331 |
+
f"token_dim ({token_dim}) != dim ({dim}) when skip_token_embedding is True"
|
| 332 |
+
)
|
| 333 |
+
|
| 334 |
+
self.pos_embedding = nn.Parameter(torch.randn(1, num_tokens, dim))
|
| 335 |
+
if emb_dropout_type == "drop":
|
| 336 |
+
self.dropout = DropTokenDropout(emb_dropout)
|
| 337 |
+
elif emb_dropout_type == "zero":
|
| 338 |
+
self.dropout = ZeroTokenDropout(emb_dropout)
|
| 339 |
+
elif emb_dropout_type == "normal":
|
| 340 |
+
self.dropout = nn.Dropout(emb_dropout)
|
| 341 |
+
|
| 342 |
+
self.transformer = TransformerCrossAttn(
|
| 343 |
+
dim,
|
| 344 |
+
depth,
|
| 345 |
+
heads,
|
| 346 |
+
dim_head,
|
| 347 |
+
mlp_dim,
|
| 348 |
+
dropout,
|
| 349 |
+
norm=norm,
|
| 350 |
+
norm_cond_dim=norm_cond_dim,
|
| 351 |
+
context_dim=context_dim,
|
| 352 |
+
)
|
| 353 |
+
|
| 354 |
+
def forward(self, inp: torch.Tensor, *args, context=None, context_list=None):
|
| 355 |
+
|
| 356 |
+
x = self.to_token_embedding(inp)
|
| 357 |
+
b, n, _ = x.shape
|
| 358 |
+
|
| 359 |
+
x = self.dropout(x)
|
| 360 |
+
x += self.pos_embedding[:, :n]
|
| 361 |
+
|
| 362 |
+
x = self.transformer(x, *args, context=context, context_list=context_list)
|
| 363 |
+
return x
|
| 364 |
+
|
unish/heads/t_cond_mlp.py
ADDED
|
@@ -0,0 +1,199 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
from typing import List, Optional
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class AdaptiveLayerNorm1D(torch.nn.Module):
|
| 8 |
+
def __init__(self, data_dim: int, norm_cond_dim: int):
|
| 9 |
+
super().__init__()
|
| 10 |
+
if data_dim <= 0:
|
| 11 |
+
raise ValueError(f"data_dim must be positive, but got {data_dim}")
|
| 12 |
+
if norm_cond_dim <= 0:
|
| 13 |
+
raise ValueError(f"norm_cond_dim must be positive, but got {norm_cond_dim}")
|
| 14 |
+
self.norm = torch.nn.LayerNorm(
|
| 15 |
+
data_dim
|
| 16 |
+
) # TODO: Check if elementwise_affine=True is correct
|
| 17 |
+
self.linear = torch.nn.Linear(norm_cond_dim, 2 * data_dim)
|
| 18 |
+
torch.nn.init.zeros_(self.linear.weight)
|
| 19 |
+
torch.nn.init.zeros_(self.linear.bias)
|
| 20 |
+
|
| 21 |
+
def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
|
| 22 |
+
# x: (batch, ..., data_dim)
|
| 23 |
+
# t: (batch, norm_cond_dim)
|
| 24 |
+
# return: (batch, data_dim)
|
| 25 |
+
x = self.norm(x)
|
| 26 |
+
alpha, beta = self.linear(t).chunk(2, dim=-1)
|
| 27 |
+
|
| 28 |
+
# Add singleton dimensions to alpha and beta
|
| 29 |
+
if x.dim() > 2:
|
| 30 |
+
alpha = alpha.view(alpha.shape[0], *([1] * (x.dim() - 2)), alpha.shape[1])
|
| 31 |
+
beta = beta.view(beta.shape[0], *([1] * (x.dim() - 2)), beta.shape[1])
|
| 32 |
+
|
| 33 |
+
return x * (1 + alpha) + beta
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class SequentialCond(torch.nn.Sequential):
|
| 37 |
+
def forward(self, input, *args, **kwargs):
|
| 38 |
+
for module in self:
|
| 39 |
+
if isinstance(module, (AdaptiveLayerNorm1D, SequentialCond, ResidualMLPBlock)):
|
| 40 |
+
# print(f'Passing on args to {module}', [a.shape for a in args])
|
| 41 |
+
input = module(input, *args, **kwargs)
|
| 42 |
+
else:
|
| 43 |
+
# print(f'Skipping passing args to {module}', [a.shape for a in args])
|
| 44 |
+
input = module(input)
|
| 45 |
+
return input
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def normalization_layer(norm: Optional[str], dim: int, norm_cond_dim: int = -1):
|
| 49 |
+
if norm == "batch":
|
| 50 |
+
return torch.nn.BatchNorm1d(dim)
|
| 51 |
+
elif norm == "layer":
|
| 52 |
+
return torch.nn.LayerNorm(dim)
|
| 53 |
+
elif norm == "ada":
|
| 54 |
+
assert norm_cond_dim > 0, f"norm_cond_dim must be positive, got {norm_cond_dim}"
|
| 55 |
+
return AdaptiveLayerNorm1D(dim, norm_cond_dim)
|
| 56 |
+
elif norm is None:
|
| 57 |
+
return torch.nn.Identity()
|
| 58 |
+
else:
|
| 59 |
+
raise ValueError(f"Unknown norm: {norm}")
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def linear_norm_activ_dropout(
|
| 63 |
+
input_dim: int,
|
| 64 |
+
output_dim: int,
|
| 65 |
+
activation: torch.nn.Module = torch.nn.ReLU(),
|
| 66 |
+
bias: bool = True,
|
| 67 |
+
norm: Optional[str] = "layer", # Options: ada/batch/layer
|
| 68 |
+
dropout: float = 0.0,
|
| 69 |
+
norm_cond_dim: int = -1,
|
| 70 |
+
) -> SequentialCond:
|
| 71 |
+
layers = []
|
| 72 |
+
layers.append(torch.nn.Linear(input_dim, output_dim, bias=bias))
|
| 73 |
+
if norm is not None:
|
| 74 |
+
layers.append(normalization_layer(norm, output_dim, norm_cond_dim))
|
| 75 |
+
layers.append(copy.deepcopy(activation))
|
| 76 |
+
if dropout > 0.0:
|
| 77 |
+
layers.append(torch.nn.Dropout(dropout))
|
| 78 |
+
return SequentialCond(*layers)
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def create_simple_mlp(
|
| 82 |
+
input_dim: int,
|
| 83 |
+
hidden_dims: List[int],
|
| 84 |
+
output_dim: int,
|
| 85 |
+
activation: torch.nn.Module = torch.nn.ReLU(),
|
| 86 |
+
bias: bool = True,
|
| 87 |
+
norm: Optional[str] = "layer", # Options: ada/batch/layer
|
| 88 |
+
dropout: float = 0.0,
|
| 89 |
+
norm_cond_dim: int = -1,
|
| 90 |
+
) -> SequentialCond:
|
| 91 |
+
layers = []
|
| 92 |
+
prev_dim = input_dim
|
| 93 |
+
for hidden_dim in hidden_dims:
|
| 94 |
+
layers.extend(
|
| 95 |
+
linear_norm_activ_dropout(
|
| 96 |
+
prev_dim, hidden_dim, activation, bias, norm, dropout, norm_cond_dim
|
| 97 |
+
)
|
| 98 |
+
)
|
| 99 |
+
prev_dim = hidden_dim
|
| 100 |
+
layers.append(torch.nn.Linear(prev_dim, output_dim, bias=bias))
|
| 101 |
+
return SequentialCond(*layers)
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
class ResidualMLPBlock(torch.nn.Module):
|
| 105 |
+
def __init__(
|
| 106 |
+
self,
|
| 107 |
+
input_dim: int,
|
| 108 |
+
hidden_dim: int,
|
| 109 |
+
num_hidden_layers: int,
|
| 110 |
+
output_dim: int,
|
| 111 |
+
activation: torch.nn.Module = torch.nn.ReLU(),
|
| 112 |
+
bias: bool = True,
|
| 113 |
+
norm: Optional[str] = "layer", # Options: ada/batch/layer
|
| 114 |
+
dropout: float = 0.0,
|
| 115 |
+
norm_cond_dim: int = -1,
|
| 116 |
+
):
|
| 117 |
+
super().__init__()
|
| 118 |
+
if not (input_dim == output_dim == hidden_dim):
|
| 119 |
+
raise NotImplementedError(
|
| 120 |
+
f"input_dim {input_dim} != output_dim {output_dim} is not implemented"
|
| 121 |
+
)
|
| 122 |
+
|
| 123 |
+
layers = []
|
| 124 |
+
prev_dim = input_dim
|
| 125 |
+
for i in range(num_hidden_layers):
|
| 126 |
+
layers.append(
|
| 127 |
+
linear_norm_activ_dropout(
|
| 128 |
+
prev_dim, hidden_dim, activation, bias, norm, dropout, norm_cond_dim
|
| 129 |
+
)
|
| 130 |
+
)
|
| 131 |
+
prev_dim = hidden_dim
|
| 132 |
+
self.model = SequentialCond(*layers)
|
| 133 |
+
self.skip = torch.nn.Identity()
|
| 134 |
+
|
| 135 |
+
def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
|
| 136 |
+
return x + self.model(x, *args, **kwargs)
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
class ResidualMLP(torch.nn.Module):
|
| 140 |
+
def __init__(
|
| 141 |
+
self,
|
| 142 |
+
input_dim: int,
|
| 143 |
+
hidden_dim: int,
|
| 144 |
+
num_hidden_layers: int,
|
| 145 |
+
output_dim: int,
|
| 146 |
+
activation: torch.nn.Module = torch.nn.ReLU(),
|
| 147 |
+
bias: bool = True,
|
| 148 |
+
norm: Optional[str] = "layer", # Options: ada/batch/layer
|
| 149 |
+
dropout: float = 0.0,
|
| 150 |
+
num_blocks: int = 1,
|
| 151 |
+
norm_cond_dim: int = -1,
|
| 152 |
+
):
|
| 153 |
+
super().__init__()
|
| 154 |
+
self.input_dim = input_dim
|
| 155 |
+
self.model = SequentialCond(
|
| 156 |
+
linear_norm_activ_dropout(
|
| 157 |
+
input_dim, hidden_dim, activation, bias, norm, dropout, norm_cond_dim
|
| 158 |
+
),
|
| 159 |
+
*[
|
| 160 |
+
ResidualMLPBlock(
|
| 161 |
+
hidden_dim,
|
| 162 |
+
hidden_dim,
|
| 163 |
+
num_hidden_layers,
|
| 164 |
+
hidden_dim,
|
| 165 |
+
activation,
|
| 166 |
+
bias,
|
| 167 |
+
norm,
|
| 168 |
+
dropout,
|
| 169 |
+
norm_cond_dim,
|
| 170 |
+
)
|
| 171 |
+
for _ in range(num_blocks)
|
| 172 |
+
],
|
| 173 |
+
torch.nn.Linear(hidden_dim, output_dim, bias=bias),
|
| 174 |
+
)
|
| 175 |
+
|
| 176 |
+
def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
|
| 177 |
+
return self.model(x, *args, **kwargs)
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
class FrequencyEmbedder(torch.nn.Module):
|
| 181 |
+
def __init__(self, num_frequencies, max_freq_log2):
|
| 182 |
+
super().__init__()
|
| 183 |
+
frequencies = 2 ** torch.linspace(0, max_freq_log2, steps=num_frequencies)
|
| 184 |
+
self.register_buffer("frequencies", frequencies)
|
| 185 |
+
|
| 186 |
+
def forward(self, x):
|
| 187 |
+
# x should be of size (N,) or (N, D)
|
| 188 |
+
N = x.size(0)
|
| 189 |
+
if x.dim() == 1: # (N,)
|
| 190 |
+
x = x.unsqueeze(1) # (N, D) where D=1
|
| 191 |
+
x_unsqueezed = x.unsqueeze(-1) # (N, D, 1)
|
| 192 |
+
scaled = self.frequencies.view(1, 1, -1) * x_unsqueezed # (N, D, num_frequencies)
|
| 193 |
+
s = torch.sin(scaled)
|
| 194 |
+
c = torch.cos(scaled)
|
| 195 |
+
embedded = torch.cat([s, c, x_unsqueezed], dim=-1).view(
|
| 196 |
+
N, -1
|
| 197 |
+
) # (N, D * 2 * num_frequencies + D)
|
| 198 |
+
return embedded
|
| 199 |
+
|
unish/heads/utils.py
ADDED
|
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def position_grid_to_embed(pos_grid: torch.Tensor, embed_dim: int, omega_0: float = 100) -> torch.Tensor:
|
| 12 |
+
"""
|
| 13 |
+
Convert 2D position grid (HxWx2) to sinusoidal embeddings (HxWxC)
|
| 14 |
+
|
| 15 |
+
Args:
|
| 16 |
+
pos_grid: Tensor of shape (H, W, 2) containing 2D coordinates
|
| 17 |
+
embed_dim: Output channel dimension for embeddings
|
| 18 |
+
|
| 19 |
+
Returns:
|
| 20 |
+
Tensor of shape (H, W, embed_dim) with positional embeddings
|
| 21 |
+
"""
|
| 22 |
+
H, W, grid_dim = pos_grid.shape
|
| 23 |
+
assert grid_dim == 2
|
| 24 |
+
pos_flat = pos_grid.reshape(-1, grid_dim) # Flatten to (H*W, 2)
|
| 25 |
+
|
| 26 |
+
# Process x and y coordinates separately
|
| 27 |
+
emb_x = make_sincos_pos_embed(embed_dim // 2, pos_flat[:, 0], omega_0=omega_0) # [1, H*W, D/2]
|
| 28 |
+
emb_y = make_sincos_pos_embed(embed_dim // 2, pos_flat[:, 1], omega_0=omega_0) # [1, H*W, D/2]
|
| 29 |
+
|
| 30 |
+
# Combine and reshape
|
| 31 |
+
emb = torch.cat([emb_x, emb_y], dim=-1) # [1, H*W, D]
|
| 32 |
+
|
| 33 |
+
return emb.view(H, W, embed_dim) # [H, W, D]
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def make_sincos_pos_embed(embed_dim: int, pos: torch.Tensor, omega_0: float = 100) -> torch.Tensor:
|
| 37 |
+
"""
|
| 38 |
+
This function generates a 1D positional embedding from a given grid using sine and cosine functions.
|
| 39 |
+
|
| 40 |
+
Args:
|
| 41 |
+
- embed_dim: The embedding dimension.
|
| 42 |
+
- pos: The position to generate the embedding from.
|
| 43 |
+
|
| 44 |
+
Returns:
|
| 45 |
+
- emb: The generated 1D positional embedding.
|
| 46 |
+
"""
|
| 47 |
+
assert embed_dim % 2 == 0
|
| 48 |
+
omega = torch.arange(embed_dim // 2, dtype=torch.double, device=pos.device)
|
| 49 |
+
omega /= embed_dim / 2.0
|
| 50 |
+
omega = 1.0 / omega_0**omega # (D/2,)
|
| 51 |
+
|
| 52 |
+
pos = pos.reshape(-1) # (M,)
|
| 53 |
+
out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product
|
| 54 |
+
|
| 55 |
+
emb_sin = torch.sin(out) # (M, D/2)
|
| 56 |
+
emb_cos = torch.cos(out) # (M, D/2)
|
| 57 |
+
|
| 58 |
+
emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D)
|
| 59 |
+
return emb.float()
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
# Inspired by https://github.com/microsoft/moge
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def create_uv_grid(
|
| 66 |
+
width: int, height: int, aspect_ratio: float = None, dtype: torch.dtype = None, device: torch.device = None
|
| 67 |
+
) -> torch.Tensor:
|
| 68 |
+
"""
|
| 69 |
+
Create a normalized UV grid of shape (width, height, 2).
|
| 70 |
+
|
| 71 |
+
The grid spans horizontally and vertically according to an aspect ratio,
|
| 72 |
+
ensuring the top-left corner is at (-x_span, -y_span) and the bottom-right
|
| 73 |
+
corner is at (x_span, y_span), normalized by the diagonal of the plane.
|
| 74 |
+
|
| 75 |
+
Args:
|
| 76 |
+
width (int): Number of points horizontally.
|
| 77 |
+
height (int): Number of points vertically.
|
| 78 |
+
aspect_ratio (float, optional): Width-to-height ratio. Defaults to width/height.
|
| 79 |
+
dtype (torch.dtype, optional): Data type of the resulting tensor.
|
| 80 |
+
device (torch.device, optional): Device on which the tensor is created.
|
| 81 |
+
|
| 82 |
+
Returns:
|
| 83 |
+
torch.Tensor: A (width, height, 2) tensor of UV coordinates.
|
| 84 |
+
"""
|
| 85 |
+
# Derive aspect ratio if not explicitly provided
|
| 86 |
+
if aspect_ratio is None:
|
| 87 |
+
aspect_ratio = float(width) / float(height)
|
| 88 |
+
|
| 89 |
+
# Compute normalized spans for X and Y
|
| 90 |
+
diag_factor = (aspect_ratio**2 + 1.0) ** 0.5
|
| 91 |
+
span_x = aspect_ratio / diag_factor
|
| 92 |
+
span_y = 1.0 / diag_factor
|
| 93 |
+
|
| 94 |
+
# Establish the linspace boundaries
|
| 95 |
+
left_x = -span_x * (width - 1) / width
|
| 96 |
+
right_x = span_x * (width - 1) / width
|
| 97 |
+
top_y = -span_y * (height - 1) / height
|
| 98 |
+
bottom_y = span_y * (height - 1) / height
|
| 99 |
+
|
| 100 |
+
# Generate 1D coordinates
|
| 101 |
+
x_coords = torch.linspace(left_x, right_x, steps=width, dtype=dtype, device=device)
|
| 102 |
+
y_coords = torch.linspace(top_y, bottom_y, steps=height, dtype=dtype, device=device)
|
| 103 |
+
|
| 104 |
+
# Create 2D meshgrid (width x height) and stack into UV
|
| 105 |
+
uu, vv = torch.meshgrid(x_coords, y_coords, indexing="xy")
|
| 106 |
+
uv_grid = torch.stack((uu, vv), dim=-1)
|
| 107 |
+
|
| 108 |
+
return uv_grid
|
unish/heads/vit.py
ADDED
|
@@ -0,0 +1,346 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
|
| 2 |
+
import math
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from functools import partial
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
import torch.utils.checkpoint as checkpoint
|
| 9 |
+
|
| 10 |
+
from timm.models.layers import drop_path, to_2tuple, trunc_normal_
|
| 11 |
+
|
| 12 |
+
def vit():
|
| 13 |
+
return ViT(
|
| 14 |
+
img_size=(256, 192),
|
| 15 |
+
patch_size=16,
|
| 16 |
+
embed_dim=1280,
|
| 17 |
+
depth=32,
|
| 18 |
+
num_heads=16,
|
| 19 |
+
ratio=1,
|
| 20 |
+
use_checkpoint=False,
|
| 21 |
+
mlp_ratio=4,
|
| 22 |
+
qkv_bias=True,
|
| 23 |
+
drop_path_rate=0.55,
|
| 24 |
+
)
|
| 25 |
+
|
| 26 |
+
def get_abs_pos(abs_pos, h, w, ori_h, ori_w, has_cls_token=True):
|
| 27 |
+
"""
|
| 28 |
+
Calculate absolute positional embeddings. If needed, resize embeddings and remove cls_token
|
| 29 |
+
dimension for the original embeddings.
|
| 30 |
+
Args:
|
| 31 |
+
abs_pos (Tensor): absolute positional embeddings with (1, num_position, C).
|
| 32 |
+
has_cls_token (bool): If true, has 1 embedding in abs_pos for cls token.
|
| 33 |
+
hw (Tuple): size of input image tokens.
|
| 34 |
+
|
| 35 |
+
Returns:
|
| 36 |
+
Absolute positional embeddings after processing with shape (1, H, W, C)
|
| 37 |
+
"""
|
| 38 |
+
cls_token = None
|
| 39 |
+
B, L, C = abs_pos.shape
|
| 40 |
+
if has_cls_token:
|
| 41 |
+
cls_token = abs_pos[:, 0:1]
|
| 42 |
+
abs_pos = abs_pos[:, 1:]
|
| 43 |
+
|
| 44 |
+
if ori_h != h or ori_w != w:
|
| 45 |
+
new_abs_pos = F.interpolate(
|
| 46 |
+
abs_pos.reshape(1, ori_h, ori_w, -1).permute(0, 3, 1, 2),
|
| 47 |
+
size=(h, w),
|
| 48 |
+
mode="bicubic",
|
| 49 |
+
align_corners=False,
|
| 50 |
+
).permute(0, 2, 3, 1).reshape(B, -1, C)
|
| 51 |
+
|
| 52 |
+
else:
|
| 53 |
+
new_abs_pos = abs_pos
|
| 54 |
+
|
| 55 |
+
if cls_token is not None:
|
| 56 |
+
new_abs_pos = torch.cat([cls_token, new_abs_pos], dim=1)
|
| 57 |
+
return new_abs_pos
|
| 58 |
+
|
| 59 |
+
class DropPath(nn.Module):
|
| 60 |
+
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
| 61 |
+
"""
|
| 62 |
+
def __init__(self, drop_prob=None):
|
| 63 |
+
super(DropPath, self).__init__()
|
| 64 |
+
self.drop_prob = drop_prob
|
| 65 |
+
|
| 66 |
+
def forward(self, x):
|
| 67 |
+
return drop_path(x, self.drop_prob, self.training)
|
| 68 |
+
|
| 69 |
+
def extra_repr(self):
|
| 70 |
+
return 'p={}'.format(self.drop_prob)
|
| 71 |
+
|
| 72 |
+
class Mlp(nn.Module):
|
| 73 |
+
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
|
| 74 |
+
super().__init__()
|
| 75 |
+
out_features = out_features or in_features
|
| 76 |
+
hidden_features = hidden_features or in_features
|
| 77 |
+
self.fc1 = nn.Linear(in_features, hidden_features)
|
| 78 |
+
self.act = act_layer()
|
| 79 |
+
self.fc2 = nn.Linear(hidden_features, out_features)
|
| 80 |
+
self.drop = nn.Dropout(drop)
|
| 81 |
+
|
| 82 |
+
def forward(self, x):
|
| 83 |
+
x = self.fc1(x)
|
| 84 |
+
x = self.act(x)
|
| 85 |
+
x = self.fc2(x)
|
| 86 |
+
x = self.drop(x)
|
| 87 |
+
return x
|
| 88 |
+
|
| 89 |
+
class Attention(nn.Module):
|
| 90 |
+
def __init__(
|
| 91 |
+
self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
|
| 92 |
+
proj_drop=0., attn_head_dim=None,):
|
| 93 |
+
super().__init__()
|
| 94 |
+
self.num_heads = num_heads
|
| 95 |
+
head_dim = dim // num_heads
|
| 96 |
+
self.dim = dim
|
| 97 |
+
|
| 98 |
+
if attn_head_dim is not None:
|
| 99 |
+
head_dim = attn_head_dim
|
| 100 |
+
all_head_dim = head_dim * self.num_heads
|
| 101 |
+
|
| 102 |
+
self.scale = qk_scale or head_dim ** -0.5
|
| 103 |
+
|
| 104 |
+
self.qkv = nn.Linear(dim, all_head_dim * 3, bias=qkv_bias)
|
| 105 |
+
|
| 106 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
| 107 |
+
self.proj = nn.Linear(all_head_dim, dim)
|
| 108 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
| 109 |
+
|
| 110 |
+
def forward(self, x):
|
| 111 |
+
B, N, C = x.shape
|
| 112 |
+
qkv = self.qkv(x)
|
| 113 |
+
qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
| 114 |
+
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
|
| 115 |
+
|
| 116 |
+
q = q * self.scale
|
| 117 |
+
attn = (q @ k.transpose(-2, -1))
|
| 118 |
+
|
| 119 |
+
attn = attn.softmax(dim=-1)
|
| 120 |
+
attn = self.attn_drop(attn)
|
| 121 |
+
|
| 122 |
+
x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
|
| 123 |
+
x = self.proj(x)
|
| 124 |
+
x = self.proj_drop(x)
|
| 125 |
+
|
| 126 |
+
return x
|
| 127 |
+
|
| 128 |
+
class Block(nn.Module):
|
| 129 |
+
|
| 130 |
+
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None,
|
| 131 |
+
drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU,
|
| 132 |
+
norm_layer=nn.LayerNorm, attn_head_dim=None
|
| 133 |
+
):
|
| 134 |
+
super().__init__()
|
| 135 |
+
|
| 136 |
+
self.norm1 = norm_layer(dim)
|
| 137 |
+
self.attn = Attention(
|
| 138 |
+
dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
|
| 139 |
+
attn_drop=attn_drop, proj_drop=drop, attn_head_dim=attn_head_dim
|
| 140 |
+
)
|
| 141 |
+
|
| 142 |
+
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
|
| 143 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
| 144 |
+
self.norm2 = norm_layer(dim)
|
| 145 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
| 146 |
+
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
|
| 147 |
+
|
| 148 |
+
def forward(self, x):
|
| 149 |
+
x = x + self.drop_path(self.attn(self.norm1(x)))
|
| 150 |
+
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
| 151 |
+
return x
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
class PatchEmbed(nn.Module):
|
| 155 |
+
""" Image to Patch Embedding
|
| 156 |
+
"""
|
| 157 |
+
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, ratio=1):
|
| 158 |
+
super().__init__()
|
| 159 |
+
img_size = to_2tuple(img_size)
|
| 160 |
+
patch_size = to_2tuple(patch_size)
|
| 161 |
+
num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) * (ratio ** 2)
|
| 162 |
+
self.patch_shape = (int(img_size[0] // patch_size[0] * ratio), int(img_size[1] // patch_size[1] * ratio))
|
| 163 |
+
self.origin_patch_shape = (int(img_size[0] // patch_size[0]), int(img_size[1] // patch_size[1]))
|
| 164 |
+
self.img_size = img_size
|
| 165 |
+
self.patch_size = patch_size
|
| 166 |
+
self.num_patches = num_patches
|
| 167 |
+
|
| 168 |
+
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=(patch_size[0] // ratio), padding=4 + 2 * (ratio//2-1))
|
| 169 |
+
|
| 170 |
+
def forward(self, x, **kwargs):
|
| 171 |
+
B, C, H, W = x.shape
|
| 172 |
+
x = self.proj(x)
|
| 173 |
+
Hp, Wp = x.shape[2], x.shape[3]
|
| 174 |
+
|
| 175 |
+
x = x.flatten(2).transpose(1, 2)
|
| 176 |
+
return x, (Hp, Wp)
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
class HybridEmbed(nn.Module):
|
| 180 |
+
""" CNN Feature Map Embedding
|
| 181 |
+
Extract feature map from CNN, flatten, project to embedding dim.
|
| 182 |
+
"""
|
| 183 |
+
def __init__(self, backbone, img_size=224, feature_size=None, in_chans=3, embed_dim=768):
|
| 184 |
+
super().__init__()
|
| 185 |
+
assert isinstance(backbone, nn.Module)
|
| 186 |
+
img_size = to_2tuple(img_size)
|
| 187 |
+
self.img_size = img_size
|
| 188 |
+
self.backbone = backbone
|
| 189 |
+
if feature_size is None:
|
| 190 |
+
with torch.no_grad():
|
| 191 |
+
training = backbone.training
|
| 192 |
+
if training:
|
| 193 |
+
backbone.eval()
|
| 194 |
+
o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))[-1]
|
| 195 |
+
feature_size = o.shape[-2:]
|
| 196 |
+
feature_dim = o.shape[1]
|
| 197 |
+
backbone.train(training)
|
| 198 |
+
else:
|
| 199 |
+
feature_size = to_2tuple(feature_size)
|
| 200 |
+
feature_dim = self.backbone.feature_info.channels()[-1]
|
| 201 |
+
self.num_patches = feature_size[0] * feature_size[1]
|
| 202 |
+
self.proj = nn.Linear(feature_dim, embed_dim)
|
| 203 |
+
|
| 204 |
+
def forward(self, x):
|
| 205 |
+
x = self.backbone(x)[-1]
|
| 206 |
+
x = x.flatten(2).transpose(1, 2)
|
| 207 |
+
x = self.proj(x)
|
| 208 |
+
return x
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
class ViT(nn.Module):
|
| 212 |
+
def __init__(self,
|
| 213 |
+
img_size=224, patch_size=16, in_chans=3, num_classes=80, embed_dim=768, depth=12,
|
| 214 |
+
num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
|
| 215 |
+
drop_path_rate=0., hybrid_backbone=None, norm_layer=None, use_checkpoint=False,
|
| 216 |
+
frozen_stages=-1, ratio=1, last_norm=True,
|
| 217 |
+
patch_padding='pad', freeze_attn=False, freeze_ffn=False,
|
| 218 |
+
):
|
| 219 |
+
super(ViT, self).__init__()
|
| 220 |
+
norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
|
| 221 |
+
self.num_classes = num_classes
|
| 222 |
+
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
|
| 223 |
+
self.frozen_stages = frozen_stages
|
| 224 |
+
self.use_checkpoint = use_checkpoint
|
| 225 |
+
self.patch_padding = patch_padding
|
| 226 |
+
self.freeze_attn = freeze_attn
|
| 227 |
+
self.freeze_ffn = freeze_ffn
|
| 228 |
+
self.depth = depth
|
| 229 |
+
|
| 230 |
+
if hybrid_backbone is not None:
|
| 231 |
+
self.patch_embed = HybridEmbed(
|
| 232 |
+
hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim)
|
| 233 |
+
else:
|
| 234 |
+
self.patch_embed = PatchEmbed(
|
| 235 |
+
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, ratio=ratio)
|
| 236 |
+
num_patches = self.patch_embed.num_patches
|
| 237 |
+
|
| 238 |
+
# since the pretraining model has class token
|
| 239 |
+
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
|
| 240 |
+
|
| 241 |
+
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
|
| 242 |
+
|
| 243 |
+
self.blocks = nn.ModuleList([
|
| 244 |
+
Block(
|
| 245 |
+
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
|
| 246 |
+
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
|
| 247 |
+
)
|
| 248 |
+
for i in range(depth)])
|
| 249 |
+
|
| 250 |
+
self.last_norm = norm_layer(embed_dim) if last_norm else nn.Identity()
|
| 251 |
+
|
| 252 |
+
if self.pos_embed is not None:
|
| 253 |
+
trunc_normal_(self.pos_embed, std=.02)
|
| 254 |
+
|
| 255 |
+
self._freeze_stages()
|
| 256 |
+
|
| 257 |
+
def _freeze_stages(self):
|
| 258 |
+
"""Freeze parameters."""
|
| 259 |
+
if self.frozen_stages >= 0:
|
| 260 |
+
self.patch_embed.eval()
|
| 261 |
+
for param in self.patch_embed.parameters():
|
| 262 |
+
param.requires_grad = False
|
| 263 |
+
|
| 264 |
+
for i in range(1, self.frozen_stages + 1):
|
| 265 |
+
m = self.blocks[i]
|
| 266 |
+
m.eval()
|
| 267 |
+
for param in m.parameters():
|
| 268 |
+
param.requires_grad = False
|
| 269 |
+
|
| 270 |
+
if self.freeze_attn:
|
| 271 |
+
for i in range(0, self.depth):
|
| 272 |
+
m = self.blocks[i]
|
| 273 |
+
m.attn.eval()
|
| 274 |
+
m.norm1.eval()
|
| 275 |
+
for param in m.attn.parameters():
|
| 276 |
+
param.requires_grad = False
|
| 277 |
+
for param in m.norm1.parameters():
|
| 278 |
+
param.requires_grad = False
|
| 279 |
+
|
| 280 |
+
if self.freeze_ffn:
|
| 281 |
+
self.pos_embed.requires_grad = False
|
| 282 |
+
self.patch_embed.eval()
|
| 283 |
+
for param in self.patch_embed.parameters():
|
| 284 |
+
param.requires_grad = False
|
| 285 |
+
for i in range(0, self.depth):
|
| 286 |
+
m = self.blocks[i]
|
| 287 |
+
m.mlp.eval()
|
| 288 |
+
m.norm2.eval()
|
| 289 |
+
for param in m.mlp.parameters():
|
| 290 |
+
param.requires_grad = False
|
| 291 |
+
for param in m.norm2.parameters():
|
| 292 |
+
param.requires_grad = False
|
| 293 |
+
|
| 294 |
+
def init_weights(self):
|
| 295 |
+
"""Initialize the weights in backbone.
|
| 296 |
+
Args:
|
| 297 |
+
pretrained (str, optional): Path to pre-trained weights.
|
| 298 |
+
Defaults to None.
|
| 299 |
+
"""
|
| 300 |
+
def _init_weights(m):
|
| 301 |
+
if isinstance(m, nn.Linear):
|
| 302 |
+
trunc_normal_(m.weight, std=.02)
|
| 303 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
| 304 |
+
nn.init.constant_(m.bias, 0)
|
| 305 |
+
elif isinstance(m, nn.LayerNorm):
|
| 306 |
+
nn.init.constant_(m.bias, 0)
|
| 307 |
+
nn.init.constant_(m.weight, 1.0)
|
| 308 |
+
|
| 309 |
+
self.apply(_init_weights)
|
| 310 |
+
|
| 311 |
+
def get_num_layers(self):
|
| 312 |
+
return len(self.blocks)
|
| 313 |
+
|
| 314 |
+
@torch.jit.ignore
|
| 315 |
+
def no_weight_decay(self):
|
| 316 |
+
return {'pos_embed', 'cls_token'}
|
| 317 |
+
|
| 318 |
+
def forward_features(self, x):
|
| 319 |
+
B, C, H, W = x.shape
|
| 320 |
+
x, (Hp, Wp) = self.patch_embed(x)
|
| 321 |
+
|
| 322 |
+
if self.pos_embed is not None:
|
| 323 |
+
# fit for multiple GPU training
|
| 324 |
+
# since the first element for pos embed (sin-cos manner) is zero, it will cause no difference
|
| 325 |
+
x = x + self.pos_embed[:, 1:] + self.pos_embed[:, :1]
|
| 326 |
+
|
| 327 |
+
for blk in self.blocks:
|
| 328 |
+
if self.use_checkpoint:
|
| 329 |
+
x = checkpoint.checkpoint(blk, x)
|
| 330 |
+
else:
|
| 331 |
+
x = blk(x)
|
| 332 |
+
|
| 333 |
+
x = self.last_norm(x)
|
| 334 |
+
|
| 335 |
+
xp = x.permute(0, 2, 1).reshape(B, -1, Hp, Wp).contiguous()
|
| 336 |
+
|
| 337 |
+
return xp
|
| 338 |
+
|
| 339 |
+
def forward(self, x):
|
| 340 |
+
x = self.forward_features(x)
|
| 341 |
+
return x
|
| 342 |
+
|
| 343 |
+
def train(self, mode=True):
|
| 344 |
+
"""Convert the model into training mode."""
|
| 345 |
+
super().train(mode)
|
| 346 |
+
self._freeze_stages()
|
unish/pi3/models/__pycache__/pi3.cpython-310.pyc
ADDED
|
Binary file (7.01 kB). View file
|
|
|
unish/pi3/models/dinov2/__init__.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
__version__ = "0.0.1"
|
unish/pi3/models/dinov2/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (181 Bytes). View file
|
|
|
unish/pi3/models/dinov2/hub/__init__.py
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
unish/pi3/models/dinov2/hub/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (164 Bytes). View file
|
|
|
unish/pi3/models/dinov2/hub/__pycache__/backbones.cpython-310.pyc
ADDED
|
Binary file (3.99 kB). View file
|
|
|
unish/pi3/models/dinov2/hub/__pycache__/utils.cpython-310.pyc
ADDED
|
Binary file (1.78 kB). View file
|
|
|
unish/pi3/models/dinov2/hub/backbones.py
ADDED
|
@@ -0,0 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
from enum import Enum
|
| 7 |
+
from typing import Union
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
|
| 11 |
+
from .utils import _DINOV2_BASE_URL, _make_dinov2_model_name
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class Weights(Enum):
|
| 15 |
+
LVD142M = "LVD142M"
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def _make_dinov2_model(
|
| 19 |
+
*,
|
| 20 |
+
arch_name: str = "vit_large",
|
| 21 |
+
img_size: int = 518,
|
| 22 |
+
patch_size: int = 14,
|
| 23 |
+
init_values: float = 1.0,
|
| 24 |
+
ffn_layer: str = "mlp",
|
| 25 |
+
block_chunks: int = 0,
|
| 26 |
+
num_register_tokens: int = 0,
|
| 27 |
+
interpolate_antialias: bool = False,
|
| 28 |
+
interpolate_offset: float = 0.1,
|
| 29 |
+
pretrained: bool = True,
|
| 30 |
+
weights: Union[Weights, str] = Weights.LVD142M,
|
| 31 |
+
**kwargs,
|
| 32 |
+
):
|
| 33 |
+
from ..models import vision_transformer as vits
|
| 34 |
+
|
| 35 |
+
if isinstance(weights, str):
|
| 36 |
+
try:
|
| 37 |
+
weights = Weights[weights]
|
| 38 |
+
except KeyError:
|
| 39 |
+
raise AssertionError(f"Unsupported weights: {weights}")
|
| 40 |
+
|
| 41 |
+
model_base_name = _make_dinov2_model_name(arch_name, patch_size)
|
| 42 |
+
vit_kwargs = dict(
|
| 43 |
+
img_size=img_size,
|
| 44 |
+
patch_size=patch_size,
|
| 45 |
+
init_values=init_values,
|
| 46 |
+
ffn_layer=ffn_layer,
|
| 47 |
+
block_chunks=block_chunks,
|
| 48 |
+
num_register_tokens=num_register_tokens,
|
| 49 |
+
interpolate_antialias=interpolate_antialias,
|
| 50 |
+
interpolate_offset=interpolate_offset,
|
| 51 |
+
)
|
| 52 |
+
vit_kwargs.update(**kwargs)
|
| 53 |
+
model = vits.__dict__[arch_name](**vit_kwargs)
|
| 54 |
+
|
| 55 |
+
if pretrained:
|
| 56 |
+
model_full_name = _make_dinov2_model_name(arch_name, patch_size, num_register_tokens)
|
| 57 |
+
url = _DINOV2_BASE_URL + f"/{model_base_name}/{model_full_name}_pretrain.pth"
|
| 58 |
+
state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu")
|
| 59 |
+
model.load_state_dict(state_dict, strict=True)
|
| 60 |
+
|
| 61 |
+
return model
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def dinov2_vits14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
|
| 65 |
+
"""
|
| 66 |
+
DINOv2 ViT-S/14 model (optionally) pretrained on the LVD-142M dataset.
|
| 67 |
+
"""
|
| 68 |
+
return _make_dinov2_model(arch_name="vit_small", pretrained=pretrained, weights=weights, **kwargs)
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def dinov2_vitb14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
|
| 72 |
+
"""
|
| 73 |
+
DINOv2 ViT-B/14 model (optionally) pretrained on the LVD-142M dataset.
|
| 74 |
+
"""
|
| 75 |
+
return _make_dinov2_model(arch_name="vit_base", pretrained=pretrained, weights=weights, **kwargs)
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def dinov2_vitl14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
|
| 79 |
+
"""
|
| 80 |
+
DINOv2 ViT-L/14 model (optionally) pretrained on the LVD-142M dataset.
|
| 81 |
+
"""
|
| 82 |
+
return _make_dinov2_model(arch_name="vit_large", pretrained=pretrained, weights=weights, **kwargs)
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def dinov2_vitg14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
|
| 86 |
+
"""
|
| 87 |
+
DINOv2 ViT-g/14 model (optionally) pretrained on the LVD-142M dataset.
|
| 88 |
+
"""
|
| 89 |
+
return _make_dinov2_model(
|
| 90 |
+
arch_name="vit_giant2",
|
| 91 |
+
ffn_layer="swiglufused",
|
| 92 |
+
weights=weights,
|
| 93 |
+
pretrained=pretrained,
|
| 94 |
+
**kwargs,
|
| 95 |
+
)
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def dinov2_vits14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
|
| 99 |
+
"""
|
| 100 |
+
DINOv2 ViT-S/14 model with registers (optionally) pretrained on the LVD-142M dataset.
|
| 101 |
+
"""
|
| 102 |
+
return _make_dinov2_model(
|
| 103 |
+
arch_name="vit_small",
|
| 104 |
+
pretrained=pretrained,
|
| 105 |
+
weights=weights,
|
| 106 |
+
num_register_tokens=4,
|
| 107 |
+
interpolate_antialias=True,
|
| 108 |
+
interpolate_offset=0.0,
|
| 109 |
+
**kwargs,
|
| 110 |
+
)
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def dinov2_vitb14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
|
| 114 |
+
"""
|
| 115 |
+
DINOv2 ViT-B/14 model with registers (optionally) pretrained on the LVD-142M dataset.
|
| 116 |
+
"""
|
| 117 |
+
return _make_dinov2_model(
|
| 118 |
+
arch_name="vit_base",
|
| 119 |
+
pretrained=pretrained,
|
| 120 |
+
weights=weights,
|
| 121 |
+
num_register_tokens=4,
|
| 122 |
+
interpolate_antialias=True,
|
| 123 |
+
interpolate_offset=0.0,
|
| 124 |
+
**kwargs,
|
| 125 |
+
)
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def dinov2_vitl14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
|
| 129 |
+
"""
|
| 130 |
+
DINOv2 ViT-L/14 model with registers (optionally) pretrained on the LVD-142M dataset.
|
| 131 |
+
"""
|
| 132 |
+
return _make_dinov2_model(
|
| 133 |
+
arch_name="vit_large",
|
| 134 |
+
pretrained=pretrained,
|
| 135 |
+
weights=weights,
|
| 136 |
+
num_register_tokens=4,
|
| 137 |
+
interpolate_antialias=True,
|
| 138 |
+
interpolate_offset=0.0,
|
| 139 |
+
**kwargs,
|
| 140 |
+
)
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
def dinov2_vitg14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
|
| 144 |
+
"""
|
| 145 |
+
DINOv2 ViT-g/14 model with registers (optionally) pretrained on the LVD-142M dataset.
|
| 146 |
+
"""
|
| 147 |
+
return _make_dinov2_model(
|
| 148 |
+
arch_name="vit_giant2",
|
| 149 |
+
ffn_layer="swiglufused",
|
| 150 |
+
weights=weights,
|
| 151 |
+
pretrained=pretrained,
|
| 152 |
+
num_register_tokens=4,
|
| 153 |
+
interpolate_antialias=True,
|
| 154 |
+
interpolate_offset=0.0,
|
| 155 |
+
**kwargs,
|
| 156 |
+
)
|
unish/pi3/models/dinov2/hub/utils.py
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import itertools
|
| 7 |
+
import math
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
import torch.nn.functional as F
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
_DINOV2_BASE_URL = "https://dl.fbaipublicfiles.com/dinov2"
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def _make_dinov2_model_name(arch_name: str, patch_size: int, num_register_tokens: int = 0) -> str:
|
| 18 |
+
compact_arch_name = arch_name.replace("_", "")[:4]
|
| 19 |
+
registers_suffix = f"_reg{num_register_tokens}" if num_register_tokens else ""
|
| 20 |
+
return f"dinov2_{compact_arch_name}{patch_size}{registers_suffix}"
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class CenterPadding(nn.Module):
|
| 24 |
+
def __init__(self, multiple):
|
| 25 |
+
super().__init__()
|
| 26 |
+
self.multiple = multiple
|
| 27 |
+
|
| 28 |
+
def _get_pad(self, size):
|
| 29 |
+
new_size = math.ceil(size / self.multiple) * self.multiple
|
| 30 |
+
pad_size = new_size - size
|
| 31 |
+
pad_size_left = pad_size // 2
|
| 32 |
+
pad_size_right = pad_size - pad_size_left
|
| 33 |
+
return pad_size_left, pad_size_right
|
| 34 |
+
|
| 35 |
+
@torch.inference_mode()
|
| 36 |
+
def forward(self, x):
|
| 37 |
+
pads = list(itertools.chain.from_iterable(self._get_pad(m) for m in x.shape[:1:-1]))
|
| 38 |
+
output = F.pad(x, pads)
|
| 39 |
+
return output
|
unish/pi3/models/dinov2/layers/__init__.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
from .dino_head import DINOHead
|
| 7 |
+
from .mlp import Mlp
|
| 8 |
+
from .patch_embed import PatchEmbed
|
| 9 |
+
from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused
|
| 10 |
+
from .block import NestedTensorBlock
|
| 11 |
+
from .attention import MemEffAttention
|
unish/pi3/models/dinov2/layers/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (451 Bytes). View file
|
|
|
unish/pi3/models/dinov2/layers/__pycache__/attention.cpython-310.pyc
ADDED
|
Binary file (2.47 kB). View file
|
|
|
unish/pi3/models/dinov2/layers/__pycache__/block.cpython-310.pyc
ADDED
|
Binary file (8.04 kB). View file
|
|
|
unish/pi3/models/dinov2/layers/__pycache__/dino_head.cpython-310.pyc
ADDED
|
Binary file (1.99 kB). View file
|
|
|
unish/pi3/models/dinov2/layers/__pycache__/drop_path.cpython-310.pyc
ADDED
|
Binary file (1.21 kB). View file
|
|
|
unish/pi3/models/dinov2/layers/__pycache__/layer_scale.cpython-310.pyc
ADDED
|
Binary file (1.01 kB). View file
|
|
|
unish/pi3/models/dinov2/layers/__pycache__/mlp.cpython-310.pyc
ADDED
|
Binary file (1.2 kB). View file
|
|
|
unish/pi3/models/dinov2/layers/__pycache__/patch_embed.cpython-310.pyc
ADDED
|
Binary file (2.65 kB). View file
|
|
|
unish/pi3/models/dinov2/layers/__pycache__/swiglu_ffn.cpython-310.pyc
ADDED
|
Binary file (2.12 kB). View file
|
|
|
unish/pi3/models/dinov2/layers/attention.py
ADDED
|
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
# References:
|
| 7 |
+
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
| 8 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
|
| 9 |
+
|
| 10 |
+
import logging
|
| 11 |
+
import os
|
| 12 |
+
import warnings
|
| 13 |
+
|
| 14 |
+
from torch import Tensor
|
| 15 |
+
from torch import nn
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
logger = logging.getLogger("dinov2")
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
|
| 22 |
+
try:
|
| 23 |
+
if XFORMERS_ENABLED:
|
| 24 |
+
from xformers.ops import memory_efficient_attention, unbind
|
| 25 |
+
|
| 26 |
+
XFORMERS_AVAILABLE = True
|
| 27 |
+
# warnings.warn("xFormers is available (Attention)")
|
| 28 |
+
else:
|
| 29 |
+
# warnings.warn("xFormers is disabled (Attention)")
|
| 30 |
+
raise ImportError
|
| 31 |
+
except ImportError:
|
| 32 |
+
XFORMERS_AVAILABLE = False
|
| 33 |
+
# warnings.warn("xFormers is not available (Attention)")
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class Attention(nn.Module):
|
| 37 |
+
def __init__(
|
| 38 |
+
self,
|
| 39 |
+
dim: int,
|
| 40 |
+
num_heads: int = 8,
|
| 41 |
+
qkv_bias: bool = False,
|
| 42 |
+
proj_bias: bool = True,
|
| 43 |
+
attn_drop: float = 0.0,
|
| 44 |
+
proj_drop: float = 0.0,
|
| 45 |
+
) -> None:
|
| 46 |
+
super().__init__()
|
| 47 |
+
self.num_heads = num_heads
|
| 48 |
+
head_dim = dim // num_heads
|
| 49 |
+
self.scale = head_dim**-0.5
|
| 50 |
+
|
| 51 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
| 52 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
| 53 |
+
self.proj = nn.Linear(dim, dim, bias=proj_bias)
|
| 54 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
| 55 |
+
|
| 56 |
+
def forward(self, x: Tensor, attn_bias=None) -> Tensor:
|
| 57 |
+
B, N, C = x.shape
|
| 58 |
+
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
| 59 |
+
|
| 60 |
+
q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
|
| 61 |
+
attn = q @ k.transpose(-2, -1)
|
| 62 |
+
|
| 63 |
+
attn = attn.softmax(dim=-1)
|
| 64 |
+
attn = self.attn_drop(attn)
|
| 65 |
+
|
| 66 |
+
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
|
| 67 |
+
x = self.proj(x)
|
| 68 |
+
x = self.proj_drop(x)
|
| 69 |
+
return x
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
class MemEffAttention(Attention):
|
| 73 |
+
def forward(self, x: Tensor, attn_bias=None) -> Tensor:
|
| 74 |
+
if not XFORMERS_AVAILABLE:
|
| 75 |
+
if attn_bias is not None:
|
| 76 |
+
raise AssertionError("xFormers is required for using nested tensors")
|
| 77 |
+
return super().forward(x)
|
| 78 |
+
|
| 79 |
+
B, N, C = x.shape
|
| 80 |
+
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
|
| 81 |
+
|
| 82 |
+
q, k, v = unbind(qkv, 2)
|
| 83 |
+
|
| 84 |
+
x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
|
| 85 |
+
x = x.reshape([B, N, C])
|
| 86 |
+
|
| 87 |
+
x = self.proj(x)
|
| 88 |
+
x = self.proj_drop(x)
|
| 89 |
+
return x
|
unish/pi3/models/dinov2/layers/block.py
ADDED
|
@@ -0,0 +1,259 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
# References:
|
| 7 |
+
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
| 8 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
|
| 9 |
+
|
| 10 |
+
import logging
|
| 11 |
+
import os
|
| 12 |
+
from typing import Callable, List, Any, Tuple, Dict
|
| 13 |
+
import warnings
|
| 14 |
+
|
| 15 |
+
import torch
|
| 16 |
+
from torch import nn, Tensor
|
| 17 |
+
|
| 18 |
+
from .attention import Attention, MemEffAttention
|
| 19 |
+
from .drop_path import DropPath
|
| 20 |
+
from .layer_scale import LayerScale
|
| 21 |
+
from .mlp import Mlp
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
logger = logging.getLogger("dinov2")
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
|
| 28 |
+
try:
|
| 29 |
+
if XFORMERS_ENABLED:
|
| 30 |
+
from xformers.ops import fmha, scaled_index_add, index_select_cat
|
| 31 |
+
|
| 32 |
+
XFORMERS_AVAILABLE = True
|
| 33 |
+
# warnings.warn("xFormers is available (Block)")
|
| 34 |
+
else:
|
| 35 |
+
# warnings.warn("xFormers is disabled (Block)")
|
| 36 |
+
raise ImportError
|
| 37 |
+
except ImportError:
|
| 38 |
+
XFORMERS_AVAILABLE = False
|
| 39 |
+
# warnings.warn("xFormers is not available (Block)")
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class Block(nn.Module):
|
| 43 |
+
def __init__(
|
| 44 |
+
self,
|
| 45 |
+
dim: int,
|
| 46 |
+
num_heads: int,
|
| 47 |
+
mlp_ratio: float = 4.0,
|
| 48 |
+
qkv_bias: bool = False,
|
| 49 |
+
proj_bias: bool = True,
|
| 50 |
+
ffn_bias: bool = True,
|
| 51 |
+
drop: float = 0.0,
|
| 52 |
+
attn_drop: float = 0.0,
|
| 53 |
+
init_values=None,
|
| 54 |
+
drop_path: float = 0.0,
|
| 55 |
+
act_layer: Callable[..., nn.Module] = nn.GELU,
|
| 56 |
+
norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
|
| 57 |
+
attn_class: Callable[..., nn.Module] = Attention,
|
| 58 |
+
ffn_layer: Callable[..., nn.Module] = Mlp,
|
| 59 |
+
) -> None:
|
| 60 |
+
super().__init__()
|
| 61 |
+
# print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}")
|
| 62 |
+
self.norm1 = norm_layer(dim)
|
| 63 |
+
self.attn = attn_class(
|
| 64 |
+
dim,
|
| 65 |
+
num_heads=num_heads,
|
| 66 |
+
qkv_bias=qkv_bias,
|
| 67 |
+
proj_bias=proj_bias,
|
| 68 |
+
attn_drop=attn_drop,
|
| 69 |
+
proj_drop=drop,
|
| 70 |
+
)
|
| 71 |
+
self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
|
| 72 |
+
self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
| 73 |
+
|
| 74 |
+
self.norm2 = norm_layer(dim)
|
| 75 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
| 76 |
+
self.mlp = ffn_layer(
|
| 77 |
+
in_features=dim,
|
| 78 |
+
hidden_features=mlp_hidden_dim,
|
| 79 |
+
act_layer=act_layer,
|
| 80 |
+
drop=drop,
|
| 81 |
+
bias=ffn_bias,
|
| 82 |
+
)
|
| 83 |
+
self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
|
| 84 |
+
self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
| 85 |
+
|
| 86 |
+
self.sample_drop_ratio = drop_path
|
| 87 |
+
|
| 88 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 89 |
+
def attn_residual_func(x: Tensor) -> Tensor:
|
| 90 |
+
return self.ls1(self.attn(self.norm1(x)))
|
| 91 |
+
|
| 92 |
+
def ffn_residual_func(x: Tensor) -> Tensor:
|
| 93 |
+
return self.ls2(self.mlp(self.norm2(x)))
|
| 94 |
+
|
| 95 |
+
if self.training and self.sample_drop_ratio > 0.1:
|
| 96 |
+
# the overhead is compensated only for a drop path rate larger than 0.1
|
| 97 |
+
x = drop_add_residual_stochastic_depth(
|
| 98 |
+
x,
|
| 99 |
+
residual_func=attn_residual_func,
|
| 100 |
+
sample_drop_ratio=self.sample_drop_ratio,
|
| 101 |
+
)
|
| 102 |
+
x = drop_add_residual_stochastic_depth(
|
| 103 |
+
x,
|
| 104 |
+
residual_func=ffn_residual_func,
|
| 105 |
+
sample_drop_ratio=self.sample_drop_ratio,
|
| 106 |
+
)
|
| 107 |
+
elif self.training and self.sample_drop_ratio > 0.0:
|
| 108 |
+
x = x + self.drop_path1(attn_residual_func(x))
|
| 109 |
+
x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2
|
| 110 |
+
else:
|
| 111 |
+
x = x + attn_residual_func(x)
|
| 112 |
+
x = x + ffn_residual_func(x)
|
| 113 |
+
return x
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
def drop_add_residual_stochastic_depth(
|
| 117 |
+
x: Tensor,
|
| 118 |
+
residual_func: Callable[[Tensor], Tensor],
|
| 119 |
+
sample_drop_ratio: float = 0.0,
|
| 120 |
+
) -> Tensor:
|
| 121 |
+
# 1) extract subset using permutation
|
| 122 |
+
b, n, d = x.shape
|
| 123 |
+
sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
|
| 124 |
+
brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
|
| 125 |
+
x_subset = x[brange]
|
| 126 |
+
|
| 127 |
+
# 2) apply residual_func to get residual
|
| 128 |
+
residual = residual_func(x_subset)
|
| 129 |
+
|
| 130 |
+
x_flat = x.flatten(1)
|
| 131 |
+
residual = residual.flatten(1)
|
| 132 |
+
|
| 133 |
+
residual_scale_factor = b / sample_subset_size
|
| 134 |
+
|
| 135 |
+
# 3) add the residual
|
| 136 |
+
x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
|
| 137 |
+
return x_plus_residual.view_as(x)
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
def get_branges_scales(x, sample_drop_ratio=0.0):
|
| 141 |
+
b, n, d = x.shape
|
| 142 |
+
sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
|
| 143 |
+
brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
|
| 144 |
+
residual_scale_factor = b / sample_subset_size
|
| 145 |
+
return brange, residual_scale_factor
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
|
| 149 |
+
if scaling_vector is None:
|
| 150 |
+
x_flat = x.flatten(1)
|
| 151 |
+
residual = residual.flatten(1)
|
| 152 |
+
x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
|
| 153 |
+
else:
|
| 154 |
+
x_plus_residual = scaled_index_add(
|
| 155 |
+
x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
|
| 156 |
+
)
|
| 157 |
+
return x_plus_residual
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
attn_bias_cache: Dict[Tuple, Any] = {}
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
def get_attn_bias_and_cat(x_list, branges=None):
|
| 164 |
+
"""
|
| 165 |
+
this will perform the index select, cat the tensors, and provide the attn_bias from cache
|
| 166 |
+
"""
|
| 167 |
+
batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list]
|
| 168 |
+
all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
|
| 169 |
+
if all_shapes not in attn_bias_cache.keys():
|
| 170 |
+
seqlens = []
|
| 171 |
+
for b, x in zip(batch_sizes, x_list):
|
| 172 |
+
for _ in range(b):
|
| 173 |
+
seqlens.append(x.shape[1])
|
| 174 |
+
attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
|
| 175 |
+
attn_bias._batch_sizes = batch_sizes
|
| 176 |
+
attn_bias_cache[all_shapes] = attn_bias
|
| 177 |
+
|
| 178 |
+
if branges is not None:
|
| 179 |
+
cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1])
|
| 180 |
+
else:
|
| 181 |
+
tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
|
| 182 |
+
cat_tensors = torch.cat(tensors_bs1, dim=1)
|
| 183 |
+
|
| 184 |
+
return attn_bias_cache[all_shapes], cat_tensors
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
def drop_add_residual_stochastic_depth_list(
|
| 188 |
+
x_list: List[Tensor],
|
| 189 |
+
residual_func: Callable[[Tensor, Any], Tensor],
|
| 190 |
+
sample_drop_ratio: float = 0.0,
|
| 191 |
+
scaling_vector=None,
|
| 192 |
+
) -> Tensor:
|
| 193 |
+
# 1) generate random set of indices for dropping samples in the batch
|
| 194 |
+
branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
|
| 195 |
+
branges = [s[0] for s in branges_scales]
|
| 196 |
+
residual_scale_factors = [s[1] for s in branges_scales]
|
| 197 |
+
|
| 198 |
+
# 2) get attention bias and index+concat the tensors
|
| 199 |
+
attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)
|
| 200 |
+
|
| 201 |
+
# 3) apply residual_func to get residual, and split the result
|
| 202 |
+
residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore
|
| 203 |
+
|
| 204 |
+
outputs = []
|
| 205 |
+
for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors):
|
| 206 |
+
outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x))
|
| 207 |
+
return outputs
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
class NestedTensorBlock(Block):
|
| 211 |
+
def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
|
| 212 |
+
"""
|
| 213 |
+
x_list contains a list of tensors to nest together and run
|
| 214 |
+
"""
|
| 215 |
+
assert isinstance(self.attn, MemEffAttention)
|
| 216 |
+
|
| 217 |
+
if self.training and self.sample_drop_ratio > 0.0:
|
| 218 |
+
|
| 219 |
+
def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
|
| 220 |
+
return self.attn(self.norm1(x), attn_bias=attn_bias)
|
| 221 |
+
|
| 222 |
+
def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
|
| 223 |
+
return self.mlp(self.norm2(x))
|
| 224 |
+
|
| 225 |
+
x_list = drop_add_residual_stochastic_depth_list(
|
| 226 |
+
x_list,
|
| 227 |
+
residual_func=attn_residual_func,
|
| 228 |
+
sample_drop_ratio=self.sample_drop_ratio,
|
| 229 |
+
scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,
|
| 230 |
+
)
|
| 231 |
+
x_list = drop_add_residual_stochastic_depth_list(
|
| 232 |
+
x_list,
|
| 233 |
+
residual_func=ffn_residual_func,
|
| 234 |
+
sample_drop_ratio=self.sample_drop_ratio,
|
| 235 |
+
scaling_vector=self.ls2.gamma if isinstance(self.ls1, LayerScale) else None,
|
| 236 |
+
)
|
| 237 |
+
return x_list
|
| 238 |
+
else:
|
| 239 |
+
|
| 240 |
+
def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
|
| 241 |
+
return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))
|
| 242 |
+
|
| 243 |
+
def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
|
| 244 |
+
return self.ls2(self.mlp(self.norm2(x)))
|
| 245 |
+
|
| 246 |
+
attn_bias, x = get_attn_bias_and_cat(x_list)
|
| 247 |
+
x = x + attn_residual_func(x, attn_bias=attn_bias)
|
| 248 |
+
x = x + ffn_residual_func(x)
|
| 249 |
+
return attn_bias.split(x)
|
| 250 |
+
|
| 251 |
+
def forward(self, x_or_x_list):
|
| 252 |
+
if isinstance(x_or_x_list, Tensor):
|
| 253 |
+
return super().forward(x_or_x_list)
|
| 254 |
+
elif isinstance(x_or_x_list, list):
|
| 255 |
+
if not XFORMERS_AVAILABLE:
|
| 256 |
+
raise AssertionError("xFormers is required for using nested tensors")
|
| 257 |
+
return self.forward_nested(x_or_x_list)
|
| 258 |
+
else:
|
| 259 |
+
raise AssertionError
|
unish/pi3/models/dinov2/layers/dino_head.py
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
from torch.nn.init import trunc_normal_
|
| 9 |
+
from torch.nn.utils import weight_norm
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class DINOHead(nn.Module):
|
| 13 |
+
def __init__(
|
| 14 |
+
self,
|
| 15 |
+
in_dim,
|
| 16 |
+
out_dim,
|
| 17 |
+
use_bn=False,
|
| 18 |
+
nlayers=3,
|
| 19 |
+
hidden_dim=2048,
|
| 20 |
+
bottleneck_dim=256,
|
| 21 |
+
mlp_bias=True,
|
| 22 |
+
):
|
| 23 |
+
super().__init__()
|
| 24 |
+
nlayers = max(nlayers, 1)
|
| 25 |
+
self.mlp = _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=hidden_dim, use_bn=use_bn, bias=mlp_bias)
|
| 26 |
+
self.apply(self._init_weights)
|
| 27 |
+
self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
|
| 28 |
+
self.last_layer.weight_g.data.fill_(1)
|
| 29 |
+
|
| 30 |
+
def _init_weights(self, m):
|
| 31 |
+
if isinstance(m, nn.Linear):
|
| 32 |
+
trunc_normal_(m.weight, std=0.02)
|
| 33 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
| 34 |
+
nn.init.constant_(m.bias, 0)
|
| 35 |
+
|
| 36 |
+
def forward(self, x):
|
| 37 |
+
x = self.mlp(x)
|
| 38 |
+
eps = 1e-6 if x.dtype == torch.float16 else 1e-12
|
| 39 |
+
x = nn.functional.normalize(x, dim=-1, p=2, eps=eps)
|
| 40 |
+
x = self.last_layer(x)
|
| 41 |
+
return x
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True):
|
| 45 |
+
if nlayers == 1:
|
| 46 |
+
return nn.Linear(in_dim, bottleneck_dim, bias=bias)
|
| 47 |
+
else:
|
| 48 |
+
layers = [nn.Linear(in_dim, hidden_dim, bias=bias)]
|
| 49 |
+
if use_bn:
|
| 50 |
+
layers.append(nn.BatchNorm1d(hidden_dim))
|
| 51 |
+
layers.append(nn.GELU())
|
| 52 |
+
for _ in range(nlayers - 2):
|
| 53 |
+
layers.append(nn.Linear(hidden_dim, hidden_dim, bias=bias))
|
| 54 |
+
if use_bn:
|
| 55 |
+
layers.append(nn.BatchNorm1d(hidden_dim))
|
| 56 |
+
layers.append(nn.GELU())
|
| 57 |
+
layers.append(nn.Linear(hidden_dim, bottleneck_dim, bias=bias))
|
| 58 |
+
return nn.Sequential(*layers)
|
unish/pi3/models/dinov2/layers/drop_path.py
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
# References:
|
| 7 |
+
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
| 8 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
from torch import nn
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def drop_path(x, drop_prob: float = 0.0, training: bool = False):
|
| 15 |
+
if drop_prob == 0.0 or not training:
|
| 16 |
+
return x
|
| 17 |
+
keep_prob = 1 - drop_prob
|
| 18 |
+
shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
|
| 19 |
+
random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
|
| 20 |
+
if keep_prob > 0.0:
|
| 21 |
+
random_tensor.div_(keep_prob)
|
| 22 |
+
output = x * random_tensor
|
| 23 |
+
return output
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class DropPath(nn.Module):
|
| 27 |
+
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
|
| 28 |
+
|
| 29 |
+
def __init__(self, drop_prob=None):
|
| 30 |
+
super(DropPath, self).__init__()
|
| 31 |
+
self.drop_prob = drop_prob
|
| 32 |
+
|
| 33 |
+
def forward(self, x):
|
| 34 |
+
return drop_path(x, self.drop_prob, self.training)
|
unish/pi3/models/dinov2/layers/layer_scale.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110
|
| 7 |
+
|
| 8 |
+
from typing import Union
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
from torch import Tensor
|
| 12 |
+
from torch import nn
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class LayerScale(nn.Module):
|
| 16 |
+
def __init__(
|
| 17 |
+
self,
|
| 18 |
+
dim: int,
|
| 19 |
+
init_values: Union[float, Tensor] = 1e-5,
|
| 20 |
+
inplace: bool = False,
|
| 21 |
+
) -> None:
|
| 22 |
+
super().__init__()
|
| 23 |
+
self.inplace = inplace
|
| 24 |
+
self.gamma = nn.Parameter(init_values * torch.ones(dim))
|
| 25 |
+
|
| 26 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 27 |
+
return x.mul_(self.gamma) if self.inplace else x * self.gamma
|
unish/pi3/models/dinov2/layers/mlp.py
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
# References:
|
| 7 |
+
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
| 8 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
from typing import Callable, Optional
|
| 12 |
+
|
| 13 |
+
from torch import Tensor, nn
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class Mlp(nn.Module):
|
| 17 |
+
def __init__(
|
| 18 |
+
self,
|
| 19 |
+
in_features: int,
|
| 20 |
+
hidden_features: Optional[int] = None,
|
| 21 |
+
out_features: Optional[int] = None,
|
| 22 |
+
act_layer: Callable[..., nn.Module] = nn.GELU,
|
| 23 |
+
drop: float = 0.0,
|
| 24 |
+
bias: bool = True,
|
| 25 |
+
) -> None:
|
| 26 |
+
super().__init__()
|
| 27 |
+
out_features = out_features or in_features
|
| 28 |
+
hidden_features = hidden_features or in_features
|
| 29 |
+
self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
|
| 30 |
+
self.act = act_layer()
|
| 31 |
+
self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
|
| 32 |
+
self.drop = nn.Dropout(drop)
|
| 33 |
+
|
| 34 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 35 |
+
x = self.fc1(x)
|
| 36 |
+
x = self.act(x)
|
| 37 |
+
x = self.drop(x)
|
| 38 |
+
x = self.fc2(x)
|
| 39 |
+
x = self.drop(x)
|
| 40 |
+
return x
|