murphylmf committed
Commit ae166e6 · 1 Parent(s): eaea719
Files changed (50)
  1. README.md +61 -7
  2. app.py +457 -0
  3. environment.yml +14 -0
  4. inference.py +186 -0
  5. install.sh +74 -0
  6. packages.txt +7 -0
  7. requirements.txt +22 -0
  8. static/teaser.svg +0 -0
  9. unish/__pycache__/pipeline.cpython-310.pyc +0 -0
  10. unish/heads/__pycache__/align_net.cpython-310.pyc +0 -0
  11. unish/heads/__pycache__/dpt_head.cpython-310.pyc +0 -0
  12. unish/heads/__pycache__/head_act.cpython-310.pyc +0 -0
  13. unish/heads/__pycache__/human_head_cliff.cpython-310.pyc +0 -0
  14. unish/heads/__pycache__/pose_transformer.cpython-310.pyc +0 -0
  15. unish/heads/__pycache__/t_cond_mlp.cpython-310.pyc +0 -0
  16. unish/heads/__pycache__/utils.cpython-310.pyc +0 -0
  17. unish/heads/__pycache__/vit.cpython-310.pyc +0 -0
  18. unish/heads/align_net.py +571 -0
  19. unish/heads/dpt_head.py +500 -0
  20. unish/heads/head_act.py +125 -0
  21. unish/heads/human_head_cliff.py +97 -0
  22. unish/heads/pose_transformer.py +364 -0
  23. unish/heads/t_cond_mlp.py +199 -0
  24. unish/heads/utils.py +108 -0
  25. unish/heads/vit.py +346 -0
  26. unish/pi3/models/__pycache__/pi3.cpython-310.pyc +0 -0
  27. unish/pi3/models/dinov2/__init__.py +6 -0
  28. unish/pi3/models/dinov2/__pycache__/__init__.cpython-310.pyc +0 -0
  29. unish/pi3/models/dinov2/hub/__init__.py +4 -0
  30. unish/pi3/models/dinov2/hub/__pycache__/__init__.cpython-310.pyc +0 -0
  31. unish/pi3/models/dinov2/hub/__pycache__/backbones.cpython-310.pyc +0 -0
  32. unish/pi3/models/dinov2/hub/__pycache__/utils.cpython-310.pyc +0 -0
  33. unish/pi3/models/dinov2/hub/backbones.py +156 -0
  34. unish/pi3/models/dinov2/hub/utils.py +39 -0
  35. unish/pi3/models/dinov2/layers/__init__.py +11 -0
  36. unish/pi3/models/dinov2/layers/__pycache__/__init__.cpython-310.pyc +0 -0
  37. unish/pi3/models/dinov2/layers/__pycache__/attention.cpython-310.pyc +0 -0
  38. unish/pi3/models/dinov2/layers/__pycache__/block.cpython-310.pyc +0 -0
  39. unish/pi3/models/dinov2/layers/__pycache__/dino_head.cpython-310.pyc +0 -0
  40. unish/pi3/models/dinov2/layers/__pycache__/drop_path.cpython-310.pyc +0 -0
  41. unish/pi3/models/dinov2/layers/__pycache__/layer_scale.cpython-310.pyc +0 -0
  42. unish/pi3/models/dinov2/layers/__pycache__/mlp.cpython-310.pyc +0 -0
  43. unish/pi3/models/dinov2/layers/__pycache__/patch_embed.cpython-310.pyc +0 -0
  44. unish/pi3/models/dinov2/layers/__pycache__/swiglu_ffn.cpython-310.pyc +0 -0
  45. unish/pi3/models/dinov2/layers/attention.py +89 -0
  46. unish/pi3/models/dinov2/layers/block.py +259 -0
  47. unish/pi3/models/dinov2/layers/dino_head.py +58 -0
  48. unish/pi3/models/dinov2/layers/drop_path.py +34 -0
  49. unish/pi3/models/dinov2/layers/layer_scale.py +27 -0
  50. unish/pi3/models/dinov2/layers/mlp.py +40 -0
README.md CHANGED
@@ -1,13 +1,67 @@
  ---
- title: UniSH
- emoji: 🏆
- colorFrom: pink
- colorTo: red
+ title: UniSH (Unified Scene & Human Reconstruction)
+ emoji: 🏃‍♂️
+ colorFrom: blue
+ colorTo: purple
  sdk: gradio
- sdk_version: 6.3.0
+ sdk_version: 5.0.0
  app_file: app.py
  pinned: false
- license: apache-2.0
+ license: cc-by-nc-4.0
  ---

- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # UniSH: Unifying Scene and Human Reconstruction in a Feed-Forward Pass
+
+ <div align="center">
+
+ Mengfei Li<sup>1</sup>, Peng Li<sup>1</sup>, Zheng Zhang<sup>2</sup>, Jiahao Lu<sup>1</sup>, Chengfeng Zhao<sup>1</sup>, Wei Xue<sup>1</sup>, <br>
+ Qifeng Liu<sup>1</sup>, Sida Peng<sup>3</sup>, Wenxiao Zhang<sup>1</sup>, Wenhan Luo<sup>1</sup>, Yuan Liu<sup>1†</sup>, Yike Guo<sup>1†</sup>
+
+ <sup>1</sup>The Hong Kong University of Science and Technology, <sup>2</sup>Beijing University of Posts and Telecommunications, <sup>3</sup>Zhejiang University
+
+ <a href="https://murphylmf.github.io/UniSH/"><img src="https://img.shields.io/badge/Project-Page-8A2BE2" alt="Project Page"></a>
+ <a href="https://arxiv.org/abs/2601.01222"><img src="https://img.shields.io/badge/arXiv-2601.01222-b31b1b.svg" alt="arXiv"></a>
+ <a href="https://github.com/murphylmf/UniSH"><img src="https://img.shields.io/badge/GitHub-Code-black.svg" alt="Code"></a>
+
+ </div>
+
+ ## Abstract
+
+ We present UniSH, a unified, feed-forward framework for joint metric-scale 3D scene and human reconstruction. A key challenge in this domain is the scarcity of large-scale, annotated real-world data, forcing a reliance on synthetic datasets. This reliance introduces a significant sim-to-real domain gap, leading to poor generalization, low-fidelity human geometry, and poor alignment on in-the-wild videos.
+
+ To address this, we propose an innovative training paradigm that effectively leverages unlabeled in-the-wild data. Our framework bridges strong, disparate priors from scene reconstruction and HMR, and is trained with two core components: (1) a robust distillation strategy to refine human surface details by distilling high-frequency details from an expert depth model, and (2) a two-stage supervision scheme, which first learns coarse localization on synthetic data, then fine-tunes on real data by directly optimizing the geometric correspondence between the SMPL mesh and the human point cloud. This approach enables our feed-forward model to jointly recover high-fidelity scene geometry, human point clouds, camera parameters, and coherent, metric-scale SMPL bodies, all in a single forward pass. Extensive experiments demonstrate that our model achieves state-of-the-art performance on human-centric scene reconstruction and delivers highly competitive results on global human motion estimation, comparing favorably against both optimization-based frameworks and HMR-only methods.
+
+ ## Method
+
+ ![Teaser](static/teaser.svg)
+
+ **The network architecture of UniSH.**
+ UniSH takes a monocular video as input. The video frames are processed by the **Reconstruction Branch** to predict per-frame camera extrinsics *E*, confidence maps *C*, and pointmaps *P*. Camera intrinsics *K* are derived from the pointmaps. Human crops from the video are fed into the **Human Body Branch** along with *K* to estimate global SMPL shape parameters *β* and per-frame pose parameters *θ<sub>i</sub>*. Features from both branches are processed by **AlignNet** to predict the global scene scale *s* and per-frame SMPL translations *t<sub>i</sub>* for coherent scene and human alignment.
+
+ ## Usage
+
+ This Space provides an interactive demo for UniSH.
+
+ 1. **Upload a Video**: Upload a monocular video containing a human.
+ 2. **Set Duration**: Choose the duration to process (default: 3 seconds).
+ 3. **Run Inference**: Click "Run Inference" to generate the 3D reconstruction.
+ 4. **Visualize**: The result will be displayed in an interactive 3D viewer where you can rotate, pan, and zoom.
+
+ ## BibTeX
+
+ ```bibtex
+ @misc{li2026unishunifyingscenehuman,
+ title={UniSH: Unifying Scene and Human Reconstruction in a Feed-Forward Pass},
+ author={Mengfei Li and Peng Li and Zheng Zhang and Jiahao Lu and Chengfeng Zhao and Wei Xue and Qifeng Liu and Sida Peng and Wenxiao Zhang and Wenhan Luo and Yuan Liu and Yike Guo},
+ year={2026},
+ eprint={2601.01222},
+ archivePrefix={arXiv},
+ primaryClass={cs.CV},
+ url={https://arxiv.org/abs/2601.01222},
+ }
+ ```
+
+ ## Acknowledgements
+
+ This website is licensed under a [Creative Commons Attribution-ShareAlike 4.0 International License](http://creativecommons.org/licenses/by-sa/4.0/).
+ Template borrowed from [Nerfies](https://github.com/nerfies/nerfies.github.io).
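Beyond the hosted demo described in the Usage section of this README, the same pipeline can be driven programmatically with the helpers that `app.py` imports from `unish.utils.inference_utils`. A minimal sketch, assuming the repository root is on `PYTHONPATH`, the checkpoint and SMPL assets are already in place, and `my_clip.mp4` is an illustrative input path:

```python
# Minimal sketch of the Space's predict() flow (illustrative paths; the values
# mirror the defaults used in app.py: fps=6.0, human_idx=0, target_size=518, chunk_size=30).
import torch
from unish.utils.inference_utils import load_model, process_video, run_inference

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = load_model()  # app.py calls this without arguments; inference.py passes --checkpoint instead
model.to(device)
model.eval()

data_dict = process_video("my_clip.mp4", 6.0, 0, 518, bbox_scale=1.0)  # video, fps, human_idx, target_size
results = run_inference(model, data_dict, device, chunk_size=30)
print(results["seq_name"])
```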
app.py ADDED
@@ -0,0 +1,457 @@
1
+ import gradio as gr
2
+ import spaces
3
+ import os
4
+ import sys
5
+ import shutil
6
+ import tempfile
7
+ import torch
8
+ import cv2
9
+ import subprocess
10
+ import numpy as np
11
+ import trimesh
12
+ from huggingface_hub import hf_hub_download
13
+
14
+ # Add current directory to path
15
+ sys.path.append(os.path.dirname(os.path.abspath(__file__)))
16
+
17
+ from unish.utils.inference_utils import (
18
+ load_model,
19
+ process_video,
20
+ run_inference,
21
+ generate_mixed_geometries_in_memory,
22
+ save_smpl_meshes_per_frame,
23
+ save_scene_only_point_clouds,
24
+ save_human_point_clouds,
25
+ save_camera_parameters_per_frame
26
+ )
27
+
28
+ MODEL = None
29
+ BODY_MODELS_PATH = "body_models/"
30
+
31
+ def download_smpl_assets(body_models_path):
32
+ """
33
+ Download SMPL models from private repository if they don't exist.
34
+ The path logic mimics SMPLWrapper's expectation:
35
+ 1. SMPLWrapper appends 'smpl' if not present in body_models_path.
36
+ 2. smplx library expects another 'smpl' folder inside that (or appends it).
37
+ Based on existing structure 'body_models/smpl/smpl/SMPL_*.pkl', the target dir is constructed below.
38
+ """
39
+ if 'smpl' not in body_models_path:
40
+ model_path = os.path.join(body_models_path, 'smpl')
41
+ else:
42
+ model_path = body_models_path
43
+
44
+ # smplx looks for a 'smpl' folder inside the given model_path
45
+ target_dir = os.path.join(model_path, 'smpl')
46
+
47
+ os.makedirs(target_dir, exist_ok=True)
48
+
49
+ files = ["SMPL_NEUTRAL.pkl", "SMPL_MALE.pkl", "SMPL_FEMALE.pkl"]
50
+ token = os.environ.get("SMPL_DOWNLOAD_TOKEN")
51
+
52
+ for filename in files:
53
+ file_path = os.path.join(target_dir, filename)
54
+ if not os.path.exists(file_path):
55
+ if not token:
56
+ print(f"Warning: SMPL_DOWNLOAD_TOKEN not set. Cannot download {filename}.")
57
+ continue
58
+
59
+ print(f"Downloading {filename} to {target_dir}...")
60
+ try:
61
+ hf_hub_download(
62
+ repo_id="Murphyyyy/UniSH-Private-Assets",
63
+ filename=filename,
64
+ local_dir=target_dir,
65
+ token=token
66
+ )
67
+ except Exception as e:
68
+ print(f"Failed to download {filename}: {e}")
69
+
70
+ def pack_sequence_to_glb(base_dir, output_path, start_frame=0, end_frame=60, scene_rate=0.5):
71
+ scene = trimesh.Scene()
72
+
73
+ print(f">>> Packing frames {start_frame} to {end_frame}...")
74
+
75
+ valid_count = 0
76
+
77
+ for i in range(start_frame, end_frame):
78
+ frame_node_name = f"frame_{valid_count}"
79
+
80
+ s_path = os.path.join(base_dir, "scene_only_point_clouds", f"scene_only_frame_{i:04d}.ply")
81
+ h_path = os.path.join(base_dir, "human_only_point_clouds", f"human_frame_{i:04d}.ply")
82
+ smpl_path = os.path.join(base_dir, "smpl_meshes_per_frame", f"smpl_mesh_frame_{i:04d}.ply")
83
+
84
+ if not (os.path.exists(h_path) or os.path.exists(smpl_path)):
85
+ continue
86
+
87
+ scene.graph.update(frame_node_name, parent="world")
88
+
89
+ if os.path.exists(smpl_path):
90
+ try:
91
+ smpl = trimesh.load(smpl_path)
92
+ flesh_color = [255, 160, 122, 255]
93
+ smpl.visual.vertex_colors = np.tile(flesh_color, (len(smpl.vertices), 1))
94
+
95
+ scene.add_geometry(smpl, node_name=f"{frame_node_name}_smpl", parent_node_name=frame_node_name)
96
+ except Exception as e:
97
+ pass
98
+
99
+ if os.path.exists(h_path):
100
+ try:
101
+ human = trimesh.load(h_path)
102
+ if isinstance(human, trimesh.PointCloud):
103
+ scene.add_geometry(human, node_name=f"{frame_node_name}_human", parent_node_name=frame_node_name)
104
+ except: pass
105
+
106
+ if os.path.exists(s_path):
107
+ try:
108
+ s_obj = trimesh.load(s_path)
109
+ if isinstance(s_obj, trimesh.PointCloud):
110
+ total_pts = len(s_obj.vertices)
111
+ if total_pts > 0:
112
+ if scene_rate < 0.99:
113
+ count = int(total_pts * scene_rate)
114
+ if count > 100:
115
+ idx = np.random.choice(total_pts, count, replace=False)
116
+ s_obj = trimesh.PointCloud(s_obj.vertices[idx], colors=s_obj.colors[idx])
117
+ scene.add_geometry(s_obj, node_name=f"{frame_node_name}_scene", parent_node_name=frame_node_name)
118
+ except: pass
119
+
120
+ valid_count += 1
121
+
122
+ if valid_count == 0:
123
+ print("Error: No valid frames found.")
124
+ return
125
+
126
+ try:
127
+ rot = trimesh.transformations.rotation_matrix(np.radians(-90), [1, 0, 0])
128
+ scene.apply_transform(rot)
129
+ except: pass
130
+
131
+ os.makedirs(os.path.dirname(output_path), exist_ok=True)
132
+ print(f">>> Exporting to {output_path}...")
133
+ scene.export(output_path)
134
+ print(f">>> Done! Saved {valid_count} frames.")
135
+
136
+ def get_player_html(glb_abs_path):
137
+ html_content = f"""
138
+ <!DOCTYPE html>
139
+ <html>
140
+ <head>
141
+ <meta charset="utf-8">
142
+ <title>UniSH Viewer</title>
143
+ <link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/bulma@0.9.4/css/bulma.min.css">
144
+ <link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.4.0/css/all.min.css">
145
+ <style>
146
+ #canvas-container {{
147
+ width: 100%;
148
+ height: 600px;
149
+ background: #f5f5f5;
150
+ border-radius: 8px;
151
+ position: relative;
152
+ overflow: hidden;
153
+ box-shadow: inset 0 0 20px rgba(0,0,0,0.05);
154
+ }}
155
+ .slider {{
156
+ width: 100%;
157
+ }}
158
+ </style>
159
+ <script type="importmap">
160
+ {{
161
+ "imports": {{
162
+ "three": "https://unpkg.com/three@0.158.0/build/three.module.js",
163
+ "three/addons/": "https://unpkg.com/three@0.158.0/examples/jsm/"
164
+ }}
165
+ }}
166
+ </script>
167
+ </head>
168
+ <body>
169
+ <div class="box" style="padding: 10px; background: #f5f5f5;">
170
+ <div id="canvas-container">
171
+ <div id="loading-overlay" style="position: absolute; top:0; left:0; width:100%; height:100%; background: rgba(0,0,0,0.7); color: white; display: flex; flex-direction: column; justify-content: center; align-items: center; z-index: 10;">
172
+ <span class="icon is-large"><i class="fas fa-spinner fa-pulse"></i></span>
173
+ <p style="margin-top: 10px;">Loading 3D Sequence...</p>
174
+ </div>
175
+ </div>
176
+
177
+ <div class="columns is-vcentered is-mobile" style="margin-top: 10px; padding: 0 10px;">
178
+ <div class="column is-narrow">
179
+ <button id="play-btn" class="button is-dark is-rounded is-small">
180
+ <span class="icon is-small"><i class="fas fa-play"></i></span>
181
+ </button>
182
+ </div>
183
+ <div class="column">
184
+ <input id="frame-slider" class="slider is-fullwidth is-circle is-dark" step="1" min="0" max="0" value="0" type="range">
185
+ </div>
186
+ <div class="column is-narrow">
187
+ <span id="frame-count" class="tag is-light" style="width: 80px;">Frame: 0</span>
188
+ </div>
189
+ </div>
190
+ </div>
191
+
192
+ <script type="module">
193
+ import * as THREE from 'three';
194
+ import {{ OrbitControls }} from 'three/addons/controls/OrbitControls.js';
195
+ import {{ GLTFLoader }} from 'three/addons/loaders/GLTFLoader.js';
196
+
197
+ // Inject the model path using f-string from Python
198
+ const MODEL_PATH = "/file={glb_abs_path}";
199
+ const FPS = 10;
200
+
201
+ let scene, camera, renderer, controls;
202
+ let frames = [];
203
+ let currentFrame = 0;
204
+ let isPlaying = false;
205
+ let intervalId = null;
206
+
207
+ const container = document.getElementById('canvas-container');
208
+ const slider = document.getElementById('frame-slider');
209
+ const playBtn = document.getElementById('play-btn');
210
+ const frameLabel = document.getElementById('frame-count');
211
+ const loadingOverlay = document.getElementById('loading-overlay');
212
+
213
+ init();
214
+
215
+ function init() {{
216
+ scene = new THREE.Scene();
217
+ scene.background = new THREE.Color(0xf5f5f5);
218
+
219
+ camera = new THREE.PerspectiveCamera(50, container.clientWidth / container.clientHeight, 0.1, 1000);
220
+ camera.position.set(-0.000, -4.272, 0.000);
221
+
222
+ renderer = new THREE.WebGLRenderer({{ antialias: true, alpha: true }});
223
+ renderer.setSize(container.clientWidth, container.clientHeight);
224
+ renderer.setPixelRatio(window.devicePixelRatio);
225
+
226
+ renderer.shadowMap.enabled = false;
227
+ renderer.useLegacyLights = false;
228
+
229
+ container.appendChild(renderer.domElement);
230
+
231
+ const hemiLight = new THREE.HemisphereLight(0xffffff, 0x444444, 3.0);
232
+ scene.add(hemiLight);
233
+
234
+ const dirLight = new THREE.DirectionalLight(0xffffff, 3.0);
235
+ dirLight.position.set(5, 10, 7);
236
+ scene.add(dirLight);
237
+
238
+ const frontLight = new THREE.DirectionalLight(0xffffff, 2.0);
239
+ frontLight.position.set(0, 0, 5);
240
+ scene.add(frontLight);
241
+
242
+ controls = new OrbitControls(camera, renderer.domElement);
243
+ controls.enableDamping = true;
244
+ controls.dampingFactor = 0.05;
245
+
246
+ controls.target.set(0.000, 0.000, 0.000);
247
+
248
+ const loader = new GLTFLoader();
249
+ console.log("Loading:", MODEL_PATH);
250
+
251
+ loader.load(MODEL_PATH, function (gltf) {{
252
+ const root = gltf.scene;
253
+ scene.add(root);
254
+
255
+ frames = [];
256
+ root.traverse((node) => {{
257
+
258
+ if (node.isMesh) {{
259
+ node.geometry.computeVertexNormals();
260
+ if (node.geometry.attributes.color) {{
261
+ node.geometry.deleteAttribute('color');
262
+ }}
263
+ node.material = new THREE.MeshStandardMaterial({{
264
+ color: 0xff9966,
265
+ roughness: 0.4,
266
+ metalness: 0.0,
267
+ side: THREE.DoubleSide
268
+ }});
269
+ node.material.vertexColors = false;
270
+ }}
271
+
272
+ if (node.isPoints) {{
273
+ if (node.name.toLowerCase().includes('scene')) {{
274
+ node.material.size = 0.05;
275
+ node.material.sizeAttenuation = true;
276
+ }}
277
+ if (node.name.toLowerCase().includes('human')) {{
278
+ node.material.size = 0.005;
279
+ }}
280
+ }}
281
+
282
+ if (node.name && node.name.startsWith('frame_')) {{
283
+ const parts = node.name.split('_');
284
+ if (parts.length === 2 && !isNaN(parseInt(parts[1]))) {{
285
+ const idx = parseInt(parts[1]);
286
+ frames[idx] = node;
287
+ node.visible = false;
288
+ }}
289
+ }}
290
+ }});
291
+
292
+ frames = frames.filter(n => n !== undefined);
293
+ console.log(`Loaded ${{frames.length}} frames.`);
294
+
295
+ if (frames.length > 0) {{
296
+ slider.max = frames.length - 1;
297
+ loadingOverlay.style.display = 'none';
298
+ showFrame(0);
299
+ }} else {{
300
+ loadingOverlay.innerHTML = "<p>No frames found.</p>";
301
+ }}
302
+
303
+ }}, undefined, function (error) {{
304
+ console.error(error);
305
+ loadingOverlay.innerHTML = "<p>Error loading model.</p>";
306
+ }});
307
+
308
+ window.addEventListener('resize', onWindowResize);
309
+ animate();
310
+ }}
311
+
312
+ function showFrame(idx) {{
313
+ if (!frames[idx]) return;
314
+ if (frames[currentFrame]) frames[currentFrame].visible = false;
315
+ frames[idx].visible = true;
316
+ currentFrame = idx;
317
+ slider.value = idx;
318
+ frameLabel.innerText = `Frame: ${{idx}}`;
319
+ }}
320
+
321
+ function togglePlay() {{
322
+ if (frames.length === 0) return;
323
+ isPlaying = !isPlaying;
324
+
325
+ const icon = playBtn.querySelector('.fa-play, .fa-pause');
326
+
327
+ if (isPlaying) {{
328
+ if(icon) {{ icon.classList.remove('fa-play'); icon.classList.add('fa-pause'); }}
329
+ intervalId = setInterval(() => {{
330
+ let next = currentFrame + 1;
331
+ if (next >= frames.length) next = 0;
332
+ showFrame(next);
333
+ }}, 1000 / FPS);
334
+ }} else {{
335
+ if(icon) {{ icon.classList.remove('fa-pause'); icon.classList.add('fa-play'); }}
336
+ clearInterval(intervalId);
337
+ }}
338
+ }}
339
+
340
+ slider.addEventListener('input', (e) => {{
341
+ if (isPlaying) togglePlay();
342
+ showFrame(parseInt(e.target.value));
343
+ }});
344
+ playBtn.addEventListener('click', togglePlay);
345
+
346
+ function onWindowResize() {{
347
+ camera.aspect = container.clientWidth / container.clientHeight;
348
+ camera.updateProjectionMatrix();
349
+ renderer.setSize(container.clientWidth, container.clientHeight);
350
+ }}
351
+
352
+ function animate() {{
353
+ requestAnimationFrame(animate);
354
+ controls.update();
355
+ renderer.render(scene, camera);
356
+ }}
357
+ </script>
358
+ </body>
359
+ </html>
360
+ """
361
+ return html_content
362
+
363
+ @spaces.GPU(duration=120)
364
+ def predict(video_path, duration_seconds=3.0):
365
+ global MODEL
366
+
367
+ # 0. Setup directories
368
+ output_dir = tempfile.mkdtemp()
369
+
370
+ # 1. Trim video
371
+ duration = min(float(duration_seconds), 10.0)
372
+ trimmed_video_path = os.path.join(output_dir, "input_trimmed.mp4")
373
+
374
+ cmd = [
375
+ "ffmpeg", "-i", video_path,
376
+ "-t", str(duration),
377
+ "-c:v", "libx264", "-c:a", "aac",
378
+ trimmed_video_path, "-y"
379
+ ]
380
+ subprocess.run(cmd, check=True)
381
+
382
+ # 2. Load Model
383
+ if MODEL is None:
384
+ MODEL = load_model()
385
+
386
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
387
+ MODEL.to(device)
388
+ MODEL.eval()
389
+
390
+ # 3. Process Video
391
+ fps = 6.0
392
+ target_size = 518
393
+ human_idx = 0
394
+ bbox_scale = 1.0
395
+
396
+ # Check and download SMPL assets
397
+ download_smpl_assets(BODY_MODELS_PATH)
398
+
399
+ data_dict = process_video(
400
+ trimmed_video_path, fps, human_idx, target_size,
401
+ bbox_scale=bbox_scale
402
+ )
403
+
404
+ # 4. Run Inference
405
+ results = run_inference(MODEL, data_dict, device, chunk_size=30)
406
+
407
+ # 5. Generate Geometries & Save
408
+ seq_name = results['seq_name']
409
+
410
+ viz_scene_point_clouds, viz_smpl_meshes, viz_scene_only_point_clouds, smpl_points_for_camera = generate_mixed_geometries_in_memory(
411
+ results, BODY_MODELS_PATH, fps=fps, conf_thres=0.1
412
+ )
413
+
414
+ # Save to disk
415
+ save_smpl_meshes_per_frame(results, output_dir, BODY_MODELS_PATH)
416
+ save_scene_only_point_clouds(viz_scene_only_point_clouds, output_dir, seq_name)
417
+ save_human_point_clouds(viz_scene_point_clouds, viz_scene_only_point_clouds, output_dir, seq_name, results)
418
+
419
+ # 6. Pack to GLB
420
+ base_dir = os.path.join(output_dir, seq_name)
421
+ output_glb_path = os.path.join(output_dir, "output.glb")
422
+
423
+ num_frames = len(viz_scene_point_clouds)
424
+
425
+ pack_sequence_to_glb(
426
+ base_dir,
427
+ output_glb_path,
428
+ start_frame=0,
429
+ end_frame=num_frames,
430
+ scene_rate=0.5
431
+ )
432
+
433
+ return get_player_html(output_glb_path)
434
+
435
+ with gr.Blocks() as demo:
436
+ gr.Markdown("# UniSH Demo")
437
+ gr.Markdown("Upload a video to reconstruct scene and human in 3D.")
438
+
439
+ with gr.Row():
440
+ with gr.Column():
441
+ input_video = gr.Video(label="Input Video")
442
+ duration_slider = gr.Slider(minimum=1, maximum=10, value=3, step=1, label="Duration to Process (seconds)")
443
+ submit_btn = gr.Button("Run Inference", variant="primary")
444
+
445
+ with gr.Column():
446
+ output_html = gr.HTML(label="3D Result", min_height=600)
447
+
448
+ submit_btn.click(
449
+ predict,
450
+ inputs=[input_video, duration_slider],
451
+ outputs=[output_html]
452
+ )
453
+
454
+ demo.queue()
455
+ demo.launch()
456
+
457
+
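A note on the output format: `pack_sequence_to_glb` writes every timestep as a scene-graph node named `frame_<i>`, and the embedded three.js player animates the sequence simply by toggling node visibility at `FPS = 10`. A rough sanity-check sketch (the GLB path is illustrative, and glTF exporters can decorate node names, so treat the count as approximate):

```python
# Rough check that an exported GLB contains the per-frame nodes the viewer expects.
import trimesh

scene = trimesh.load("output.glb")  # illustrative path; use the file written by pack_sequence_to_glb
frame_nodes = [name for name in scene.graph.nodes if name.startswith("frame_")]
print(f"{len(frame_nodes)} frame nodes, {len(scene.geometry)} geometries")
```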
environment.yml ADDED
@@ -0,0 +1,14 @@
+ name: unish
+ channels:
+ - conda-forge
+ - defaults
+ dependencies:
+ - python=3.10
+ - pip
+ - git
+ - ninja
+ - mesalib
+ - libgl-devel
+ - libegl-devel
+ - gxx_linux-64=11.*
+ - ffmpeg
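In a typical setup this file is consumed with `conda env create -f environment.yml` followed by `conda activate unish`; the accompanying `install.sh` (further below) refuses to run unless a conda environment is active.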
inference.py ADDED
@@ -0,0 +1,186 @@
1
+ import argparse
2
+ import os
3
+ import torch
4
+ import numpy as np
5
+ import random
6
+ import logging
7
+ from unish.utils.inference_utils import *
8
+
9
+ def setup_seed(seed):
10
+ torch.manual_seed(seed)
11
+ torch.cuda.manual_seed_all(seed)
12
+ np.random.seed(seed)
13
+ random.seed(seed)
14
+ torch.backends.cudnn.deterministic = True
15
+
16
+ def setup_logging(output_dir):
17
+ os.makedirs(output_dir, exist_ok=True)
18
+
19
+ # Create logger
20
+ logger = logging.getLogger()
21
+ logger.setLevel(logging.INFO)
22
+
23
+ # Create handlers
24
+ c_handler = logging.StreamHandler()
25
+ f_handler = logging.FileHandler(os.path.join(output_dir, 'inference.log'), mode='w')
26
+
27
+ # Create formatters and add it to handlers
28
+ c_format = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
29
+ f_format = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
30
+ c_handler.setFormatter(c_format)
31
+ f_handler.setFormatter(f_format)
32
+
33
+ # Add handlers to the logger
34
+ logger.addHandler(c_handler)
35
+ logger.addHandler(f_handler)
36
+
37
+ return logger
38
+
39
+ def main():
40
+ parser = argparse.ArgumentParser(description="Video Inference Script")
41
+ parser.add_argument("--video_path", type=str, required=True,
42
+ help="Path to the input video file or directory containing images")
43
+ parser.add_argument("--fps", type=float, default=6.0,
44
+ help="Target FPS for frame extraction (default: 6.0)")
45
+ parser.add_argument("--original_fps", type=float, default=30.0,
46
+ help="Original FPS of the image sequence (default: 30.0, used only for directory input)")
47
+ parser.add_argument("--target_size", type=int, default=518,
48
+ help="Target size for frame processing (default: 518)")
49
+ parser.add_argument("--checkpoint", type=str, default="checkpoints/unish_release.safetensors",
50
+ help="Path to the model checkpoint")
51
+ parser.add_argument("--output_dir", type=str, default="inference_results_video",
52
+ help="Output directory for results")
53
+ parser.add_argument("--body_models_path", type=str, default="body_models/",
54
+ help="Path to SMPL body models")
55
+ parser.add_argument("--device", type=str, default="cuda",
56
+ help="Device to run inference on")
57
+ parser.add_argument("--save_results", action="store_true", default=True,
58
+ help="Save additional results including smpl_points_for_camera (default: True)")
59
+ parser.add_argument("--chunk_size", type=int, default=30,
60
+ help="Number of frames to process in each chunk during inference (default: 30)")
61
+ parser.add_argument("--gpu_id", type=int, default=0,
62
+ help="GPU ID to use for inference (default: 0)")
63
+ parser.add_argument("--camera_mode", type=str, default="fixed",
64
+ choices=["predicted", "fixed"],
65
+ help="Camera mode: 'predicted' uses model-predicted camera parameters, "
66
+ "'fixed' uses a fixed camera angle (default: predicted)")
67
+ parser.add_argument("--human_idx", type=int, default=0,
68
+ help="Human index to process (default: 0)")
69
+ parser.add_argument("--start_idx", type=int, default=None,
70
+ help="Start frame index for processing (default: None, process from beginning)")
71
+ parser.add_argument("--end_idx", type=int, default=None,
72
+ help="End frame index for processing (default: None, process to end)")
73
+ parser.add_argument("--bbox_scale", type=float, default=1.0,
74
+ help="Scale factor for bounding box size (default: 1.0)")
75
+ parser.add_argument("--conf_thres", type=float, default=0.1,
76
+ help="Confidence threshold for point cloud generation (default: 0.1)")
77
+
78
+ # New arguments
79
+ parser.add_argument("--seed", type=int, default=42, help="Random seed for reproducibility")
80
+ parser.add_argument("--yolo_ckpt", type=str, default="ckpts/yolo11n.pt", help="Path to YOLO checkpoint")
81
+ parser.add_argument("--sam2_model", type=str, default="facebook/sam2-hiera-large", help="SAM2 model name or path")
82
+
83
+ args = parser.parse_args()
84
+
85
+ # Setup seed
86
+ setup_seed(args.seed)
87
+
88
+ # Setup logging
89
+ logger = setup_logging(args.output_dir)
90
+
91
+ # Setup device
92
+ if torch.cuda.is_available():
93
+ if args.device == "cuda":
94
+ # Use specified GPU ID
95
+ device = torch.device(f"cuda:{args.gpu_id}")
96
+ # Set the current CUDA device
97
+ torch.cuda.set_device(args.gpu_id)
98
+ logger.info(
99
+ f"Using GPU {args.gpu_id}: {torch.cuda.get_device_name(args.gpu_id)}")
100
+ else:
101
+ device = torch.device(args.device)
102
+ else:
103
+ device = torch.device("cpu")
104
+ logger.info("CUDA not available, using CPU")
105
+
106
+ logger.info(f"Using device: {device}")
107
+
108
+ # Load model
109
+ logger.info("Loading model...")
110
+ model = load_model(args.checkpoint)
111
+ model = model.to(device)
112
+ model.eval()
113
+
114
+ # Process video
115
+ logger.info(f"Processing video: {args.video_path}")
116
+ data_dict = process_video(
117
+ args.video_path, args.fps, args.human_idx, args.target_size,
118
+ bbox_scale=args.bbox_scale, start_idx=args.start_idx, end_idx=args.end_idx,
119
+ original_fps=args.original_fps,
120
+ yolo_ckpt=args.yolo_ckpt, sam2_model=args.sam2_model
121
+ )
122
+
123
+ # Run inference
124
+ results = run_inference(model, data_dict, device, args.chunk_size)
125
+
126
+ # Create output directory
127
+ os.makedirs(args.output_dir, exist_ok=True)
128
+
129
+ viz_scene_point_clouds, viz_smpl_meshes, viz_scene_only_point_clouds, smpl_points_for_camera = generate_mixed_geometries_in_memory(
130
+ results, args.body_models_path, fps=args.fps, conf_thres=args.conf_thres
131
+ )
132
+
133
+ # Determine camera mode based on arguments
134
+ use_predicted_camera = (args.camera_mode == "predicted")
135
+ logger.info(f"Using {args.camera_mode} camera mode")
136
+
137
+ original_rgb_images = results['rgb_images']
138
+
139
+ if original_rgb_images is not None:
140
+ if hasattr(original_rgb_images, 'permute'): # It's a torch tensor
141
+ original_rgb_images = original_rgb_images.permute(
142
+ 0, 2, 3, 1).cpu().numpy() # [S, H, W, 3]
143
+ elif not isinstance(original_rgb_images, np.ndarray):
144
+ original_rgb_images = np.array(original_rgb_images)
145
+
146
+ # Ensure proper data type and range
147
+ if original_rgb_images.max() <= 1.0:
148
+ original_rgb_images = (
149
+ original_rgb_images * 255).astype(np.uint8)
150
+
151
+ original_human_boxes = data_dict['human_boxes']
152
+
153
+ run_visualization(viz_scene_point_clouds, viz_smpl_meshes, smpl_points_for_camera,
154
+ args.output_dir, results['seq_name'],
155
+ fps=args.fps, # Use original fps
156
+ rgb_images=original_rgb_images,
157
+ human_boxes=original_human_boxes,
158
+ chunk_size=args.chunk_size, # Use original chunk size
159
+ results=results,
160
+ use_predicted_camera=use_predicted_camera,
161
+ scene_only_point_clouds=viz_scene_only_point_clouds,
162
+ conf_thres=args.conf_thres)
163
+
164
+ if args.save_results:
165
+
166
+ logger.info("Creating SMPL meshes per frame...")
167
+ save_smpl_meshes_per_frame(
168
+ results, args.output_dir, args.body_models_path)
169
+
170
+ logger.info("Saving scene point clouds (without human)...")
171
+ save_scene_only_point_clouds(
172
+ viz_scene_only_point_clouds, args.output_dir, results['seq_name'])
173
+
174
+ logger.info("Saving human point clouds...")
175
+ save_human_point_clouds(viz_scene_point_clouds,
176
+ viz_scene_only_point_clouds, args.output_dir, results['seq_name'], results)
177
+
178
+ logger.info("Saving camera parameters per frame...")
179
+ save_camera_parameters_per_frame(
180
+ results, args.output_dir, results['seq_name'])
181
+
182
+ logger.info(f"Inference completed! Results saved to {args.output_dir}")
183
+
184
+
185
+ if __name__ == "__main__":
186
+ main()
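Based on the argument parser above, a representative local invocation might look like `python inference.py --video_path demo.mp4 --camera_mode predicted --save_results`, where `demo.mp4` stands in for any monocular clip and all other options fall back to the defaults listed in the parser.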
install.sh ADDED
@@ -0,0 +1,74 @@
+ #!/bin/bash
+ set -e
+
+ # ==========================================
+ # UniSH Auto-Install Script
+ # ==========================================
+
+ get_cuda_version() {
+     if [ ! -z "$1" ]; then echo "$1"; return; fi
+     if command -v nvidia-smi &> /dev/null; then
+         DRIVER_CUDA_MAJOR=$(nvidia-smi | grep "CUDA Version" | awk -F'CUDA Version:' '{print $2}' | awk -F'.' '{print $1}' | tr -d '[:space:]')
+         if [ "$DRIVER_CUDA_MAJOR" == "12" ]; then echo "12.1"; elif [ "$DRIVER_CUDA_MAJOR" == "11" ]; then echo "11.8"; else echo "12.1"; fi
+     else echo "12.1"; fi
+ }
+
+ if [[ -z "$CONDA_PREFIX" ]]; then
+     echo "❌ Error: Please activate the conda environment first!"
+     exit 1
+ fi
+
+ TARGET_CUDA=$(get_cuda_version "$1")
+ echo "========================================"
+ echo " Detected/Selected CUDA: $TARGET_CUDA"
+ echo "========================================"
+
+ if [[ "$TARGET_CUDA" == "12.1" ]]; then TORCH_INDEX_URL="https://download.pytorch.org/whl/cu121";
+ elif [[ "$TARGET_CUDA" == "11.8" ]]; then TORCH_INDEX_URL="https://download.pytorch.org/whl/cu118";
+ else TORCH_INDEX_URL=""; fi
+
+ echo "[1/6] Installing PyTorch 2.4.1 (CUDA $TARGET_CUDA)..."
+ pip install torch==2.4.1 torchvision==0.19.1 --index-url $TORCH_INDEX_URL
+
+ echo "[2/6] Installing Safe Requirements..."
+ pip install -r requirements.txt
+
+ echo "[3/6] Installing Custom Utils3D..."
+ pip install "git+https://github.com/EasternJournalist/utils3d.git@3fab839f0be9931dac7c8488eb0e1600c236e183"
+
+ echo "[4/6] Installing Heavy Dependencies..."
+ pip install open3d==0.19.0 --no-deps
+ pip install ultralytics==8.3.227 --no-deps
+ pip install timm==1.0.24 --no-deps
+
+ echo "[5/6] Installing MMCV & PyTorch3D..."
+ pip install mmcv==2.2.0 --no-deps --no-binary mmcv
+ pip install "git+https://github.com/facebookresearch/pytorch3d.git@stable" --no-build-isolation
+
+ echo "[6/6] Installing SAM 2 (With Setuptools Fix)..."
+
+ pip install setuptools==69.5.1 wheel
+ rm -rf _tmp_install_sam2
+
+ mkdir -p _tmp_install_sam2
+ cd _tmp_install_sam2
+
+ echo " -> Cloning SAM 2..."
+ git clone https://github.com/facebookresearch/segment-anything-2.git --depth 1
+ cd segment-anything-2
+
+ echo " -> Patching setup.py..."
+ python -c "
+ path = 'setup.py'
+ with open(path, 'r') as f: c = f.read()
+ c = c.replace('torch>=2.5.1', 'torch>=2.4.1')
+ with open(path, 'w') as f: f.write(c)
+ "
+ pip install . --no-deps --no-build-isolation
+ cd ../..
+ rm -rf _tmp_install_sam2
+
+ echo "========================================"
+ echo "Installation Complete!"
+ python -c "import torch; print(f'PyTorch: {torch.__version__} | CUDA: {torch.version.cuda}')"
+ echo "========================================"
packages.txt ADDED
@@ -0,0 +1,7 @@
+ ffmpeg
+ libgl1-mesa-glx
+ libglib2.0-0
+ libegl1-mesa
+ xvfb
+
+
requirements.txt ADDED
@@ -0,0 +1,22 @@
+ torch==2.4.1
+ torchvision==0.19.1
+ numpy
+ scipy
+ trimesh
+ tqdm
+ opencv-python-headless
+ pillow
+ gradio
+ spaces
+ ninja
+ einops
+ safetensors
+ huggingface_hub
+ open3d==0.19.0
+ ultralytics==8.3.227
+ timm==1.0.24
+ git+https://github.com/EasternJournalist/utils3d.git@3fab839f0be9931dac7c8488eb0e1600c236e183
+ mmcv==2.2.0 --find-links https://download.openmmlab.com/mmcv/dist/cu121/torch2.4/index.html
+ pytorch3d @ https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/py310_cu121_pyt241/pytorch3d-0.7.8-cp310-cp310-linux_x86_64.whl
+ git+https://github.com/facebookresearch/segment-anything-2.git
+ smplx
static/teaser.svg ADDED
unish/__pycache__/pipeline.cpython-310.pyc ADDED
Binary file (6.53 kB).
 
unish/heads/__pycache__/align_net.cpython-310.pyc ADDED
Binary file (13.6 kB).
 
unish/heads/__pycache__/dpt_head.cpython-310.pyc ADDED
Binary file (12.6 kB).
 
unish/heads/__pycache__/head_act.cpython-310.pyc ADDED
Binary file (3.11 kB).
 
unish/heads/__pycache__/human_head_cliff.cpython-310.pyc ADDED
Binary file (2.92 kB).
 
unish/heads/__pycache__/pose_transformer.cpython-310.pyc ADDED
Binary file (10.9 kB).
 
unish/heads/__pycache__/t_cond_mlp.cpython-310.pyc ADDED
Binary file (6.08 kB).
 
unish/heads/__pycache__/utils.cpython-310.pyc ADDED
Binary file (3.14 kB).
 
unish/heads/__pycache__/vit.cpython-310.pyc ADDED
Binary file (11.2 kB).
 
unish/heads/align_net.py ADDED
@@ -0,0 +1,571 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import math
5
+ import numpy as np
6
+
7
+ from unish.utils.data_utils import rot6d_to_rotmat
8
+ from unish.utils.constants import SMPL_MEAN_PARAMS
9
+
10
+
11
+ class TimeStepRoPE1D(nn.Module):
12
+ """1D RoPE for timestep embedding, similar to pi3's RoPE2D but for 1D time sequence"""
13
+
14
+ def __init__(self, freq=100.0):
15
+ super().__init__()
16
+ self.base = freq
17
+ self.cache = {}
18
+ self.max_train_len = 120
19
+
20
+ def get_cos_sin(self, D, seq_len, device, dtype):
21
+ if (D, seq_len, device, dtype) in self.cache:
22
+ return self.cache[D, seq_len, device, dtype]
23
+
24
+ if seq_len <= self.max_train_len:
25
+ assert D % 2 == 0
26
+
27
+ inv_freq = 1.0 / (self.base ** (torch.arange(0, D, 2).float().to(device) / D))
28
+ t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype)
29
+ freqs = torch.einsum("i,j->ij", t, inv_freq).to(dtype)
30
+
31
+ freqs = torch.cat((freqs, freqs), dim=-1)
32
+ cos = freqs.cos() # (seq_len, D)
33
+ sin = freqs.sin() # (seq_len, D)
34
+ self.cache[D, seq_len, device, dtype] = (cos, sin)
35
+ return cos, sin
36
+
37
+ else:
38
+ cos_train, sin_train = self.get_cos_sin(D, self.max_train_len, device, dtype)
39
+ cos_train_res = cos_train.transpose(0, 1).unsqueeze(0)
40
+ sin_train_res = sin_train.transpose(0, 1).unsqueeze(0)
41
+
42
+ # [1, D, max_train_len] -> [1, D, seq_len]
43
+ cos_interp = F.interpolate(cos_train_res, size=seq_len, mode='linear', align_corners=True)
44
+ sin_interp = F.interpolate(sin_train_res, size=seq_len, mode='linear', align_corners=True)
45
+
46
+ # [1, D, seq_len] -> [seq_len, D]
47
+ cos_final = cos_interp.squeeze(0).transpose(0, 1)
48
+ sin_final = sin_interp.squeeze(0).transpose(0, 1)
49
+
50
+ self.cache[D, seq_len, device, dtype] = (cos_final, sin_final)
51
+ return cos_final, sin_final
52
+
53
+ @staticmethod
54
+ def rotate_half(x):
55
+ x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
56
+ return torch.cat((-x2, x1), dim=-1)
57
+
58
+ def apply_rope1d(self, tokens, pos1d, cos, sin):
59
+ """Apply 1D RoPE to tokens based on 1D positions"""
60
+ cos = torch.nn.functional.embedding(pos1d, cos)[:, None, :, :] # [batch, 1, seq_len, D]
61
+ sin = torch.nn.functional.embedding(pos1d, sin)[:, None, :, :] # [batch, 1, seq_len, D]
62
+ return (tokens * cos) + (self.rotate_half(tokens) * sin)
63
+
64
+ def forward(self, tokens, positions):
65
+ """
66
+ Apply 1D RoPE to tokens based on timestep positions.
67
+ Args:
68
+ tokens: [batch, num_heads, seq_len, head_dim]
69
+ positions: [batch, seq_len] - timestep positions (0, 1, 2, ...)
70
+ Returns:
71
+ tokens with RoPE applied: [batch, num_heads, seq_len, head_dim]
72
+ """
73
+ head_dim = tokens.size(3)
74
+ assert head_dim % 2 == 0, "head_dim should be a multiple of two"
75
+ assert positions.ndim == 2 # [batch, seq_len]
76
+
77
+ cos, sin = self.get_cos_sin(head_dim, int(positions.max()) + 1, tokens.device, tokens.dtype)
78
+
79
+ return self.apply_rope1d(tokens, positions.long(), cos, sin)
80
+
81
+
82
+ class TransformerDecoderLayer(nn.Module):
83
+ """单层Transformer Decoder with RoPE support"""
84
+
85
+ def __init__(self, hidden_dim=512, num_heads=8, ff_dim=1024, dropout=0.1, use_rope=True):
86
+ super().__init__()
87
+
88
+ self.use_rope = use_rope
89
+ self.hidden_dim = hidden_dim
90
+ self.num_heads = num_heads
91
+ self.head_dim = hidden_dim // num_heads
92
+
93
+ if use_rope:
94
+ self.self_attention = None
95
+ self.cross_attention = None
96
+
97
+ self.self_q_proj = nn.Linear(hidden_dim, hidden_dim, bias=True)
98
+ self.self_k_proj = nn.Linear(hidden_dim, hidden_dim, bias=True)
99
+ self.self_v_proj = nn.Linear(hidden_dim, hidden_dim, bias=True)
100
+ self.self_out_proj = nn.Linear(hidden_dim, hidden_dim, bias=True)
101
+
102
+ self.cross_q_proj = nn.Linear(hidden_dim, hidden_dim, bias=True)
103
+ self.cross_k_proj = nn.Linear(hidden_dim, hidden_dim, bias=True)
104
+ self.cross_v_proj = nn.Linear(hidden_dim, hidden_dim, bias=True)
105
+ self.cross_out_proj = nn.Linear(hidden_dim, hidden_dim, bias=True)
106
+
107
+ # RoPE for timestep embedding
108
+ self.timestep_rope = TimeStepRoPE1D(freq=100.0)
109
+ else:
110
+ self.self_attention = nn.MultiheadAttention(
111
+ embed_dim=hidden_dim,
112
+ num_heads=num_heads,
113
+ dropout=dropout,
114
+ batch_first=True
115
+ )
116
+
117
+ self.cross_attention = nn.MultiheadAttention(
118
+ embed_dim=hidden_dim,
119
+ num_heads=num_heads,
120
+ dropout=dropout,
121
+ batch_first=True
122
+ )
123
+
124
+ self.feed_forward = nn.Sequential(
125
+ nn.Linear(hidden_dim, ff_dim),
126
+ nn.ReLU(),
127
+ nn.Dropout(dropout),
128
+ nn.Linear(ff_dim, hidden_dim),
129
+ nn.Dropout(dropout)
130
+ )
131
+
132
+ self.norm1 = nn.LayerNorm(hidden_dim) # for self attention
133
+ self.norm2 = nn.LayerNorm(hidden_dim) # for cross attention
134
+ self.norm3 = nn.LayerNorm(hidden_dim) # for feed forward
135
+
136
+ # Dropout
137
+ self.dropout = nn.Dropout(dropout)
138
+ self.attn_dropout = nn.Dropout(dropout)
139
+
140
+ # Scale factor for attention
141
+ self.scale = self.head_dim ** -0.5
142
+
143
+ # Gradient checkpointing flag
144
+ self.use_gradient_checkpoint = False
145
+
146
+ def gradient_checkpointing_enable(self):
147
+ """Enable gradient checkpointing for memory optimization."""
148
+ self.use_gradient_checkpoint = True
149
+
150
+ def _rope_attention(self, q_proj, k_proj, v_proj, out_proj, query, key, value, timestep_pos=None):
151
+ """Apply RoPE-based attention using torch.nn.functional.scaled_dot_product_attention"""
152
+ batch_size, seq_len, _ = query.shape
153
+
154
+ # Project Q, K, V
155
+ q = q_proj(query).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
156
+ k = k_proj(key).view(batch_size, key.shape[1], self.num_heads, self.head_dim).transpose(1, 2)
157
+ v = v_proj(value).view(batch_size, value.shape[1], self.num_heads, self.head_dim).transpose(1, 2)
158
+
159
+ # Apply RoPE to Q and K if timestep positions are provided
160
+ if timestep_pos is not None and self.use_rope:
161
+ # For self-attention, both q and k use the same timestep positions
162
+ if query.shape == key.shape: # self-attention case
163
+ q = self.timestep_rope(q, timestep_pos)
164
+ k = self.timestep_rope(k, timestep_pos)
165
+ else: # cross-attention case
166
+ # Only apply RoPE to query (cam_token), key/value are spatial features
167
+ q = self.timestep_rope(q, timestep_pos)
168
+
169
+ attn_output = F.scaled_dot_product_attention(
170
+ q, k, v,
171
+ dropout_p=self.attn_dropout.p if self.training else 0.0,
172
+ scale=self.scale
173
+ )
174
+
175
+ # Reshape output
176
+ attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.hidden_dim)
177
+
178
+ # Output projection
179
+ return out_proj(attn_output)
180
+
181
+ def forward(self, query, key, value, self_attn_mask=None, cross_attn_mask=None, timestep_pos=None):
182
+ """
183
+ Args:
184
+ query: [batch, num_views, hidden_dim]
185
+ key: [batch, num_views, hidden_dim]
186
+ value: [batch, num_views, hidden_dim]
187
+ timestep_pos: [batch, num_views] - timestep positions for RoPE
188
+ """
189
+ if self.use_gradient_checkpoint and self.training:
190
+ from torch.utils.checkpoint import checkpoint
191
+
192
+ if self.use_rope:
193
+ # 1. Self Attention + Residual with RoPE (with gradient checkpointing)
194
+ self_attn_output = checkpoint(
195
+ self._rope_attention,
196
+ self.self_q_proj, self.self_k_proj, self.self_v_proj, self.self_out_proj,
197
+ query, query, query, timestep_pos,
198
+ use_reentrant=False
199
+ )
200
+ query = self.norm1(query + self.dropout(self_attn_output))
201
+
202
+ # 2. Cross Attention + Residual with RoPE (with gradient checkpointing)
203
+ cross_attn_output = checkpoint(
204
+ self._rope_attention,
205
+ self.cross_q_proj, self.cross_k_proj, self.cross_v_proj, self.cross_out_proj,
206
+ query, key, value, timestep_pos,
207
+ use_reentrant=False
208
+ )
209
+ query = self.norm2(query + self.dropout(cross_attn_output))
210
+ else:
211
+ # 1. Self Attention + Residual (with gradient checkpointing)
212
+ def self_attn_fn(q, k, v):
213
+ out, _ = self.self_attention(q, k, v, attn_mask=self_attn_mask)
214
+ return out
215
+ self_attn_output = checkpoint(self_attn_fn, query, query, query, use_reentrant=False)
216
+ query = self.norm1(query + self.dropout(self_attn_output))
217
+
218
+ # 2. Cross Attention + Residual (with gradient checkpointing)
219
+ def cross_attn_fn(q, k, v):
220
+ out, _ = self.cross_attention(q, k, v, attn_mask=cross_attn_mask)
221
+ return out
222
+ cross_attn_output = checkpoint(cross_attn_fn, query, key, value, use_reentrant=False)
223
+ query = self.norm2(query + self.dropout(cross_attn_output))
224
+
225
+ # 3. Feed Forward + Residual (with gradient checkpointing)
226
+ ff_output = checkpoint(self.feed_forward, query, use_reentrant=False)
227
+ query = self.norm3(query + ff_output)
228
+ else:
229
+ # Original implementation without gradient checkpointing
230
+ if self.use_rope:
231
+ # 1. Self Attention + Residual with RoPE
232
+ self_attn_output = self._rope_attention(
233
+ self.self_q_proj, self.self_k_proj, self.self_v_proj, self.self_out_proj,
234
+ query, query, query, timestep_pos
235
+ )
236
+ query = self.norm1(query + self.dropout(self_attn_output))
237
+
238
+ # 2. Cross Attention + Residual with RoPE
239
+ cross_attn_output = self._rope_attention(
240
+ self.cross_q_proj, self.cross_k_proj, self.cross_v_proj, self.cross_out_proj,
241
+ query, key, value, timestep_pos
242
+ )
243
+ query = self.norm2(query + self.dropout(cross_attn_output))
244
+ else:
245
+ # 1. Self Attention + Residual (original implementation)
246
+ self_attn_output, _ = self.self_attention(query, query, query, attn_mask=self_attn_mask)
247
+ query = self.norm1(query + self.dropout(self_attn_output))
248
+
249
+ # 2. Cross Attention + Residual (original implementation)
250
+ cross_attn_output, _ = self.cross_attention(query, key, value, attn_mask=cross_attn_mask)
251
+ query = self.norm2(query + self.dropout(cross_attn_output))
252
+
253
+ # 3. Feed Forward + Residual
254
+ ff_output = self.feed_forward(query)
255
+ query = self.norm3(query + ff_output)
256
+
257
+ return query
258
+
259
+
260
+ class CrossViewTransformerDecoderLayer(nn.Module):
261
+ """Cross-view Transformer Decoder Layer for V4 - handles concatenated tokens from multiple views"""
262
+
263
+ def __init__(self, hidden_dim=512, num_heads=8, ff_dim=1024, dropout=0.1, use_rope=True):
264
+ super().__init__()
265
+
266
+ self.use_rope = use_rope
267
+ self.hidden_dim = hidden_dim
268
+ self.num_heads = num_heads
269
+ self.head_dim = hidden_dim // num_heads
270
+
271
+ if use_rope:
272
+ self.self_attention = None
273
+ self.cross_attention = None
274
+
275
+ # Self-attention components
276
+ self.self_q_proj = nn.Linear(hidden_dim, hidden_dim, bias=True)
277
+ self.self_k_proj = nn.Linear(hidden_dim, hidden_dim, bias=True)
278
+ self.self_v_proj = nn.Linear(hidden_dim, hidden_dim, bias=True)
279
+ self.self_out_proj = nn.Linear(hidden_dim, hidden_dim, bias=True)
280
+
281
+ # Cross-attention components
282
+ self.cross_q_proj = nn.Linear(hidden_dim, hidden_dim, bias=True)
283
+ self.cross_k_proj = nn.Linear(hidden_dim, hidden_dim, bias=True)
284
+ self.cross_v_proj = nn.Linear(hidden_dim, hidden_dim, bias=True)
285
+ self.cross_out_proj = nn.Linear(hidden_dim, hidden_dim, bias=True)
286
+
287
+ # RoPE for timestep embedding
288
+ self.timestep_rope = TimeStepRoPE1D(freq=100.0)
289
+ else:
290
+ # Self-attention layer
291
+ self.self_attention = nn.MultiheadAttention(
292
+ embed_dim=hidden_dim,
293
+ num_heads=num_heads,
294
+ dropout=dropout,
295
+ batch_first=True
296
+ )
297
+
298
+ # Cross-attention layer
299
+ self.cross_attention = nn.MultiheadAttention(
300
+ embed_dim=hidden_dim,
301
+ num_heads=num_heads,
302
+ dropout=dropout,
303
+ batch_first=True
304
+ )
305
+
306
+ self.feed_forward = nn.Sequential(
307
+ nn.Linear(hidden_dim, ff_dim),
308
+ nn.ReLU(),
309
+ nn.Dropout(dropout),
310
+ nn.Linear(ff_dim, hidden_dim),
311
+ nn.Dropout(dropout)
312
+ )
313
+
314
+ self.norm1 = nn.LayerNorm(hidden_dim) # for self attention
315
+ self.norm2 = nn.LayerNorm(hidden_dim) # for cross attention
316
+ self.norm3 = nn.LayerNorm(hidden_dim) # for feed forward
317
+
318
+ self.dropout = nn.Dropout(dropout)
319
+ self.attn_dropout = nn.Dropout(dropout)
320
+
321
+ self.scale = self.head_dim ** -0.5
322
+
323
+ self.use_gradient_checkpoint = False
324
+
325
+ def gradient_checkpointing_enable(self):
326
+ """Enable gradient checkpointing for memory optimization."""
327
+ self.use_gradient_checkpoint = True
328
+
329
+ def _rope_attention(self, q_proj, k_proj, v_proj, out_proj, query, key, value, query_timestep_pos=None, key_timestep_pos=None):
330
+ """Apply RoPE-based attention for cross-view scenarios using torch.nn.functional.scaled_dot_product_attention"""
331
+ batch_size, query_seq_len, _ = query.shape
332
+ _, key_seq_len, _ = key.shape
333
+
334
+ # Project Q, K, V
335
+ q = q_proj(query).view(batch_size, query_seq_len, self.num_heads, self.head_dim).transpose(1, 2)
336
+ k = k_proj(key).view(batch_size, key_seq_len, self.num_heads, self.head_dim).transpose(1, 2)
337
+ v = v_proj(value).view(batch_size, key_seq_len, self.num_heads, self.head_dim).transpose(1, 2)
338
+
339
+ # Apply RoPE to Q and K if timestep positions are provided
340
+ if self.use_rope:
341
+ if query_timestep_pos is not None:
342
+ q_scale = q[:, :, 0:1, :] # [batch, num_heads, 1, head_dim] - scale token
343
+ q_cam = q[:, :, 1:, :] # [batch, num_heads, num_views, head_dim] - cam tokens
344
+
345
+ cam_timestep_pos = query_timestep_pos[:, 1:]
346
+ q_cam_rope = self.timestep_rope(q_cam, cam_timestep_pos)
347
+
348
+ q = torch.cat([q_scale, q_cam_rope], dim=2)
349
+ if key_timestep_pos is not None:
350
+ k = self.timestep_rope(k, key_timestep_pos)
351
+
352
+ attn_output = F.scaled_dot_product_attention(
353
+ q, k, v,
354
+ dropout_p=self.attn_dropout.p if self.training else 0.0,
355
+ scale=self.scale
356
+ )
357
+
358
+ # Reshape output
359
+ attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, query_seq_len, self.hidden_dim)
360
+
361
+ # Output projection
362
+ return out_proj(attn_output)
363
+
364
+ def forward(self, query, key, value, query_timestep_pos=None, key_timestep_pos=None):
365
+ """
366
+ Args:
367
+ query: [batch, num_queries, hidden_dim] - cam tokens + scale token
368
+ key: [batch, num_views * num_tokens, hidden_dim] - concatenated feature tokens from all views
369
+ value: [batch, num_views * num_tokens, hidden_dim] - concatenated feature tokens from all views
370
+ query_timestep_pos: [batch, num_queries] - timestep positions for query tokens
371
+ key_timestep_pos: [batch, num_views * num_tokens] - timestep positions for key/value tokens
372
+ """
373
+ if self.use_gradient_checkpoint and self.training:
374
+ from torch.utils.checkpoint import checkpoint
375
+
376
+ if self.use_rope:
377
+ # 1. Self Attention + Residual with RoPE (with gradient checkpointing)
378
+ self_attn_output = checkpoint(
379
+ self._rope_attention,
380
+ self.self_q_proj, self.self_k_proj, self.self_v_proj, self.self_out_proj,
381
+ query, query, query, query_timestep_pos, query_timestep_pos,
382
+ use_reentrant=False
383
+ )
384
+ query = self.norm1(query + self.dropout(self_attn_output))
385
+
386
+ # 2. Cross Attention + Residual with RoPE (with gradient checkpointing)
387
+ cross_attn_output = checkpoint(
388
+ self._rope_attention,
389
+ self.cross_q_proj, self.cross_k_proj, self.cross_v_proj, self.cross_out_proj,
390
+ query, key, value, query_timestep_pos, key_timestep_pos,
391
+ use_reentrant=False
392
+ )
393
+ query = self.norm2(query + self.dropout(cross_attn_output))
394
+ else:
395
+ # 1. Self Attention + Residual (with gradient checkpointing)
396
+ def self_attn_fn(q, k, v):
397
+ out, _ = self.self_attention(q, k, v)
398
+ return out
399
+ self_attn_output = checkpoint(self_attn_fn, query, query, query, use_reentrant=False)
400
+ query = self.norm1(query + self.dropout(self_attn_output))
401
+
402
+ # 2. Cross Attention + Residual (with gradient checkpointing)
403
+ def cross_attn_fn(q, k, v):
404
+ out, _ = self.cross_attention(q, k, v)
405
+ return out
406
+ cross_attn_output = checkpoint(cross_attn_fn, query, key, value, use_reentrant=False)
407
+ query = self.norm2(query + self.dropout(cross_attn_output))
408
+
409
+ # 3. Feed Forward + Residual (with gradient checkpointing)
410
+ ff_output = checkpoint(self.feed_forward, query, use_reentrant=False)
411
+ query = self.norm3(query + ff_output)
412
+ else:
413
+ # Original implementation without gradient checkpointing
414
+ if self.use_rope:
415
+ # 1. Self Attention + Residual with RoPE
416
+ self_attn_output = self._rope_attention(
417
+ self.self_q_proj, self.self_k_proj, self.self_v_proj, self.self_out_proj,
418
+ query, query, query, query_timestep_pos, query_timestep_pos
419
+ )
420
+ query = self.norm1(query + self.dropout(self_attn_output))
421
+
422
+ # 2. Cross Attention + Residual with RoPE
423
+ cross_attn_output = self._rope_attention(
424
+ self.cross_q_proj, self.cross_k_proj, self.cross_v_proj, self.cross_out_proj,
425
+ query, key, value, query_timestep_pos, key_timestep_pos
426
+ )
427
+ query = self.norm2(query + self.dropout(cross_attn_output))
428
+ else:
429
+ # 1. Self Attention + Residual (original implementation)
430
+ self_attn_output, _ = self.self_attention(query, query, query)
431
+ query = self.norm1(query + self.dropout(self_attn_output))
432
+
433
+ # 2. Cross Attention + Residual (original implementation)
434
+ cross_attn_output, _ = self.cross_attention(query, key, value)
435
+ query = self.norm2(query + self.dropout(cross_attn_output))
436
+
437
+ # 3. Feed Forward + Residual
438
+ ff_output = self.feed_forward(query)
439
+ query = self.norm3(query + ff_output)
440
+
441
+ return query
442
+
443
+
444
+ class AlignNet(nn.Module):
445
+ def __init__(self, aggregated_dim=2048, cam_dim=1024, hidden_dim=512, num_heads=8, ff_dim=512, dropout=0.1, use_rope=True, num_decoder_layers=2):
446
+ super().__init__()
447
+
448
+ self.use_rope = use_rope
449
+ self.hidden_dim = hidden_dim
450
+ self.num_decoder_layers = num_decoder_layers
451
+
452
+ self.scale_token = nn.Parameter(torch.randn(1, 1, hidden_dim) * 0.02)
453
+
454
+ self.cam_feature_adapter = nn.Sequential(
455
+ nn.LayerNorm(cam_dim),
456
+ nn.Linear(cam_dim, hidden_dim),
457
+ nn.ReLU(),
458
+ nn.Dropout(dropout)
459
+ )
460
+
461
+ self.patch_feature_adapter = nn.Sequential(
462
+ nn.LayerNorm(aggregated_dim),
463
+ nn.Linear(aggregated_dim, hidden_dim),
464
+ nn.ReLU(),
465
+ nn.Dropout(dropout)
466
+ )
467
+ self.register_feature_adapter = nn.Sequential(
468
+ nn.LayerNorm(aggregated_dim),
469
+ nn.Linear(aggregated_dim, hidden_dim),
470
+ nn.ReLU(),
471
+ nn.Dropout(dropout)
472
+ )
473
+
474
+ self.decoder_layers = nn.ModuleList([
475
+ CrossViewTransformerDecoderLayer(hidden_dim, num_heads, ff_dim, dropout, use_rope=use_rope)
476
+ for _ in range(num_decoder_layers)
477
+ ])
478
+
479
+ mean_params = SMPL_MEAN_PARAMS
480
+ init_body_pose = torch.from_numpy(mean_params['pose'].astype(np.float32)).unsqueeze(0)
481
+ init_betas = torch.from_numpy(mean_params['shape'].astype('float32')).unsqueeze(0)
482
+ init_cam = torch.from_numpy(mean_params['cam'].astype(np.float32)).unsqueeze(0)
483
+ self.register_buffer('init_body_pose', init_body_pose)
484
+ self.register_buffer('init_betas', init_betas)
485
+ self.register_buffer('init_cam', init_cam)
486
+
487
+ self.trans_head = nn.Linear(hidden_dim, 3)
488
+
489
+ self.scale_head = nn.Linear(hidden_dim, 1)
490
+
491
+ self.joint_conversion_fn = rot6d_to_rotmat
492
+
493
+ def gradient_checkpointing_enable(self):
494
+ """Enable gradient checkpointing for memory optimization."""
495
+ for layer in self.decoder_layers:
496
+ if hasattr(layer, 'gradient_checkpointing_enable'):
497
+ layer.gradient_checkpointing_enable()
498
+
499
+ def forward(self, hidden_tokens, cam_token, fps=6.0):
500
+ batch_size, num_views, num_tokens, _ = hidden_tokens.shape
501
+
502
+ register_tokens = hidden_tokens[:, :, :5, :]
503
+ patch_tokens = hidden_tokens[:, :, 5:, :]
504
+
505
+ if cam_token.dim() == 4:
506
+ cam_token = cam_token.squeeze(2) # [batch, num_views, 1, 1024] -> [batch, num_views, 1024]
507
+
508
+ cam_adapted = self.cam_feature_adapter(cam_token) # [batch, num_views, hidden_dim]
509
+
510
+ patch_tokens_reshaped = patch_tokens.view(batch_size * num_views, patch_tokens.shape[2], -1) # [batch*num_views, 777, 2048]
511
+ patch_adapted_tokens = self.patch_feature_adapter(patch_tokens_reshaped) # [batch*num_views, 777, hidden_dim]
512
+ patch_adapted_tokens = patch_adapted_tokens.view(batch_size, num_views, patch_tokens.shape[2], -1) # [batch, num_views, 777, hidden_dim]
513
+
514
+ register_tokens_reshaped = register_tokens.view(batch_size * num_views, 5, -1) # [batch*num_views, 5, 2048]
515
+ register_adapted_tokens = self.register_feature_adapter(register_tokens_reshaped) # [batch*num_views, 5, hidden_dim]
516
+ register_adapted_tokens = register_adapted_tokens.view(batch_size, num_views, 5, -1) # [batch, num_views, 5, hidden_dim]
517
+
518
+ fused_features_per_view = torch.cat([register_adapted_tokens, patch_adapted_tokens], dim=2) # [batch, num_views, 782, hidden_dim]
519
+
520
+ concatenated_features = fused_features_per_view.view(batch_size, num_views * num_tokens, -1)
521
+
522
+ scale_token_expanded = self.scale_token.expand(batch_size, -1, -1)
523
+
524
+ query_tokens = torch.cat([scale_token_expanded, cam_adapted], dim=1)
525
+
526
+ if self.use_rope:
527
+ base_fps = 6.0
528
+
529
+ time_scale = base_fps / fps
530
+
531
+ scale_timestep = torch.zeros((batch_size, 1), device=cam_adapted.device, dtype=torch.long)
532
+
533
+ cam_timestep_float = torch.arange(num_views, device=cam_adapted.device, dtype=torch.float32) * time_scale
534
+ cam_timestep = cam_timestep_float.round().long().unsqueeze(0).expand(batch_size, -1)
535
+ query_timestep_pos = torch.cat([scale_timestep, cam_timestep], dim=1) # [batch, 1 + num_views]
536
+
537
+ key_timestep_base_float = torch.arange(num_views, device=cam_adapted.device, dtype=torch.float32) * time_scale
538
+ key_timestep_base = key_timestep_base_float.round().long()
539
+ key_timestep_pos = key_timestep_base.unsqueeze(1).expand(-1, num_tokens).flatten()
540
+ key_timestep_pos = key_timestep_pos.unsqueeze(0).expand(batch_size, -1) # [batch, num_views * num_tokens]
541
+ else:
542
+ query_timestep_pos = None
543
+ key_timestep_pos = None
544
+
545
+ decoder_output = query_tokens
546
+ for i, layer in enumerate(self.decoder_layers):
547
+ residual = decoder_output
548
+
549
+ decoder_output = layer(
550
+ decoder_output, concatenated_features, concatenated_features,
551
+ query_timestep_pos=query_timestep_pos, key_timestep_pos=key_timestep_pos
552
+ )
553
+
554
+ decoder_output = decoder_output + residual
555
+
556
+ scale_output = decoder_output[:, 0, :]
557
+ cam_outputs = decoder_output[:, 1:, :]
558
+
559
+ scale_logits = self.scale_head(scale_output) # [batch, 1]
560
+ scale = F.softplus(scale_logits)
561
+
562
+ trans_raw = self.trans_head(cam_outputs) # [batch, num_views, 3]
563
+ xy, z = trans_raw.split([2, 1], dim=-1) # xy: [batch, num_views, 2], z: [batch, num_views, 1]
564
+ z = torch.exp(z)
565
+ trans = torch.cat([xy * z, z], dim=-1) # [batch, num_views, 3]
566
+
567
+
568
+ return {
569
+ "scale": scale, # [batch, 1]
570
+ "trans_cam": trans, # [batch, num_views, 3]
571
+ }
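Editor's note: a minimal usage sketch for the AlignNet head above (illustrative only, not part of the committed file; the batch/view/token counts are hypothetical and follow the 5-register + 777-patch layout assumed in the comments, and the SMPL mean parameters shipped with this repo must be available when the module is constructed).

import torch
from unish.heads.align_net import AlignNet

align_net = AlignNet(aggregated_dim=2048, cam_dim=1024, hidden_dim=512).eval()
hidden_tokens = torch.randn(2, 8, 5 + 777, 2048)   # [batch, views, register + patch tokens, aggregated_dim]
cam_token = torch.randn(2, 8, 1, 1024)             # squeezed internally to [batch, views, cam_dim]
with torch.no_grad():
    out = align_net(hidden_tokens, cam_token, fps=6.0)
print(out["scale"].shape)       # torch.Size([2, 1])
print(out["trans_cam"].shape)   # torch.Size([2, 8, 3])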
unish/heads/dpt_head.py ADDED
@@ -0,0 +1,500 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+
8
+ # Inspired by https://github.com/DepthAnything/Depth-Anything-V2
9
+
10
+
11
+ import os
12
+ from typing import List, Dict, Tuple, Union
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+ import torch.nn.functional as F
17
+ from .head_act import activate_head
18
+ from .utils import create_uv_grid, position_grid_to_embed
19
+
20
+
21
+ class DPTHead(nn.Module):
22
+ """
23
+ DPT Head for dense prediction tasks.
24
+
25
+ This implementation follows the architecture described in "Vision Transformers for Dense Prediction"
26
+ (https://arxiv.org/abs/2103.13413). The DPT head processes features from a vision transformer
27
+ backbone and produces dense predictions by fusing multi-scale features.
28
+
29
+ Args:
30
+ dim_in (int): Input dimension (channels).
31
+ patch_size (int, optional): Patch size. Default is 14.
32
+ output_dim (int, optional): Number of output channels. Default is 4.
33
+ activation (str, optional): Activation type. Default is "inv_log".
34
+ conf_activation (str, optional): Confidence activation type. Default is "expp1".
35
+ features (int, optional): Feature channels for intermediate representations. Default is 256.
36
+ out_channels (List[int], optional): Output channels for each intermediate layer.
37
+ intermediate_layer_idx (List[int], optional): Indices of layers from aggregated tokens used for DPT.
38
+ pos_embed (bool, optional): Whether to use positional embedding. Default is True.
39
+ feature_only (bool, optional): If True, return features only without the last several layers and activation head. Default is False.
40
+ down_ratio (int, optional): Downscaling factor for the output resolution. Default is 1.
41
+ """
42
+
43
+ def __init__(
44
+ self,
45
+ dim_in: int,
46
+ patch_size: int = 14,
47
+ output_dim: int = 4,
48
+ activation: str = "inv_log",
49
+ conf_activation: str = "expp1",
50
+ features: int = 256,
51
+ out_channels: List[int] = [256, 512, 1024, 1024],
52
+ intermediate_layer_idx: List[int] = [4, 11, 17, 23],
53
+ pos_embed: bool = True,
54
+ feature_only: bool = False,
55
+ down_ratio: int = 1,
56
+ ) -> None:
57
+ super(DPTHead, self).__init__()
58
+ self.patch_size = patch_size
59
+ self.activation = activation
60
+ self.conf_activation = conf_activation
61
+ self.pos_embed = pos_embed
62
+ self.feature_only = feature_only
63
+ self.down_ratio = down_ratio
64
+ self.intermediate_layer_idx = intermediate_layer_idx
65
+
66
+ self.norm = nn.LayerNorm(dim_in)
67
+
68
+ # Projection layers for each output channel from tokens.
69
+ self.projects = nn.ModuleList(
70
+ [
71
+ nn.Conv2d(
72
+ in_channels=dim_in,
73
+ out_channels=oc,
74
+ kernel_size=1,
75
+ stride=1,
76
+ padding=0,
77
+ )
78
+ for oc in out_channels
79
+ ]
80
+ )
81
+
82
+ # Resize layers for upsampling feature maps.
83
+ self.resize_layers = nn.ModuleList(
84
+ [
85
+ nn.ConvTranspose2d(
86
+ in_channels=out_channels[0], out_channels=out_channels[0], kernel_size=4, stride=4, padding=0
87
+ ),
88
+ nn.ConvTranspose2d(
89
+ in_channels=out_channels[1], out_channels=out_channels[1], kernel_size=2, stride=2, padding=0
90
+ ),
91
+ nn.Identity(),
92
+ nn.Conv2d(
93
+ in_channels=out_channels[3], out_channels=out_channels[3], kernel_size=3, stride=2, padding=1
94
+ ),
95
+ ]
96
+ )
97
+
98
+ self.scratch = _make_scratch(
99
+ out_channels,
100
+ features,
101
+ expand=False,
102
+ )
103
+
104
+ # Attach additional modules to scratch.
105
+ self.scratch.stem_transpose = None
106
+ self.scratch.refinenet1 = _make_fusion_block(features)
107
+ self.scratch.refinenet2 = _make_fusion_block(features)
108
+ self.scratch.refinenet3 = _make_fusion_block(features)
109
+ self.scratch.refinenet4 = _make_fusion_block(features, has_residual=False)
110
+
111
+ head_features_1 = features
112
+ head_features_2 = 32
113
+
114
+ if feature_only:
115
+ self.scratch.output_conv1 = nn.Conv2d(head_features_1, head_features_1, kernel_size=3, stride=1, padding=1)
116
+ else:
117
+ self.scratch.output_conv1 = nn.Conv2d(
118
+ head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1
119
+ )
120
+ conv2_in_channels = head_features_1 // 2
121
+
122
+ self.scratch.output_conv2 = nn.Sequential(
123
+ nn.Conv2d(conv2_in_channels, head_features_2, kernel_size=3, stride=1, padding=1),
124
+ nn.ReLU(inplace=True),
125
+ nn.Conv2d(head_features_2, output_dim, kernel_size=1, stride=1, padding=0),
126
+ )
127
+
128
+ def forward(
129
+ self,
130
+ aggregated_tokens_list: List[torch.Tensor],
131
+ images: torch.Tensor,
132
+ patch_start_idx: int,
133
+ frames_chunk_size: int = 8,
134
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
135
+ """
136
+ Forward pass through the DPT head, supports processing by chunking frames.
137
+ Args:
138
+ aggregated_tokens_list (List[Tensor]): List of token tensors from different transformer layers.
139
+ images (Tensor): Input images with shape [B, S, 3, H, W], in range [0, 1].
140
+ patch_start_idx (int): Starting index for patch tokens in the token sequence.
141
+ Used to separate patch tokens from other tokens (e.g., camera or register tokens).
142
+ frames_chunk_size (int, optional): Number of frames to process in each chunk.
143
+ If None or larger than S, all frames are processed at once. Default: 8.
144
+
145
+ Returns:
146
+ Tensor or Tuple[Tensor, Tensor]:
147
+ - If feature_only=True: Feature maps with shape [B, S, C, H, W]
148
+ - Otherwise: Tuple of (predictions, confidence) both with shape [B, S, 1, H, W]
149
+ """
150
+ B, S, _, H, W = images.shape
151
+
152
+ # If frames_chunk_size is not specified or greater than S, process all frames at once
153
+ if frames_chunk_size is None or frames_chunk_size >= S:
154
+ return self._forward_impl(aggregated_tokens_list, images, patch_start_idx)
155
+
156
+ # Otherwise, process frames in chunks to manage memory usage
157
+ assert frames_chunk_size > 0
158
+
159
+ # Process frames in batches
160
+ all_preds = []
161
+ all_conf = []
162
+
163
+ for frames_start_idx in range(0, S, frames_chunk_size):
164
+ frames_end_idx = min(frames_start_idx + frames_chunk_size, S)
165
+
166
+ # Process batch of frames
167
+ if self.feature_only:
168
+ chunk_output = self._forward_impl(
169
+ aggregated_tokens_list, images, patch_start_idx, frames_start_idx, frames_end_idx
170
+ )
171
+ all_preds.append(chunk_output)
172
+ else:
173
+ chunk_preds, chunk_conf = self._forward_impl(
174
+ aggregated_tokens_list, images, patch_start_idx, frames_start_idx, frames_end_idx
175
+ )
176
+ all_preds.append(chunk_preds)
177
+ all_conf.append(chunk_conf)
178
+
179
+ # Concatenate results along the sequence dimension
180
+ if self.feature_only:
181
+ return torch.cat(all_preds, dim=1)
182
+ else:
183
+ return torch.cat(all_preds, dim=1), torch.cat(all_conf, dim=1)
184
+
185
+ def _forward_impl(
186
+ self,
187
+ aggregated_tokens_list: List[torch.Tensor],
188
+ images: torch.Tensor,
189
+ patch_start_idx: int,
190
+ frames_start_idx: int = None,
191
+ frames_end_idx: int = None,
192
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
193
+ """
194
+ Implementation of the forward pass through the DPT head.
195
+
196
+ This method processes a specific chunk of frames from the sequence.
197
+
198
+ Args:
199
+ aggregated_tokens_list (List[Tensor]): List of token tensors from different transformer layers.
200
+ images (Tensor): Input images with shape [B, S, 3, H, W].
201
+ patch_start_idx (int): Starting index for patch tokens.
202
+ frames_start_idx (int, optional): Starting index for frames to process.
203
+ frames_end_idx (int, optional): Ending index for frames to process.
204
+
205
+ Returns:
206
+ Tensor or Tuple[Tensor, Tensor]: Feature maps or (predictions, confidence).
207
+ """
208
+ if frames_start_idx is not None and frames_end_idx is not None:
209
+ images = images[:, frames_start_idx:frames_end_idx].contiguous()
210
+
211
+ B, S, _, H, W = images.shape
212
+
213
+ patch_h, patch_w = H // self.patch_size, W // self.patch_size
214
+
215
+ out = []
216
+ dpt_idx = 0
217
+
218
+ for layer_idx in self.intermediate_layer_idx:
219
+ x = aggregated_tokens_list[layer_idx][:, :, patch_start_idx:]
220
+
221
+ x = x.to(self.projects[0].weight.dtype)
222
+
223
+ # Select frames if processing a chunk
224
+ if frames_start_idx is not None and frames_end_idx is not None:
225
+ x = x[:, frames_start_idx:frames_end_idx]
226
+
227
+ x = x.reshape(B * S, -1, x.shape[-1])
228
+
229
+ x = self.norm(x)
230
+
231
+ x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w))
232
+
233
+ x = self.projects[dpt_idx](x)
234
+ if self.pos_embed:
235
+ x = self._apply_pos_embed(x, W, H).to(self.projects[0].weight.dtype)
236
+
237
+ x = self.resize_layers[dpt_idx](x)
238
+
239
+ out.append(x)
240
+ dpt_idx += 1
241
+
242
+ # Fuse features from multiple layers.
243
+ out = self.scratch_forward(out)
244
+ # Interpolate fused output to match target image resolution.
245
+ out = custom_interpolate(
246
+ out,
247
+ (int(patch_h * self.patch_size / self.down_ratio), int(patch_w * self.patch_size / self.down_ratio)),
248
+ mode="bilinear",
249
+ align_corners=True,
250
+ )
251
+
252
+ if self.pos_embed:
253
+ out = self._apply_pos_embed(out, W, H).to(self.projects[0].weight.dtype)
254
+
255
+ if self.feature_only:
256
+ return out.view(B, S, *out.shape[1:])
257
+
258
+ out = self.scratch.output_conv2(out)
259
+ preds, conf = activate_head(out, activation=self.activation, conf_activation=self.conf_activation)
260
+
261
+ preds = preds.view(B, S, *preds.shape[1:])
262
+ conf = conf.view(B, S, *conf.shape[1:])
263
+ return preds, conf
264
+
265
+ def _apply_pos_embed(self, x: torch.Tensor, W: int, H: int, ratio: float = 0.1) -> torch.Tensor:
266
+ """
267
+ Apply positional embedding to tensor x.
268
+ """
269
+ patch_w = x.shape[-1]
270
+ patch_h = x.shape[-2]
271
+ pos_embed = create_uv_grid(patch_w, patch_h, aspect_ratio=W / H, dtype=x.dtype, device=x.device)
272
+ pos_embed = position_grid_to_embed(pos_embed, x.shape[1])
273
+ pos_embed = pos_embed * ratio
274
+ pos_embed = pos_embed.permute(2, 0, 1)[None].expand(x.shape[0], -1, -1, -1)
275
+ return x + pos_embed
276
+
277
+ def scratch_forward(self, features: List[torch.Tensor]) -> torch.Tensor:
278
+ """
279
+ Forward pass through the fusion blocks.
280
+
281
+ Args:
282
+ features (List[Tensor]): List of feature maps from different layers.
283
+
284
+ Returns:
285
+ Tensor: Fused feature map.
286
+ """
287
+ layer_1, layer_2, layer_3, layer_4 = features
288
+
289
+ layer_1_rn = self.scratch.layer1_rn(layer_1)
290
+ layer_2_rn = self.scratch.layer2_rn(layer_2)
291
+ layer_3_rn = self.scratch.layer3_rn(layer_3)
292
+ layer_4_rn = self.scratch.layer4_rn(layer_4)
293
+
294
+ out = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])
295
+ del layer_4_rn, layer_4
296
+
297
+ out = self.scratch.refinenet3(out, layer_3_rn, size=layer_2_rn.shape[2:])
298
+ del layer_3_rn, layer_3
299
+
300
+ out = self.scratch.refinenet2(out, layer_2_rn, size=layer_1_rn.shape[2:])
301
+ del layer_2_rn, layer_2
302
+
303
+ out = self.scratch.refinenet1(out, layer_1_rn)
304
+ del layer_1_rn, layer_1
305
+
306
+ out = self.scratch.output_conv1(out)
307
+ return out
308
+
309
+
310
+ ################################################################################
311
+ # Modules
312
+ ################################################################################
313
+
314
+
315
+ def _make_fusion_block(features: int, size: int = None, has_residual: bool = True, groups: int = 1) -> nn.Module:
316
+ return FeatureFusionBlock(
317
+ features,
318
+ nn.ReLU(inplace=True),
319
+ deconv=False,
320
+ bn=False,
321
+ expand=False,
322
+ align_corners=True,
323
+ size=size,
324
+ has_residual=has_residual,
325
+ groups=groups,
326
+ )
327
+
328
+
329
+ def _make_scratch(in_shape: List[int], out_shape: int, groups: int = 1, expand: bool = False) -> nn.Module:
330
+ scratch = nn.Module()
331
+ out_shape1 = out_shape
332
+ out_shape2 = out_shape
333
+ out_shape3 = out_shape
334
+ if len(in_shape) >= 4:
335
+ out_shape4 = out_shape
336
+
337
+ if expand:
338
+ out_shape1 = out_shape
339
+ out_shape2 = out_shape * 2
340
+ out_shape3 = out_shape * 4
341
+ if len(in_shape) >= 4:
342
+ out_shape4 = out_shape * 8
343
+
344
+ scratch.layer1_rn = nn.Conv2d(
345
+ in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
346
+ )
347
+ scratch.layer2_rn = nn.Conv2d(
348
+ in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
349
+ )
350
+ scratch.layer3_rn = nn.Conv2d(
351
+ in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
352
+ )
353
+ if len(in_shape) >= 4:
354
+ scratch.layer4_rn = nn.Conv2d(
355
+ in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
356
+ )
357
+ return scratch
358
+
359
+
360
+ class ResidualConvUnit(nn.Module):
361
+ """Residual convolution module."""
362
+
363
+ def __init__(self, features, activation, bn, groups=1):
364
+ """Init.
365
+
366
+ Args:
367
+ features (int): number of features
368
+ """
369
+ super().__init__()
370
+
371
+ self.bn = bn
372
+ self.groups = groups
373
+ self.conv1 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
374
+ self.conv2 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
375
+
376
+ self.norm1 = None
377
+ self.norm2 = None
378
+
379
+ self.activation = activation
380
+ self.skip_add = nn.quantized.FloatFunctional()
381
+
382
+ def forward(self, x):
383
+ """Forward pass.
384
+
385
+ Args:
386
+ x (tensor): input
387
+
388
+ Returns:
389
+ tensor: output
390
+ """
391
+
392
+ out = self.activation(x)
393
+ out = self.conv1(out)
394
+ if self.norm1 is not None:
395
+ out = self.norm1(out)
396
+
397
+ out = self.activation(out)
398
+ out = self.conv2(out)
399
+ if self.norm2 is not None:
400
+ out = self.norm2(out)
401
+
402
+ return self.skip_add.add(out, x)
403
+
404
+
405
+ class FeatureFusionBlock(nn.Module):
406
+ """Feature fusion block."""
407
+
408
+ def __init__(
409
+ self,
410
+ features,
411
+ activation,
412
+ deconv=False,
413
+ bn=False,
414
+ expand=False,
415
+ align_corners=True,
416
+ size=None,
417
+ has_residual=True,
418
+ groups=1,
419
+ ):
420
+ """Init.
421
+
422
+ Args:
423
+ features (int): number of features
424
+ """
425
+ super(FeatureFusionBlock, self).__init__()
426
+
427
+ self.deconv = deconv
428
+ self.align_corners = align_corners
429
+ self.groups = groups
430
+ self.expand = expand
431
+ out_features = features
432
+ if self.expand == True:
433
+ out_features = features // 2
434
+
435
+ self.out_conv = nn.Conv2d(
436
+ features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=self.groups
437
+ )
438
+
439
+ if has_residual:
440
+ self.resConfUnit1 = ResidualConvUnit(features, activation, bn, groups=self.groups)
441
+
442
+ self.has_residual = has_residual
443
+ self.resConfUnit2 = ResidualConvUnit(features, activation, bn, groups=self.groups)
444
+
445
+ self.skip_add = nn.quantized.FloatFunctional()
446
+ self.size = size
447
+
448
+ def forward(self, *xs, size=None):
449
+ """Forward pass.
450
+
451
+ Returns:
452
+ tensor: output
453
+ """
454
+ output = xs[0]
455
+
456
+ if self.has_residual:
457
+ res = self.resConfUnit1(xs[1])
458
+ output = self.skip_add.add(output, res)
459
+
460
+ output = self.resConfUnit2(output)
461
+
462
+ if (size is None) and (self.size is None):
463
+ modifier = {"scale_factor": 2}
464
+ elif size is None:
465
+ modifier = {"size": self.size}
466
+ else:
467
+ modifier = {"size": size}
468
+
469
+ output = custom_interpolate(output, **modifier, mode="bilinear", align_corners=self.align_corners)
470
+ output = self.out_conv(output)
471
+
472
+ return output
473
+
474
+
475
+ def custom_interpolate(
476
+ x: torch.Tensor,
477
+ size: Tuple[int, int] = None,
478
+ scale_factor: float = None,
479
+ mode: str = "bilinear",
480
+ align_corners: bool = True,
481
+ ) -> torch.Tensor:
482
+ """
483
+ Custom interpolate to avoid INT_MAX issues in nn.functional.interpolate.
484
+ """
485
+ if size is None:
486
+ size = (int(x.shape[-2] * scale_factor), int(x.shape[-1] * scale_factor))
487
+
488
+ INT_MAX = 1610612736
489
+
490
+ input_elements = size[0] * size[1] * x.shape[0] * x.shape[1]
491
+
492
+ if input_elements > INT_MAX:
493
+ chunks = torch.chunk(x, chunks=(input_elements // INT_MAX) + 1, dim=0)
494
+ interpolated_chunks = [
495
+ nn.functional.interpolate(chunk, size=size, mode=mode, align_corners=align_corners) for chunk in chunks
496
+ ]
497
+ x = torch.cat(interpolated_chunks, dim=0)
498
+ return x.contiguous()
499
+ else:
500
+ return nn.functional.interpolate(x, size=size, mode=mode, align_corners=align_corners)
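Editor's note: a minimal shape-check sketch for DPTHead (illustrative only, not part of the committed file; dim_in=2048 and the 5 leading non-patch tokens are hypothetical choices, and the 24 token tensors simply cover the default intermediate_layer_idx of [4, 11, 17, 23]).

import torch
from unish.heads.dpt_head import DPTHead

head = DPTHead(dim_in=2048, patch_size=14, output_dim=4).eval()
B, S, H, W = 1, 2, 224, 224
patch_h, patch_w = H // 14, W // 14                       # 16 x 16 patch grid
tokens = [torch.randn(B, S, 5 + patch_h * patch_w, 2048) for _ in range(24)]
images = torch.rand(B, S, 3, H, W)
with torch.no_grad():
    preds, conf = head(tokens, images, patch_start_idx=5)
print(preds.shape, conf.shape)   # [1, 2, 224, 224, 3] and [1, 2, 224, 224]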
unish/heads/head_act.py ADDED
@@ -0,0 +1,125 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+
8
+ import torch
9
+ import torch.nn.functional as F
10
+
11
+
12
+ def activate_pose(pred_pose_enc, trans_act="linear", quat_act="linear", fl_act="linear"):
13
+ """
14
+ Activate pose parameters with specified activation functions.
15
+
16
+ Args:
17
+ pred_pose_enc: Tensor containing encoded pose parameters [translation, quaternion, focal length]
18
+ trans_act: Activation type for translation component
19
+ quat_act: Activation type for quaternion component
20
+ fl_act: Activation type for focal length component
21
+
22
+ Returns:
23
+ Activated pose parameters tensor
24
+ """
25
+ T = pred_pose_enc[..., :3]
26
+ quat = pred_pose_enc[..., 3:7]
27
+ fl = pred_pose_enc[..., 7:] # or fov
28
+
29
+ T = base_pose_act(T, trans_act)
30
+ quat = base_pose_act(quat, quat_act)
31
+ fl = base_pose_act(fl, fl_act) # or fov
32
+
33
+ pred_pose_enc = torch.cat([T, quat, fl], dim=-1)
34
+
35
+ return pred_pose_enc
36
+
37
+
38
+ def base_pose_act(pose_enc, act_type="linear"):
39
+ """
40
+ Apply basic activation function to pose parameters.
41
+
42
+ Args:
43
+ pose_enc: Tensor containing encoded pose parameters
44
+ act_type: Activation type ("linear", "inv_log", "exp", "relu")
45
+
46
+ Returns:
47
+ Activated pose parameters
48
+ """
49
+ if act_type == "linear":
50
+ return pose_enc
51
+ elif act_type == "inv_log":
52
+ return inverse_log_transform(pose_enc)
53
+ elif act_type == "exp":
54
+ return torch.exp(pose_enc)
55
+ elif act_type == "relu":
56
+ return F.relu(pose_enc)
57
+ else:
58
+ raise ValueError(f"Unknown act_type: {act_type}")
59
+
60
+
61
+ def activate_head(out, activation="norm_exp", conf_activation="expp1"):
62
+ """
63
+ Process network output to extract 3D points and confidence values.
64
+
65
+ Args:
66
+ out: Network output tensor (B, C, H, W)
67
+ activation: Activation type for 3D points
68
+ conf_activation: Activation type for confidence values
69
+
70
+ Returns:
71
+ Tuple of (3D points tensor, confidence tensor)
72
+ """
73
+    # Move channels from dim 1 to the last dimension => (B, H, W, C)
74
+ fmap = out.permute(0, 2, 3, 1) # B,H,W,C expected
75
+
76
+ # Split into xyz (first C-1 channels) and confidence (last channel)
77
+ xyz = fmap[:, :, :, :-1]
78
+ conf = fmap[:, :, :, -1]
79
+
80
+ if activation == "norm_exp":
81
+ d = xyz.norm(dim=-1, keepdim=True).clamp(min=1e-8)
82
+ xyz_normed = xyz / d
83
+ pts3d = xyz_normed * torch.expm1(d)
84
+ elif activation == "norm":
85
+ pts3d = xyz / xyz.norm(dim=-1, keepdim=True)
86
+ elif activation == "exp":
87
+ pts3d = torch.exp(xyz)
88
+ elif activation == "relu":
89
+ pts3d = F.relu(xyz)
90
+ elif activation == "inv_log":
91
+ pts3d = inverse_log_transform(xyz)
92
+ elif activation == "xy_inv_log":
93
+ xy, z = xyz.split([2, 1], dim=-1)
94
+ z = inverse_log_transform(z)
95
+ pts3d = torch.cat([xy * z, z], dim=-1)
96
+ elif activation == "sigmoid":
97
+ pts3d = torch.sigmoid(xyz)
98
+ elif activation == "linear":
99
+ pts3d = xyz
100
+ else:
101
+ raise ValueError(f"Unknown activation: {activation}")
102
+
103
+ if conf_activation == "expp1":
104
+ conf_out = 1 + conf.exp()
105
+ elif conf_activation == "expp0":
106
+ conf_out = conf.exp()
107
+ elif conf_activation == "sigmoid":
108
+ conf_out = torch.sigmoid(conf)
109
+ else:
110
+ raise ValueError(f"Unknown conf_activation: {conf_activation}")
111
+
112
+ return pts3d, conf_out
113
+
114
+
115
+ def inverse_log_transform(y):
116
+ """
117
+ Apply inverse log transform: sign(y) * (exp(|y|) - 1)
118
+
119
+ Args:
120
+ y: Input tensor
121
+
122
+ Returns:
123
+ Transformed tensor
124
+ """
125
+ return torch.sign(y) * (torch.expm1(torch.abs(y)))
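Editor's note: a small sketch of activate_head on a raw 4-channel prediction map (illustrative only, not part of the committed file).

import torch
from unish.heads.head_act import activate_head

out = torch.randn(2, 4, 32, 32)      # (B, C, H, W): 3 point channels + 1 confidence channel
pts3d, conf = activate_head(out, activation="norm_exp", conf_activation="expp1")
print(pts3d.shape, conf.shape)       # torch.Size([2, 32, 32, 3]) torch.Size([2, 32, 32])
print(bool((conf >= 1).all()))       # True: "expp1" maps confidence through 1 + exp(x)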
unish/heads/human_head_cliff.py ADDED
@@ -0,0 +1,97 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import numpy as np
5
+ import einops
6
+
7
+
8
+ from unish.utils.data_utils import rot6d_to_rotmat
9
+ from unish.utils.constants import SMPL_MEAN_PARAMS
10
+ from .pose_transformer import TransformerDecoder
11
+
12
+ TRANSFORMER_DECODER={'depth': 6,
13
+ 'heads': 8,
14
+ 'mlp_dim': 1024,
15
+ 'dim_head': 64,
16
+ 'dropout': 0.0,
17
+ 'emb_dropout': 0.0,
18
+ 'norm': 'layer',
19
+ 'context_dim': 1280}
20
+
21
+ NUM_POSE_INPUT = 23
22
+ NUM_BETAS_INPUT = 10
23
+ NUM_BETAS = 10
24
+ NUM_POSE_PARAMS = 23
25
+
26
+ class HumanHeadCliff(nn.Module):
27
+
28
+ def __init__(self):
29
+ super().__init__()
30
+ self.joint_rep_dim = 6
31
+ npose = self.joint_rep_dim * (NUM_POSE_INPUT + 1)
32
+ self.npose = npose
33
+ transformer_args = dict(
34
+ num_tokens=1,
35
+ token_dim=(3 + npose + NUM_BETAS_INPUT + 3),
36
+ dim=1024,
37
+ )
38
+ transformer_args = (transformer_args | dict(TRANSFORMER_DECODER))
39
+ self.transformer = TransformerDecoder(
40
+ **transformer_args
41
+ )
42
+ dim=transformer_args['dim']
43
+ self.decpose = nn.Linear(dim, self.joint_rep_dim * (NUM_POSE_PARAMS + 1))
44
+ self.decshape = nn.Linear(dim, NUM_BETAS)
45
+ # self.deccam = nn.Linear(dim, 3)
46
+ # self.deckp = nn.Linear(dim, 88)
47
+
48
+ mean_params = SMPL_MEAN_PARAMS
49
+ init_body_pose = torch.from_numpy(mean_params['pose'].astype(np.float32)).unsqueeze(0)
50
+ init_betas = torch.from_numpy(mean_params['shape'].astype('float32')).unsqueeze(0)
51
+ init_cam = torch.from_numpy(mean_params['cam'].astype(np.float32)).unsqueeze(0)
52
+ self.register_buffer('init_body_pose', init_body_pose)
53
+ self.register_buffer('init_betas', init_betas)
54
+ self.register_buffer('init_cam', init_cam)
55
+
56
+ def gradient_checkpointing_enable(self):
57
+ """Enable gradient checkpointing for memory optimization."""
58
+ if hasattr(self.transformer, 'gradient_checkpointing_enable'):
59
+ self.transformer.gradient_checkpointing_enable()
60
+
61
+ def forward(self, x, bbox_info, **kwargs):
62
+ """
63
+ x: (B, N, C, H, W)
64
+ bbox_info: [cx / f, cy / f, box_size / f], (B, N, 3)
65
+ """
66
+
67
+ batch_size, num_views = x.shape[:2]
68
+ x = einops.rearrange(x, 'b n c h w -> (b n) (h w) c')
69
+
70
+ init_body_pose = self.init_body_pose.expand(batch_size * num_views, -1)
71
+ init_betas = self.init_betas.expand(batch_size * num_views, -1)
72
+ init_cam = self.init_cam.expand(batch_size * num_views, -1)
73
+ bbox_info = bbox_info.view(-1, 3)
74
+
75
+ pred_body_pose = init_body_pose
76
+ pred_betas = init_betas
77
+ pred_cam = init_cam
78
+ token = torch.cat([bbox_info, pred_body_pose, pred_betas, pred_cam], dim=-1)[:, None, :]
79
+
80
+ # Pass through transformer
81
+ token_out = self.transformer(token, context=x)
82
+ token_out = token_out.squeeze(1) # (B, C)
83
+
84
+ pred_body_pose = self.decpose(token_out) + pred_body_pose
85
+ pred_betas = self.decshape(token_out) + pred_betas
86
+
87
+ joint_conversion_fn = rot6d_to_rotmat
88
+
89
+ pred_body_pose = pred_body_pose.view(-1, 6)
90
+ pred_body_pose = joint_conversion_fn(pred_body_pose).view(batch_size, num_views, -1)
91
+ pred_betas = pred_betas.view(batch_size, num_views, -1).mean(dim=1)
92
+ token_out = token_out.view(batch_size, num_views, -1)
93
+
94
+ pred_smpl_params = {'pose_cam': pred_body_pose,
95
+ 'token_out': token_out,
96
+ 'betas': pred_betas}
97
+ return pred_smpl_params
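Editor's note: a minimal usage sketch for HumanHeadCliff (illustrative only, not part of the committed file; the feature-map size is hypothetical, the channel count of 1280 matches the decoder's context_dim above, and the SMPL mean parameters shipped with this repo must be available at construction time).

import torch
from unish.heads.human_head_cliff import HumanHeadCliff

head = HumanHeadCliff().eval()
x = torch.randn(2, 4, 1280, 16, 12)    # (B, N, C, H, W) per-view feature maps
bbox_info = torch.zeros(2, 4, 3)       # [cx / f, cy / f, box_size / f] per view
with torch.no_grad():
    params = head(x, bbox_info)
print(params["pose_cam"].shape)        # torch.Size([2, 4, 216]) flattened rotation matrices
print(params["betas"].shape)           # torch.Size([2, 10]), averaged over views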
unish/heads/pose_transformer.py ADDED
@@ -0,0 +1,364 @@
1
+ from inspect import isfunction
2
+ from typing import Callable, Optional
3
+
4
+ import torch
5
+ from einops import rearrange
6
+ from einops.layers.torch import Rearrange
7
+ from torch import nn
8
+
9
+ from .t_cond_mlp import (
10
+ AdaptiveLayerNorm1D,
11
+ FrequencyEmbedder,
12
+ normalization_layer,
13
+ )
14
+ # from .vit import Attention, FeedForward
15
+
16
+
17
+ def exists(val):
18
+ return val is not None
19
+
20
+
21
+ def default(val, d):
22
+ if exists(val):
23
+ return val
24
+ return d() if isfunction(d) else d
25
+
26
+
27
+ class PreNorm(nn.Module):
28
+ def __init__(self, dim: int, fn: Callable, norm: str = "layer", norm_cond_dim: int = -1):
29
+ super().__init__()
30
+ self.norm = normalization_layer(norm, dim, norm_cond_dim)
31
+ self.fn = fn
32
+
33
+ def forward(self, x: torch.Tensor, *args, **kwargs):
34
+ if isinstance(self.norm, AdaptiveLayerNorm1D):
35
+ return self.fn(self.norm(x, *args), **kwargs)
36
+ else:
37
+ return self.fn(self.norm(x), **kwargs)
38
+
39
+
40
+ class FeedForward(nn.Module):
41
+ def __init__(self, dim, hidden_dim, dropout=0.0):
42
+ super().__init__()
43
+ self.net = nn.Sequential(
44
+ nn.Linear(dim, hidden_dim),
45
+ nn.GELU(),
46
+ nn.Dropout(dropout),
47
+ nn.Linear(hidden_dim, dim),
48
+ nn.Dropout(dropout),
49
+ )
50
+
51
+ def forward(self, x):
52
+ return self.net(x)
53
+
54
+
55
+ class Attention(nn.Module):
56
+ def __init__(self, dim, heads=8, dim_head=64, dropout=0.0):
57
+ super().__init__()
58
+ inner_dim = dim_head * heads
59
+ project_out = not (heads == 1 and dim_head == dim)
60
+
61
+ self.heads = heads
62
+ self.scale = dim_head**-0.5
63
+
64
+ self.attend = nn.Softmax(dim=-1)
65
+ self.dropout = nn.Dropout(dropout)
66
+
67
+ self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
68
+
69
+ self.to_out = (
70
+ nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout))
71
+ if project_out
72
+ else nn.Identity()
73
+ )
74
+
75
+ def forward(self, x):
76
+ qkv = self.to_qkv(x).chunk(3, dim=-1)
77
+ q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), qkv)
78
+
79
+ dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
80
+
81
+ attn = self.attend(dots)
82
+ attn = self.dropout(attn)
83
+
84
+ out = torch.matmul(attn, v)
85
+ out = rearrange(out, "b h n d -> b n (h d)")
86
+ return self.to_out(out)
87
+
88
+
89
+ class CrossAttention(nn.Module):
90
+ def __init__(self, dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
91
+ super().__init__()
92
+ inner_dim = dim_head * heads
93
+ project_out = not (heads == 1 and dim_head == dim)
94
+
95
+ self.heads = heads
96
+ self.scale = dim_head**-0.5
97
+
98
+ self.attend = nn.Softmax(dim=-1)
99
+ self.dropout = nn.Dropout(dropout)
100
+
101
+ context_dim = default(context_dim, dim)
102
+ self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias=False)
103
+ self.to_q = nn.Linear(dim, inner_dim, bias=False)
104
+
105
+ self.to_out = (
106
+ nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout))
107
+ if project_out
108
+ else nn.Identity()
109
+ )
110
+
111
+ def forward(self, x, context=None):
112
+ context = default(context, x)
113
+ k, v = self.to_kv(context).chunk(2, dim=-1)
114
+ q = self.to_q(x)
115
+ q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), [q, k, v])
116
+
117
+ dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
118
+
119
+ attn = self.attend(dots)
120
+ attn = self.dropout(attn)
121
+
122
+ out = torch.matmul(attn, v)
123
+ out = rearrange(out, "b h n d -> b n (h d)")
124
+ return self.to_out(out)
125
+
126
+
127
+ class Transformer(nn.Module):
128
+ def __init__(
129
+ self,
130
+ dim: int,
131
+ depth: int,
132
+ heads: int,
133
+ dim_head: int,
134
+ mlp_dim: int,
135
+ dropout: float = 0.0,
136
+ norm: str = "layer",
137
+ norm_cond_dim: int = -1,
138
+ ):
139
+ super().__init__()
140
+ self.layers = nn.ModuleList([])
141
+ for _ in range(depth):
142
+ sa = Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)
143
+ ff = FeedForward(dim, mlp_dim, dropout=dropout)
144
+ self.layers.append(
145
+ nn.ModuleList(
146
+ [
147
+ PreNorm(dim, sa, norm=norm, norm_cond_dim=norm_cond_dim),
148
+ PreNorm(dim, ff, norm=norm, norm_cond_dim=norm_cond_dim),
149
+ ]
150
+ )
151
+ )
152
+
153
+ def forward(self, x: torch.Tensor, *args):
154
+ for attn, ff in self.layers:
155
+ x = attn(x, *args) + x
156
+ x = ff(x, *args) + x
157
+ return x
158
+
159
+
160
+ class TransformerCrossAttn(nn.Module):
161
+ def __init__(
162
+ self,
163
+ dim: int,
164
+ depth: int,
165
+ heads: int,
166
+ dim_head: int,
167
+ mlp_dim: int,
168
+ dropout: float = 0.0,
169
+ norm: str = "layer",
170
+ norm_cond_dim: int = -1,
171
+ context_dim: Optional[int] = None,
172
+ ):
173
+ super().__init__()
174
+ self.layers = nn.ModuleList([])
175
+ for _ in range(depth):
176
+ sa = Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)
177
+ ca = CrossAttention(
178
+ dim, context_dim=context_dim, heads=heads, dim_head=dim_head, dropout=dropout
179
+ )
180
+ ff = FeedForward(dim, mlp_dim, dropout=dropout)
181
+ self.layers.append(
182
+ nn.ModuleList(
183
+ [
184
+ PreNorm(dim, sa, norm=norm, norm_cond_dim=norm_cond_dim),
185
+ PreNorm(dim, ca, norm=norm, norm_cond_dim=norm_cond_dim),
186
+ PreNorm(dim, ff, norm=norm, norm_cond_dim=norm_cond_dim),
187
+ ]
188
+ )
189
+ )
190
+
191
+ def forward(self, x: torch.Tensor, *args, context=None, context_list=None):
192
+ if context_list is None:
193
+ context_list = [context] * len(self.layers)
194
+ if len(context_list) != len(self.layers):
195
+ raise ValueError(f"len(context_list) != len(self.layers) ({len(context_list)} != {len(self.layers)})")
196
+
197
+ b, n = x.shape[:2]
198
+
199
+ for i, (self_attn, cross_attn, ff) in enumerate(self.layers):
200
+ x = self_attn(x, *args) + x
201
+ # TODO
202
+ # x = x.view(b*n, 1, -1)
203
+ x = cross_attn(x, *args, context=context_list[i]) + x
204
+ # x = x.view(b, n, -1)
205
+ x = ff(x, *args) + x
206
+ return x
207
+
208
+
209
+ class DropTokenDropout(nn.Module):
210
+ def __init__(self, p: float = 0.1):
211
+ super().__init__()
212
+ if p < 0 or p > 1:
213
+ raise ValueError(
214
+ "dropout probability has to be between 0 and 1, " "but got {}".format(p)
215
+ )
216
+ self.p = p
217
+
218
+ def forward(self, x: torch.Tensor):
219
+ # x: (batch_size, seq_len, dim)
220
+ if self.training and self.p > 0:
221
+ zero_mask = torch.full_like(x[0, :, 0], self.p).bernoulli().bool()
222
+ # TODO: permutation idx for each batch using torch.argsort
223
+ if zero_mask.any():
224
+ x = x[:, ~zero_mask, :]
225
+ return x
226
+
227
+
228
+ class ZeroTokenDropout(nn.Module):
229
+ def __init__(self, p: float = 0.1):
230
+ super().__init__()
231
+ if p < 0 or p > 1:
232
+ raise ValueError(
233
+ "dropout probability has to be between 0 and 1, " "but got {}".format(p)
234
+ )
235
+ self.p = p
236
+
237
+ def forward(self, x: torch.Tensor):
238
+ # x: (batch_size, seq_len, dim)
239
+ if self.training and self.p > 0:
240
+ zero_mask = torch.full_like(x[:, :, 0], self.p).bernoulli().bool()
241
+ # Zero-out the masked tokens
242
+ x[zero_mask, :] = 0
243
+ return x
244
+
245
+
246
+ class TransformerEncoder(nn.Module):
247
+ def __init__(
248
+ self,
249
+ num_tokens: int,
250
+ token_dim: int,
251
+ dim: int,
252
+ depth: int,
253
+ heads: int,
254
+ mlp_dim: int,
255
+ dim_head: int = 64,
256
+ dropout: float = 0.0,
257
+ emb_dropout: float = 0.0,
258
+ emb_dropout_type: str = "drop",
259
+ emb_dropout_loc: str = "token",
260
+ norm: str = "layer",
261
+ norm_cond_dim: int = -1,
262
+ token_pe_numfreq: int = -1,
263
+ ):
264
+ super().__init__()
265
+ if token_pe_numfreq > 0:
266
+ token_dim_new = token_dim * (2 * token_pe_numfreq + 1)
267
+ self.to_token_embedding = nn.Sequential(
268
+ Rearrange("b n d -> (b n) d", n=num_tokens, d=token_dim),
269
+ FrequencyEmbedder(token_pe_numfreq, token_pe_numfreq - 1),
270
+ Rearrange("(b n) d -> b n d", n=num_tokens, d=token_dim_new),
271
+ nn.Linear(token_dim_new, dim),
272
+ )
273
+ else:
274
+ self.to_token_embedding = nn.Linear(token_dim, dim)
275
+ self.pos_embedding = nn.Parameter(torch.randn(1, num_tokens, dim))
276
+ if emb_dropout_type == "drop":
277
+ self.dropout = DropTokenDropout(emb_dropout)
278
+ elif emb_dropout_type == "zero":
279
+ self.dropout = ZeroTokenDropout(emb_dropout)
280
+ else:
281
+ raise ValueError(f"Unknown emb_dropout_type: {emb_dropout_type}")
282
+ self.emb_dropout_loc = emb_dropout_loc
283
+
284
+ self.transformer = Transformer(
285
+ dim, depth, heads, dim_head, mlp_dim, dropout, norm=norm, norm_cond_dim=norm_cond_dim
286
+ )
287
+
288
+ def forward(self, inp: torch.Tensor, *args, **kwargs):
289
+ x = inp
290
+
291
+ if self.emb_dropout_loc == "input":
292
+ x = self.dropout(x)
293
+ x = self.to_token_embedding(x)
294
+
295
+ if self.emb_dropout_loc == "token":
296
+ x = self.dropout(x)
297
+ b, n, _ = x.shape
298
+ x += self.pos_embedding[:, :n]
299
+
300
+ if self.emb_dropout_loc == "token_afterpos":
301
+ x = self.dropout(x)
302
+ x = self.transformer(x, *args)
303
+ return x
304
+
305
+
306
+ class TransformerDecoder(nn.Module):
307
+ def __init__(
308
+ self,
309
+ num_tokens: int,
310
+ token_dim: int,
311
+ dim: int,
312
+ depth: int,
313
+ heads: int,
314
+ mlp_dim: int,
315
+ dim_head: int = 64,
316
+ dropout: float = 0.0,
317
+ emb_dropout: float = 0.0,
318
+ emb_dropout_type: str = 'drop',
319
+ norm: str = "layer",
320
+ norm_cond_dim: int = -1,
321
+ context_dim: Optional[int] = None,
322
+ skip_token_embedding: bool = False,
323
+ ):
324
+ super().__init__()
325
+ if not skip_token_embedding:
326
+ self.to_token_embedding = nn.Linear(token_dim, dim)
327
+ else:
328
+ self.to_token_embedding = nn.Identity()
329
+ if token_dim != dim:
330
+ raise ValueError(
331
+ f"token_dim ({token_dim}) != dim ({dim}) when skip_token_embedding is True"
332
+ )
333
+
334
+ self.pos_embedding = nn.Parameter(torch.randn(1, num_tokens, dim))
335
+ if emb_dropout_type == "drop":
336
+ self.dropout = DropTokenDropout(emb_dropout)
337
+ elif emb_dropout_type == "zero":
338
+ self.dropout = ZeroTokenDropout(emb_dropout)
339
+ elif emb_dropout_type == "normal":
340
+ self.dropout = nn.Dropout(emb_dropout)
341
+
342
+ self.transformer = TransformerCrossAttn(
343
+ dim,
344
+ depth,
345
+ heads,
346
+ dim_head,
347
+ mlp_dim,
348
+ dropout,
349
+ norm=norm,
350
+ norm_cond_dim=norm_cond_dim,
351
+ context_dim=context_dim,
352
+ )
353
+
354
+ def forward(self, inp: torch.Tensor, *args, context=None, context_list=None):
355
+
356
+ x = self.to_token_embedding(inp)
357
+ b, n, _ = x.shape
358
+
359
+ x = self.dropout(x)
360
+ x += self.pos_embedding[:, :n]
361
+
362
+ x = self.transformer(x, *args, context=context, context_list=context_list)
363
+ return x
364
+
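Editor's note: a standalone sketch of the TransformerDecoder above, mirroring how HumanHeadCliff instantiates it (illustrative only, not part of the committed file; the context length is hypothetical).

import torch
from unish.heads.pose_transformer import TransformerDecoder

dec = TransformerDecoder(num_tokens=1, token_dim=160, dim=1024, depth=6, heads=8,
                         mlp_dim=1024, dim_head=64, context_dim=1280)
token = torch.randn(2, 1, 160)         # one query token per sample
context = torch.randn(2, 192, 1280)    # e.g. flattened per-view image features
out = dec(token, context=context)
print(out.shape)                       # torch.Size([2, 1, 1024])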
unish/heads/t_cond_mlp.py ADDED
@@ -0,0 +1,199 @@
1
+ import copy
2
+ from typing import List, Optional
3
+
4
+ import torch
5
+
6
+
7
+ class AdaptiveLayerNorm1D(torch.nn.Module):
8
+ def __init__(self, data_dim: int, norm_cond_dim: int):
9
+ super().__init__()
10
+ if data_dim <= 0:
11
+ raise ValueError(f"data_dim must be positive, but got {data_dim}")
12
+ if norm_cond_dim <= 0:
13
+ raise ValueError(f"norm_cond_dim must be positive, but got {norm_cond_dim}")
14
+ self.norm = torch.nn.LayerNorm(
15
+ data_dim
16
+ ) # TODO: Check if elementwise_affine=True is correct
17
+ self.linear = torch.nn.Linear(norm_cond_dim, 2 * data_dim)
18
+ torch.nn.init.zeros_(self.linear.weight)
19
+ torch.nn.init.zeros_(self.linear.bias)
20
+
21
+ def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
22
+ # x: (batch, ..., data_dim)
23
+ # t: (batch, norm_cond_dim)
24
+ # return: (batch, data_dim)
25
+ x = self.norm(x)
26
+ alpha, beta = self.linear(t).chunk(2, dim=-1)
27
+
28
+ # Add singleton dimensions to alpha and beta
29
+ if x.dim() > 2:
30
+ alpha = alpha.view(alpha.shape[0], *([1] * (x.dim() - 2)), alpha.shape[1])
31
+ beta = beta.view(beta.shape[0], *([1] * (x.dim() - 2)), beta.shape[1])
32
+
33
+ return x * (1 + alpha) + beta
34
+
35
+
36
+ class SequentialCond(torch.nn.Sequential):
37
+ def forward(self, input, *args, **kwargs):
38
+ for module in self:
39
+ if isinstance(module, (AdaptiveLayerNorm1D, SequentialCond, ResidualMLPBlock)):
40
+ # print(f'Passing on args to {module}', [a.shape for a in args])
41
+ input = module(input, *args, **kwargs)
42
+ else:
43
+ # print(f'Skipping passing args to {module}', [a.shape for a in args])
44
+ input = module(input)
45
+ return input
46
+
47
+
48
+ def normalization_layer(norm: Optional[str], dim: int, norm_cond_dim: int = -1):
49
+ if norm == "batch":
50
+ return torch.nn.BatchNorm1d(dim)
51
+ elif norm == "layer":
52
+ return torch.nn.LayerNorm(dim)
53
+ elif norm == "ada":
54
+ assert norm_cond_dim > 0, f"norm_cond_dim must be positive, got {norm_cond_dim}"
55
+ return AdaptiveLayerNorm1D(dim, norm_cond_dim)
56
+ elif norm is None:
57
+ return torch.nn.Identity()
58
+ else:
59
+ raise ValueError(f"Unknown norm: {norm}")
60
+
61
+
62
+ def linear_norm_activ_dropout(
63
+ input_dim: int,
64
+ output_dim: int,
65
+ activation: torch.nn.Module = torch.nn.ReLU(),
66
+ bias: bool = True,
67
+ norm: Optional[str] = "layer", # Options: ada/batch/layer
68
+ dropout: float = 0.0,
69
+ norm_cond_dim: int = -1,
70
+ ) -> SequentialCond:
71
+ layers = []
72
+ layers.append(torch.nn.Linear(input_dim, output_dim, bias=bias))
73
+ if norm is not None:
74
+ layers.append(normalization_layer(norm, output_dim, norm_cond_dim))
75
+ layers.append(copy.deepcopy(activation))
76
+ if dropout > 0.0:
77
+ layers.append(torch.nn.Dropout(dropout))
78
+ return SequentialCond(*layers)
79
+
80
+
81
+ def create_simple_mlp(
82
+ input_dim: int,
83
+ hidden_dims: List[int],
84
+ output_dim: int,
85
+ activation: torch.nn.Module = torch.nn.ReLU(),
86
+ bias: bool = True,
87
+ norm: Optional[str] = "layer", # Options: ada/batch/layer
88
+ dropout: float = 0.0,
89
+ norm_cond_dim: int = -1,
90
+ ) -> SequentialCond:
91
+ layers = []
92
+ prev_dim = input_dim
93
+ for hidden_dim in hidden_dims:
94
+ layers.extend(
95
+ linear_norm_activ_dropout(
96
+ prev_dim, hidden_dim, activation, bias, norm, dropout, norm_cond_dim
97
+ )
98
+ )
99
+ prev_dim = hidden_dim
100
+ layers.append(torch.nn.Linear(prev_dim, output_dim, bias=bias))
101
+ return SequentialCond(*layers)
102
+
103
+
104
+ class ResidualMLPBlock(torch.nn.Module):
105
+ def __init__(
106
+ self,
107
+ input_dim: int,
108
+ hidden_dim: int,
109
+ num_hidden_layers: int,
110
+ output_dim: int,
111
+ activation: torch.nn.Module = torch.nn.ReLU(),
112
+ bias: bool = True,
113
+ norm: Optional[str] = "layer", # Options: ada/batch/layer
114
+ dropout: float = 0.0,
115
+ norm_cond_dim: int = -1,
116
+ ):
117
+ super().__init__()
118
+ if not (input_dim == output_dim == hidden_dim):
119
+ raise NotImplementedError(
120
+ f"input_dim {input_dim} != output_dim {output_dim} is not implemented"
121
+ )
122
+
123
+ layers = []
124
+ prev_dim = input_dim
125
+ for i in range(num_hidden_layers):
126
+ layers.append(
127
+ linear_norm_activ_dropout(
128
+ prev_dim, hidden_dim, activation, bias, norm, dropout, norm_cond_dim
129
+ )
130
+ )
131
+ prev_dim = hidden_dim
132
+ self.model = SequentialCond(*layers)
133
+ self.skip = torch.nn.Identity()
134
+
135
+ def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
136
+ return x + self.model(x, *args, **kwargs)
137
+
138
+
139
+ class ResidualMLP(torch.nn.Module):
140
+ def __init__(
141
+ self,
142
+ input_dim: int,
143
+ hidden_dim: int,
144
+ num_hidden_layers: int,
145
+ output_dim: int,
146
+ activation: torch.nn.Module = torch.nn.ReLU(),
147
+ bias: bool = True,
148
+ norm: Optional[str] = "layer", # Options: ada/batch/layer
149
+ dropout: float = 0.0,
150
+ num_blocks: int = 1,
151
+ norm_cond_dim: int = -1,
152
+ ):
153
+ super().__init__()
154
+ self.input_dim = input_dim
155
+ self.model = SequentialCond(
156
+ linear_norm_activ_dropout(
157
+ input_dim, hidden_dim, activation, bias, norm, dropout, norm_cond_dim
158
+ ),
159
+ *[
160
+ ResidualMLPBlock(
161
+ hidden_dim,
162
+ hidden_dim,
163
+ num_hidden_layers,
164
+ hidden_dim,
165
+ activation,
166
+ bias,
167
+ norm,
168
+ dropout,
169
+ norm_cond_dim,
170
+ )
171
+ for _ in range(num_blocks)
172
+ ],
173
+ torch.nn.Linear(hidden_dim, output_dim, bias=bias),
174
+ )
175
+
176
+ def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
177
+ return self.model(x, *args, **kwargs)
178
+
179
+
180
+ class FrequencyEmbedder(torch.nn.Module):
181
+ def __init__(self, num_frequencies, max_freq_log2):
182
+ super().__init__()
183
+ frequencies = 2 ** torch.linspace(0, max_freq_log2, steps=num_frequencies)
184
+ self.register_buffer("frequencies", frequencies)
185
+
186
+ def forward(self, x):
187
+ # x should be of size (N,) or (N, D)
188
+ N = x.size(0)
189
+ if x.dim() == 1: # (N,)
190
+ x = x.unsqueeze(1) # (N, D) where D=1
191
+ x_unsqueezed = x.unsqueeze(-1) # (N, D, 1)
192
+ scaled = self.frequencies.view(1, 1, -1) * x_unsqueezed # (N, D, num_frequencies)
193
+ s = torch.sin(scaled)
194
+ c = torch.cos(scaled)
195
+ embedded = torch.cat([s, c, x_unsqueezed], dim=-1).view(
196
+ N, -1
197
+ ) # (N, D * 2 * num_frequencies + D)
198
+ return embedded
199
+
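Editor's note: a small sketch of the FrequencyEmbedder above (illustrative only, not part of the committed file; the frequency count is hypothetical).

import torch
from unish.heads.t_cond_mlp import FrequencyEmbedder

emb = FrequencyEmbedder(num_frequencies=8, max_freq_log2=7)
x = torch.randn(4, 3)       # (N, D)
feat = emb(x)
print(feat.shape)           # torch.Size([4, 51]): D * (2 * num_frequencies + 1) = 3 * 17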
unish/heads/utils.py ADDED
@@ -0,0 +1,108 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+
10
+
11
+ def position_grid_to_embed(pos_grid: torch.Tensor, embed_dim: int, omega_0: float = 100) -> torch.Tensor:
12
+ """
13
+ Convert 2D position grid (HxWx2) to sinusoidal embeddings (HxWxC)
14
+
15
+ Args:
16
+ pos_grid: Tensor of shape (H, W, 2) containing 2D coordinates
17
+ embed_dim: Output channel dimension for embeddings
18
+
19
+ Returns:
20
+ Tensor of shape (H, W, embed_dim) with positional embeddings
21
+ """
22
+ H, W, grid_dim = pos_grid.shape
23
+ assert grid_dim == 2
24
+ pos_flat = pos_grid.reshape(-1, grid_dim) # Flatten to (H*W, 2)
25
+
26
+ # Process x and y coordinates separately
27
+ emb_x = make_sincos_pos_embed(embed_dim // 2, pos_flat[:, 0], omega_0=omega_0) # [1, H*W, D/2]
28
+ emb_y = make_sincos_pos_embed(embed_dim // 2, pos_flat[:, 1], omega_0=omega_0) # [1, H*W, D/2]
29
+
30
+ # Combine and reshape
31
+ emb = torch.cat([emb_x, emb_y], dim=-1) # [1, H*W, D]
32
+
33
+ return emb.view(H, W, embed_dim) # [H, W, D]
34
+
35
+
36
+ def make_sincos_pos_embed(embed_dim: int, pos: torch.Tensor, omega_0: float = 100) -> torch.Tensor:
37
+ """
38
+ This function generates a 1D positional embedding from a given grid using sine and cosine functions.
39
+
40
+ Args:
41
+ - embed_dim: The embedding dimension.
42
+ - pos: The position to generate the embedding from.
43
+
44
+ Returns:
45
+ - emb: The generated 1D positional embedding.
46
+ """
47
+ assert embed_dim % 2 == 0
48
+ omega = torch.arange(embed_dim // 2, dtype=torch.double, device=pos.device)
49
+ omega /= embed_dim / 2.0
50
+ omega = 1.0 / omega_0**omega # (D/2,)
51
+
52
+ pos = pos.reshape(-1) # (M,)
53
+ out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product
54
+
55
+ emb_sin = torch.sin(out) # (M, D/2)
56
+ emb_cos = torch.cos(out) # (M, D/2)
57
+
58
+ emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D)
59
+ return emb.float()
60
+
61
+
62
+ # Inspired by https://github.com/microsoft/moge
63
+
64
+
65
+ def create_uv_grid(
66
+ width: int, height: int, aspect_ratio: float = None, dtype: torch.dtype = None, device: torch.device = None
67
+ ) -> torch.Tensor:
68
+ """
69
+    Create a normalized UV grid of shape (height, width, 2).
70
+
71
+ The grid spans horizontally and vertically according to an aspect ratio,
72
+ ensuring the top-left corner is at (-x_span, -y_span) and the bottom-right
73
+ corner is at (x_span, y_span), normalized by the diagonal of the plane.
74
+
75
+ Args:
76
+ width (int): Number of points horizontally.
77
+ height (int): Number of points vertically.
78
+ aspect_ratio (float, optional): Width-to-height ratio. Defaults to width/height.
79
+ dtype (torch.dtype, optional): Data type of the resulting tensor.
80
+ device (torch.device, optional): Device on which the tensor is created.
81
+
82
+ Returns:
83
+        torch.Tensor: A (height, width, 2) tensor of UV coordinates.
84
+ """
85
+ # Derive aspect ratio if not explicitly provided
86
+ if aspect_ratio is None:
87
+ aspect_ratio = float(width) / float(height)
88
+
89
+ # Compute normalized spans for X and Y
90
+ diag_factor = (aspect_ratio**2 + 1.0) ** 0.5
91
+ span_x = aspect_ratio / diag_factor
92
+ span_y = 1.0 / diag_factor
93
+
94
+ # Establish the linspace boundaries
95
+ left_x = -span_x * (width - 1) / width
96
+ right_x = span_x * (width - 1) / width
97
+ top_y = -span_y * (height - 1) / height
98
+ bottom_y = span_y * (height - 1) / height
99
+
100
+ # Generate 1D coordinates
101
+ x_coords = torch.linspace(left_x, right_x, steps=width, dtype=dtype, device=device)
102
+ y_coords = torch.linspace(top_y, bottom_y, steps=height, dtype=dtype, device=device)
103
+
104
+ # Create 2D meshgrid (width x height) and stack into UV
105
+ uu, vv = torch.meshgrid(x_coords, y_coords, indexing="xy")
106
+ uv_grid = torch.stack((uu, vv), dim=-1)
107
+
108
+ return uv_grid
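Editor's note: a quick sketch combining the two helpers above, mirroring how the DPT head builds its positional embedding (illustrative only, not part of the committed file; the grid size is hypothetical). With "xy" meshgrid indexing the returned grid is laid out as (height, width, 2).

import torch
from unish.heads.utils import create_uv_grid, position_grid_to_embed

uv = create_uv_grid(width=37, height=26, dtype=torch.float32)   # (26, 37, 2)
pe = position_grid_to_embed(uv, embed_dim=128)                  # (26, 37, 128)
print(uv.shape, pe.shape)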
unish/heads/vit.py ADDED
@@ -0,0 +1,346 @@
+ # Copyright (c) OpenMMLab. All rights reserved.
+ import math
+
+ import torch
+ from functools import partial
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import torch.utils.checkpoint as checkpoint
+
+ from timm.models.layers import drop_path, to_2tuple, trunc_normal_
+
+ def vit():
+     return ViT(
+         img_size=(256, 192),
+         patch_size=16,
+         embed_dim=1280,
+         depth=32,
+         num_heads=16,
+         ratio=1,
+         use_checkpoint=False,
+         mlp_ratio=4,
+         qkv_bias=True,
+         drop_path_rate=0.55,
+     )
+
+ def get_abs_pos(abs_pos, h, w, ori_h, ori_w, has_cls_token=True):
+     """
+     Calculate absolute positional embeddings. If needed, resize embeddings and remove cls_token
+     dimension for the original embeddings.
+     Args:
+         abs_pos (Tensor): absolute positional embeddings with (1, num_position, C).
+         has_cls_token (bool): If true, has 1 embedding in abs_pos for cls token.
+         hw (Tuple): size of input image tokens.
+
+     Returns:
+         Absolute positional embeddings after processing with shape (1, H, W, C)
+     """
+     cls_token = None
+     B, L, C = abs_pos.shape
+     if has_cls_token:
+         cls_token = abs_pos[:, 0:1]
+         abs_pos = abs_pos[:, 1:]
+
+     if ori_h != h or ori_w != w:
+         new_abs_pos = F.interpolate(
+             abs_pos.reshape(1, ori_h, ori_w, -1).permute(0, 3, 1, 2),
+             size=(h, w),
+             mode="bicubic",
+             align_corners=False,
+         ).permute(0, 2, 3, 1).reshape(B, -1, C)
+
+     else:
+         new_abs_pos = abs_pos
+
+     if cls_token is not None:
+         new_abs_pos = torch.cat([cls_token, new_abs_pos], dim=1)
+     return new_abs_pos
+
+ class DropPath(nn.Module):
+     """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+     """
+     def __init__(self, drop_prob=None):
+         super(DropPath, self).__init__()
+         self.drop_prob = drop_prob
+
+     def forward(self, x):
+         return drop_path(x, self.drop_prob, self.training)
+
+     def extra_repr(self):
+         return 'p={}'.format(self.drop_prob)
+
+ class Mlp(nn.Module):
+     def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
+         super().__init__()
+         out_features = out_features or in_features
+         hidden_features = hidden_features or in_features
+         self.fc1 = nn.Linear(in_features, hidden_features)
+         self.act = act_layer()
+         self.fc2 = nn.Linear(hidden_features, out_features)
+         self.drop = nn.Dropout(drop)
+
+     def forward(self, x):
+         x = self.fc1(x)
+         x = self.act(x)
+         x = self.fc2(x)
+         x = self.drop(x)
+         return x
+
+ class Attention(nn.Module):
+     def __init__(
+             self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
+             proj_drop=0., attn_head_dim=None,):
+         super().__init__()
+         self.num_heads = num_heads
+         head_dim = dim // num_heads
+         self.dim = dim
+
+         if attn_head_dim is not None:
+             head_dim = attn_head_dim
+         all_head_dim = head_dim * self.num_heads
+
+         self.scale = qk_scale or head_dim ** -0.5
+
+         self.qkv = nn.Linear(dim, all_head_dim * 3, bias=qkv_bias)
+
+         self.attn_drop = nn.Dropout(attn_drop)
+         self.proj = nn.Linear(all_head_dim, dim)
+         self.proj_drop = nn.Dropout(proj_drop)
+
+     def forward(self, x):
+         B, N, C = x.shape
+         qkv = self.qkv(x)
+         qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
+         q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)
+
+         q = q * self.scale
+         attn = (q @ k.transpose(-2, -1))
+
+         attn = attn.softmax(dim=-1)
+         attn = self.attn_drop(attn)
+
+         x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
+         x = self.proj(x)
+         x = self.proj_drop(x)
+
+         return x
+
+ class Block(nn.Module):
+
+     def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None,
+                  drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU,
+                  norm_layer=nn.LayerNorm, attn_head_dim=None
+                  ):
+         super().__init__()
+
+         self.norm1 = norm_layer(dim)
+         self.attn = Attention(
+             dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
+             attn_drop=attn_drop, proj_drop=drop, attn_head_dim=attn_head_dim
+         )
+
+         # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
+         self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+         self.norm2 = norm_layer(dim)
+         mlp_hidden_dim = int(dim * mlp_ratio)
+         self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+
+     def forward(self, x):
+         x = x + self.drop_path(self.attn(self.norm1(x)))
+         x = x + self.drop_path(self.mlp(self.norm2(x)))
+         return x
+
+
+ class PatchEmbed(nn.Module):
+     """ Image to Patch Embedding
+     """
+     def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, ratio=1):
+         super().__init__()
+         img_size = to_2tuple(img_size)
+         patch_size = to_2tuple(patch_size)
+         num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) * (ratio ** 2)
+         self.patch_shape = (int(img_size[0] // patch_size[0] * ratio), int(img_size[1] // patch_size[1] * ratio))
+         self.origin_patch_shape = (int(img_size[0] // patch_size[0]), int(img_size[1] // patch_size[1]))
+         self.img_size = img_size
+         self.patch_size = patch_size
+         self.num_patches = num_patches
+
+         self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=(patch_size[0] // ratio), padding=4 + 2 * (ratio//2-1))
+
+     def forward(self, x, **kwargs):
+         B, C, H, W = x.shape
+         x = self.proj(x)
+         Hp, Wp = x.shape[2], x.shape[3]
+
+         x = x.flatten(2).transpose(1, 2)
+         return x, (Hp, Wp)
+
+
+ class HybridEmbed(nn.Module):
+     """ CNN Feature Map Embedding
+     Extract feature map from CNN, flatten, project to embedding dim.
+     """
+     def __init__(self, backbone, img_size=224, feature_size=None, in_chans=3, embed_dim=768):
+         super().__init__()
+         assert isinstance(backbone, nn.Module)
+         img_size = to_2tuple(img_size)
+         self.img_size = img_size
+         self.backbone = backbone
+         if feature_size is None:
+             with torch.no_grad():
+                 training = backbone.training
+                 if training:
+                     backbone.eval()
+                 o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))[-1]
+                 feature_size = o.shape[-2:]
+                 feature_dim = o.shape[1]
+                 backbone.train(training)
+         else:
+             feature_size = to_2tuple(feature_size)
+             feature_dim = self.backbone.feature_info.channels()[-1]
+         self.num_patches = feature_size[0] * feature_size[1]
+         self.proj = nn.Linear(feature_dim, embed_dim)
+
+     def forward(self, x):
+         x = self.backbone(x)[-1]
+         x = x.flatten(2).transpose(1, 2)
+         x = self.proj(x)
+         return x
+
+
+ class ViT(nn.Module):
+     def __init__(self,
+                  img_size=224, patch_size=16, in_chans=3, num_classes=80, embed_dim=768, depth=12,
+                  num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
+                  drop_path_rate=0., hybrid_backbone=None, norm_layer=None, use_checkpoint=False,
+                  frozen_stages=-1, ratio=1, last_norm=True,
+                  patch_padding='pad', freeze_attn=False, freeze_ffn=False,
+                  ):
+         super(ViT, self).__init__()
+         norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
+         self.num_classes = num_classes
+         self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
+         self.frozen_stages = frozen_stages
+         self.use_checkpoint = use_checkpoint
+         self.patch_padding = patch_padding
+         self.freeze_attn = freeze_attn
+         self.freeze_ffn = freeze_ffn
+         self.depth = depth
+
+         if hybrid_backbone is not None:
+             self.patch_embed = HybridEmbed(
+                 hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim)
+         else:
+             self.patch_embed = PatchEmbed(
+                 img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, ratio=ratio)
+         num_patches = self.patch_embed.num_patches
+
+         # since the pretraining model has class token
+         self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
+
+         dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
+
+         self.blocks = nn.ModuleList([
+             Block(
+                 dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
+                 drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
+             )
+             for i in range(depth)])
+
+         self.last_norm = norm_layer(embed_dim) if last_norm else nn.Identity()
+
+         if self.pos_embed is not None:
+             trunc_normal_(self.pos_embed, std=.02)
+
+         self._freeze_stages()
+
+     def _freeze_stages(self):
+         """Freeze parameters."""
+         if self.frozen_stages >= 0:
+             self.patch_embed.eval()
+             for param in self.patch_embed.parameters():
+                 param.requires_grad = False
+
+             for i in range(1, self.frozen_stages + 1):
+                 m = self.blocks[i]
+                 m.eval()
+                 for param in m.parameters():
+                     param.requires_grad = False
+
+         if self.freeze_attn:
+             for i in range(0, self.depth):
+                 m = self.blocks[i]
+                 m.attn.eval()
+                 m.norm1.eval()
+                 for param in m.attn.parameters():
+                     param.requires_grad = False
+                 for param in m.norm1.parameters():
+                     param.requires_grad = False
+
+         if self.freeze_ffn:
+             self.pos_embed.requires_grad = False
+             self.patch_embed.eval()
+             for param in self.patch_embed.parameters():
+                 param.requires_grad = False
+             for i in range(0, self.depth):
+                 m = self.blocks[i]
+                 m.mlp.eval()
+                 m.norm2.eval()
+                 for param in m.mlp.parameters():
+                     param.requires_grad = False
+                 for param in m.norm2.parameters():
+                     param.requires_grad = False
+
+     def init_weights(self):
+         """Initialize the weights in backbone.
+         Args:
+             pretrained (str, optional): Path to pre-trained weights.
+                 Defaults to None.
+         """
+         def _init_weights(m):
+             if isinstance(m, nn.Linear):
+                 trunc_normal_(m.weight, std=.02)
+                 if isinstance(m, nn.Linear) and m.bias is not None:
+                     nn.init.constant_(m.bias, 0)
+             elif isinstance(m, nn.LayerNorm):
+                 nn.init.constant_(m.bias, 0)
+                 nn.init.constant_(m.weight, 1.0)
+
+         self.apply(_init_weights)
+
+     def get_num_layers(self):
+         return len(self.blocks)
+
+     @torch.jit.ignore
+     def no_weight_decay(self):
+         return {'pos_embed', 'cls_token'}
+
+     def forward_features(self, x):
+         B, C, H, W = x.shape
+         x, (Hp, Wp) = self.patch_embed(x)
+
+         if self.pos_embed is not None:
+             # fit for multiple GPU training
+             # since the first element for pos embed (sin-cos manner) is zero, it will cause no difference
+             x = x + self.pos_embed[:, 1:] + self.pos_embed[:, :1]
+
+         for blk in self.blocks:
+             if self.use_checkpoint:
+                 x = checkpoint.checkpoint(blk, x)
+             else:
+                 x = blk(x)
+
+         x = self.last_norm(x)
+
+         xp = x.permute(0, 2, 1).reshape(B, -1, Hp, Wp).contiguous()
+
+         return xp
+
+     def forward(self, x):
+         x = self.forward_features(x)
+         return x
+
+     def train(self, mode=True):
+         """Convert the model into training mode."""
+         super().train(mode)
+         self._freeze_stages()
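Note (not part of the committed file): a minimal usage sketch for the `vit()` factory above, assuming the package layout in this diff makes it importable as `unish.heads.vit`. With `img_size=(256, 192)` and `patch_size=16`, `forward_features` returns a feature map of shape `(B, 1280, 16, 12)`.

```python
import torch
from unish.heads.vit import vit  # import path assumed from this diff

model = vit().eval()  # ViT-Huge-sized backbone: embed_dim=1280, depth=32
with torch.no_grad():
    feats = model(torch.randn(1, 3, 256, 192))  # (B, C, H, W) matching img_size=(256, 192)
print(feats.shape)  # expected: torch.Size([1, 1280, 16, 12])
```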
unish/pi3/models/__pycache__/pi3.cpython-310.pyc ADDED
Binary file (7.01 kB). View file
 
unish/pi3/models/dinov2/__init__.py ADDED
@@ -0,0 +1,6 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the Apache License, Version 2.0
+ # found in the LICENSE file in the root directory of this source tree.
+
+ __version__ = "0.0.1"
unish/pi3/models/dinov2/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (181 Bytes). View file
 
unish/pi3/models/dinov2/hub/__init__.py ADDED
@@ -0,0 +1,4 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the Apache License, Version 2.0
+ # found in the LICENSE file in the root directory of this source tree.
unish/pi3/models/dinov2/hub/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (164 Bytes). View file
 
unish/pi3/models/dinov2/hub/__pycache__/backbones.cpython-310.pyc ADDED
Binary file (3.99 kB). View file
 
unish/pi3/models/dinov2/hub/__pycache__/utils.cpython-310.pyc ADDED
Binary file (1.78 kB). View file
 
unish/pi3/models/dinov2/hub/backbones.py ADDED
@@ -0,0 +1,156 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the Apache License, Version 2.0
+ # found in the LICENSE file in the root directory of this source tree.
+
+ from enum import Enum
+ from typing import Union
+
+ import torch
+
+ from .utils import _DINOV2_BASE_URL, _make_dinov2_model_name
+
+
+ class Weights(Enum):
+     LVD142M = "LVD142M"
+
+
+ def _make_dinov2_model(
+     *,
+     arch_name: str = "vit_large",
+     img_size: int = 518,
+     patch_size: int = 14,
+     init_values: float = 1.0,
+     ffn_layer: str = "mlp",
+     block_chunks: int = 0,
+     num_register_tokens: int = 0,
+     interpolate_antialias: bool = False,
+     interpolate_offset: float = 0.1,
+     pretrained: bool = True,
+     weights: Union[Weights, str] = Weights.LVD142M,
+     **kwargs,
+ ):
+     from ..models import vision_transformer as vits
+
+     if isinstance(weights, str):
+         try:
+             weights = Weights[weights]
+         except KeyError:
+             raise AssertionError(f"Unsupported weights: {weights}")
+
+     model_base_name = _make_dinov2_model_name(arch_name, patch_size)
+     vit_kwargs = dict(
+         img_size=img_size,
+         patch_size=patch_size,
+         init_values=init_values,
+         ffn_layer=ffn_layer,
+         block_chunks=block_chunks,
+         num_register_tokens=num_register_tokens,
+         interpolate_antialias=interpolate_antialias,
+         interpolate_offset=interpolate_offset,
+     )
+     vit_kwargs.update(**kwargs)
+     model = vits.__dict__[arch_name](**vit_kwargs)
+
+     if pretrained:
+         model_full_name = _make_dinov2_model_name(arch_name, patch_size, num_register_tokens)
+         url = _DINOV2_BASE_URL + f"/{model_base_name}/{model_full_name}_pretrain.pth"
+         state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu")
+         model.load_state_dict(state_dict, strict=True)
+
+     return model
+
+
+ def dinov2_vits14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
+     """
+     DINOv2 ViT-S/14 model (optionally) pretrained on the LVD-142M dataset.
+     """
+     return _make_dinov2_model(arch_name="vit_small", pretrained=pretrained, weights=weights, **kwargs)
+
+
+ def dinov2_vitb14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
+     """
+     DINOv2 ViT-B/14 model (optionally) pretrained on the LVD-142M dataset.
+     """
+     return _make_dinov2_model(arch_name="vit_base", pretrained=pretrained, weights=weights, **kwargs)
+
+
+ def dinov2_vitl14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
+     """
+     DINOv2 ViT-L/14 model (optionally) pretrained on the LVD-142M dataset.
+     """
+     return _make_dinov2_model(arch_name="vit_large", pretrained=pretrained, weights=weights, **kwargs)
+
+
+ def dinov2_vitg14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
+     """
+     DINOv2 ViT-g/14 model (optionally) pretrained on the LVD-142M dataset.
+     """
+     return _make_dinov2_model(
+         arch_name="vit_giant2",
+         ffn_layer="swiglufused",
+         weights=weights,
+         pretrained=pretrained,
+         **kwargs,
+     )
+
+
+ def dinov2_vits14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
+     """
+     DINOv2 ViT-S/14 model with registers (optionally) pretrained on the LVD-142M dataset.
+     """
+     return _make_dinov2_model(
+         arch_name="vit_small",
+         pretrained=pretrained,
+         weights=weights,
+         num_register_tokens=4,
+         interpolate_antialias=True,
+         interpolate_offset=0.0,
+         **kwargs,
+     )
+
+
+ def dinov2_vitb14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
+     """
+     DINOv2 ViT-B/14 model with registers (optionally) pretrained on the LVD-142M dataset.
+     """
+     return _make_dinov2_model(
+         arch_name="vit_base",
+         pretrained=pretrained,
+         weights=weights,
+         num_register_tokens=4,
+         interpolate_antialias=True,
+         interpolate_offset=0.0,
+         **kwargs,
+     )
+
+
+ def dinov2_vitl14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
+     """
+     DINOv2 ViT-L/14 model with registers (optionally) pretrained on the LVD-142M dataset.
+     """
+     return _make_dinov2_model(
+         arch_name="vit_large",
+         pretrained=pretrained,
+         weights=weights,
+         num_register_tokens=4,
+         interpolate_antialias=True,
+         interpolate_offset=0.0,
+         **kwargs,
+     )
+
+
+ def dinov2_vitg14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
+     """
+     DINOv2 ViT-g/14 model with registers (optionally) pretrained on the LVD-142M dataset.
+     """
+     return _make_dinov2_model(
+         arch_name="vit_giant2",
+         ffn_layer="swiglufused",
+         weights=weights,
+         pretrained=pretrained,
+         num_register_tokens=4,
+         interpolate_antialias=True,
+         interpolate_offset=0.0,
+         **kwargs,
+     )
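Note (not part of the committed file): a hedged sketch of calling one of these hub entrypoints. It assumes the vendored `..models.vision_transformer` module imported inside `_make_dinov2_model` is also present in this repository, and passes `pretrained=False` so no checkpoint is downloaded from `_DINOV2_BASE_URL`.

```python
from unish.pi3.models.dinov2.hub.backbones import dinov2_vitl14  # import path assumed from this diff

model = dinov2_vitl14(pretrained=False)  # ViT-L/14 with the img_size=518, patch_size=14 defaults
print(sum(p.numel() for p in model.parameters()) / 1e6, "M parameters")
```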
unish/pi3/models/dinov2/hub/utils.py ADDED
@@ -0,0 +1,39 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the Apache License, Version 2.0
+ # found in the LICENSE file in the root directory of this source tree.
+
+ import itertools
+ import math
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+
+ _DINOV2_BASE_URL = "https://dl.fbaipublicfiles.com/dinov2"
+
+
+ def _make_dinov2_model_name(arch_name: str, patch_size: int, num_register_tokens: int = 0) -> str:
+     compact_arch_name = arch_name.replace("_", "")[:4]
+     registers_suffix = f"_reg{num_register_tokens}" if num_register_tokens else ""
+     return f"dinov2_{compact_arch_name}{patch_size}{registers_suffix}"
+
+
+ class CenterPadding(nn.Module):
+     def __init__(self, multiple):
+         super().__init__()
+         self.multiple = multiple
+
+     def _get_pad(self, size):
+         new_size = math.ceil(size / self.multiple) * self.multiple
+         pad_size = new_size - size
+         pad_size_left = pad_size // 2
+         pad_size_right = pad_size - pad_size_left
+         return pad_size_left, pad_size_right
+
+     @torch.inference_mode()
+     def forward(self, x):
+         pads = list(itertools.chain.from_iterable(self._get_pad(m) for m in x.shape[:1:-1]))
+         output = F.pad(x, pads)
+         return output
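Note (not part of the committed file): a small sketch of `CenterPadding`. For a 4D input it pads only the trailing spatial dimensions (H and W) up to the nearest multiple, splitting the padding between the two sides.

```python
import torch
from unish.pi3.models.dinov2.hub.utils import CenterPadding  # import path assumed from this diff

pad = CenterPadding(multiple=14)
x = torch.randn(1, 3, 250, 333)
print(pad(x).shape)  # torch.Size([1, 3, 252, 336]) -- H and W rounded up to multiples of 14
```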
unish/pi3/models/dinov2/layers/__init__.py ADDED
@@ -0,0 +1,11 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the Apache License, Version 2.0
+ # found in the LICENSE file in the root directory of this source tree.
+
+ from .dino_head import DINOHead
+ from .mlp import Mlp
+ from .patch_embed import PatchEmbed
+ from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused
+ from .block import NestedTensorBlock
+ from .attention import MemEffAttention
unish/pi3/models/dinov2/layers/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (451 Bytes). View file
 
unish/pi3/models/dinov2/layers/__pycache__/attention.cpython-310.pyc ADDED
Binary file (2.47 kB). View file
 
unish/pi3/models/dinov2/layers/__pycache__/block.cpython-310.pyc ADDED
Binary file (8.04 kB). View file
 
unish/pi3/models/dinov2/layers/__pycache__/dino_head.cpython-310.pyc ADDED
Binary file (1.99 kB). View file
 
unish/pi3/models/dinov2/layers/__pycache__/drop_path.cpython-310.pyc ADDED
Binary file (1.21 kB). View file
 
unish/pi3/models/dinov2/layers/__pycache__/layer_scale.cpython-310.pyc ADDED
Binary file (1.01 kB). View file
 
unish/pi3/models/dinov2/layers/__pycache__/mlp.cpython-310.pyc ADDED
Binary file (1.2 kB). View file
 
unish/pi3/models/dinov2/layers/__pycache__/patch_embed.cpython-310.pyc ADDED
Binary file (2.65 kB). View file
 
unish/pi3/models/dinov2/layers/__pycache__/swiglu_ffn.cpython-310.pyc ADDED
Binary file (2.12 kB). View file
 
unish/pi3/models/dinov2/layers/attention.py ADDED
@@ -0,0 +1,89 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the Apache License, Version 2.0
+ # found in the LICENSE file in the root directory of this source tree.
+
+ # References:
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
+
+ import logging
+ import os
+ import warnings
+
+ from torch import Tensor
+ from torch import nn
+
+
+ logger = logging.getLogger("dinov2")
+
+
+ XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
+ try:
+     if XFORMERS_ENABLED:
+         from xformers.ops import memory_efficient_attention, unbind
+
+         XFORMERS_AVAILABLE = True
+         # warnings.warn("xFormers is available (Attention)")
+     else:
+         # warnings.warn("xFormers is disabled (Attention)")
+         raise ImportError
+ except ImportError:
+     XFORMERS_AVAILABLE = False
+     # warnings.warn("xFormers is not available (Attention)")
+
+
+ class Attention(nn.Module):
+     def __init__(
+         self,
+         dim: int,
+         num_heads: int = 8,
+         qkv_bias: bool = False,
+         proj_bias: bool = True,
+         attn_drop: float = 0.0,
+         proj_drop: float = 0.0,
+     ) -> None:
+         super().__init__()
+         self.num_heads = num_heads
+         head_dim = dim // num_heads
+         self.scale = head_dim**-0.5
+
+         self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+         self.attn_drop = nn.Dropout(attn_drop)
+         self.proj = nn.Linear(dim, dim, bias=proj_bias)
+         self.proj_drop = nn.Dropout(proj_drop)
+
+     def forward(self, x: Tensor, attn_bias=None) -> Tensor:
+         B, N, C = x.shape
+         qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+
+         q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
+         attn = q @ k.transpose(-2, -1)
+
+         attn = attn.softmax(dim=-1)
+         attn = self.attn_drop(attn)
+
+         x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+         x = self.proj(x)
+         x = self.proj_drop(x)
+         return x
+
+
+ class MemEffAttention(Attention):
+     def forward(self, x: Tensor, attn_bias=None) -> Tensor:
+         if not XFORMERS_AVAILABLE:
+             if attn_bias is not None:
+                 raise AssertionError("xFormers is required for using nested tensors")
+             return super().forward(x)
+
+         B, N, C = x.shape
+         qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
+
+         q, k, v = unbind(qkv, 2)
+
+         x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
+         x = x.reshape([B, N, C])
+
+         x = self.proj(x)
+         x = self.proj_drop(x)
+         return x
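Note (not part of the committed file): a sketch showing that `MemEffAttention` works as a drop-in dense attention layer; when xFormers is unavailable (and `attn_bias` is None) it falls back to the plain `Attention.forward` path above, so the call below works either way.

```python
import torch
from unish.pi3.models.dinov2.layers.attention import MemEffAttention  # import path assumed from this diff

attn = MemEffAttention(dim=384, num_heads=6, qkv_bias=True)
tokens = torch.randn(2, 197, 384)  # (batch, sequence, channels)
print(attn(tokens).shape)  # torch.Size([2, 197, 384])
```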
unish/pi3/models/dinov2/layers/block.py ADDED
@@ -0,0 +1,259 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the Apache License, Version 2.0
+ # found in the LICENSE file in the root directory of this source tree.
+
+ # References:
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
+
+ import logging
+ import os
+ from typing import Callable, List, Any, Tuple, Dict
+ import warnings
+
+ import torch
+ from torch import nn, Tensor
+
+ from .attention import Attention, MemEffAttention
+ from .drop_path import DropPath
+ from .layer_scale import LayerScale
+ from .mlp import Mlp
+
+
+ logger = logging.getLogger("dinov2")
+
+
+ XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
+ try:
+     if XFORMERS_ENABLED:
+         from xformers.ops import fmha, scaled_index_add, index_select_cat
+
+         XFORMERS_AVAILABLE = True
+         # warnings.warn("xFormers is available (Block)")
+     else:
+         # warnings.warn("xFormers is disabled (Block)")
+         raise ImportError
+ except ImportError:
+     XFORMERS_AVAILABLE = False
+     # warnings.warn("xFormers is not available (Block)")
+
+
+ class Block(nn.Module):
+     def __init__(
+         self,
+         dim: int,
+         num_heads: int,
+         mlp_ratio: float = 4.0,
+         qkv_bias: bool = False,
+         proj_bias: bool = True,
+         ffn_bias: bool = True,
+         drop: float = 0.0,
+         attn_drop: float = 0.0,
+         init_values=None,
+         drop_path: float = 0.0,
+         act_layer: Callable[..., nn.Module] = nn.GELU,
+         norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
+         attn_class: Callable[..., nn.Module] = Attention,
+         ffn_layer: Callable[..., nn.Module] = Mlp,
+     ) -> None:
+         super().__init__()
+         # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}")
+         self.norm1 = norm_layer(dim)
+         self.attn = attn_class(
+             dim,
+             num_heads=num_heads,
+             qkv_bias=qkv_bias,
+             proj_bias=proj_bias,
+             attn_drop=attn_drop,
+             proj_drop=drop,
+         )
+         self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
+         self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+         self.norm2 = norm_layer(dim)
+         mlp_hidden_dim = int(dim * mlp_ratio)
+         self.mlp = ffn_layer(
+             in_features=dim,
+             hidden_features=mlp_hidden_dim,
+             act_layer=act_layer,
+             drop=drop,
+             bias=ffn_bias,
+         )
+         self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
+         self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+         self.sample_drop_ratio = drop_path
+
+     def forward(self, x: Tensor) -> Tensor:
+         def attn_residual_func(x: Tensor) -> Tensor:
+             return self.ls1(self.attn(self.norm1(x)))
+
+         def ffn_residual_func(x: Tensor) -> Tensor:
+             return self.ls2(self.mlp(self.norm2(x)))
+
+         if self.training and self.sample_drop_ratio > 0.1:
+             # the overhead is compensated only for a drop path rate larger than 0.1
+             x = drop_add_residual_stochastic_depth(
+                 x,
+                 residual_func=attn_residual_func,
+                 sample_drop_ratio=self.sample_drop_ratio,
+             )
+             x = drop_add_residual_stochastic_depth(
+                 x,
+                 residual_func=ffn_residual_func,
+                 sample_drop_ratio=self.sample_drop_ratio,
+             )
+         elif self.training and self.sample_drop_ratio > 0.0:
+             x = x + self.drop_path1(attn_residual_func(x))
+             x = x + self.drop_path1(ffn_residual_func(x))  # FIXME: drop_path2
+         else:
+             x = x + attn_residual_func(x)
+             x = x + ffn_residual_func(x)
+         return x
+
+
+ def drop_add_residual_stochastic_depth(
+     x: Tensor,
+     residual_func: Callable[[Tensor], Tensor],
+     sample_drop_ratio: float = 0.0,
+ ) -> Tensor:
+     # 1) extract subset using permutation
+     b, n, d = x.shape
+     sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
+     brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
+     x_subset = x[brange]
+
+     # 2) apply residual_func to get residual
+     residual = residual_func(x_subset)
+
+     x_flat = x.flatten(1)
+     residual = residual.flatten(1)
+
+     residual_scale_factor = b / sample_subset_size
+
+     # 3) add the residual
+     x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
+     return x_plus_residual.view_as(x)
+
+
+ def get_branges_scales(x, sample_drop_ratio=0.0):
+     b, n, d = x.shape
+     sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
+     brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
+     residual_scale_factor = b / sample_subset_size
+     return brange, residual_scale_factor
+
+
+ def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
+     if scaling_vector is None:
+         x_flat = x.flatten(1)
+         residual = residual.flatten(1)
+         x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
+     else:
+         x_plus_residual = scaled_index_add(
+             x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
+         )
+     return x_plus_residual
+
+
+ attn_bias_cache: Dict[Tuple, Any] = {}
+
+
+ def get_attn_bias_and_cat(x_list, branges=None):
+     """
+     this will perform the index select, cat the tensors, and provide the attn_bias from cache
+     """
+     batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list]
+     all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
+     if all_shapes not in attn_bias_cache.keys():
+         seqlens = []
+         for b, x in zip(batch_sizes, x_list):
+             for _ in range(b):
+                 seqlens.append(x.shape[1])
+         attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
+         attn_bias._batch_sizes = batch_sizes
+         attn_bias_cache[all_shapes] = attn_bias
+
+     if branges is not None:
+         cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1])
+     else:
+         tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
+         cat_tensors = torch.cat(tensors_bs1, dim=1)
+
+     return attn_bias_cache[all_shapes], cat_tensors
+
+
+ def drop_add_residual_stochastic_depth_list(
+     x_list: List[Tensor],
+     residual_func: Callable[[Tensor, Any], Tensor],
+     sample_drop_ratio: float = 0.0,
+     scaling_vector=None,
+ ) -> Tensor:
+     # 1) generate random set of indices for dropping samples in the batch
+     branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
+     branges = [s[0] for s in branges_scales]
+     residual_scale_factors = [s[1] for s in branges_scales]
+
+     # 2) get attention bias and index+concat the tensors
+     attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)
+
+     # 3) apply residual_func to get residual, and split the result
+     residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias))  # type: ignore
+
+     outputs = []
+     for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors):
+         outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x))
+     return outputs
+
+
+ class NestedTensorBlock(Block):
+     def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
+         """
+         x_list contains a list of tensors to nest together and run
+         """
+         assert isinstance(self.attn, MemEffAttention)
+
+         if self.training and self.sample_drop_ratio > 0.0:
+
+             def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
+                 return self.attn(self.norm1(x), attn_bias=attn_bias)
+
+             def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
+                 return self.mlp(self.norm2(x))
+
+             x_list = drop_add_residual_stochastic_depth_list(
+                 x_list,
+                 residual_func=attn_residual_func,
+                 sample_drop_ratio=self.sample_drop_ratio,
+                 scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,
+             )
+             x_list = drop_add_residual_stochastic_depth_list(
+                 x_list,
+                 residual_func=ffn_residual_func,
+                 sample_drop_ratio=self.sample_drop_ratio,
+                 scaling_vector=self.ls2.gamma if isinstance(self.ls1, LayerScale) else None,
+             )
+             return x_list
+         else:
+
+             def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
+                 return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))
+
+             def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
+                 return self.ls2(self.mlp(self.norm2(x)))
+
+             attn_bias, x = get_attn_bias_and_cat(x_list)
+             x = x + attn_residual_func(x, attn_bias=attn_bias)
+             x = x + ffn_residual_func(x)
+             return attn_bias.split(x)
+
+     def forward(self, x_or_x_list):
+         if isinstance(x_or_x_list, Tensor):
+             return super().forward(x_or_x_list)
+         elif isinstance(x_or_x_list, list):
+             if not XFORMERS_AVAILABLE:
+                 raise AssertionError("xFormers is required for using nested tensors")
+             return self.forward_nested(x_or_x_list)
+         else:
+             raise AssertionError
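Note (not part of the committed file): a sketch of `NestedTensorBlock` on a single dense tensor, where it behaves as an ordinary pre-norm transformer block with LayerScale; the list (nested-tensor) input path additionally requires xFormers.

```python
import torch
from unish.pi3.models.dinov2.layers.attention import MemEffAttention
from unish.pi3.models.dinov2.layers.block import NestedTensorBlock  # import paths assumed from this diff

block = NestedTensorBlock(dim=384, num_heads=6, qkv_bias=True, init_values=1.0,
                          attn_class=MemEffAttention)
x = torch.randn(2, 197, 384)
print(block(x).shape)  # torch.Size([2, 197, 384])
```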
unish/pi3/models/dinov2/layers/dino_head.py ADDED
@@ -0,0 +1,58 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the Apache License, Version 2.0
+ # found in the LICENSE file in the root directory of this source tree.
+
+ import torch
+ import torch.nn as nn
+ from torch.nn.init import trunc_normal_
+ from torch.nn.utils import weight_norm
+
+
+ class DINOHead(nn.Module):
+     def __init__(
+         self,
+         in_dim,
+         out_dim,
+         use_bn=False,
+         nlayers=3,
+         hidden_dim=2048,
+         bottleneck_dim=256,
+         mlp_bias=True,
+     ):
+         super().__init__()
+         nlayers = max(nlayers, 1)
+         self.mlp = _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=hidden_dim, use_bn=use_bn, bias=mlp_bias)
+         self.apply(self._init_weights)
+         self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
+         self.last_layer.weight_g.data.fill_(1)
+
+     def _init_weights(self, m):
+         if isinstance(m, nn.Linear):
+             trunc_normal_(m.weight, std=0.02)
+             if isinstance(m, nn.Linear) and m.bias is not None:
+                 nn.init.constant_(m.bias, 0)
+
+     def forward(self, x):
+         x = self.mlp(x)
+         eps = 1e-6 if x.dtype == torch.float16 else 1e-12
+         x = nn.functional.normalize(x, dim=-1, p=2, eps=eps)
+         x = self.last_layer(x)
+         return x
+
+
+ def _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True):
+     if nlayers == 1:
+         return nn.Linear(in_dim, bottleneck_dim, bias=bias)
+     else:
+         layers = [nn.Linear(in_dim, hidden_dim, bias=bias)]
+         if use_bn:
+             layers.append(nn.BatchNorm1d(hidden_dim))
+         layers.append(nn.GELU())
+         for _ in range(nlayers - 2):
+             layers.append(nn.Linear(hidden_dim, hidden_dim, bias=bias))
+             if use_bn:
+                 layers.append(nn.BatchNorm1d(hidden_dim))
+             layers.append(nn.GELU())
+         layers.append(nn.Linear(hidden_dim, bottleneck_dim, bias=bias))
+         return nn.Sequential(*layers)
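Note (not part of the committed file): a sketch of the `DINOHead` data flow: MLP bottleneck, L2 normalization, then the weight-normalized prototype layer.

```python
import torch
from unish.pi3.models.dinov2.layers.dino_head import DINOHead  # import path assumed from this diff

head = DINOHead(in_dim=384, out_dim=65536)
logits = head(torch.randn(8, 384))
print(logits.shape)  # torch.Size([8, 65536])
```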
unish/pi3/models/dinov2/layers/drop_path.py ADDED
@@ -0,0 +1,34 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the Apache License, Version 2.0
+ # found in the LICENSE file in the root directory of this source tree.
+
+ # References:
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py
+
+
+ from torch import nn
+
+
+ def drop_path(x, drop_prob: float = 0.0, training: bool = False):
+     if drop_prob == 0.0 or not training:
+         return x
+     keep_prob = 1 - drop_prob
+     shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
+     random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
+     if keep_prob > 0.0:
+         random_tensor.div_(keep_prob)
+     output = x * random_tensor
+     return output
+
+
+ class DropPath(nn.Module):
+     """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
+
+     def __init__(self, drop_prob=None):
+         super(DropPath, self).__init__()
+         self.drop_prob = drop_prob
+
+     def forward(self, x):
+         return drop_path(x, self.drop_prob, self.training)
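Note (not part of the committed file): a sketch of the stochastic-depth behaviour of `drop_path`: entire samples are zeroed and the survivors are rescaled by `1 / keep_prob`, so the expected value is preserved; outside training it is the identity.

```python
import torch
from unish.pi3.models.dinov2.layers.drop_path import drop_path  # import path assumed from this diff

x = torch.ones(4, 3, 8)
y = drop_path(x, drop_prob=0.5, training=True)
print(y[:, 0, 0])  # each sample is either all zeros or scaled by 2.0
print(torch.equal(drop_path(x, drop_prob=0.5, training=False), x))  # True
```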
unish/pi3/models/dinov2/layers/layer_scale.py ADDED
@@ -0,0 +1,27 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the Apache License, Version 2.0
+ # found in the LICENSE file in the root directory of this source tree.
+
+ # Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110
+
+ from typing import Union
+
+ import torch
+ from torch import Tensor
+ from torch import nn
+
+
+ class LayerScale(nn.Module):
+     def __init__(
+         self,
+         dim: int,
+         init_values: Union[float, Tensor] = 1e-5,
+         inplace: bool = False,
+     ) -> None:
+         super().__init__()
+         self.inplace = inplace
+         self.gamma = nn.Parameter(init_values * torch.ones(dim))
+
+     def forward(self, x: Tensor) -> Tensor:
+         return x.mul_(self.gamma) if self.inplace else x * self.gamma
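Note (not part of the committed file): a sketch of `LayerScale`; with the default `init_values=1e-5` the residual branch it wraps starts out close to zero, which is the usual motivation for this layer.

```python
import torch
from unish.pi3.models.dinov2.layers.layer_scale import LayerScale  # import path assumed from this diff

ls = LayerScale(dim=384, init_values=1e-5)
x = torch.randn(2, 197, 384)
print(ls(x).abs().max())  # roughly 1e-5 times the input scale
```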
unish/pi3/models/dinov2/layers/mlp.py ADDED
@@ -0,0 +1,40 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the Apache License, Version 2.0
+ # found in the LICENSE file in the root directory of this source tree.
+
+ # References:
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py
+
+
+ from typing import Callable, Optional
+
+ from torch import Tensor, nn
+
+
+ class Mlp(nn.Module):
+     def __init__(
+         self,
+         in_features: int,
+         hidden_features: Optional[int] = None,
+         out_features: Optional[int] = None,
+         act_layer: Callable[..., nn.Module] = nn.GELU,
+         drop: float = 0.0,
+         bias: bool = True,
+     ) -> None:
+         super().__init__()
+         out_features = out_features or in_features
+         hidden_features = hidden_features or in_features
+         self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
+         self.act = act_layer()
+         self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
+         self.drop = nn.Dropout(drop)
+
+     def forward(self, x: Tensor) -> Tensor:
+         x = self.fc1(x)
+         x = self.act(x)
+         x = self.drop(x)
+         x = self.fc2(x)
+         x = self.drop(x)
+         return x
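Note (not part of the committed file): a sketch of this `Mlp`, which applies the shared dropout module after the activation and again after the second linear layer.

```python
import torch
from unish.pi3.models.dinov2.layers.mlp import Mlp  # import path assumed from this diff

mlp = Mlp(in_features=384, hidden_features=1536, drop=0.1)
print(mlp(torch.randn(2, 197, 384)).shape)  # torch.Size([2, 197, 384])
```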