HanzhouLiu committed on
Commit
a6e928c
·
1 Parent(s): 3d936c4

Track all files under examples/ with Git LFS

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +1 -0
  2. LICENSE +21 -0
  3. README.md +5 -8
  4. app.py +272 -0
  5. examples/demo_styles/00011395.png +3 -0
  6. examples/demo_styles/00018289.png +3 -0
  7. examples/demo_styles/00038427.png +3 -0
  8. examples/demo_styles/00047052.png +3 -0
  9. examples/demo_styles/00047819.png +3 -0
  10. examples/demo_styles/00054987.png +3 -0
  11. examples/demo_styles/00066540.png +3 -0
  12. examples/demo_styles/00069352.png +3 -0
  13. examples/demo_styles/00091988.png +3 -0
  14. examples/demo_styles/1098.png +3 -0
  15. examples/demo_styles/1414.png +3 -0
  16. examples/demo_styles/1842.png +3 -0
  17. examples/demo_styles/201.png +3 -0
  18. examples/demo_styles/2190.png +3 -0
  19. examples/demo_styles/23.jpeg +3 -0
  20. examples/demo_styles/24.jpeg +3 -0
  21. examples/demo_styles/5.jpeg +3 -0
  22. examples/demo_styles/977.png +3 -0
  23. examples/video/bungeenerf_colosseum.mp4 +3 -0
  24. examples/video/dtu_scan_106.mp4 +3 -0
  25. examples/video/fillerbuster_hand_hand.mp4 +3 -0
  26. examples/video/fillerbuster_ramen.mp4 +3 -0
  27. examples/video/fox.mp4 +3 -0
  28. examples/video/horizongs_hillside_summer.mp4 +3 -0
  29. examples/video/kitti360.mp4 +3 -0
  30. examples/video/llff_fortress.mp4 +3 -0
  31. examples/video/llff_horns.mp4 +3 -0
  32. examples/video/matrixcity_street.mp4 +3 -0
  33. examples/video/meganerf_rubble.mp4 +3 -0
  34. examples/video/re10k_1eca36ec55b88fe4.mp4 +3 -0
  35. examples/video/vrnerf_apartment.mp4 +3 -0
  36. examples/video/vrnerf_kitchen.mp4 +3 -0
  37. examples/video/vrnerf_riverview.mp4 +3 -0
  38. examples/video/vrnerf_workshop.mp4 +3 -0
  39. requirements.txt +38 -0
  40. src/dataset/shims/normalize_shim.py +29 -0
  41. src/dataset/types.py +51 -0
  42. src/geometry/camera_emb.py +29 -0
  43. src/geometry/projection.py +261 -0
  44. src/misc/image_io.py +248 -0
  45. src/misc/sh_rotation.py +111 -0
  46. src/misc/sht.py +1637 -0
  47. src/misc/utils.py +73 -0
  48. src/model/decoder/__init__.py +12 -0
  49. src/model/decoder/cuda_splatting.py +244 -0
  50. src/model/decoder/decoder.py +47 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ examples/** filter=lfs diff=lfs merge=lfs -text
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2025 Hanzhou(Marco) Liu
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md CHANGED
@@ -1,13 +1,10 @@
1
  ---
2
- title: Stylos Gradio
3
- emoji: 🏆
4
- colorFrom: yellow
5
- colorTo: purple
6
  sdk: gradio
7
- sdk_version: 5.49.1
8
  app_file: app.py
9
  pinned: false
10
- license: apache-2.0
11
  ---
12
-
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: Stylos Style Transfer
3
+ emoji: 🎨
4
+ colorFrom: pink
5
+ colorTo: blue
6
  sdk: gradio
7
+ sdk_version: 5.41.1
8
  app_file: app.py
9
  pinned: false
 
10
  ---
 
 
app.py ADDED
@@ -0,0 +1,272 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ Stylos 3D Stylization Demo — Pro Space Edition with Quota Limits
5
+ Author: Hanzhou Liu
6
+ """
7
+
8
+ # ===============================================================
9
+ # ZeroGPU & Gradio Compatibility
10
+ # ===============================================================
11
# ZeroGPU/Gradio compatibility shim: some Gradio versions expose
# Queue.pending_message_lock as a plain attribute (or not at all), which
# breaks `async with` usage inside gradio's queueing internals.
import asyncio
import gradio.queueing as grq

# NOTE(review): this installs one class-level asyncio.Lock shared by all Queue
# instances, created outside any running event loop — appears acceptable for a
# single-process demo Space, but verify against the pinned gradio version.
if not hasattr(grq.Queue, "pending_message_lock") or not hasattr(grq.Queue.pending_message_lock, "__aenter__"):
    grq.Queue.pending_message_lock = asyncio.Lock()
15
+
16
+ # ===============================================================
17
+ # Imports
18
+ # ===============================================================
19
+ import gc
20
+ import os
21
+ import shutil
22
+ import sys
23
+ import time
24
+ from pathlib import Path
25
+ from datetime import datetime
26
+ from dataclasses import dataclass
27
+
28
+ import cv2
29
+ import torch
30
+ import gradio as gr
31
+ from PIL import Image
32
+ from huggingface_hub import snapshot_download
33
+ import spaces
34
+
35
+ # ===============================================================
36
+ # Project Imports
37
+ # ===============================================================
38
+ THIS_FILE = Path(__file__).resolve()
39
+ PROJECT_ROOT = THIS_FILE.parent
40
+ sys.path.append(str(PROJECT_ROOT))
41
+
42
+ from src.misc.image_io import save_interpolated_video
43
+ from src.model.model.stylos import Stylos
44
+ from src.model.ply_export import export_ply
45
+ from src.utils.image import process_image
46
+
47
+ # ===============================================================
48
+ # Constants
49
+ # ===============================================================
50
# Scratch directory for per-run outputs (extracted frames, style image,
# rendered videos, exported PLY files).
TMP_ROOT = Path("demo_tmp")
TMP_ROOT.mkdir(exist_ok=True)

# Preset (content video, style image) pairs offered as demo examples.
EXAMPLES = [
    ["examples/video/re10k_1eca36ec55b88fe4.mp4", "examples/demo_styles/23.jpeg"],
    ["examples/video/bungeenerf_colosseum.mp4", "examples/demo_styles/24.jpeg"],
    ["examples/video/fox.mp4", "examples/demo_styles/201.png"],
    ["examples/video/vrnerf_apartment.mp4", "examples/demo_styles/977.png"],
]

# ===============================================================
# Usage Limits
# ===============================================================
MAX_RUNS_PER_USER = 5    # Max runs per user per day
MAX_GPU_TIME = 120       # Max GPU time per task (seconds)
MAX_FRAMES_PER_RUN = 32  # Max frames per reconstruction
_user_usage = {}         # Temporary quota memory (clears on restart)
67
+
68
+
69
def check_user_quota(user_id: str):
    """Track and enforce a per-user daily run quota.

    Increments the caller's counter for today and raises once the counter
    exceeds ``MAX_RUNS_PER_USER``. Counters live in the module-level
    ``_user_usage`` dict, so they reset whenever the Space restarts.

    Args:
        user_id: Stable identifier for the caller (e.g. HF username or "guest").

    Returns:
        A short status string showing runs used today.

    Raises:
        gr.Error: When the daily limit has been exceeded.
    """
    today = time.strftime("%Y-%m-%d")
    key = f"{user_id}_{today}"
    # Fix: drop counters from previous days so _user_usage cannot grow without
    # bound on a long-lived process (keys have the form f"{user}_{YYYY-MM-DD}").
    for stale in [k for k in _user_usage if not k.endswith(today)]:
        del _user_usage[stale]
    _user_usage[key] = _user_usage.get(key, 0) + 1
    if _user_usage[key] > MAX_RUNS_PER_USER:
        raise gr.Error(f"⚠️ You have reached your daily limit ({MAX_RUNS_PER_USER} runs). Please try again tomorrow.")
    return f"✅ Run {_user_usage[key]} / {MAX_RUNS_PER_USER}"
77
+
78
+
79
+ # ===============================================================
80
+ # Model Container
81
+ # ===============================================================
82
@dataclass
class ModelBundle:
    """Bundles the loaded Stylos model with the device it runs on."""

    # Loaded Stylos network (set to eval mode with gradients disabled in main()).
    stylos_model: Stylos
    # Device the model's weights live on; inputs are moved here before inference.
    device: torch.device
86
+
87
+
88
+ # ===============================================================
89
+ # Utility Functions
90
+ # ===============================================================
91
def create_run_dir(base_dir: Path = TMP_ROOT) -> Path:
    """Create and return a fresh, timestamped working directory under *base_dir*."""
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
    run_dir = base_dir / f"run_{stamp}"
    run_dir.mkdir(parents=True, exist_ok=True)
    return run_dir
95
+
96
+
97
def ensure_dir(path: Path, clear: bool = False):
    """Ensure *path* exists as a directory; with clear=True, recreate it empty.

    Returns *path* so the call can be used inline.
    """
    should_wipe = clear and path.exists()
    if should_wipe:
        # Remove any stale contents so the caller starts from a clean slate.
        shutil.rmtree(path)
    path.mkdir(parents=True, exist_ok=True)
    return path
102
+
103
+
104
def empty_cuda():
    """Collect Python garbage and release cached CUDA memory (no-op on CPU)."""
    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
108
+
109
+
110
def ingest_content(video_input=None, reuse_dir=None):
    """Sample roughly one frame per second from the uploaded video.

    Frames are written as numbered PNGs into ``<run_dir>/images`` (cleared
    first). Returns ``(run_dir, sorted_frame_paths)``.
    """
    empty_cuda()
    if reuse_dir and reuse_dir.exists():
        target_dir = reuse_dir
    else:
        target_dir = create_run_dir()
    img_dir = ensure_dir(target_dir / "images", clear=True)
    paths = []

    if video_input:
        raw = video_input if isinstance(video_input, str) else video_input["name"]
        src = Path(raw)
        cap = cv2.VideoCapture(str(src))
        # OpenCV reports 0 fps for some containers; fall back to 30.
        fps = cap.get(cv2.CAP_PROP_FPS) or 30.0
        interval = max(1, int(fps))  # keep ~1 frame per second of footage
        idx = 0
        frame_id = 0
        while True:
            ok, frame = cap.read()
            if not ok:
                break
            idx += 1
            if idx % interval:
                continue
            outp = img_dir / f"{frame_id:06}.png"
            cv2.imwrite(str(outp), frame)
            paths.append(outp)
            frame_id += 1
        cap.release()
    paths.sort()
    return target_dir, paths
136
+
137
+
138
def ingest_style(style_input, reuse_dir=None):
    """Persist the uploaded style image as ``<run_dir>/styles/style.jpg``.

    Returns ``(run_dir, [saved_path])`` — the list is empty when nothing was
    uploaded.
    """
    if reuse_dir and reuse_dir.exists():
        target_dir = reuse_dir
    else:
        target_dir = create_run_dir()
    style_dir = ensure_dir(target_dir / "styles", clear=True)
    dst = style_dir / "style.jpg"
    if style_input:
        # Force RGB so grayscale/RGBA uploads are normalized before saving.
        Image.open(style_input).convert("RGB").save(dst)
    return target_dir, [dst] if dst.exists() else []
146
+
147
+
148
+ # ===============================================================
149
+ # Inference
150
+ # ===============================================================
151
@spaces.GPU()
def run_reconstruction(target_dir: Path, bundle: ModelBundle, user_id="guest"):
    """Run Stylos inference on a prepared run directory.

    Expects ``target_dir/images`` (extracted frames) and
    ``target_dir/styles/style.jpg`` to exist. Renders interpolated RGB and
    depth videos into ``target_dir`` and exports the predicted Gaussians as a
    .ply file.

    Args:
        target_dir: Per-run working directory produced by the ingest steps.
        bundle: Loaded model plus the device it runs on.
        user_id: Quota bucket; "guest" for anonymous visitors.

    Returns:
        Tuple of (ply_path_str, rgb_video_path, depth_video_path).

    Raises:
        gr.Error: On missing inputs, frame-count, quota, or time-limit violations.
    """
    start_time = time.time()
    check_user_quota(user_id)

    if not target_dir.exists():
        raise gr.Error("❌ Temporary directory not found.")

    img_dir = target_dir / "images"
    style_img = target_dir / "styles" / "style.jpg"
    if not img_dir.exists() or not style_img.exists():
        raise gr.Error("⚠️ Please upload both a content video and a style image.")

    imgs = sorted([img_dir / f for f in os.listdir(img_dir)])
    if len(imgs) > MAX_FRAMES_PER_RUN:
        raise gr.Error(f"⚠️ Maximum {MAX_FRAMES_PER_RUN} frames allowed per run.")

    # Stack the frames into a (1, views, C, H, W) batch; the style image gets
    # an extra view dim to become (1, 1, C, H, W).
    tensors = [process_image(str(p)).to(bundle.device) for p in imgs]
    content = torch.stack(tensors, dim=0).unsqueeze(0)
    style = process_image(str(style_img)).unsqueeze(0).unsqueeze(0).to(bundle.device)

    # NOTE(review): this only bounds preprocessing time — the model call below
    # is not interrupted once it has started.
    if time.time() - start_time > MAX_GPU_TIME:
        raise gr.Error("⚠️ Exceeded GPU time limit. Please try a shorter sequence.")

    with torch.no_grad():
        # (x + 1) * 0.5 — presumably process_image outputs [-1, 1] and the
        # model expects [0, 1]; TODO confirm against src/utils/image.py.
        gauss, pose_dict = bundle.stylos_model.inference(
            (content + 1) * 0.5, style_image=(style + 1) * 0.5
        )

    extr, intr = pose_dict["extrinsic"], pose_dict["intrinsic"]
    rgb_path, depth_path = save_interpolated_video(
        extr, intr, 1, 448, 448, gauss, str(target_dir), bundle.stylos_model.decoder
    )

    # Export only the first batch element's Gaussians (batch size is 1 here).
    ply_path = target_dir / "gaussians.ply"
    export_ply(
        gauss.means[0],
        gauss.scales[0],
        gauss.rotations[0],
        gauss.harmonics[0],
        gauss.opacities[0],
        ply_path,
        save_sh_dc_only=True,
    )

    empty_cuda()
    return str(ply_path), rgb_path, depth_path
198
+
199
+
200
+ # ===============================================================
201
+ # Gradio Callbacks
202
+ # ===============================================================
203
def cb_update(video_input, style_input):
    """Gradio callback: re-ingest both inputs and refresh the previews.

    Enables the Reconstruct button only when frames AND a style image exist.
    """
    tdir, imgs = ingest_content(video_input)
    tdir, styles = ingest_style(style_input, reuse_dir=tdir)
    ready = len(imgs) and len(styles)
    frame_paths = [str(p) for p in imgs]
    style_path = str(styles[0]) if styles else None
    return str(tdir), frame_paths, style_path, gr.update(interactive=ready)
208
+
209
+
210
def cb_reconstruct(target_dir_str):
    """Gradio callback: resolve the caller's identity, then run reconstruction."""
    from spaces import get_token_username

    # Anonymous visitors share the "guest" quota bucket.
    user = get_token_username() or "guest"
    return run_reconstruction(Path(target_dir_str), GLOBAL_BUNDLE, user)
215
+
216
+
217
+ # ===============================================================
218
+ # UI
219
+ # ===============================================================
220
def create_interface():
    """Build and return the Gradio Blocks UI.

    Component creation order matters: widgets instantiated inside the Blocks
    context are laid out top-to-bottom in exactly this order.
    """
    theme = gr.themes.Soft()
    with gr.Blocks(title="Stylos 3D Stylization Demo", theme=theme) as demo:
        gr.Markdown("### 🎨 **Stylos 3D Stylization Demo (with Quota Limits)**")

        # Hidden textbox carries the per-run working directory between callbacks.
        run_dir_text = gr.Textbox(visible=False, value="None")
        video_input = gr.Video(label="Upload Video", height=300)
        style_input = gr.Image(label="Upload Style Image", type="filepath")
        gallery = gr.Gallery(label="Extracted Frames", height=200)
        reconstruct_btn = gr.Button("Reconstruct", variant="primary", interactive=False)
        model3d = gr.Model3D(label="3D Gaussian Splat", height=400)
        rgb_out = gr.Video(label="Stylized RGB")
        depth_out = gr.Video(label="Depth")

        # Re-extract frames/style whenever either input changes; cb_update also
        # toggles the Reconstruct button's interactivity.
        video_input.change(cb_update, [video_input, style_input], [run_dir_text, gallery, style_input, reconstruct_btn])
        style_input.change(cb_update, [video_input, style_input], [run_dir_text, gallery, style_input, reconstruct_btn])
        reconstruct_btn.click(cb_reconstruct, [run_dir_text], [model3d, rgb_out, depth_out])

    return demo
239
+
240
+
241
+ # ===============================================================
242
+ # Entry Point
243
+ # ===============================================================
244
# Populated by main() before the UI launches; read by cb_reconstruct().
GLOBAL_BUNDLE = None

def main():
    """Download the checkpoint, load the Stylos model, and launch the app."""
    global GLOBAL_BUNDLE
    print("🚀 Starting Stylos Demo with Quota Limits")

    # Pull only the needed checkpoint subtree from the public weights dataset.
    weights_dir = snapshot_download(
        repo_id="HanzhouLiu/Stylos_Weights",
        repo_type="dataset",
        allow_patterns=["DL3DV/2025-10-09_16-10-03/*"],
        token=False,
    )
    weights_dir = os.path.join(weights_dir, "DL3DV/2025-10-09_16-10-03")
    print(f"✅ Checkpoint ready at: {weights_dir}")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = Stylos.from_pretrained(weights_dir).to(device)
    model.eval()
    # Freeze all parameters; the demo only ever runs inference.
    for p in model.parameters():
        p.requires_grad = False

    GLOBAL_BUNDLE = ModelBundle(model, device)

    demo = create_interface()
    demo.queue(max_size=20).launch(show_error=True, ssr_mode=False)


if __name__ == "__main__":
    main()
examples/demo_styles/00011395.png ADDED

Git LFS Details

  • SHA256: f0ecf88adf7896cc453d48ab54b1f493bcd89a16629bd4fe68b0b80fa3e3d6bd
  • Pointer size: 131 Bytes
  • Size of remote file: 110 kB
examples/demo_styles/00018289.png ADDED

Git LFS Details

  • SHA256: a8b163ef02e4e37842084c7110c69256927a0b16ae69299b018bbc707b2f5703
  • Pointer size: 131 Bytes
  • Size of remote file: 118 kB
examples/demo_styles/00038427.png ADDED

Git LFS Details

  • SHA256: 2a3f3037c717c67465a1b03e99b0e92fe0294270ce657bb4fc9ba9c908683849
  • Pointer size: 131 Bytes
  • Size of remote file: 103 kB
examples/demo_styles/00047052.png ADDED

Git LFS Details

  • SHA256: bae75fb2476cf8a7008a68a693111f8988e6b6e047132e1d7fff1fb98e17fa87
  • Pointer size: 131 Bytes
  • Size of remote file: 132 kB
examples/demo_styles/00047819.png ADDED

Git LFS Details

  • SHA256: 6ffca4e8e794e34b36fc71076b871fa9545a4c6dfc906be53f724ec530c2a955
  • Pointer size: 131 Bytes
  • Size of remote file: 118 kB
examples/demo_styles/00054987.png ADDED

Git LFS Details

  • SHA256: ce942e67d605ed0c9f2dc67e240216a89f41168f9ef0e23a90da823c61bac575
  • Pointer size: 131 Bytes
  • Size of remote file: 146 kB
examples/demo_styles/00066540.png ADDED

Git LFS Details

  • SHA256: 81cc32dad07296d8cbbd72aad631f28a08b4d8e503270e92b3fdad22e094f617
  • Pointer size: 131 Bytes
  • Size of remote file: 121 kB
examples/demo_styles/00069352.png ADDED

Git LFS Details

  • SHA256: 0b2bad6fb8b6de3a49e428922ced7f419b3c0c412e7abdef4bbbb829ed6c05b2
  • Pointer size: 131 Bytes
  • Size of remote file: 110 kB
examples/demo_styles/00091988.png ADDED

Git LFS Details

  • SHA256: 79f119985dce1d1ef7bc4d778b0ab4712b81de8bb58e7ea8c9af671301fc5eeb
  • Pointer size: 130 Bytes
  • Size of remote file: 98.4 kB
examples/demo_styles/1098.png ADDED

Git LFS Details

  • SHA256: 1198d903907f3276a6df12226fbc03b9c931daaabe6df40931ee150348c9f85e
  • Pointer size: 131 Bytes
  • Size of remote file: 165 kB
examples/demo_styles/1414.png ADDED

Git LFS Details

  • SHA256: 020aa921fe5e72df4a8aa0ca98de86b8ef8699dc58a10dd2f332a7892cbf8475
  • Pointer size: 130 Bytes
  • Size of remote file: 74 kB
examples/demo_styles/1842.png ADDED

Git LFS Details

  • SHA256: 78193a86a7e6acad22f2e8f482fdd1395e094f6a764fcc76290e9c3b8e147b1f
  • Pointer size: 131 Bytes
  • Size of remote file: 256 kB
examples/demo_styles/201.png ADDED

Git LFS Details

  • SHA256: c0f2502c470dd398aab6368a6c829d55b63b82266c7f20d063208618eeb65760
  • Pointer size: 131 Bytes
  • Size of remote file: 147 kB
examples/demo_styles/2190.png ADDED

Git LFS Details

  • SHA256: 85c93031078cf0c90f073b9a8d90d8192d7f2491111a5980f362b1f7e8eab6c9
  • Pointer size: 131 Bytes
  • Size of remote file: 178 kB
examples/demo_styles/23.jpeg ADDED

Git LFS Details

  • SHA256: 1b28969f1364083063523603b2764b0a434e07a2cb6dfea7b073b733ce202c17
  • Pointer size: 130 Bytes
  • Size of remote file: 28.9 kB
examples/demo_styles/24.jpeg ADDED

Git LFS Details

  • SHA256: cc912a234a7b2b2447afbd3a8a4807dfd610ed94abf8f8c70acebd7df3aa58df
  • Pointer size: 130 Bytes
  • Size of remote file: 24.7 kB
examples/demo_styles/5.jpeg ADDED

Git LFS Details

  • SHA256: 09b0a91e961e61342a3e9adcf0ddb0a1c8c662e4be70c63ae2714dfce002eaa8
  • Pointer size: 130 Bytes
  • Size of remote file: 17.1 kB
examples/demo_styles/977.png ADDED

Git LFS Details

  • SHA256: 0034399cd1422020945fb69cecfb1c87218b23b11e0fecb99fb0566e25c77f5d
  • Pointer size: 131 Bytes
  • Size of remote file: 115 kB
examples/video/bungeenerf_colosseum.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:416b6af945547b5d19476823672de552944c7b5a147d29e9e8243e91a16aee3e
3
+ size 329073
examples/video/dtu_scan_106.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:16d7a06325cd368b134908e600a6c0741c7d0d188f1db690532b8ac85d65fef5
3
+ size 352188
examples/video/fillerbuster_hand_hand.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7b4ca982672bc92342b3e722c171d9d2e4d67a5a8116cd9f346956fbe01e253f
3
+ size 319404
examples/video/fillerbuster_ramen.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d60346a64a0a0d6805131d0d57edeeb0dae24f24c3f10560e95df65531221229
3
+ size 660736
examples/video/fox.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d3fa2ccff78e5d8085bb58f3def2d482e8df285ced5ef1b56abfe3766f0d90e0
3
+ size 2361921
examples/video/horizongs_hillside_summer.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e5dff78d9c00b3776bfca3a370061698bddead2ae940fe5a42d082ccf2ca80d1
3
+ size 1606537
examples/video/kitti360.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8c6b13929b2c2aae8b95921d8626f5be06f6afffe05ea4e47940ffeb9906f9fc
3
+ size 1843629
examples/video/llff_fortress.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:90ea046a0ec78651975529ebe6b9c72b60c19561fe61b15b15b9df0e44d9fe9a
3
+ size 196243
examples/video/llff_horns.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3bc4c443c2a3f889f0c1283e98bd6a7026c36858fb37808bb2e8699ad1a2c1d8
3
+ size 372570
examples/video/matrixcity_street.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:aa415f27177398b4e06f580beb3778701ca55784afade2fd6a058212213febc8
3
+ size 3163684
examples/video/meganerf_rubble.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3410c759eb73ca2403ab8fe35d5ebabdbc25e3a0e67d8670a89fe17686246ed0
3
+ size 450116
examples/video/re10k_1eca36ec55b88fe4.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b3516eea797fe8035a7ff6d80098dfddd53a8d087dc3c00419d4192d73960d00
3
+ size 35089
examples/video/vrnerf_apartment.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4fdd5f165a4293cd95e3dd88d84b1f370decdd86308aa67a9d3832e01f4d6906
3
+ size 2076392
examples/video/vrnerf_kitchen.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3db5d766ec86a7abdfe1f033b252337e6d934ea15035fafb4d0fc0c0e9e9740a
3
+ size 775715
examples/video/vrnerf_riverview.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9b8187936cc49910ef330a37b1bbdab0076096d6c01f33b097c11937184de168
3
+ size 768290
examples/video/vrnerf_workshop.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c0f1334acc74bd70086a9be94d0c36838ebd7499af27f942c315e1ba282e285b
3
+ size 1718918
requirements.txt ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ trimesh
2
+ numpy==1.25.0
3
+ wheel
4
+ tqdm
5
+ lightning
6
+ black
7
+ ruff
8
+ hydra-core
9
+ jaxtyping
10
+ beartype
11
+ wandb
12
+ einops
13
+ colorama
14
+ scikit-image
15
+ colorspacious
16
+ matplotlib
17
+ moviepy
18
+ imageio
19
+ timm
20
+ dacite
21
+ lpips
22
+ e3nn
23
+ plyfile
24
+ tabulate
25
+ svg.py
26
+ scikit-video
27
+ opencv-python
28
+ Pillow
29
+ #xformers==0.0.24
30
+ #huggingface-hub<0.14
31
+ xformers
32
+ moviepy==1.0.3
33
+ pydantic
34
+ open3d
35
+ einops
36
+ safetensors
37
+ torch_scatter @ https://data.pyg.org/whl/torch-2.8.0%2Bcu128/torch_scatter-2.1.2%2Bpt28cu128-cp310-cp310-linux_x86_64.whl
38
+ gsplat @ https://github.com/nerfstudio-project/gsplat/releases/download/v1.5.3/gsplat-1.5.3+pt22cu121-cp310-cp310-linux_x86_64.whl
src/dataset/shims/normalize_shim.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from einops import einsum, reduce, repeat
3
+ from jaxtyping import Float
4
+ from torch import Tensor
5
+
6
+ from ..types import BatchedExample
7
+
8
+
9
def inverse_normalize_image(tensor, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)):
    """Undo channelwise normalization: x * std + mean, broadcast over (C, 1, 1)."""
    col_shape = (-1, 1, 1)
    m = torch.as_tensor(mean, dtype=tensor.dtype, device=tensor.device).view(col_shape)
    s = torch.as_tensor(std, dtype=tensor.dtype, device=tensor.device).view(col_shape)
    return tensor * s + m
13
+
14
+
15
def normalize_image(tensor, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)):
    """Channelwise-normalize an image tensor: (x - mean) / std."""
    def as_channel_col(values):
        # Shape (C, 1, 1) so the stats broadcast over height and width.
        return torch.as_tensor(values, dtype=tensor.dtype, device=tensor.device).view(-1, 1, 1)

    return (tensor - as_channel_col(mean)) / as_channel_col(std)
19
+
20
+
21
def apply_normalize_shim(
    batch: BatchedExample,
    mean: tuple[float, float, float] = (0.5, 0.5, 0.5),
    std: tuple[float, float, float] = (0.5, 0.5, 0.5),
) -> BatchedExample:
    """Normalize the context image (and style image, when present) in place."""
    ctx = batch["context"]
    ctx["image"] = normalize_image(ctx["image"], mean, std)
    if "style_image" in ctx:
        ctx["style_image"] = normalize_image(ctx["style_image"], mean, std)
    return batch
src/dataset/types.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Callable, Literal, TypedDict
2
+
3
+ from jaxtyping import Float, Int64
4
+ from torch import Tensor
5
+
6
# Which phase of the pipeline a dataset/loader is serving.
Stage = Literal["train", "val", "test"]


# The following types mainly exist to make type-hinted keys show up in VS Code. Some
# dimensions are annotated as "_" because either:
# 1. They're expected to change as part of a function call (e.g., resizing the dataset).
# 2. They're expected to vary within the same function call (e.g., the number of views,
#    which differs between context and target BatchedViews).
14
+
15
+
16
class BatchedViews(TypedDict, total=False):
    """Per-view camera/image tensors for a batch; every key is optional."""

    extrinsics: Float[Tensor, "batch _ 4 4"]  # batch view 4 4
    intrinsics: Float[Tensor, "batch _ 3 3"]  # batch view 3 3
    image: Float[Tensor, "batch _ _ _ _"]  # batch view channel height width
    near: Float[Tensor, "batch _"]  # batch view
    far: Float[Tensor, "batch _"]  # batch view
    index: Int64[Tensor, "batch _"]  # batch view
    overlap: Float[Tensor, "batch _"]  # batch view
24
+
25
+
26
class BatchedExample(TypedDict, total=False):
    """One batched example: target and context view groups plus scene names."""

    target: BatchedViews
    context: BatchedViews
    scene: list[str]
30
+
31
+
32
class UnbatchedViews(TypedDict, total=False):
    """Per-view tensors for a single (unbatched) example."""

    extrinsics: Float[Tensor, "_ 4 4"]
    intrinsics: Float[Tensor, "_ 3 3"]
    image: Float[Tensor, "_ 3 height width"]
    near: Float[Tensor, " _"]
    far: Float[Tensor, " _"]
    index: Int64[Tensor, " _"]
39
+
40
+
41
class UnbatchedExample(TypedDict, total=False):
    """A single (unbatched) example: target and context views plus a scene name."""

    target: UnbatchedViews
    context: UnbatchedViews
    scene: str
45
+
46
+
47
# A data shim modifies the example after it's been returned from the data loader.
DataShim = Callable[[BatchedExample], BatchedExample]

# Unions for code that accepts either batched or unbatched data.
AnyExample = BatchedExample | UnbatchedExample
AnyViews = BatchedViews | UnbatchedViews
src/geometry/camera_emb.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from einops import rearrange
2
+
3
+ from .projection import sample_image_grid, get_local_rays
4
+ from ..misc.sht import rsh_cart_2, rsh_cart_4, rsh_cart_6, rsh_cart_8
5
+
6
+
7
def get_intrinsic_embedding(context, degree=0, downsample=1, merge_hw=False):
    """Build a per-pixel camera-ray embedding from the context intrinsics.

    Samples a (possibly downsampled) pixel grid, converts it to unit ray
    directions via the intrinsics, and optionally lifts the directions with
    real spherical harmonics of the given degree.

    Args:
        context: Mapping with "image" (b, v, c, h, w) and "intrinsics"
            (b, v, 3, 3) tensors — shapes inferred from usage below.
        degree: SH degree; 0 returns the raw xyz directions. NOTE(review):
            rsh_cart_6 is imported at module level but 6 is not accepted here.
        downsample: Integer factor dividing the image resolution.
        merge_hw: If True, flatten the spatial grid to (b, v, h*w, d);
            otherwise return channels-first (b, v, d, h, w).

    Returns:
        Direction-embedding tensor, laid out per *merge_hw*.
    """
    assert degree in [0, 2, 4, 8]

    b, v, _, h, w = context["image"].shape
    device = context["image"].device
    tgt_h, tgt_w = h // downsample, w // downsample
    xy_ray, _ = sample_image_grid((tgt_h, tgt_w), device)
    xy_ray = xy_ray[None, None, ...].expand(b, v, -1, -1, -1)  # [b, v, h, w, 2]
    directions = get_local_rays(xy_ray, rearrange(context["intrinsics"], "b v i j -> b v () () i j"),)

    if degree == 2:
        directions = rsh_cart_2(directions)
    elif degree == 4:
        directions = rsh_cart_4(directions)
    elif degree == 8:
        directions = rsh_cart_8(directions)

    if merge_hw:
        directions = rearrange(directions, "b v h w d -> b v (h w) d")
    else:
        directions = rearrange(directions, "b v h w d -> b v d h w")

    return directions
src/geometry/projection.py ADDED
@@ -0,0 +1,261 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from math import prod
2
+
3
+ import torch
4
+ from einops import einsum, rearrange, reduce, repeat
5
+ from jaxtyping import Bool, Float, Int64
6
+ from torch import Tensor
7
+
8
+
9
def homogenize_points(
    points: Float[Tensor, "*batch dim"],
) -> Float[Tensor, "*batch dim+1"]:
    """Append a homogeneous 1-coordinate to batched points (xyz -> xyz1)."""
    ones = torch.ones_like(points[..., :1])
    return torch.cat((points, ones), dim=-1)
14
+
15
+
16
def homogenize_vectors(
    vectors: Float[Tensor, "*batch dim"],
) -> Float[Tensor, "*batch dim+1"]:
    """Append a homogeneous 0-coordinate to batched vectors (xyz -> xyz0)."""
    zeros = torch.zeros_like(vectors[..., :1])
    return torch.cat((vectors, zeros), dim=-1)
21
+
22
+
23
def transform_rigid(
    homogeneous_coordinates: Float[Tensor, "*#batch dim"],
    transformation: Float[Tensor, "*#batch dim dim"],
) -> Float[Tensor, "*batch dim"]:
    """Apply a rigid-body transformation to points or vectors."""
    # Batched matrix-vector product: out_i = sum_j T_ij * x_j.
    return einsum(transformation, homogeneous_coordinates, "... i j, ... j -> ... i")
29
+
30
+
31
def transform_cam2world(
    homogeneous_coordinates: Float[Tensor, "*#batch dim"],
    extrinsics: Float[Tensor, "*#batch dim dim"],
) -> Float[Tensor, "*batch dim"]:
    """Transform points from 3D camera coordinates to 3D world coordinates."""
    # Extrinsics follow the camera-to-world convention here, so they apply directly.
    return transform_rigid(homogeneous_coordinates, extrinsics)
37
+
38
+
39
def transform_world2cam(
    homogeneous_coordinates: Float[Tensor, "*#batch dim"],
    extrinsics: Float[Tensor, "*#batch dim dim"],
) -> Float[Tensor, "*batch dim"]:
    """Transform points from 3D world coordinates to 3D camera coordinates."""
    # Inverts the cam-to-world extrinsics. NOTE(review): .inverse() is
    # recomputed on every call — fine for small batches.
    return transform_rigid(homogeneous_coordinates, extrinsics.inverse())
45
+
46
+
47
def project_camera_space(
    points: Float[Tensor, "*#batch dim"],
    intrinsics: Float[Tensor, "*#batch dim dim"],
    epsilon: float = torch.finfo(torch.float32).eps,
    infinity: float = 1e8,
) -> Float[Tensor, "*batch dim-1"]:
    """Perspective-divide camera-space points and map them through the intrinsics.

    *epsilon* guards the divide against zero depth; infinities produced by the
    divide are clamped to +/-*infinity* before the intrinsic matrix is applied.
    """
    normalized = points / (points[..., -1:] + epsilon)
    normalized = normalized.nan_to_num(posinf=infinity, neginf=-infinity)
    pixel = einsum(intrinsics, normalized, "... i j, ... j -> ... i")
    return pixel[..., :-1]
57
+
58
+
59
def project(
    points: Float[Tensor, "*#batch dim"],
    extrinsics: Float[Tensor, "*#batch dim+1 dim+1"],
    intrinsics: Float[Tensor, "*#batch dim dim"],
    epsilon: float = torch.finfo(torch.float32).eps,
) -> tuple[
    Float[Tensor, "*batch dim-1"],  # xy coordinates
    Bool[Tensor, " *batch"],  # whether points are in front of the camera
]:
    """Project world-space points to image space, flagging those in front of the camera."""
    # World -> camera, dropping the homogeneous coordinate afterwards.
    cam_points = transform_world2cam(homogenize_points(points), extrinsics)[..., :-1]
    in_front_of_camera = cam_points[..., -1] >= 0
    xy = project_camera_space(cam_points, intrinsics, epsilon=epsilon)
    return xy, in_front_of_camera
72
+
73
+
74
def unproject(
    coordinates: Float[Tensor, "*#batch dim"],
    z: Float[Tensor, "*#batch"],
    intrinsics: Float[Tensor, "*#batch dim+1 dim+1"],
) -> Float[Tensor, "*batch dim+1"]:
    """Lift 2D camera coordinates to 3D camera-space points at the given depths."""
    # Back-project through the inverse intrinsics...
    homogeneous = homogenize_points(coordinates)
    rays = einsum(intrinsics.inverse(), homogeneous, "... i j, ... j -> ... i")
    # ...then scale each ray by its depth value.
    return rays * z[..., None]
89
+
90
+
91
def get_world_rays(
    coordinates: Float[Tensor, "*#batch dim"],
    extrinsics: Float[Tensor, "*#batch dim+2 dim+2"],
    intrinsics: Float[Tensor, "*#batch dim+1 dim+1"],
) -> tuple[
    Float[Tensor, "*batch dim+1"],  # origins
    Float[Tensor, "*batch dim+1"],  # directions
]:
    """Turn pixel coordinates into world-space ray origins and unit directions."""
    # Unit-length ray directions in the camera frame (unproject at depth 1).
    cam_dirs = unproject(coordinates, torch.ones_like(coordinates[..., 0]), intrinsics)
    cam_dirs = cam_dirs / cam_dirs.norm(dim=-1, keepdim=True)

    # Rotate into the world frame; directions transform with w=0.
    world_dirs = transform_cam2world(homogenize_vectors(cam_dirs), extrinsics)[..., :-1]

    # The camera center (translation column of the extrinsics) is the shared
    # origin of every ray; broadcast it to match the directions.
    origins = extrinsics[..., :-1, -1].broadcast_to(world_dirs.shape)

    return origins, world_dirs
115
+
116
+
117
def get_local_rays(
    coordinates: Float[Tensor, "*#batch dim"],
    intrinsics: Float[Tensor, "*#batch dim+1 dim+1"],
) -> Float[Tensor, "*batch dim+1"]:
    """Return unit-length camera-space ray directions for the given pixels."""
    # Unproject at unit depth, then normalize to unit length.
    rays = unproject(
        coordinates,
        torch.ones_like(coordinates[..., 0]),
        intrinsics,
    )
    return rays / rays.norm(dim=-1, keepdim=True)
129
+
130
+
131
def sample_image_grid(
    shape: tuple[int, ...],
    device: torch.device = torch.device("cpu"),
) -> tuple[
    Float[Tensor, "*shape dim"],  # float coordinates (xy indexing)
    Int64[Tensor, "*shape dim"],  # integer indices (ij indexing)
]:
    """Get normalized (range 0 to 1) coordinates and integer indices for an image."""
    # Integer pixel indices per axis, stacked so that in 2D each entry is a
    # (row, col) pair.
    axes = [torch.arange(n, device=device) for n in shape]
    stacked_indices = torch.stack(torch.meshgrid(*axes, indexing="ij"), dim=-1)

    # Pixel-center coordinates in (0, 1). The per-axis lists are reversed
    # before the "xy" meshgrid so that in 2D each entry is (x, y) rather than
    # (row, col).
    centers = [(axis + 0.5) / n for axis, n in zip(axes, shape)]
    coordinates = torch.stack(
        torch.meshgrid(*reversed(centers), indexing="xy"), dim=-1
    )

    return coordinates, stacked_indices
152
+
153
+
154
def sample_training_rays(
    image: Float[Tensor, "batch view channel ..."],
    intrinsics: Float[Tensor, "batch view dim dim"],
    extrinsics: Float[Tensor, "batch view dim+1 dim+1"],
    num_rays: int,
) -> tuple[
    Float[Tensor, "batch ray dim"],  # origins
    Float[Tensor, "batch ray dim"],  # directions
    Float[Tensor, "batch ray 3"],  # sampled color
]:
    """Randomly sample world-space training rays, with their pixel colors,
    drawn uniformly across all views of each batch element.
    """
    device = extrinsics.device
    b, v, _, *grid_shape = image.shape

    # Generate all possible target rays.
    xy, _ = sample_image_grid(tuple(grid_shape), device)
    origins, directions = get_world_rays(
        rearrange(xy, "... d -> ... () () d"),
        extrinsics,
        intrinsics,
    )
    # Flatten the view and spatial axes so each ray is addressed by one index.
    origins = rearrange(origins, "... b v xy -> b (v ...) xy", b=b, v=v)
    directions = rearrange(directions, "... b v xy -> b (v ...) xy", b=b, v=v)
    pixels = rearrange(image, "b v c ... -> b (v ...) c")

    # Sample random rays. Sampling is with replacement: randint may repeat ids.
    num_possible_rays = v * prod(grid_shape)
    ray_indices = torch.randint(num_possible_rays, (b, num_rays), device=device)
    batch_indices = repeat(torch.arange(b, device=device), "b -> b n", n=num_rays)

    return (
        origins[batch_indices, ray_indices],
        directions[batch_indices, ray_indices],
        pixels[batch_indices, ray_indices],
    )
188
+
189
+
190
def intersect_rays(
    origins_x: Float[Tensor, "*#batch 3"],
    directions_x: Float[Tensor, "*#batch 3"],
    origins_y: Float[Tensor, "*#batch 3"],
    directions_y: Float[Tensor, "*#batch 3"],
    eps: float = 1e-5,
    inf: float = 1e10,
) -> Float[Tensor, "*batch 3"]:
    """Compute the least-squares intersection of rays. Uses the math from here:
    https://math.stackexchange.com/a/1762491/286022

    Args:
        origins_x, directions_x: first set of rays.
        origins_y, directions_y: second set of rays.
            NOTE(review): the parallelism test below compares a raw dot product
            against 1 - eps, which presumes unit-length directions -- confirm
            at call sites.
        eps: dot-product tolerance above which a ray pair counts as parallel.
        inf: placeholder coordinate assigned to parallel pairs' "intersections".
    """

    # Broadcast the rays so their shapes match.
    shape = torch.broadcast_shapes(
        origins_x.shape,
        directions_x.shape,
        origins_y.shape,
        directions_y.shape,
    )
    origins_x = origins_x.broadcast_to(shape)
    directions_x = directions_x.broadcast_to(shape)
    origins_y = origins_y.broadcast_to(shape)
    directions_y = directions_y.broadcast_to(shape)

    # Detect and remove batch elements where the directions are parallel.
    parallel = einsum(directions_x, directions_y, "... xyz, ... xyz -> ...") > 1 - eps
    origins_x = origins_x[~parallel]
    directions_x = directions_x[~parallel]
    origins_y = origins_y[~parallel]
    directions_y = directions_y[~parallel]

    # Stack the rays into (2, *shape).
    origins = torch.stack([origins_x, origins_y], dim=0)
    directions = torch.stack([directions_x, directions_y], dim=0)
    dtype = origins.dtype
    device = origins.device

    # Compute n_i * n_i^T - eye(3) from the equation.
    n = einsum(directions, directions, "r b i, r b j -> r b i j")
    n = n - torch.eye(3, dtype=dtype, device=device).broadcast_to((2, 1, 3, 3))

    # Compute the left-hand side of the equation.
    lhs = reduce(n, "r b i j -> b i j", "sum")

    # Compute the right-hand side of the equation.
    rhs = einsum(n, origins, "r b i j, r b j -> r b i")
    rhs = reduce(rhs, "r b i -> b i", "sum")

    # Solve lhs @ p = rhs in the least-squares sense for the point p that is
    # closest to both rays.
    result = torch.linalg.lstsq(lhs, rhs).solution

    # Handle the case of parallel lines by setting depth to infinity.
    result_all = torch.ones(shape, dtype=dtype, device=device) * inf
    result_all[~parallel] = result
    return result_all
245
+
246
+
247
def get_fov(intrinsics: Float[Tensor, "batch 3 3"]) -> Float[Tensor, "batch 2"]:
    """Compute horizontal and vertical fields of view (radians) from
    normalized intrinsics."""
    inv_intrinsics = intrinsics.inverse()

    def unit_ray(uv):
        # Back-project a normalized image point into a unit camera-space ray.
        point = torch.tensor(uv, dtype=torch.float32, device=intrinsics.device)
        ray = einsum(inv_intrinsics, point, "b i j, j -> b i")
        return ray / ray.norm(dim=-1, keepdim=True)

    # Rays through the midpoints of the four image edges.
    left = unit_ray([0, 0.5, 1])
    right = unit_ray([1, 0.5, 1])
    top = unit_ray([0.5, 0, 1])
    bottom = unit_ray([0.5, 1, 1])

    # The FOV along each axis is the angle between the opposite edge rays.
    fov_x = (left * right).sum(dim=-1).acos()
    fov_y = (top * bottom).sum(dim=-1).acos()
    return torch.stack((fov_x, fov_y), dim=-1)
src/misc/image_io.py ADDED
@@ -0,0 +1,248 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ import os
3
+ from pathlib import Path
4
+ from typing import Union
5
+
6
+ import cv2
7
+ import imageio
8
+ import numpy as np
9
+ import skvideo
10
+ import torch
11
+ import torchvision.transforms as tf
12
+ from einops import rearrange, repeat
13
+ from jaxtyping import Float, UInt8
14
+
15
+ from matplotlib import pyplot as plt
16
+ from matplotlib.figure import Figure
17
+ from PIL import Image
18
+ from torch import Tensor
19
+
20
+ FloatImage = Union[
21
+ Float[Tensor, "height width"],
22
+ Float[Tensor, "channel height width"],
23
+ Float[Tensor, "batch channel height width"],
24
+ ]
25
+
26
+
27
def fig_to_image(
    fig: Figure,
    dpi: int = 100,
    device: torch.device = torch.device("cpu"),
) -> Float[Tensor, "3 height width"]:
    """Rasterize a matplotlib figure into a float RGB tensor in [0, 1]."""
    # Render the figure into an in-memory raw (RGBA) byte buffer.
    buffer = io.BytesIO()
    fig.savefig(buffer, format="raw", dpi=dpi)
    buffer.seek(0)
    raw = np.frombuffer(buffer.getvalue(), dtype=np.uint8)
    # The raw format is row-major RGBA at the figure's pixel dimensions.
    height = int(fig.bbox.bounds[3])
    width = int(fig.bbox.bounds[2])
    pixels = rearrange(raw, "(h w c) -> c h w", h=height, w=width, c=4)
    buffer.close()
    # Normalize to [0, 1] and drop the alpha channel.
    return (torch.tensor(pixels, device=device, dtype=torch.float32) / 255)[:3]
41
+
42
+
43
def prep_image(image: FloatImage) -> UInt8[np.ndarray, "height width channel"]:
    """Convert a float image tensor (range [0, 1]) to a uint8 HWC numpy array."""
    # Batched input: lay the batch out side by side along the width.
    if image.ndim == 4:
        image = rearrange(image, "b c h w -> c h (b w)")

    # Grayscale input without a channel axis: add one.
    if image.ndim == 2:
        image = rearrange(image, "h w -> () h w")

    # Promote single-channel images to RGB; beyond that only RGB/RGBA is valid.
    channel, _, _ = image.shape
    if channel == 1:
        image = repeat(image, "() h w -> c h w", c=3)
    assert image.shape[0] in (3, 4)

    # Quantize to uint8 after clamping into the valid range.
    quantized = (image.detach().clip(min=0, max=1) * 255).type(torch.uint8)
    return quantized.permute(1, 2, 0).cpu().numpy()
60
+
61
+
62
def save_image(
    image: FloatImage,
    path: Union[Path, str],
) -> None:
    """Save an image. Assumed to be in range 0-1."""
    target = Path(path)

    # Make sure the destination directory exists before writing.
    target.parent.mkdir(exist_ok=True, parents=True)

    # Quantize to uint8 and write via PIL.
    Image.fromarray(prep_image(image)).save(target)
74
+
75
+
76
def load_image(
    path: Union[Path, str],
) -> Float[Tensor, "3 height width"]:
    """Load an image from disk as a float tensor in [0, 1], keeping only RGB."""
    pil_image = Image.open(path)
    tensor = tf.ToTensor()(pil_image)
    # Drop any alpha channel.
    return tensor[:3]
80
+
81
+
82
def save_video(tensor, save_path, fps=10):
    """
    Save a tensor of shape (N, C, H, W) as a video file using imageio.
    Args:
        tensor: Tensor of shape (N, C, H, W) in range [0, 1]
        save_path: Path to save the video file
        fps: Frames per second for the video
    """
    # Convert tensor to numpy array and adjust dimensions: (N, C, H, W) -> (N, H, W, C).
    video = tensor.cpu().detach().numpy()
    video = np.transpose(video, (0, 2, 3, 1))

    # Clip into [0, 1] before scaling so out-of-range values cannot wrap
    # around during the uint8 cast, then scale to [0, 255].
    video = (np.clip(video, 0.0, 1.0) * 255).astype(np.uint8)

    # Ensure the directory exists. dirname() is empty when save_path is a bare
    # filename, and os.makedirs("") would raise, so guard it. (os is imported
    # at module level; no need to re-import here.)
    directory = os.path.dirname(save_path)
    if directory:
        os.makedirs(directory, exist_ok=True)

    # Use imageio to write video (handles codec compatibility automatically);
    # close the writer even if a frame write fails so the file handle is freed.
    writer = imageio.get_writer(save_path, fps=fps)
    try:
        for frame in video:
            writer.append_data(frame)
    finally:
        writer.close()
111
+
112
def save_images(tensor, save_path):
    """
    Save a tensor of shape (N, C, H, W) as a series of images using imageio.
    Args:
        tensor: Tensor of shape (N, C, H, W) in range [0, 1]
        save_path: Path to save the video file
    """
    # (N, C, H, W) float tensor -> (N, H, W, C) numpy array.
    frames = tensor.cpu().detach().numpy()
    frames = np.transpose(frames, (0, 2, 3, 1))

    # Scale to [0, 255] and convert to uint8.
    frames = (frames * 255).astype(np.uint8)

    os.makedirs(save_path, exist_ok=True)
    # Write one zero-padded PNG per frame: 000.png, 001.png, ...
    for index, frame in enumerate(frames):
        imageio.imwrite(os.path.join(save_path, f"{index:03d}.png"), frame)
130
+
131
def save_interpolated_video(
    pred_extrinsics, pred_intrinsics, b, h, w, gaussians, save_path, decoder_func, t=10,
    save_rgb_video=True, save_depth_video=True, save_rgb=False, save_depth=False,
    save_name=""
):
    """Render and save RGB/depth videos along a camera path built by
    interpolating between consecutive predicted camera poses.

    Args:
        pred_extrinsics: (b, v, 4, 4) camera poses.
        pred_intrinsics: (b, v, 3, 3) camera intrinsics.
        b, h, w: batch size and render height/width.
        gaussians: scene representation handed to the decoder unchanged.
        save_path: output directory for videos/frames.
        decoder_func: renderer whose .forward(gaussians, extrinsics, intrinsics,
            near, far, (h, w)) returns an object with .color and .depth.
        t: number of interpolated views inserted between each keyframe pair.
        save_rgb_video, save_depth_video: whether to write the mp4 files.
        save_rgb, save_depth: whether to also dump per-frame PNG folders.
        save_name: filename prefix for all outputs.

    Returns:
        (rgb_path, depth_path): the mp4 paths (returned regardless of whether
        the corresponding video was actually written).
    """
    interpolated_extrinsics = []
    interpolated_intrinsics = []

    if pred_extrinsics.shape[1] == 1:
        # Single frame: duplicate it so a (static) video can still be written.
        for _ in range(t):
            interpolated_extrinsics.append(pred_extrinsics[:, 0].unsqueeze(1))
            interpolated_intrinsics.append(pred_intrinsics[:, 0].unsqueeze(1))
    else:
        # For each pair of neighboring frames...
        for i in range(pred_extrinsics.shape[1] - 1):
            # ...add keyframe i itself...
            interpolated_extrinsics.append(pred_extrinsics[:, i : i + 1])
            interpolated_intrinsics.append(pred_intrinsics[:, i : i + 1])

            # ...then t in-between views between keyframes i and i + 1.
            for j in range(1, t + 1):
                alpha = j / (t + 1)

                start_extrinsic = pred_extrinsics[:, i]
                end_extrinsic = pred_extrinsics[:, i + 1]

                # Separate rotation and translation.
                start_rot = start_extrinsic[:, :3, :3]
                end_rot = end_extrinsic[:, :3, :3]
                start_trans = start_extrinsic[:, :3, 3]
                end_trans = end_extrinsic[:, :3, 3]

                # Translation: plain linear interpolation.
                interp_trans = (1 - alpha) * start_trans + alpha * end_trans

                # Rotation: linear blend of the matrix entries, then
                # re-orthogonalized via SVD so the result is a valid rotation.
                interp_rot_flat = (
                    (1 - alpha) * start_rot.reshape(b, 9)
                    + alpha * end_rot.reshape(b, 9)
                )
                interp_rot = interp_rot_flat.reshape(b, 3, 3)
                u, _, v_mat = torch.svd(interp_rot)
                interp_rot = torch.bmm(u, v_mat.transpose(1, 2))

                # Combine interpolated rotation and translation into a 4x4 pose.
                interp_extrinsic = (
                    torch.eye(4, device=pred_extrinsics.device)
                    .unsqueeze(0)
                    .repeat(b, 1, 1)
                )
                interp_extrinsic[:, :3, :3] = interp_rot
                interp_extrinsic[:, :3, 3] = interp_trans

                # Intrinsics: linear interpolation.
                interp_intrinsic = (
                    (1 - alpha) * pred_intrinsics[:, i]
                    + alpha * pred_intrinsics[:, i + 1]
                )

                interpolated_extrinsics.append(interp_extrinsic.unsqueeze(1))
                interpolated_intrinsics.append(interp_intrinsic.unsqueeze(1))

        # Include the final keyframe. (Previously this append happened after
        # the concatenation below, so the last view was silently dropped from
        # the rendered video.)
        interpolated_extrinsics.append(pred_extrinsics[:, -1:])
        interpolated_intrinsics.append(pred_intrinsics[:, -1:])

    # Concatenate all frames along the view axis.
    pred_all_extrinsic = torch.cat(interpolated_extrinsics, dim=1)
    pred_all_intrinsic = torch.cat(interpolated_intrinsics, dim=1)

    num_frames = pred_all_extrinsic.shape[1]

    # Render the interpolated trajectory (near/far planes fixed at 0.1 / 100).
    interpolated_output = decoder_func.forward(
        gaussians,
        pred_all_extrinsic,
        pred_all_intrinsic.float(),
        torch.ones(1, num_frames, device=pred_all_extrinsic.device) * 0.1,
        torch.ones(1, num_frames, device=pred_all_extrinsic.device) * 100,
        (h, w),
    )

    # Convert to video format.
    video = interpolated_output.color[0].clip(min=0, max=1)
    depth = interpolated_output.depth[0]

    # Normalize depth for visualization with robust (1%/99%) percentiles.
    # Only every num_views-th frame is used for the quantiles, to avoid
    # `quantile() input tensor is too large`.
    num_views = pred_extrinsics.shape[1]
    depth_lo = depth[::num_views].quantile(0.01)
    depth_hi = depth[::num_views].quantile(0.99)
    depth_norm = (depth - depth_lo) / (depth_hi - depth_lo)
    depth_norm = plt.cm.turbo(depth_norm.cpu().numpy())
    depth_colored = (
        torch.from_numpy(depth_norm[..., :3]).permute(0, 3, 1, 2).to(depth.device)
    )
    depth_colored = depth_colored.clip(min=0, max=1)

    # Save videos.
    if save_depth_video:
        save_video(depth_colored, os.path.join(save_path, f"{save_name}depth.mp4"))
    if save_rgb_video:
        save_video(video, os.path.join(save_path, f"{save_name}rgb.mp4"))

    # Save per-frame images.
    if save_rgb:
        save_images(video, os.path.join(save_path, f"{save_name}rgb_frames"))
    if save_depth:
        save_images(depth_colored, os.path.join(save_path, f"{save_name}depth_frames"))

    return os.path.join(save_path, f"{save_name}rgb.mp4"), os.path.join(save_path, f"{save_name}depth.mp4")
src/misc/sh_rotation.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from math import isqrt
2
+
3
+ import torch
4
+ from e3nn.o3 import matrix_to_angles, wigner_D
5
+ from einops import einsum
6
+ from jaxtyping import Float
7
+ from torch import Tensor
8
+
9
+
10
def rotate_sh(
    sh_coefficients: Float[Tensor, "*#batch n"],
    rotations: Float[Tensor, "*#batch 3 3"],
) -> Float[Tensor, "*batch n"]:
    """Rotate real spherical-harmonics coefficients by 3x3 rotation matrices.

    Args:
        sh_coefficients: coefficients laid out degree-by-degree; n is assumed
            to be a perfect square (degrees 0 .. isqrt(n) - 1).
        rotations: rotation matrices, broadcastable against the coefficients.

    Returns:
        Coefficients of the rotated functions, same layout as the input.
    """
    device = sh_coefficients.device
    dtype = sh_coefficients.dtype

    # change the basis from YZX -> XYZ to fit the convention of e3nn
    P = torch.tensor([[0, 0, 1], [1, 0, 0], [0, 1, 0]],
                     dtype=sh_coefficients.dtype, device=sh_coefficients.device)
    inversed_P = torch.tensor([[0, 1, 0], [0, 0, 1], [1, 0, 0], ],
                              dtype=sh_coefficients.dtype, device=sh_coefficients.device)
    permuted_rotation_matrix = inversed_P @ rotations @ P

    *_, n = sh_coefficients.shape
    # Decompose each rotation into Euler angles for e3nn's Wigner-D matrices.
    alpha, beta, gamma = matrix_to_angles(permuted_rotation_matrix)
    result = []
    for degree in range(isqrt(n)):
        with torch.device(device):
            # NOTE(review): beta is negated here -- presumably compensating for
            # the basis permutation above; confirm against e3nn's conventions.
            sh_rotations = wigner_D(degree, alpha, -beta, gamma).type(dtype)
        # Apply the degree-d rotation to that degree's coefficient slice.
        sh_rotated = einsum(
            sh_rotations,
            sh_coefficients[..., degree**2 : (degree + 1) ** 2],
            "... i j, ... j -> ... i",
        )
        result.append(sh_rotated)

    return torch.cat(result, dim=-1)
38
+
39
+
40
+ # def rotate_sh(
41
+ # sh_coefficients: Float[Tensor, "*#batch n"],
42
+ # rotations: Float[Tensor, "*#batch 3 3"],
43
+ # ) -> Float[Tensor, "*batch n"]:
44
+ # device = sh_coefficients.device
45
+ # dtype = sh_coefficients.dtype
46
+ #
47
+ # *_, n = sh_coefficients.shape
48
+ # alpha, beta, gamma = matrix_to_angles(rotations)
49
+ # result = []
50
+ # for degree in range(isqrt(n)):
51
+ # with torch.device(device):
52
+ # sh_rotations = wigner_D(degree, alpha, beta, gamma).type(dtype)
53
+ # sh_rotated = einsum(
54
+ # sh_rotations,
55
+ # sh_coefficients[..., degree**2 : (degree + 1) ** 2],
56
+ # "... i j, ... j -> ... i",
57
+ # )
58
+ # result.append(sh_rotated)
59
+ #
60
+ # return torch.cat(result, dim=-1)
61
+
62
+
63
if __name__ == "__main__":
    # Visual sanity check: rotate random SH coefficients about the x-axis and
    # render each rotated function on the unit sphere as an image sequence.
    from pathlib import Path

    import matplotlib.pyplot as plt
    from e3nn.o3 import spherical_harmonics
    from matplotlib import cm
    from scipy.spatial.transform.rotation import Rotation as R

    device = torch.device("cuda")

    # Generate random spherical harmonics coefficients.
    degree = 4
    coefficients = torch.rand((degree + 1) ** 2, dtype=torch.float32, device=device)

    def plot_sh(sh_coefficients, path: Path) -> None:
        # Sample the unit sphere on a 100x100 (phi, theta) grid.
        phi = torch.linspace(0, torch.pi, 100, device=device)
        theta = torch.linspace(0, 2 * torch.pi, 100, device=device)
        phi, theta = torch.meshgrid(phi, theta, indexing="xy")
        x = torch.sin(phi) * torch.cos(theta)
        y = torch.sin(phi) * torch.sin(theta)
        z = torch.cos(phi)
        xyz = torch.stack([x, y, z], dim=-1)
        # Evaluate the SH expansion at every grid point, then rescale to
        # [0, 1] for colormapping.
        sh = spherical_harmonics(list(range(degree + 1)), xyz, True)
        result = einsum(sh, sh_coefficients, "... n, n -> ...")
        result = (result - result.min()) / (result.max() - result.min())

        # Set the aspect ratio to 1 so our sphere looks spherical
        fig = plt.figure(figsize=plt.figaspect(1.0))
        ax = fig.add_subplot(111, projection="3d")
        ax.plot_surface(
            x.cpu().numpy(),
            y.cpu().numpy(),
            z.cpu().numpy(),
            rstride=1,
            cstride=1,
            facecolors=cm.seismic(result.cpu().numpy()),
        )
        # Turn off the axis planes
        ax.set_axis_off()
        path.parent.mkdir(exist_ok=True, parents=True)
        plt.savefig(path)

    # 30 frames covering one full revolution about the x-axis.
    for i, angle in enumerate(torch.linspace(0, 2 * torch.pi, 30)):
        rotation = torch.tensor(
            R.from_euler("x", angle.item()).as_matrix(), device=device
        )
        plot_sh(rotate_sh(coefficients, rotation), Path(f"sh_rotation/{i:0>3}.png"))

    print("Done!")
src/misc/sht.py ADDED
@@ -0,0 +1,1637 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Real spherical harmonics in Cartesian form for PyTorch.
2
+
3
+ This is an autogenerated file. See
4
+ https://github.com/cheind/torch-spherical-harmonics
5
+ for more information.
6
+ """
7
+
8
+ import torch
9
+
10
+
11
+ def rsh_cart_0(xyz: torch.Tensor):
12
+ """Computes all real spherical harmonics up to degree 0.
13
+
14
+ This is an autogenerated method. See
15
+ https://github.com/cheind/torch-spherical-harmonics
16
+ for more information.
17
+
18
+ Params:
19
+ xyz: (N,...,3) tensor of points on the unit sphere
20
+
21
+ Returns:
22
+ rsh: (N,...,1) real spherical harmonics
23
+ projections of input. Ynm is found at index
24
+ `n*(n+1) + m`, with `0 <= n <= degree` and
25
+ `-n <= m <= n`.
26
+ """
27
+
28
+ return torch.stack(
29
+ [
30
+ xyz.new_tensor(0.282094791773878).expand(xyz.shape[:-1]),
31
+ ],
32
+ -1,
33
+ )
34
+
35
+
36
+ def rsh_cart_1(xyz: torch.Tensor):
37
+ """Computes all real spherical harmonics up to degree 1.
38
+
39
+ This is an autogenerated method. See
40
+ https://github.com/cheind/torch-spherical-harmonics
41
+ for more information.
42
+
43
+ Params:
44
+ xyz: (N,...,3) tensor of points on the unit sphere
45
+
46
+ Returns:
47
+ rsh: (N,...,4) real spherical harmonics
48
+ projections of input. Ynm is found at index
49
+ `n*(n+1) + m`, with `0 <= n <= degree` and
50
+ `-n <= m <= n`.
51
+ """
52
+ x = xyz[..., 0]
53
+ y = xyz[..., 1]
54
+ z = xyz[..., 2]
55
+
56
+ return torch.stack(
57
+ [
58
+ xyz.new_tensor(0.282094791773878).expand(xyz.shape[:-1]),
59
+ -0.48860251190292 * y,
60
+ 0.48860251190292 * z,
61
+ -0.48860251190292 * x,
62
+ ],
63
+ -1,
64
+ )
65
+
66
+
67
def rsh_cart_2(xyz: torch.Tensor):
    """Computes all real spherical harmonics up to degree 2.

    This is an autogenerated method. See
    https://github.com/cheind/torch-spherical-harmonics
    for more information.

    Params:
        xyz: (N,...,3) tensor of points on the unit sphere

    Returns:
        rsh: (N,...,9) real spherical harmonics
            projections of input. Ynm is found at index
            `n*(n+1) + m`, with `0 <= n <= degree` and
            `-n <= m <= n`.
    """
    x = xyz[..., 0]
    y = xyz[..., 1]
    z = xyz[..., 2]

    # Monomials shared by several of the stacked terms below.
    x2 = x**2
    y2 = y**2
    z2 = z**2
    xy = x * y
    xz = x * z
    yz = y * z

    # Harmonics in (n, m) order: (0,0), (1,-1), (1,0), (1,1), (2,-2) ... (2,2).
    return torch.stack(
        [
            xyz.new_tensor(0.282094791773878).expand(xyz.shape[:-1]),
            -0.48860251190292 * y,
            0.48860251190292 * z,
            -0.48860251190292 * x,
            1.09254843059208 * xy,
            -1.09254843059208 * yz,
            0.94617469575756 * z2 - 0.31539156525252,
            -1.09254843059208 * xz,
            0.54627421529604 * x2 - 0.54627421529604 * y2,
        ],
        -1,
    )
108
+
109
+
110
def rsh_cart_3(xyz: torch.Tensor):
    """Computes all real spherical harmonics up to degree 3.

    This is an autogenerated method. See
    https://github.com/cheind/torch-spherical-harmonics
    for more information.

    Params:
        xyz: (N,...,3) tensor of points on the unit sphere

    Returns:
        rsh: (N,...,16) real spherical harmonics
            projections of input. Ynm is found at index
            `n*(n+1) + m`, with `0 <= n <= degree` and
            `-n <= m <= n`.
    """
    x = xyz[..., 0]
    y = xyz[..., 1]
    z = xyz[..., 2]

    # Monomials shared by several of the stacked terms below.
    x2 = x**2
    y2 = y**2
    z2 = z**2
    xy = x * y
    xz = x * z
    yz = y * z

    # Harmonics in (n, m) order, degrees 0 through 3.
    return torch.stack(
        [
            xyz.new_tensor(0.282094791773878).expand(xyz.shape[:-1]),
            -0.48860251190292 * y,
            0.48860251190292 * z,
            -0.48860251190292 * x,
            1.09254843059208 * xy,
            -1.09254843059208 * yz,
            0.94617469575756 * z2 - 0.31539156525252,
            -1.09254843059208 * xz,
            0.54627421529604 * x2 - 0.54627421529604 * y2,
            -0.590043589926644 * y * (3.0 * x2 - y2),
            2.89061144264055 * xy * z,
            0.304697199642977 * y * (1.5 - 7.5 * z2),
            1.24392110863372 * z * (1.5 * z2 - 0.5) - 0.497568443453487 * z,
            0.304697199642977 * x * (1.5 - 7.5 * z2),
            1.44530572132028 * z * (x2 - y2),
            -0.590043589926644 * x * (x2 - 3.0 * y2),
        ],
        -1,
    )
158
+
159
+
160
+ def rsh_cart_4(xyz: torch.Tensor):
161
+ """Computes all real spherical harmonics up to degree 4.
162
+
163
+ This is an autogenerated method. See
164
+ https://github.com/cheind/torch-spherical-harmonics
165
+ for more information.
166
+
167
+ Params:
168
+ xyz: (N,...,3) tensor of points on the unit sphere
169
+
170
+ Returns:
171
+ rsh: (N,...,25) real spherical harmonics
172
+ projections of input. Ynm is found at index
173
+ `n*(n+1) + m`, with `0 <= n <= degree` and
174
+ `-n <= m <= n`.
175
+ """
176
+ x = xyz[..., 0]
177
+ y = xyz[..., 1]
178
+ z = xyz[..., 2]
179
+
180
+ x2 = x**2
181
+ y2 = y**2
182
+ z2 = z**2
183
+ xy = x * y
184
+ xz = x * z
185
+ yz = y * z
186
+ x4 = x2**2
187
+ y4 = y2**2
188
+ z4 = z2**2
189
+
190
+ return torch.stack(
191
+ [
192
+ xyz.new_tensor(0.282094791773878).expand(xyz.shape[:-1]),
193
+ -0.48860251190292 * y,
194
+ 0.48860251190292 * z,
195
+ -0.48860251190292 * x,
196
+ 1.09254843059208 * xy,
197
+ -1.09254843059208 * yz,
198
+ 0.94617469575756 * z2 - 0.31539156525252,
199
+ -1.09254843059208 * xz,
200
+ 0.54627421529604 * x2 - 0.54627421529604 * y2,
201
+ -0.590043589926644 * y * (3.0 * x2 - y2),
202
+ 2.89061144264055 * xy * z,
203
+ 0.304697199642977 * y * (1.5 - 7.5 * z2),
204
+ 1.24392110863372 * z * (1.5 * z2 - 0.5) - 0.497568443453487 * z,
205
+ 0.304697199642977 * x * (1.5 - 7.5 * z2),
206
+ 1.44530572132028 * z * (x2 - y2),
207
+ -0.590043589926644 * x * (x2 - 3.0 * y2),
208
+ 2.5033429417967 * xy * (x2 - y2),
209
+ -1.77013076977993 * yz * (3.0 * x2 - y2),
210
+ 0.126156626101008 * xy * (52.5 * z2 - 7.5),
211
+ 0.267618617422916 * y * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z),
212
+ 1.48099765681286
213
+ * z
214
+ * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
215
+ - 0.952069922236839 * z2
216
+ + 0.317356640745613,
217
+ 0.267618617422916 * x * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z),
218
+ 0.063078313050504 * (x2 - y2) * (52.5 * z2 - 7.5),
219
+ -1.77013076977993 * xz * (x2 - 3.0 * y2),
220
+ -3.75501441269506 * x2 * y2
221
+ + 0.625835735449176 * x4
222
+ + 0.625835735449176 * y4,
223
+ ],
224
+ -1,
225
+ )
226
+
227
+
228
+ def rsh_cart_5(xyz: torch.Tensor):
229
+ """Computes all real spherical harmonics up to degree 5.
230
+
231
+ This is an autogenerated method. See
232
+ https://github.com/cheind/torch-spherical-harmonics
233
+ for more information.
234
+
235
+ Params:
236
+ xyz: (N,...,3) tensor of points on the unit sphere
237
+
238
+ Returns:
239
+ rsh: (N,...,36) real spherical harmonics
240
+ projections of input. Ynm is found at index
241
+ `n*(n+1) + m`, with `0 <= n <= degree` and
242
+ `-n <= m <= n`.
243
+ """
244
+ x = xyz[..., 0]
245
+ y = xyz[..., 1]
246
+ z = xyz[..., 2]
247
+
248
+ x2 = x**2
249
+ y2 = y**2
250
+ z2 = z**2
251
+ xy = x * y
252
+ xz = x * z
253
+ yz = y * z
254
+ x4 = x2**2
255
+ y4 = y2**2
256
+ z4 = z2**2
257
+
258
+ return torch.stack(
259
+ [
260
+ xyz.new_tensor(0.282094791773878).expand(xyz.shape[:-1]),
261
+ -0.48860251190292 * y,
262
+ 0.48860251190292 * z,
263
+ -0.48860251190292 * x,
264
+ 1.09254843059208 * xy,
265
+ -1.09254843059208 * yz,
266
+ 0.94617469575756 * z2 - 0.31539156525252,
267
+ -1.09254843059208 * xz,
268
+ 0.54627421529604 * x2 - 0.54627421529604 * y2,
269
+ -0.590043589926644 * y * (3.0 * x2 - y2),
270
+ 2.89061144264055 * xy * z,
271
+ 0.304697199642977 * y * (1.5 - 7.5 * z2),
272
+ 1.24392110863372 * z * (1.5 * z2 - 0.5) - 0.497568443453487 * z,
273
+ 0.304697199642977 * x * (1.5 - 7.5 * z2),
274
+ 1.44530572132028 * z * (x2 - y2),
275
+ -0.590043589926644 * x * (x2 - 3.0 * y2),
276
+ 2.5033429417967 * xy * (x2 - y2),
277
+ -1.77013076977993 * yz * (3.0 * x2 - y2),
278
+ 0.126156626101008 * xy * (52.5 * z2 - 7.5),
279
+ 0.267618617422916 * y * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z),
280
+ 1.48099765681286
281
+ * z
282
+ * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
283
+ - 0.952069922236839 * z2
284
+ + 0.317356640745613,
285
+ 0.267618617422916 * x * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z),
286
+ 0.063078313050504 * (x2 - y2) * (52.5 * z2 - 7.5),
287
+ -1.77013076977993 * xz * (x2 - 3.0 * y2),
288
+ -3.75501441269506 * x2 * y2
289
+ + 0.625835735449176 * x4
290
+ + 0.625835735449176 * y4,
291
+ -0.65638205684017 * y * (-10.0 * x2 * y2 + 5.0 * x4 + y4),
292
+ 8.30264925952416 * xy * z * (x2 - y2),
293
+ 0.00931882475114763 * y * (52.5 - 472.5 * z2) * (3.0 * x2 - y2),
294
+ 0.0913054625709205 * xy * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z),
295
+ 0.241571547304372
296
+ * y
297
+ * (
298
+ 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
299
+ + 9.375 * z2
300
+ - 1.875
301
+ ),
302
+ -1.24747010616985 * z * (1.5 * z2 - 0.5)
303
+ + 1.6840846433293
304
+ * z
305
+ * (
306
+ 1.75
307
+ * z
308
+ * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
309
+ - 1.125 * z2
310
+ + 0.375
311
+ )
312
+ + 0.498988042467941 * z,
313
+ 0.241571547304372
314
+ * x
315
+ * (
316
+ 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
317
+ + 9.375 * z2
318
+ - 1.875
319
+ ),
320
+ 0.0456527312854602 * (x2 - y2) * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z),
321
+ 0.00931882475114763 * x * (52.5 - 472.5 * z2) * (x2 - 3.0 * y2),
322
+ 2.07566231488104 * z * (-6.0 * x2 * y2 + x4 + y4),
323
+ -0.65638205684017 * x * (-10.0 * x2 * y2 + x4 + 5.0 * y4),
324
+ ],
325
+ -1,
326
+ )
327
+
328
+
329
def rsh_cart_6(xyz: torch.Tensor):
    """Computes all real spherical harmonics up to degree 6.

    This is an autogenerated method. See
    https://github.com/cheind/torch-spherical-harmonics
    for more information.

    Params:
        xyz: (N,...,3) tensor of points on the unit sphere

    Returns:
        rsh: (N,...,49) real spherical harmonics
            projections of input. Ynm is found at index
            `n*(n+1) + m`, with `0 <= n <= degree` and
            `-n <= m <= n`.
    """
    x = xyz[..., 0]
    y = xyz[..., 1]
    z = xyz[..., 2]

    # Shared monomials, computed once and reused across many basis terms.
    x2 = x**2
    y2 = y**2
    z2 = z**2
    xy = x * y
    xz = x * z
    yz = y * z
    x4 = x2**2
    y4 = y2**2
    # NOTE: z2**2 is never referenced by the degree-6 expansion (higher
    # powers of z only appear through the nested terms below), so it is
    # deliberately not precomputed.

    return torch.stack(
        [
            # l = 0 (constant term; new_tensor keeps dtype/device of the input)
            xyz.new_tensor(0.282094791773878).expand(xyz.shape[:-1]),
            # l = 1
            -0.48860251190292 * y,
            0.48860251190292 * z,
            -0.48860251190292 * x,
            # l = 2
            1.09254843059208 * xy,
            -1.09254843059208 * yz,
            0.94617469575756 * z2 - 0.31539156525252,
            -1.09254843059208 * xz,
            0.54627421529604 * x2 - 0.54627421529604 * y2,
            # l = 3
            -0.590043589926644 * y * (3.0 * x2 - y2),
            2.89061144264055 * xy * z,
            0.304697199642977 * y * (1.5 - 7.5 * z2),
            1.24392110863372 * z * (1.5 * z2 - 0.5) - 0.497568443453487 * z,
            0.304697199642977 * x * (1.5 - 7.5 * z2),
            1.44530572132028 * z * (x2 - y2),
            -0.590043589926644 * x * (x2 - 3.0 * y2),
            # l = 4
            2.5033429417967 * xy * (x2 - y2),
            -1.77013076977993 * yz * (3.0 * x2 - y2),
            0.126156626101008 * xy * (52.5 * z2 - 7.5),
            0.267618617422916 * y * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z),
            1.48099765681286 * z * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
            - 0.952069922236839 * z2
            + 0.317356640745613,
            0.267618617422916 * x * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z),
            0.063078313050504 * (x2 - y2) * (52.5 * z2 - 7.5),
            -1.77013076977993 * xz * (x2 - 3.0 * y2),
            -3.75501441269506 * x2 * y2 + 0.625835735449176 * x4 + 0.625835735449176 * y4,
            # l = 5
            -0.65638205684017 * y * (-10.0 * x2 * y2 + 5.0 * x4 + y4),
            8.30264925952416 * xy * z * (x2 - y2),
            0.00931882475114763 * y * (52.5 - 472.5 * z2) * (3.0 * x2 - y2),
            0.0913054625709205 * xy * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z),
            0.241571547304372 * y * (2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + 9.375 * z2 - 1.875),
            -1.24747010616985 * z * (1.5 * z2 - 0.5)
            + 1.6840846433293 * z * (1.75 * z * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z) - 1.125 * z2 + 0.375)
            + 0.498988042467941 * z,
            0.241571547304372 * x * (2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + 9.375 * z2 - 1.875),
            0.0456527312854602 * (x2 - y2) * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z),
            0.00931882475114763 * x * (52.5 - 472.5 * z2) * (x2 - 3.0 * y2),
            2.07566231488104 * z * (-6.0 * x2 * y2 + x4 + y4),
            -0.65638205684017 * x * (-10.0 * x2 * y2 + x4 + 5.0 * y4),
            # l = 6
            4.09910463115149 * x**4 * xy - 13.6636821038383 * xy**3 + 4.09910463115149 * xy * y**4,
            -2.36661916223175 * yz * (-10.0 * x2 * y2 + 5.0 * x4 + y4),
            0.00427144889505798 * xy * (x2 - y2) * (5197.5 * z2 - 472.5),
            0.00584892228263444 * y * (3.0 * x2 - y2) * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z),
            0.0701870673916132 * xy * (2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z) - 91.875 * z2 + 13.125),
            0.221950995245231 * y * (
                -2.8 * z * (1.5 - 7.5 * z2)
                + 2.2 * z * (2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + 9.375 * z2 - 1.875)
                - 4.8 * z
            ),
            -1.48328138624466 * z * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
            + 1.86469659985043 * z * (
                -1.33333333333333 * z * (1.5 * z2 - 0.5)
                + 1.8 * z * (1.75 * z * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z) - 1.125 * z2 + 0.375)
                + 0.533333333333333 * z
            )
            + 0.953538034014426 * z2
            - 0.317846011338142,
            0.221950995245231 * x * (
                -2.8 * z * (1.5 - 7.5 * z2)
                + 2.2 * z * (2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + 9.375 * z2 - 1.875)
                - 4.8 * z
            ),
            0.0350935336958066 * (x2 - y2) * (2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z) - 91.875 * z2 + 13.125),
            0.00584892228263444 * x * (x2 - 3.0 * y2) * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z),
            0.0010678622237645 * (5197.5 * z2 - 472.5) * (-6.0 * x2 * y2 + x4 + y4),
            -2.36661916223175 * xz * (-10.0 * x2 * y2 + x4 + 5.0 * y4),
            0.683184105191914 * x2**3
            + 10.2477615778787 * x2 * y4
            - 10.2477615778787 * x4 * y2
            - 0.683184105191914 * y2**3,
        ],
        -1,
    )
507
+
508
+
509
def rsh_cart_7(xyz: torch.Tensor):
    """Computes all real spherical harmonics up to degree 7.

    This is an autogenerated method. See
    https://github.com/cheind/torch-spherical-harmonics
    for more information.

    Params:
        xyz: (N,...,3) tensor of points on the unit sphere

    Returns:
        rsh: (N,...,64) real spherical harmonics
            projections of input. Ynm is found at index
            `n*(n+1) + m`, with `0 <= n <= degree` and
            `-n <= m <= n`.
    """
    x = xyz[..., 0]
    y = xyz[..., 1]
    z = xyz[..., 2]

    # Shared monomials, computed once and reused across many basis terms.
    x2 = x**2
    y2 = y**2
    z2 = z**2
    xy = x * y
    xz = x * z
    yz = y * z
    x4 = x2**2
    y4 = y2**2
    # NOTE: z2**2 is never referenced by the degree-7 expansion (higher
    # powers of z only appear through the nested terms below), so it is
    # deliberately not precomputed.

    return torch.stack(
        [
            # l = 0 (constant term; new_tensor keeps dtype/device of the input)
            xyz.new_tensor(0.282094791773878).expand(xyz.shape[:-1]),
            # l = 1
            -0.48860251190292 * y,
            0.48860251190292 * z,
            -0.48860251190292 * x,
            # l = 2
            1.09254843059208 * xy,
            -1.09254843059208 * yz,
            0.94617469575756 * z2 - 0.31539156525252,
            -1.09254843059208 * xz,
            0.54627421529604 * x2 - 0.54627421529604 * y2,
            # l = 3
            -0.590043589926644 * y * (3.0 * x2 - y2),
            2.89061144264055 * xy * z,
            0.304697199642977 * y * (1.5 - 7.5 * z2),
            1.24392110863372 * z * (1.5 * z2 - 0.5) - 0.497568443453487 * z,
            0.304697199642977 * x * (1.5 - 7.5 * z2),
            1.44530572132028 * z * (x2 - y2),
            -0.590043589926644 * x * (x2 - 3.0 * y2),
            # l = 4
            2.5033429417967 * xy * (x2 - y2),
            -1.77013076977993 * yz * (3.0 * x2 - y2),
            0.126156626101008 * xy * (52.5 * z2 - 7.5),
            0.267618617422916 * y * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z),
            1.48099765681286 * z * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
            - 0.952069922236839 * z2
            + 0.317356640745613,
            0.267618617422916 * x * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z),
            0.063078313050504 * (x2 - y2) * (52.5 * z2 - 7.5),
            -1.77013076977993 * xz * (x2 - 3.0 * y2),
            -3.75501441269506 * x2 * y2 + 0.625835735449176 * x4 + 0.625835735449176 * y4,
            # l = 5
            -0.65638205684017 * y * (-10.0 * x2 * y2 + 5.0 * x4 + y4),
            8.30264925952416 * xy * z * (x2 - y2),
            0.00931882475114763 * y * (52.5 - 472.5 * z2) * (3.0 * x2 - y2),
            0.0913054625709205 * xy * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z),
            0.241571547304372 * y * (2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + 9.375 * z2 - 1.875),
            -1.24747010616985 * z * (1.5 * z2 - 0.5)
            + 1.6840846433293 * z * (1.75 * z * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z) - 1.125 * z2 + 0.375)
            + 0.498988042467941 * z,
            0.241571547304372 * x * (2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + 9.375 * z2 - 1.875),
            0.0456527312854602 * (x2 - y2) * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z),
            0.00931882475114763 * x * (52.5 - 472.5 * z2) * (x2 - 3.0 * y2),
            2.07566231488104 * z * (-6.0 * x2 * y2 + x4 + y4),
            -0.65638205684017 * x * (-10.0 * x2 * y2 + x4 + 5.0 * y4),
            # l = 6
            4.09910463115149 * x**4 * xy - 13.6636821038383 * xy**3 + 4.09910463115149 * xy * y**4,
            -2.36661916223175 * yz * (-10.0 * x2 * y2 + 5.0 * x4 + y4),
            0.00427144889505798 * xy * (x2 - y2) * (5197.5 * z2 - 472.5),
            0.00584892228263444 * y * (3.0 * x2 - y2) * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z),
            0.0701870673916132 * xy * (2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z) - 91.875 * z2 + 13.125),
            0.221950995245231 * y * (
                -2.8 * z * (1.5 - 7.5 * z2)
                + 2.2 * z * (2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + 9.375 * z2 - 1.875)
                - 4.8 * z
            ),
            -1.48328138624466 * z * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
            + 1.86469659985043 * z * (
                -1.33333333333333 * z * (1.5 * z2 - 0.5)
                + 1.8 * z * (1.75 * z * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z) - 1.125 * z2 + 0.375)
                + 0.533333333333333 * z
            )
            + 0.953538034014426 * z2
            - 0.317846011338142,
            0.221950995245231 * x * (
                -2.8 * z * (1.5 - 7.5 * z2)
                + 2.2 * z * (2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + 9.375 * z2 - 1.875)
                - 4.8 * z
            ),
            0.0350935336958066 * (x2 - y2) * (2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z) - 91.875 * z2 + 13.125),
            0.00584892228263444 * x * (x2 - 3.0 * y2) * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z),
            0.0010678622237645 * (5197.5 * z2 - 472.5) * (-6.0 * x2 * y2 + x4 + y4),
            -2.36661916223175 * xz * (-10.0 * x2 * y2 + x4 + 5.0 * y4),
            0.683184105191914 * x2**3
            + 10.2477615778787 * x2 * y4
            - 10.2477615778787 * x4 * y2
            - 0.683184105191914 * y2**3,
            # l = 7
            -0.707162732524596 * y * (7.0 * x2**3 + 21.0 * x2 * y4 - 35.0 * x4 * y2 - y2**3),
            2.6459606618019 * z * (6.0 * x**4 * xy - 20.0 * xy**3 + 6.0 * xy * y**4),
            9.98394571852353e-5 * y * (5197.5 - 67567.5 * z2) * (-10.0 * x2 * y2 + 5.0 * x4 + y4),
            0.00239614697244565 * xy * (x2 - y2) * (4.33333333333333 * z * (5197.5 * z2 - 472.5) - 3150.0 * z),
            0.00397356022507413 * y * (3.0 * x2 - y2) * (
                3.25 * z * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z) + 1063.125 * z2 - 118.125
            ),
            0.0561946276120613 * xy * (
                -4.8 * z * (52.5 * z2 - 7.5)
                + 2.6 * z * (2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z) - 91.875 * z2 + 13.125)
                + 48.0 * z
            ),
            0.206472245902897 * y * (
                -2.625 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
                + 2.16666666666667 * z * (
                    -2.8 * z * (1.5 - 7.5 * z2)
                    + 2.2 * z * (2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + 9.375 * z2 - 1.875)
                    - 4.8 * z
                )
                - 10.9375 * z2
                + 2.1875
            ),
            1.24862677781952 * z * (1.5 * z2 - 0.5)
            - 1.68564615005635 * z * (1.75 * z * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z) - 1.125 * z2 + 0.375)
            + 2.02901851395672 * z * (
                -1.45833333333333 * z * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
                + 1.83333333333333 * z * (
                    -1.33333333333333 * z * (1.5 * z2 - 0.5)
                    + 1.8 * z * (1.75 * z * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z) - 1.125 * z2 + 0.375)
                    + 0.533333333333333 * z
                )
                + 0.9375 * z2
                - 0.3125
            )
            - 0.499450711127808 * z,
            0.206472245902897 * x * (
                -2.625 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
                + 2.16666666666667 * z * (
                    -2.8 * z * (1.5 - 7.5 * z2)
                    + 2.2 * z * (2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + 9.375 * z2 - 1.875)
                    - 4.8 * z
                )
                - 10.9375 * z2
                + 2.1875
            ),
            0.0280973138060306 * (x2 - y2) * (
                -4.8 * z * (52.5 * z2 - 7.5)
                + 2.6 * z * (2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z) - 91.875 * z2 + 13.125)
                + 48.0 * z
            ),
            0.00397356022507413 * x * (x2 - 3.0 * y2) * (
                3.25 * z * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z) + 1063.125 * z2 - 118.125
            ),
            0.000599036743111412 * (4.33333333333333 * z * (5197.5 * z2 - 472.5) - 3150.0 * z) * (-6.0 * x2 * y2 + x4 + y4),
            9.98394571852353e-5 * x * (5197.5 - 67567.5 * z2) * (-10.0 * x2 * y2 + x4 + 5.0 * y4),
            2.6459606618019 * z * (x2**3 + 15.0 * x2 * y4 - 15.0 * x4 * y2 - y2**3),
            -0.707162732524596 * x * (x2**3 + 35.0 * x2 * y4 - 21.0 * x4 * y2 - 7.0 * y2**3),
        ],
        -1,
    )
830
+
831
+
832
# @torch.jit.script
def rsh_cart_8(xyz: torch.Tensor):
    """Computes all real spherical harmonics up to degree 8.

    This is an autogenerated method. See
    https://github.com/cheind/torch-spherical-harmonics
    for more information.

    Params:
        xyz: (N,...,3) tensor of points on the unit sphere

    Returns:
        rsh: (N,...,81) real spherical harmonics
            projections of input. Ynm is found at index
            `n*(n+1) + m`, with `0 <= n <= degree` and
            `-n <= m <= n`.
    """
    x = xyz[..., 0]
    y = xyz[..., 1]
    z = xyz[..., 2]

    # Shared monomials, computed once and reused across many basis terms.
    x2 = x**2
    y2 = y**2
    z2 = z**2
    xy = x * y
    xz = x * z
    yz = y * z
    x4 = x2**2
    y4 = y2**2
    # NOTE: z2**2 is never referenced by the degree-8 expansion (higher
    # powers of z only appear through the nested terms below), so it is
    # deliberately not precomputed.

    return torch.stack(
        [
            # l = 0. Use new_tensor so the constant term inherits xyz's dtype
            # AND device: the previous torch.ones(..., device=xyz.device) was
            # always float32, which made torch.stack raise on float64/float16
            # inputs and was inconsistent with the other rsh_cart_* functions.
            xyz.new_tensor(0.282094791773878).expand(xyz.shape[:-1]),
            # l = 1
            -0.48860251190292 * y,
            0.48860251190292 * z,
            -0.48860251190292 * x,
            # l = 2
            1.09254843059208 * xy,
            -1.09254843059208 * yz,
            0.94617469575756 * z2 - 0.31539156525252,
            -1.09254843059208 * xz,
            0.54627421529604 * x2 - 0.54627421529604 * y2,
            # l = 3
            -0.590043589926644 * y * (3.0 * x2 - y2),
            2.89061144264055 * xy * z,
            0.304697199642977 * y * (1.5 - 7.5 * z2),
            1.24392110863372 * z * (1.5 * z2 - 0.5) - 0.497568443453487 * z,
            0.304697199642977 * x * (1.5 - 7.5 * z2),
            1.44530572132028 * z * (x2 - y2),
            -0.590043589926644 * x * (x2 - 3.0 * y2),
            # l = 4
            2.5033429417967 * xy * (x2 - y2),
            -1.77013076977993 * yz * (3.0 * x2 - y2),
            0.126156626101008 * xy * (52.5 * z2 - 7.5),
            0.267618617422916 * y * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z),
            1.48099765681286 * z * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
            - 0.952069922236839 * z2
            + 0.317356640745613,
            0.267618617422916 * x * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z),
            0.063078313050504 * (x2 - y2) * (52.5 * z2 - 7.5),
            -1.77013076977993 * xz * (x2 - 3.0 * y2),
            -3.75501441269506 * x2 * y2 + 0.625835735449176 * x4 + 0.625835735449176 * y4,
            # l = 5
            -0.65638205684017 * y * (-10.0 * x2 * y2 + 5.0 * x4 + y4),
            8.30264925952416 * xy * z * (x2 - y2),
            0.00931882475114763 * y * (52.5 - 472.5 * z2) * (3.0 * x2 - y2),
            0.0913054625709205 * xy * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z),
            0.241571547304372 * y * (2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + 9.375 * z2 - 1.875),
            -1.24747010616985 * z * (1.5 * z2 - 0.5)
            + 1.6840846433293 * z * (1.75 * z * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z) - 1.125 * z2 + 0.375)
            + 0.498988042467941 * z,
            0.241571547304372 * x * (2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + 9.375 * z2 - 1.875),
            0.0456527312854602 * (x2 - y2) * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z),
            0.00931882475114763 * x * (52.5 - 472.5 * z2) * (x2 - 3.0 * y2),
            2.07566231488104 * z * (-6.0 * x2 * y2 + x4 + y4),
            -0.65638205684017 * x * (-10.0 * x2 * y2 + x4 + 5.0 * y4),
            # l = 6
            4.09910463115149 * x**4 * xy - 13.6636821038383 * xy**3 + 4.09910463115149 * xy * y**4,
            -2.36661916223175 * yz * (-10.0 * x2 * y2 + 5.0 * x4 + y4),
            0.00427144889505798 * xy * (x2 - y2) * (5197.5 * z2 - 472.5),
            0.00584892228263444 * y * (3.0 * x2 - y2) * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z),
            0.0701870673916132 * xy * (2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z) - 91.875 * z2 + 13.125),
            0.221950995245231 * y * (
                -2.8 * z * (1.5 - 7.5 * z2)
                + 2.2 * z * (2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + 9.375 * z2 - 1.875)
                - 4.8 * z
            ),
            -1.48328138624466 * z * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
            + 1.86469659985043 * z * (
                -1.33333333333333 * z * (1.5 * z2 - 0.5)
                + 1.8 * z * (1.75 * z * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z) - 1.125 * z2 + 0.375)
                + 0.533333333333333 * z
            )
            + 0.953538034014426 * z2
            - 0.317846011338142,
            0.221950995245231 * x * (
                -2.8 * z * (1.5 - 7.5 * z2)
                + 2.2 * z * (2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + 9.375 * z2 - 1.875)
                - 4.8 * z
            ),
            0.0350935336958066 * (x2 - y2) * (2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z) - 91.875 * z2 + 13.125),
            0.00584892228263444 * x * (x2 - 3.0 * y2) * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z),
            0.0010678622237645 * (5197.5 * z2 - 472.5) * (-6.0 * x2 * y2 + x4 + y4),
            -2.36661916223175 * xz * (-10.0 * x2 * y2 + x4 + 5.0 * y4),
            0.683184105191914 * x2**3
            + 10.2477615778787 * x2 * y4
            - 10.2477615778787 * x4 * y2
            - 0.683184105191914 * y2**3,
            # l = 7
            -0.707162732524596 * y * (7.0 * x2**3 + 21.0 * x2 * y4 - 35.0 * x4 * y2 - y2**3),
            2.6459606618019 * z * (6.0 * x**4 * xy - 20.0 * xy**3 + 6.0 * xy * y**4),
            9.98394571852353e-5 * y * (5197.5 - 67567.5 * z2) * (-10.0 * x2 * y2 + 5.0 * x4 + y4),
            0.00239614697244565 * xy * (x2 - y2) * (4.33333333333333 * z * (5197.5 * z2 - 472.5) - 3150.0 * z),
            0.00397356022507413 * y * (3.0 * x2 - y2) * (
                3.25 * z * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z) + 1063.125 * z2 - 118.125
            ),
            0.0561946276120613 * xy * (
                -4.8 * z * (52.5 * z2 - 7.5)
                + 2.6 * z * (2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z) - 91.875 * z2 + 13.125)
                + 48.0 * z
            ),
            0.206472245902897 * y * (
                -2.625 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
                + 2.16666666666667 * z * (
                    -2.8 * z * (1.5 - 7.5 * z2)
                    + 2.2 * z * (2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + 9.375 * z2 - 1.875)
                    - 4.8 * z
                )
                - 10.9375 * z2
                + 2.1875
            ),
            1.24862677781952 * z * (1.5 * z2 - 0.5)
            - 1.68564615005635 * z * (1.75 * z * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z) - 1.125 * z2 + 0.375)
            + 2.02901851395672 * z * (
                -1.45833333333333 * z * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
                + 1.83333333333333 * z * (
                    -1.33333333333333 * z * (1.5 * z2 - 0.5)
                    + 1.8 * z * (1.75 * z * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z) - 1.125 * z2 + 0.375)
                    + 0.533333333333333 * z
                )
                + 0.9375 * z2
                - 0.3125
            )
            - 0.499450711127808 * z,
            0.206472245902897 * x * (
                -2.625 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
                + 2.16666666666667 * z * (
                    -2.8 * z * (1.5 - 7.5 * z2)
                    + 2.2 * z * (2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + 9.375 * z2 - 1.875)
                    - 4.8 * z
                )
                - 10.9375 * z2
                + 2.1875
            ),
            0.0280973138060306 * (x2 - y2) * (
                -4.8 * z * (52.5 * z2 - 7.5)
                + 2.6 * z * (2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z) - 91.875 * z2 + 13.125)
                + 48.0 * z
            ),
            0.00397356022507413 * x * (x2 - 3.0 * y2) * (
                3.25 * z * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z) + 1063.125 * z2 - 118.125
            ),
            0.000599036743111412 * (4.33333333333333 * z * (5197.5 * z2 - 472.5) - 3150.0 * z) * (-6.0 * x2 * y2 + x4 + y4),
            9.98394571852353e-5 * x * (5197.5 - 67567.5 * z2) * (-10.0 * x2 * y2 + x4 + 5.0 * y4),
            2.6459606618019 * z * (x2**3 + 15.0 * x2 * y4 - 15.0 * x4 * y2 - y2**3),
            -0.707162732524596 * x * (x2**3 + 35.0 * x2 * y4 - 21.0 * x4 * y2 - 7.0 * y2**3),
            # l = 8
            5.83141328139864 * xy * (x2**3 + 7.0 * x2 * y4 - 7.0 * x4 * y2 - y2**3),
            -2.91570664069932 * yz * (7.0 * x2**3 + 21.0 * x2 * y4 - 35.0 * x4 * y2 - y2**3),
            7.87853281621404e-6 * (1013512.5 * z2 - 67567.5) * (6.0 * x**4 * xy - 20.0 * xy**3 + 6.0 * xy * y**4),
            5.10587282657803e-5 * y * (5.0 * z * (5197.5 - 67567.5 * z2) + 41580.0 * z) * (-10.0 * x2 * y2 + 5.0 * x4 + y4),
            0.00147275890257803 * xy * (x2 - y2) * (
                3.75 * z * (4.33333333333333 * z * (5197.5 * z2 - 472.5) - 3150.0 * z) - 14293.125 * z2 + 1299.375
            ),
            0.0028519853513317 * y * (3.0 * x2 - y2) * (
                -7.33333333333333 * z * (52.5 - 472.5 * z2)
                + 3.0 * z * (3.25 * z * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z) + 1063.125 * z2 - 118.125)
                - 560.0 * z
            ),
            0.0463392770473559 * xy * (
                -4.125 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z)
                + 2.5 * z * (
                    -4.8 * z * (52.5 * z2 - 7.5)
                    + 2.6 * z * (2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z) - 91.875 * z2 + 13.125)
                    + 48.0 * z
                )
                + 137.8125 * z2
                - 19.6875
            ),
            0.193851103820053 * y * (
                3.2 * z * (1.5 - 7.5 * z2)
                - 2.51428571428571 * z * (2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + 9.375 * z2 - 1.875)
                + 2.14285714285714 * z * (
                    -2.625 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
                    + 2.16666666666667 * z * (
                        -2.8 * z * (1.5 - 7.5 * z2)
                        + 2.2 * z * (2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + 9.375 * z2 - 1.875)
                        - 4.8 * z
                    )
                    - 10.9375 * z2
                    + 2.1875
                )
                + 5.48571428571429 * z
            ),
            1.48417251362228 * z * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
            - 1.86581687426801 * z * (
                -1.33333333333333 * z * (1.5 * z2 - 0.5)
                + 1.8 * z * (1.75 * z * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z) - 1.125 * z2 + 0.375)
                + 0.533333333333333 * z
            )
            + 2.1808249179756 * z * (
                1.14285714285714 * z * (1.5 * z2 - 0.5)
                - 1.54285714285714 * z * (1.75 * z * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z) - 1.125 * z2 + 0.375)
                + 1.85714285714286 * z * (
                    -1.45833333333333 * z * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
                    + 1.83333333333333 * z * (
                        -1.33333333333333 * z * (1.5 * z2 - 0.5)
                        + 1.8 * z * (1.75 * z * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z) - 1.125 * z2 + 0.375)
                        + 0.533333333333333 * z
                    )
                    + 0.9375 * z2
                    - 0.3125
                )
                - 0.457142857142857 * z
            )
            - 0.954110901614325 * z2
            + 0.318036967204775,
            0.193851103820053 * x * (
                3.2 * z * (1.5 - 7.5 * z2)
                - 2.51428571428571 * z * (2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + 9.375 * z2 - 1.875)
                + 2.14285714285714 * z * (
                    -2.625 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
                    + 2.16666666666667 * z * (
                        -2.8 * z * (1.5 - 7.5 * z2)
                        + 2.2 * z * (2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + 9.375 * z2 - 1.875)
                        - 4.8 * z
                    )
                    - 10.9375 * z2
                    + 2.1875
                )
                + 5.48571428571429 * z
            ),
            0.0231696385236779 * (x2 - y2) * (
                -4.125 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z)
                + 2.5 * z * (
                    -4.8 * z * (52.5 * z2 - 7.5)
                    + 2.6 * z * (2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z) - 91.875 * z2 + 13.125)
                    + 48.0 * z
                )
                + 137.8125 * z2
                - 19.6875
            ),
            0.0028519853513317 * x * (x2 - 3.0 * y2) * (
                -7.33333333333333 * z * (52.5 - 472.5 * z2)
                + 3.0 * z * (3.25 * z * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z) + 1063.125 * z2 - 118.125)
                - 560.0 * z
            ),
            0.000368189725644507 * (-6.0 * x2 * y2 + x4 + y4) * (
                3.75 * z * (4.33333333333333 * z * (5197.5 * z2 - 472.5) - 3150.0 * z) - 14293.125 * z2 + 1299.375
            ),
            5.10587282657803e-5 * x * (5.0 * z * (5197.5 - 67567.5 * z2) + 41580.0 * z) * (-10.0 * x2 * y2 + x4 + 5.0 * y4),
            7.87853281621404e-6 * (1013512.5 * z2 - 67567.5) * (x2**3 + 15.0 * x2 * y4 - 15.0 * x4 * y2 - y2**3),
            -2.91570664069932 * xz * (x2**3 + 35.0 * x2 * y4 - 21.0 * x4 * y2 - 7.0 * y2**3),
            -20.4099464848952 * x2**3 * y2
            - 20.4099464848952 * x2 * y2**3
            + 0.72892666017483 * x4**2
            + 51.0248662122381 * x4 * y4
            + 0.72892666017483 * y4**2,
        ],
        -1,
    )
1394
+
1395
+
1396
# Public API of this module: the autogenerated Cartesian real-spherical-
# harmonics evaluators, one per maximum degree (0 through 8).
__all__ = [
    "rsh_cart_0",
    "rsh_cart_1",
    "rsh_cart_2",
    "rsh_cart_3",
    "rsh_cart_4",
    "rsh_cart_5",
    "rsh_cart_6",
    "rsh_cart_7",
    "rsh_cart_8",
]
1407
+
1408
+
1409
+ from typing import Optional
1410
+ import torch
1411
+
1412
+
1413
+ class SphHarm(torch.nn.Module):
1414
+ def __init__(self, m, n, dtype=torch.float32) -> None:
1415
+ super().__init__()
1416
+ self.dtype = dtype
1417
+ m = torch.tensor(list(range(-m + 1, m)))
1418
+ n = torch.tensor(list(range(n)))
1419
+ self.is_normalized = False
1420
+ vals = torch.cartesian_prod(m, n).T
1421
+ vals = vals[:, vals[0] <= vals[1]]
1422
+ m, n = vals.unbind(0)
1423
+
1424
+ self.register_buffer("m", tensor=m)
1425
+ self.register_buffer("n", tensor=n)
1426
+ self.register_buffer("l_max", tensor=torch.max(self.n))
1427
+
1428
+ f_a, f_b, initial_value, d0_mask_3d, d1_mask_3d = self._init_legendre()
1429
+ self.register_buffer("f_a", tensor=f_a)
1430
+ self.register_buffer("f_b", tensor=f_b)
1431
+ self.register_buffer("d0_mask_3d", tensor=d0_mask_3d)
1432
+ self.register_buffer("d1_mask_3d", tensor=d1_mask_3d)
1433
+ self.register_buffer("initial_value", tensor=initial_value)
1434
+
1435
+ @property
1436
+ def device(self):
1437
+ return next(self.buffers()).device
1438
+
1439
+ def forward(self, points: torch.Tensor) -> torch.Tensor:
1440
+ """Computes the spherical harmonics."""
1441
+ # Y_l^m = (-1) ^ m c_l^m P_l^m(cos(theta)) exp(i m phi)
1442
+ B, N, D = points.shape
1443
+ dtype = points.dtype
1444
+ theta, phi = points.view(-1, D).to(self.dtype).unbind(-1)
1445
+ cos_colatitude = torch.cos(phi)
1446
+ legendre = self._gen_associated_legendre(cos_colatitude)
1447
+ vals = torch.stack([self.m.abs(), self.n], dim=0)
1448
+ vals = torch.cat(
1449
+ [
1450
+ vals.repeat(1, theta.shape[0]),
1451
+ torch.arange(theta.shape[0], device=theta.device)
1452
+ .unsqueeze(0)
1453
+ .repeat_interleave(vals.shape[1], dim=1),
1454
+ ],
1455
+ dim=0,
1456
+ )
1457
+ legendre_vals = legendre[vals[0], vals[1], vals[2]]
1458
+ legendre_vals = legendre_vals.reshape(-1, theta.shape[0])
1459
+ angle = torch.outer(self.m.abs(), theta)
1460
+ vandermonde = torch.complex(torch.cos(angle), torch.sin(angle))
1461
+ harmonics = torch.complex(
1462
+ legendre_vals * torch.real(vandermonde),
1463
+ legendre_vals * torch.imag(vandermonde),
1464
+ )
1465
+
1466
+ # Negative order.
1467
+ m = self.m.unsqueeze(-1)
1468
+ harmonics = torch.where(
1469
+ m < 0, (-1.0) ** m.abs() * torch.conj(harmonics), harmonics
1470
+ )
1471
+ harmonics = harmonics.permute(1, 0).reshape(B, N, -1).to(dtype)
1472
+ return harmonics
1473
+
1474
+ def _gen_recurrence_mask(self) -> tuple[torch.Tensor, torch.Tensor]:
1475
+ """Generates mask for recurrence relation on the remaining entries.
1476
+
1477
+ The remaining entries are with respect to the diagonal and offdiagonal
1478
+ entries.
1479
+
1480
+ Args:
1481
+ l_max: see `gen_normalized_legendre`.
1482
+ Returns:
1483
+ torch.Tensors representing the mask used by the recurrence relations.
1484
+ """
1485
+
1486
+ # Computes all coefficients.
1487
+ m_mat, l_mat = torch.meshgrid(
1488
+ torch.arange(0, self.l_max + 1, device=self.device, dtype=self.dtype),
1489
+ torch.arange(0, self.l_max + 1, device=self.device, dtype=self.dtype),
1490
+ indexing="ij",
1491
+ )
1492
+ if self.is_normalized:
1493
+ c0 = l_mat * l_mat
1494
+ c1 = m_mat * m_mat
1495
+ c2 = 2.0 * l_mat
1496
+ c3 = (l_mat - 1.0) * (l_mat - 1.0)
1497
+ d0 = torch.sqrt((4.0 * c0 - 1.0) / (c0 - c1))
1498
+ d1 = torch.sqrt(((c2 + 1.0) * (c3 - c1)) / ((c2 - 3.0) * (c0 - c1)))
1499
+ else:
1500
+ d0 = (2.0 * l_mat - 1.0) / (l_mat - m_mat)
1501
+ d1 = (l_mat + m_mat - 1.0) / (l_mat - m_mat)
1502
+
1503
+ d0_mask_indices = torch.triu_indices(self.l_max + 1, 1)
1504
+ d1_mask_indices = torch.triu_indices(self.l_max + 1, 2)
1505
+
1506
+ d_zeros = torch.zeros(
1507
+ (self.l_max + 1, self.l_max + 1), dtype=self.dtype, device=self.device
1508
+ )
1509
+ d_zeros[d0_mask_indices] = d0[d0_mask_indices]
1510
+ d0_mask = d_zeros
1511
+
1512
+ d_zeros = torch.zeros(
1513
+ (self.l_max + 1, self.l_max + 1), dtype=self.dtype, device=self.device
1514
+ )
1515
+ d_zeros[d1_mask_indices] = d1[d1_mask_indices]
1516
+ d1_mask = d_zeros
1517
+
1518
+ # Creates a 3D mask that contains 1s on the diagonal plane and 0s elsewhere.
1519
+ i = torch.arange(self.l_max + 1, device=self.device)[:, None, None]
1520
+ j = torch.arange(self.l_max + 1, device=self.device)[None, :, None]
1521
+ k = torch.arange(self.l_max + 1, device=self.device)[None, None, :]
1522
+ mask = (i + j - k == 0).to(self.dtype)
1523
+ d0_mask_3d = torch.einsum("jk,ijk->ijk", d0_mask, mask)
1524
+ d1_mask_3d = torch.einsum("jk,ijk->ijk", d1_mask, mask)
1525
+ return (d0_mask_3d, d1_mask_3d)
1526
+
1527
+ def _recursive(self, i: int, p_val: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
1528
+ coeff_0 = self.d0_mask_3d[i]
1529
+ coeff_1 = self.d1_mask_3d[i]
1530
+ h = torch.einsum(
1531
+ "ij,ijk->ijk",
1532
+ coeff_0,
1533
+ torch.einsum("ijk,k->ijk", torch.roll(p_val, shifts=1, dims=1), x),
1534
+ ) - torch.einsum("ij,ijk->ijk", coeff_1, torch.roll(p_val, shifts=2, dims=1))
1535
+ p_val = p_val + h
1536
+ return p_val
1537
+
1538
+ def _init_legendre(self):
1539
+ a_idx = torch.arange(1, self.l_max + 1, dtype=self.dtype, device=self.device)
1540
+ b_idx = torch.arange(self.l_max, dtype=self.dtype, device=self.device)
1541
+ if self.is_normalized:
1542
+ # The initial value p(0,0).
1543
+ initial_value: torch.Tensor = torch.tensor(
1544
+ 0.5 / (torch.pi**0.5), device=self.device
1545
+ )
1546
+ f_a = torch.cumprod(-1 * torch.sqrt(1.0 + 0.5 / a_idx), dim=0)
1547
+ f_b = torch.sqrt(2.0 * b_idx + 3.0)
1548
+ else:
1549
+ # The initial value p(0,0).
1550
+ initial_value = torch.tensor(1.0, device=self.device)
1551
+ f_a = torch.cumprod(1.0 - 2.0 * a_idx, dim=0)
1552
+ f_b = 2.0 * b_idx + 1.0
1553
+
1554
+ d0_mask_3d, d1_mask_3d = self._gen_recurrence_mask()
1555
+ return f_a, f_b, initial_value, d0_mask_3d, d1_mask_3d
1556
+
1557
    def _gen_associated_legendre(self, x: torch.Tensor) -> torch.Tensor:
        r"""Computes associated Legendre functions (ALFs) of the first kind.

        The ALFs of the first kind are used in spherical harmonics. The spherical
        harmonic of degree `l` and order `m` can be written as
        `Y_l^m(θ, φ) = N_l^m * P_l^m(cos(θ)) * exp(i m φ)`, where `N_l^m` is the
        normalization factor and θ and φ are the colatitude and longitude,
        respectively. `N_l^m` is chosen in the way that the spherical harmonics form
        a set of orthonormal basis function of L^2(S^2). For the computational
        efficiency of spherical harmonics transform, the normalization factor is
        used in the computation of the ALFs. In addition, normalizing `P_l^m`
        avoids overflow/underflow and achieves better numerical stability. Three
        recurrence relations are used in the computation.

        The maximum degree is taken from `self.l_max` (both the degrees and
        orders are `[0, 1, 2, ..., l_max]`), and `self.is_normalized` selects
        whether the orthonormal normalization `N_l^m` is folded in.

        Args:
            x: A vector of type `float32`, `float64` containing the sampled points in
              spherical coordinates, at which the ALFs are computed; `x` is essentially
              `cos(θ)`. For the numerical integration used by the spherical harmonics
              transforms, `x` contains the quadrature points in the interval of
              `[-1, 1]`. There are several approaches to provide the quadrature points:
              Gauss-Legendre method (`scipy.special.roots_legendre`), Gauss-Chebyshev
              method (`scipy.special.roots_chebyu`), and Driscoll & Healy
              method (Driscoll, James R., and Dennis M. Healy. "Computing Fourier
              transforms and convolutions on the 2-sphere." Advances in applied
              mathematics 15, no. 2 (1994): 202-250.). The Gauss-Legendre quadrature
              points are nearly equal-spaced along θ and provide exact discrete
              orthogonality, (P^m)^T W P_m = I, where `T` represents the transpose
              operation, `W` is a diagonal matrix containing the quadrature weights,
              and `I` is the identity matrix. The Gauss-Chebyshev points are equally
              spaced, which only provide approximate discrete orthogonality. The
              Driscoll & Healy quadrature points are equally spaced and provide the
              exact discrete orthogonality. The number of sampling points is required to
              be twice as the number of frequency points (modes) in the Driscoll & Healy
              approach, which enables FFT and achieves a fast spherical harmonics
              transform.

        Returns:
            The 3D array of shape `(l_max + 1, l_max + 1, len(x))` containing the values
            of the ALFs at `x`; the dimensions in the sequence of order, degree, and
            evaluation points.
        """
        p = torch.zeros(
            (self.l_max + 1, self.l_max + 1, x.shape[0]), dtype=x.dtype, device=x.device
        )
        p[0, 0] = self.initial_value

        # Compute the diagonal entries p(l,l) with recurrence.
        # y[l - 1] = (1 - x^2)^(l / 2) via a cumulative product over rows.
        y = torch.cumprod(
            torch.broadcast_to(torch.sqrt(1.0 - x * x), (self.l_max, x.shape[0])), dim=0
        )
        p_diag = self.initial_value * torch.einsum("i,ij->ij", self.f_a, y)
        # torch.diag_indices(l_max + 1)
        diag_indices = torch.stack(
            [torch.arange(0, self.l_max + 1, device=x.device)] * 2, dim=0
        )
        # Fill p(1,1) ... p(l_max, l_max); p(0,0) was seeded above.
        p[(diag_indices[0][1:], diag_indices[1][1:])] = p_diag

        diag_indices = torch.stack(
            [torch.arange(0, self.l_max, device=x.device)] * 2, dim=0
        )

        # Compute the off-diagonal entries with recurrence:
        # p(m, m + 1) = f_b[m] * x * p(m, m).
        p_offdiag = torch.einsum(
            "ij,ij->ij",
            torch.einsum("i,j->ij", self.f_b, x),
            p[(diag_indices[0], diag_indices[1])],
        )  # p[torch.diag_indices(l_max)])
        p[(diag_indices[0][: self.l_max], diag_indices[1][: self.l_max] + 1)] = (
            p_offdiag
        )

        # Compute the remaining entries with recurrence, one degree at a time.
        if self.l_max > 1:
            for i in range(2, self.l_max + 1):
                p = self._recursive(i, p, x)
        return p
src/misc/utils.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ from src.visualization.color_map import apply_color_map_to_image
4
+ import torch.distributed as dist
5
+
6
def inverse_normalize(tensor, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)):
    """Undo a channelwise normalization: returns ``tensor * std + mean``.

    ``mean`` and ``std`` are reshaped to (C, 1, 1) so they broadcast over the
    trailing spatial dimensions of a channel-first image tensor.
    """
    param_shape = (-1, 1, 1)
    std_t = torch.as_tensor(std, dtype=tensor.dtype, device=tensor.device).reshape(param_shape)
    mean_t = torch.as_tensor(mean, dtype=tensor.dtype, device=tensor.device).reshape(param_shape)
    return tensor * std_t + mean_t
10
+
11
+
12
+ # Color-map the result.
13
def vis_depth_map(result, near=None, far=None):
    """Color-map a depth tensor with the "turbo" palette in log-depth space.

    When ``near``/``far`` are not supplied, they are estimated from robust
    quantiles (1% of positive depths / 99% of all depths). Depths are mapped so
    near values are high (bright) and far values low.

    Args:
        result: depth tensor (any shape accepted by apply_color_map_to_image).
        near, far: optional scalar tensors; if given, both must be given.
    """
    if near is None and far is None:
        # Cap the element count so the quantile stays cheap on huge maps.
        far = result.view(-1)[:16_000_000].quantile(0.99).log()
        try:
            near = result[result > 0][:16_000_000].quantile(0.01).log()
        # Was a bare `except:`; narrow it so Ctrl-C / SystemExit still propagate.
        # torch raises when quantile sees an empty selection (no positive depths).
        except Exception:
            print("No valid depth values found.")
            near = torch.zeros_like(far)
    else:
        near = near.log()
        far = far.log()

    result = result.log()
    # Normalize log-depth to [0, 1] with near -> 1 and far -> 0.
    result = 1 - (result - near) / (far - near)
    return apply_color_map_to_image(result, "turbo")
28
+
29
+
30
def confidence_map(result):
    """Color-map a confidence tensor with the "magma" palette.

    Values are divided by the global maximum so the input to the color map lies
    in [0, 1] (assuming non-negative confidences — TODO confirm with callers).
    """
    # Dead commented-out log-quantile normalization removed; plain max-scaling
    # is what this function actually does.
    result = result / result.view(-1).max()
    return apply_color_map_to_image(result, "magma")
41
+
42
+
43
def get_overlap_tag(overlap):
    """Bucket an overlap ratio into "small" / "medium" / "large" / "ignore".

    Boundaries (inclusive upper): small (0.05, 0.3], medium (0.3, 0.55],
    large (0.55, 0.8]; anything outside (0.05, 0.8] is "ignore".

    BUGFIX: previously an overlap below 0.05 failed the `0.05 <= overlap`
    guard and fell through to the `overlap <= 0.55` branch, returning
    "medium" instead of "ignore".
    """
    if overlap < 0.05 or overlap > 0.8:
        return "ignore"
    if overlap <= 0.3:
        return "small"
    if overlap <= 0.55:
        return "medium"
    return "large"
54
+
55
+
56
def is_dist_avail_and_initialized():
    """Return True when torch.distributed is both available and initialized."""
    return dist.is_available() and dist.is_initialized()
62
+
63
+
64
def get_world_size():
    """World size of the current process group, or 1 when not distributed."""
    return dist.get_world_size() if is_dist_avail_and_initialized() else 1
68
+
69
+
70
def get_rank():
    """Rank of the current process, or 0 when not distributed."""
    return dist.get_rank() if is_dist_avail_and_initialized() else 0
src/model/decoder/__init__.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .decoder import Decoder
2
+ from .decoder_splatting_cuda import DecoderSplattingCUDA, DecoderSplattingCUDACfg
3
+
4
# Registry mapping decoder config names to their implementations.
DECODERS = {
    "splatting_cuda": DecoderSplattingCUDA,
}

# Only one decoder variant exists, so the config "union" is a plain alias.
DecoderCfg = DecoderSplattingCUDACfg
9
+
10
+
11
def get_decoder(decoder_cfg: DecoderCfg) -> Decoder:
    """Instantiate the decoder class registered under ``decoder_cfg.name``."""
    decoder_cls = DECODERS[decoder_cfg.name]
    return decoder_cls(decoder_cfg)
src/model/decoder/cuda_splatting.py ADDED
@@ -0,0 +1,244 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from math import isqrt
2
+ from typing import Literal
3
+
4
+ import torch
5
+ from diff_gaussian_rasterization import (
6
+ GaussianRasterizationSettings,
7
+ GaussianRasterizer,
8
+ )
9
+ from einops import einsum, rearrange, repeat
10
+ from jaxtyping import Float, Bool
11
+ from torch import Tensor
12
+
13
+ from ...geometry.projection import get_fov, homogenize_points
14
+
15
+
16
def get_projection_matrix(
    near: Float[Tensor, " batch"],
    far: Float[Tensor, " batch"],
    fov_x: Float[Tensor, " batch"],
    fov_y: Float[Tensor, " batch"],
) -> Float[Tensor, "batch 4 4"]:
    """Maps points in the viewing frustum to (-1, 1) on the X/Y axes and (0, 1) on the Z
    axis. Differs from the OpenGL version in that Z doesn't have range (-1, 1) after
    transformation and that Z is flipped.
    """
    half_tan_x = (0.5 * fov_x).tan()
    half_tan_y = (0.5 * fov_y).tan()

    # Symmetric frustum extents measured on the near plane.
    top = half_tan_y * near
    right = half_tan_x * near
    bottom = -top
    left = -right

    (batch,) = near.shape
    proj = torch.zeros((batch, 4, 4), dtype=torch.float32, device=near.device)
    proj[:, 0, 0] = 2 * near / (right - left)
    proj[:, 1, 1] = 2 * near / (top - bottom)
    proj[:, 0, 2] = (right + left) / (right - left)
    proj[:, 1, 2] = (top + bottom) / (top - bottom)
    proj[:, 2, 2] = far / (far - near)
    proj[:, 2, 3] = -(far * near) / (far - near)
    proj[:, 3, 2] = 1
    return proj
+
45
+
46
def render_cuda(
    extrinsics: Float[Tensor, "batch 4 4"],
    intrinsics: Float[Tensor, "batch 3 3"],
    near: Float[Tensor, " batch"],
    far: Float[Tensor, " batch"],
    image_shape: tuple[int, int],
    background_color: Float[Tensor, "batch 3"],
    gaussian_means: Float[Tensor, "batch gaussian 3"],
    gaussian_covariances: Float[Tensor, "batch gaussian 3 3"],
    gaussian_sh_coefficients: Float[Tensor, "batch gaussian 3 d_sh"],
    gaussian_opacities: Float[Tensor, "batch gaussian"],
    scale_invariant: bool = True,
    use_sh: bool = True,
    cam_rot_delta: Float[Tensor, "batch 3"] | None = None,
    cam_trans_delta: Float[Tensor, "batch 3"] | None = None,
    voxel_masks: Bool[Tensor, "batch gaussian"] | None = None,
) -> tuple[Float[Tensor, "batch 3 height width"], Float[Tensor, "batch height width"]]:
    """Rasterize 3D Gaussians with the CUDA splatting backend, one batch
    element at a time.

    Extrinsics are camera-to-world transforms (they are inverted below to get
    the view matrix, and their translation is used as the camera position).
    ``voxel_masks`` optionally restricts which Gaussians are rendered per
    batch element. Returns (rendered colors, rendered depths); note that the
    per-element ``radii`` are collected but not returned.
    """
    assert use_sh or gaussian_sh_coefficients.shape[-1] == 1

    # Make sure everything is in a range where numerical issues don't appear.
    # Rescaling the whole scene by 1/near keeps geometry and covariances consistent.
    if scale_invariant:
        scale = 1 / near
        extrinsics = extrinsics.clone()
        extrinsics[..., :3, 3] = extrinsics[..., :3, 3] * scale[:, None]
        gaussian_covariances = gaussian_covariances * (scale[:, None, None, None] ** 2)
        gaussian_means = gaussian_means * scale[:, None, None]
        near = near * scale
        far = far * scale

    # SH degree is inferred from the number of coefficients: n = (degree + 1)^2.
    _, _, _, n = gaussian_sh_coefficients.shape
    degree = isqrt(n) - 1
    shs = rearrange(gaussian_sh_coefficients, "b g xyz n -> b g n xyz").contiguous()

    b, _, _ = extrinsics.shape
    h, w = image_shape

    fov_x, fov_y = get_fov(intrinsics).unbind(dim=-1)
    tan_fov_x = (0.5 * fov_x).tan()
    tan_fov_y = (0.5 * fov_y).tan()

    # The rasterizer expects row-major (transposed) matrices.
    projection_matrix = get_projection_matrix(near, far, fov_x, fov_y)
    projection_matrix = rearrange(projection_matrix, "b i j -> b j i")
    view_matrix = rearrange(extrinsics.inverse(), "b i j -> b j i")
    full_projection = view_matrix @ projection_matrix

    all_images = []
    all_radii = []
    all_depths = []
    for i in range(b):
        # Set up a tensor for the gradients of the screen-space means.
        mean_gradients = torch.zeros_like(gaussian_means[i], requires_grad=True)
        try:
            mean_gradients.retain_grad()
        except Exception:
            pass

        settings = GaussianRasterizationSettings(
            image_height=h,
            image_width=w,
            tanfovx=tan_fov_x[i].item(),
            tanfovy=tan_fov_y[i].item(),
            bg=background_color[i],
            scale_modifier=1.0,
            viewmatrix=view_matrix[i],
            projmatrix=full_projection[i],
            projmatrix_raw=projection_matrix[i],
            sh_degree=degree,
            campos=extrinsics[i, :3, 3],
            prefiltered=False,  # This matches the original usage.
            debug=False,
        )
        rasterizer = GaussianRasterizer(settings)

        # The rasterizer takes the covariance's upper triangle (6 values).
        row, col = torch.triu_indices(3, 3)

        if voxel_masks is not None:
            # Render only the Gaussians selected by this element's mask.
            voxel_mask = voxel_masks[i]
            image, radii, depth, opacity, n_touched = rasterizer(
                means3D=gaussian_means[i][voxel_mask],
                means2D=mean_gradients[voxel_mask],
                shs=shs[i][voxel_mask] if use_sh else None,
                colors_precomp=None if use_sh else shs[i, :, 0, :][voxel_mask],
                opacities=gaussian_opacities[i][voxel_mask, ..., None],
                cov3D_precomp=gaussian_covariances[i, :, row, col][voxel_mask],
                theta=cam_rot_delta[i] if cam_rot_delta is not None else None,
                rho=cam_trans_delta[i] if cam_trans_delta is not None else None,
            )
        else:
            image, radii, depth, opacity, n_touched = rasterizer(
                means3D=gaussian_means[i],
                means2D=mean_gradients,
                shs=shs[i] if use_sh else None,
                colors_precomp=None if use_sh else shs[i, :, 0, :],
                opacities=gaussian_opacities[i, ..., None],
                cov3D_precomp=gaussian_covariances[i, :, row, col],
                theta=cam_rot_delta[i] if cam_rot_delta is not None else None,
                rho=cam_trans_delta[i] if cam_trans_delta is not None else None,
            )
        all_images.append(image)
        all_radii.append(radii)
        all_depths.append(depth.squeeze(0))
    return torch.stack(all_images), torch.stack(all_depths)
148
+
149
+
150
def render_cuda_orthographic(
    extrinsics: Float[Tensor, "batch 4 4"],
    width: Float[Tensor, " batch"],
    height: Float[Tensor, " batch"],
    near: Float[Tensor, " batch"],
    far: Float[Tensor, " batch"],
    image_shape: tuple[int, int],
    background_color: Float[Tensor, "batch 3"],
    gaussian_means: Float[Tensor, "batch gaussian 3"],
    gaussian_covariances: Float[Tensor, "batch gaussian 3 3"],
    gaussian_sh_coefficients: Float[Tensor, "batch gaussian 3 d_sh"],
    gaussian_opacities: Float[Tensor, "batch gaussian"],
    fov_degrees: float = 0.1,
    use_sh: bool = True,
    dump: dict | None = None,
) -> Float[Tensor, "batch 3 height width"]:
    """Approximate an orthographic render by pushing the camera far away and
    using a very narrow perspective frustum (``fov_degrees``).

    NOTE(review): ``move_back[2, 3] = -distance_to_near`` writes a batch
    tensor into a scalar slot, which only works for batch size 1 — confirm
    callers. Unlike render_cuda, tanfovx/tanfovy are passed as tensors here
    rather than ``.item()`` scalars.
    """
    b, _, _ = extrinsics.shape
    h, w = image_shape
    assert use_sh or gaussian_sh_coefficients.shape[-1] == 1

    # SH degree is inferred from the coefficient count: n = (degree + 1)^2.
    _, _, _, n = gaussian_sh_coefficients.shape
    degree = isqrt(n) - 1
    shs = rearrange(gaussian_sh_coefficients, "b g xyz n -> b g n xyz").contiguous()

    # Create fake "orthographic" projection by moving the camera back and picking a
    # small field of view.
    fov_x = torch.tensor(fov_degrees, device=extrinsics.device).deg2rad()
    tan_fov_x = (0.5 * fov_x).tan()
    distance_to_near = (0.5 * width) / tan_fov_x
    tan_fov_y = 0.5 * height / distance_to_near
    fov_y = (2 * tan_fov_y).atan()
    near = near + distance_to_near
    far = far + distance_to_near
    move_back = torch.eye(4, dtype=torch.float32, device=extrinsics.device)
    move_back[2, 3] = -distance_to_near
    extrinsics = extrinsics @ move_back

    # Escape hatch for visualization/figures.
    if dump is not None:
        dump["extrinsics"] = extrinsics
        dump["fov_x"] = fov_x
        dump["fov_y"] = fov_y
        dump["near"] = near
        dump["far"] = far

    # The rasterizer expects row-major (transposed) matrices.
    projection_matrix = get_projection_matrix(
        near, far, repeat(fov_x, "-> b", b=b), fov_y
    )
    projection_matrix = rearrange(projection_matrix, "b i j -> b j i")
    view_matrix = rearrange(extrinsics.inverse(), "b i j -> b j i")
    full_projection = view_matrix @ projection_matrix

    all_images = []
    all_radii = []
    for i in range(b):
        # Set up a tensor for the gradients of the screen-space means.
        mean_gradients = torch.zeros_like(gaussian_means[i], requires_grad=True)
        try:
            mean_gradients.retain_grad()
        except Exception:
            pass

        settings = GaussianRasterizationSettings(
            image_height=h,
            image_width=w,
            tanfovx=tan_fov_x,
            tanfovy=tan_fov_y,
            bg=background_color[i],
            scale_modifier=1.0,
            viewmatrix=view_matrix[i],
            projmatrix=full_projection[i],
            projmatrix_raw=projection_matrix[i],
            sh_degree=degree,
            campos=extrinsics[i, :3, 3],
            prefiltered=False,  # This matches the original usage.
            debug=False,
        )
        rasterizer = GaussianRasterizer(settings)

        # The rasterizer takes the covariance's upper triangle (6 values).
        row, col = torch.triu_indices(3, 3)

        image, radii, depth, opacity, n_touched = rasterizer(
            means3D=gaussian_means[i],
            means2D=mean_gradients,
            shs=shs[i] if use_sh else None,
            colors_precomp=None if use_sh else shs[i, :, 0, :],
            opacities=gaussian_opacities[i, ..., None],
            cov3D_precomp=gaussian_covariances[i, :, row, col],
        )
        all_images.append(image)
        all_radii.append(radii)
    return torch.stack(all_images)
+
243
+
244
+ DepthRenderingMode = Literal["depth", "disparity", "relative_disparity", "log"]
src/model/decoder/decoder.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import ABC, abstractmethod
2
+ from dataclasses import dataclass
3
+ from typing import Generic, Literal, TypeVar, Optional
4
+
5
+ from jaxtyping import Float
6
+ from torch import Tensor, nn
7
+
8
+ from ..types import Gaussians
9
+
10
+ DepthRenderingMode = Literal[
11
+ "depth",
12
+ "log",
13
+ "disparity",
14
+ "relative_disparity",
15
+ ]
16
+
17
@dataclass
class DecoderOutput:
    """Raster outputs a decoder produces for a batch of target views.

    Annotation style unified to ``X | None`` (the class previously mixed
    ``Optional[X]`` and ``X | None``); field names, order, and defaults are
    unchanged.
    """

    # Rendered RGB images.
    color: Float[Tensor, "batch view 3 height width"]
    # Rendered depth maps, when requested.
    depth: Float[Tensor, "batch view height width"] | None
    # Per-pixel alpha accumulation — presumably rasterizer opacity; confirm.
    alpha: Float[Tensor, "batch view height width"] | None
    # Extra level-of-detail renderings — schema defined by the decoder; confirm.
    lod_rendering: dict | None
    # Optional per-pixel 3D points.
    pts_all: Float[Tensor, "batch view height width 3"] | None = None
    # Optional per-pixel confidence.
    conf: Float[Tensor, "batch view height width"] | None = None
25
+
26
+ T = TypeVar("T")
27
+
28
+
29
class Decoder(nn.Module, ABC, Generic[T]):
    """Abstract base class for decoders that render Gaussians into images.

    ``T`` is the type of the decoder's configuration dataclass, stored on
    ``self.cfg``.
    """

    # Decoder configuration (set in __init__).
    cfg: T

    def __init__(self, cfg: T) -> None:
        super().__init__()
        self.cfg = cfg

    @abstractmethod
    def forward(
        self,
        gaussians: Gaussians,
        extrinsics: Float[Tensor, "batch view 4 4"],
        intrinsics: Float[Tensor, "batch view 3 3"],
        near: Float[Tensor, "batch view"],
        far: Float[Tensor, "batch view"],
        image_shape: tuple[int, int],
        depth_mode: DepthRenderingMode | None = None,
    ) -> DecoderOutput:
        """Render the scene Gaussians into the given target views.

        ``extrinsics`` are per-view camera matrices (the CUDA splatting path
        inverts them to obtain view matrices, i.e. camera-to-world),
        ``image_shape`` is (height, width), and ``depth_mode`` selects how —
        and whether — depth is rendered.
        """
        pass