bennyguo commited on
Commit
c43e1bf
·
0 Parent(s):

initial commit

Browse files
Files changed (9) hide show
  1. .gitattributes +35 -0
  2. .gitignore +5 -0
  3. README.md +14 -0
  4. app.py +198 -0
  5. example_inputs_b64.py +0 -0
  6. model.py +1725 -0
  7. requirements.txt +6 -0
  8. static/viewer/viewer.html +167 -0
  9. triposplat.py +598 -0
.gitattributes ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ static/example_inputs/
2
+ gradio_outputs/
3
+ ckpts/
4
+ __pycache__/
5
+ *.pyc
README.md ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: TripoSplat
3
+ emoji: 👁
4
+ colorFrom: purple
5
+ colorTo: blue
6
+ sdk: gradio
7
+ sdk_version: 6.15.2
8
+ python_version: '3.12'
9
+ app_file: app.py
10
+ pinned: false
11
+ license: mit
12
+ ---
13
+
14
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,198 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """TripoSplat Gradio demo with Spark.js in-browser viewer.
2
+ Usage: python app.py
3
+ """
4
+ import base64
5
+ import subprocess
6
+ import tempfile
7
+ import time
8
+ from pathlib import Path
9
+ from uuid import uuid4
10
+
11
+ import gradio as gr
12
+ import spaces
13
+ import torch
14
+
15
+ from triposplat import TripoSplatPipeline
16
+ import example_inputs_b64 as _b64
17
+
18
+ # ----------------------------------------------------------------------------
19
+ # Download checkpoints from HuggingFace Hub (VAST-AI/TripoSplat)
20
+ # ----------------------------------------------------------------------------
21
+
22
+ subprocess.run(
23
+ [
24
+ "hf", "download",
25
+ "VAST-AI/TripoSplat",
26
+ "--local-dir", "ckpts"
27
+ ],
28
+ check=True,
29
+ )
30
+
31
+ # ----------------------------------------------------------------------------
32
+ # Pipeline (loaded once at startup)
33
+ # ----------------------------------------------------------------------------
34
+
35
+ PIPE = TripoSplatPipeline(
36
+ ckpt_path = "ckpts/diffusion_models/triposplat_fp16.safetensors",
37
+ decoder_path = "ckpts/vae/triposplat_vae_decoder_fp16.safetensors",
38
+ dinov3_path = "ckpts/clip_vision/dino_v3_vit_h.safetensors",
39
+ flux2_vae_encoder_path = "ckpts/vae/flux2-vae.safetensors",
40
+ rmbg_path = "ckpts/background_removal/birefnet.safetensors",
41
+ device = "cuda",
42
+ )
43
+
44
+ OUT_ROOT = Path("gradio_outputs").resolve()
45
+ OUT_ROOT.mkdir(parents=True, exist_ok=True)
46
+ VIEWER_HTML = Path("static/viewer/viewer.html").resolve()
47
+
48
+ # Decode example images from base64 into a persistent temp directory so that
49
+ # gr.Examples (which needs file paths) works without binary files in the repo.
50
+ _EXAMPLES_TMPDIR = tempfile.mkdtemp(prefix="triposplat_examples_")
51
+ def _write_example(varname: str, filename: str) -> str:
52
+ path = Path(_EXAMPLES_TMPDIR) / filename
53
+ path.write_bytes(base64.b64decode(getattr(_b64, varname)))
54
+ return str(path)
55
+
56
+ EXAMPLES = [
57
+ _write_example("CREATURE_BUTTERFLY", "creature_butterfly.webp"),
58
+ _write_example("BUILDING_STONE_HOUSE", "building_stone_house.webp"),
59
+ _write_example("VEHICLE_PIRATE_SHIP", "vehicle_pirate_ship.webp"),
60
+ _write_example("PLANT_WATER_LILY", "plant_water_lily.webp"),
61
+ ]
62
+
63
+ PLACEHOLDER_HTML = (
64
+ "<div style='display:flex;align-items:center;justify-content:center;height:520px;"
65
+ "color:#94a3b8;font:16px system-ui;background:#111318;border-radius:12px'>"
66
+ "3D viewer will appear here after generation</div>"
67
+ )
68
+
69
+
70
+ def _gr_file(path: Path) -> str:
71
+ """Gradio serves any file under `allowed_paths` at `/gradio_api/file=<abspath>`."""
72
+ return f"/gradio_api/file={path.as_posix()}"
73
+
74
+
75
+ def _viewer_iframe(ply_path: Path) -> str:
76
+ ts = time.time() # cache-bust so the iframe reloads each generation
77
+ src = f"{_gr_file(VIEWER_HTML)}?ply={_gr_file(ply_path)}&ts={ts}"
78
+ return (
79
+ f"<iframe src='{src}' "
80
+ "style='width:100%;height:520px;border:0;border-radius:12px;background:#0a0b0e'></iframe>"
81
+ )
82
+
83
+
84
+ # ----------------------------------------------------------------------------
85
+ # Event handlers
86
+ # ----------------------------------------------------------------------------
87
+
88
+ def on_image_change(image):
89
+ """Run preprocessing as soon as the input changes — gives the user instant
90
+ feedback on the matte/crop without waiting for the full generation."""
91
+ if image is None:
92
+ return None
93
+ return PIPE.preprocess_image(image)
94
+
95
+
96
+ @spaces.GPU
97
+ def generate(prepared, seed: int, steps: int, guidance_scale: float,
98
+ num_gaussians: int, output_format: str,
99
+ progress=gr.Progress(track_tqdm=True)):
100
+ if prepared is None:
101
+ raise gr.Error("Please upload an image and wait for preprocessing to finish.")
102
+
103
+ progress(0, desc="Generating...")
104
+ t0 = time.time()
105
+ gen = torch.Generator(device=PIPE._device).manual_seed(int(seed))
106
+ cond = PIPE.encode_image(prepared, generator=gen)
107
+ out = PIPE.sample_latent(cond, steps=int(steps),
108
+ guidance_scale=float(guidance_scale),
109
+ generator=gen, show_progress=True)
110
+ gaussian = PIPE.decode_latent(out["latent"], num_gaussians=int(num_gaussians))
111
+ gen_dt = time.time() - t0
112
+
113
+ out_dir = OUT_ROOT / uuid4().hex[:12]
114
+ out_dir.mkdir(parents=True, exist_ok=True)
115
+ ply_path = out_dir / "splat.ply"
116
+ gaussian.save_ply(str(ply_path))
117
+
118
+ fmt = output_format.lower()
119
+ if fmt == "ply":
120
+ download_path = ply_path
121
+ elif fmt == "splat":
122
+ download_path = out_dir / "splat.splat"
123
+ gaussian.save_splat(str(download_path))
124
+ else:
125
+ raise gr.Error(f"Unknown output format: {output_format}")
126
+
127
+ info = (f"{gaussian.get_xyz.shape[0]:,} gaussians · "
128
+ f"generation: {gen_dt:.1f}s · saved: {download_path.name}")
129
+ return _viewer_iframe(ply_path), gr.update(value=str(download_path), interactive=True), info
130
+
131
+
132
+ # ----------------------------------------------------------------------------
133
+ # Gradio UI
134
+ # ----------------------------------------------------------------------------
135
+
136
+ with gr.Blocks(title="TripoSplat") as demo:
137
+ gr.Markdown("# TripoSplat")
138
+ gr.Markdown(
139
+ "TripoSplat converts a single 2D image into high-quality and variable number of 3D Gaussians, developed by [TripoAI](https://www.tripo3d.ai/). "
140
+ "It can serve as a powerful pipeline tool for asset creation, AR/VR, game development, simulation environments, and beyond.\n\n"
141
+ "[Read Paper](https://arxiv.org/abs/2605.16355) | [Research Blog](https://www.tripo3d.ai/research/triposplat)"
142
+ )
143
+
144
+ image_in = gr.Image(label="Input image", type="pil", image_mode="RGBA",
145
+ height=320, render=False)
146
+
147
+ gr.Examples(
148
+ examples=[[p] for p in EXAMPLES],
149
+ inputs=[image_in],
150
+ label="Examples (click to load)",
151
+ examples_per_page=10,
152
+ cache_examples=False,
153
+ )
154
+
155
+ with gr.Row():
156
+ with gr.Column(scale=1):
157
+ image_in.render()
158
+
159
+ with gr.Accordion("Sampling settings", open=False):
160
+ seed_in = gr.Number(label="Seed", value=42, precision=0)
161
+ steps_in = gr.Slider(label="Inference steps", minimum=1, maximum=50, step=1, value=20)
162
+ cfg_in = gr.Slider(label="Guidance scale", minimum=1.0, maximum=10.0, step=0.5, value=3.0)
163
+ num_g_in = gr.Dropdown(
164
+ label="Number of gaussians",
165
+ choices=["32768", "65536", "131072", "262144"],
166
+ value="262144",
167
+ )
168
+ fmt_in = gr.Dropdown(label="Download format", choices=["ply", "splat"], value="ply")
169
+
170
+ run_btn = gr.Button("Generate", variant="primary")
171
+ prepared_out = gr.Image(label="Preprocessed input", interactive=False, height=240)
172
+ info_out = gr.Markdown()
173
+
174
+ with gr.Column(scale=2):
175
+ viewer_out = gr.HTML(value=PLACEHOLDER_HTML, label="Spark.js viewer")
176
+ file_out = gr.DownloadButton(label="Download", value=None, interactive=False)
177
+
178
+ image_in.change(
179
+ fn=on_image_change,
180
+ inputs=[image_in],
181
+ outputs=[prepared_out],
182
+ )
183
+
184
+ run_btn.click(
185
+ fn=generate,
186
+ inputs=[prepared_out, seed_in, steps_in, cfg_in, num_g_in, fmt_in],
187
+ outputs=[viewer_out, file_out, info_out],
188
+ )
189
+
190
+
191
+ if __name__ == "__main__":
192
+ demo.launch(
193
+ allowed_paths=[
194
+ str(VIEWER_HTML.parent),
195
+ str(OUT_ROOT),
196
+ _EXAMPLES_TMPDIR,
197
+ ],
198
+ )
example_inputs_b64.py ADDED
The diff for this file is too large to render. See raw diff
 
model.py ADDED
@@ -0,0 +1,1725 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional
2
+ import math
3
+ import re
4
+
5
+ import numpy as np
6
+ import safetensors.torch
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ from torchvision.ops import deform_conv2d
11
+
12
+
13
+ # ---------------------------------------------------------------------------
14
+ # DINOv3 ViT-H/16+
15
+ # ---------------------------------------------------------------------------
16
+
17
+ def _rotate_half(x: torch.Tensor) -> torch.Tensor:
18
+ x1, x2 = x.chunk(2, dim=-1)
19
+ return torch.cat((-x2, x1), dim=-1)
20
+
21
+
22
+ class DinoV3PatchEmbed(nn.Module):
23
+ def __init__(self, patch_size=16, in_chans=3, embed_dim=1280):
24
+ super().__init__()
25
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=True)
26
+
27
+ def forward(self, x):
28
+ return self.proj(x).flatten(2).transpose(1, 2)
29
+
30
+
31
+ class DinoV3RotaryEmbedding2D(nn.Module):
32
+ def __init__(self, dim: int, base: float = 100.0):
33
+ super().__init__()
34
+ inv_freq = 1.0 / (base ** torch.arange(0, 1, 4.0 / dim, dtype=torch.float32))
35
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
36
+
37
+ def forward(self, height: int, width: int, device: torch.device, dtype: torch.dtype):
38
+ coords_h = torch.arange(0.5, height, dtype=torch.float32, device=device) / height
39
+ coords_w = torch.arange(0.5, width, dtype=torch.float32, device=device) / width
40
+ coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij"), dim=-1)
41
+ coords = (2.0 * coords - 1.0).flatten(0, 1)
42
+ angles = (2 * math.pi * coords[:, :, None] * self.inv_freq[None, None, :]).flatten(1, 2).tile(2)
43
+ cos = angles.cos().unsqueeze(0).unsqueeze(0)
44
+ sin = angles.sin().unsqueeze(0).unsqueeze(0)
45
+ return cos.to(dtype=dtype), sin.to(dtype=dtype)
46
+
47
+
48
+ class DinoV3Attention(nn.Module):
49
+ def __init__(self, dim: int, num_heads: int, qkv_bias: tuple = (True, False, True)):
50
+ super().__init__()
51
+ self.num_heads = num_heads
52
+ self.head_dim = dim // num_heads
53
+ q_bias, k_bias, v_bias = qkv_bias
54
+ self.q_proj = nn.Linear(dim, dim, bias=q_bias)
55
+ self.k_proj = nn.Linear(dim, dim, bias=k_bias)
56
+ self.v_proj = nn.Linear(dim, dim, bias=v_bias)
57
+ self.o_proj = nn.Linear(dim, dim, bias=True)
58
+
59
+ def forward(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor,
60
+ num_prefix_tokens: int = 0) -> torch.Tensor:
61
+ B, N, C = x.shape
62
+ q = self.q_proj(x).reshape(B, N, self.num_heads, self.head_dim).transpose(1, 2)
63
+ k = self.k_proj(x).reshape(B, N, self.num_heads, self.head_dim).transpose(1, 2)
64
+ v = self.v_proj(x).reshape(B, N, self.num_heads, self.head_dim).transpose(1, 2)
65
+ if num_prefix_tokens > 0:
66
+ q_pre, q_pat = q.split((num_prefix_tokens, N - num_prefix_tokens), dim=-2)
67
+ k_pre, k_pat = k.split((num_prefix_tokens, N - num_prefix_tokens), dim=-2)
68
+ q = torch.cat((q_pre, q_pat * cos + _rotate_half(q_pat) * sin), dim=-2)
69
+ k = torch.cat((k_pre, k_pat * cos + _rotate_half(k_pat) * sin), dim=-2)
70
+ else:
71
+ q = q * cos + _rotate_half(q) * sin
72
+ k = k * cos + _rotate_half(k) * sin
73
+ out = F.scaled_dot_product_attention(q, k, v)
74
+ return self.o_proj(out.transpose(1, 2).reshape(B, N, C))
75
+
76
+
77
+ class DinoV3MLP(nn.Module):
78
+ def __init__(self, dim: int, hidden_dim: int, bias: bool = True):
79
+ super().__init__()
80
+ self.gate_proj = nn.Linear(dim, hidden_dim, bias=bias)
81
+ self.up_proj = nn.Linear(dim, hidden_dim, bias=bias)
82
+ self.down_proj = nn.Linear(hidden_dim, dim, bias=bias)
83
+
84
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
85
+ return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))
86
+
87
+
88
+ class DinoV3Block(nn.Module):
89
+ def __init__(self, dim: int, num_heads: int, mlp_ratio: float = 4.0,
90
+ qkv_bias: tuple = (True, False, True), layerscale_init: float = 1.0,
91
+ mlp_bias: bool = True, eps: float = 1e-5):
92
+ super().__init__()
93
+ self.norm1 = nn.LayerNorm(dim, eps=eps)
94
+ self.attn = DinoV3Attention(dim, num_heads, qkv_bias=qkv_bias)
95
+ self.ls1 = nn.Parameter(torch.ones(dim) * layerscale_init)
96
+ self.norm2 = nn.LayerNorm(dim, eps=eps)
97
+ self.mlp = DinoV3MLP(dim, int(dim * mlp_ratio), bias=mlp_bias)
98
+ self.ls2 = nn.Parameter(torch.ones(dim) * layerscale_init)
99
+
100
+ def forward(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor,
101
+ num_prefix_tokens: int = 0) -> torch.Tensor:
102
+ x = x + self.ls1 * self.attn(self.norm1(x), cos, sin, num_prefix_tokens=num_prefix_tokens)
103
+ x = x + self.ls2 * self.mlp(self.norm2(x))
104
+ return x
105
+
106
+
107
+ class DinoV3ViT(nn.Module):
108
+ def __init__(self, hidden_size: int = 1280, num_heads: int = 20, num_layers: int = 32,
109
+ patch_size: int = 16, num_register_tokens: int = 4,
110
+ intermediate_size: int = 5120, layerscale_init: float = 1.0,
111
+ query_bias: bool = True, key_bias: bool = False, value_bias: bool = True,
112
+ mlp_bias: bool = True, rope_theta: float = 100.0, layer_norm_eps: float = 1e-5):
113
+ super().__init__()
114
+ self.patch_size = patch_size
115
+ self.num_register_tokens = num_register_tokens
116
+ self.patch_embed = DinoV3PatchEmbed(patch_size=patch_size, embed_dim=hidden_size)
117
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, hidden_size))
118
+ self.register_tokens = nn.Parameter(torch.zeros(1, num_register_tokens, hidden_size))
119
+ self.rope = DinoV3RotaryEmbedding2D(dim=hidden_size // num_heads, base=rope_theta)
120
+ qkv_bias = (query_bias, key_bias, value_bias)
121
+ self.blocks = nn.ModuleList([
122
+ DinoV3Block(hidden_size, num_heads, mlp_ratio=intermediate_size / hidden_size,
123
+ qkv_bias=qkv_bias, layerscale_init=layerscale_init,
124
+ mlp_bias=mlp_bias, eps=layer_norm_eps)
125
+ for _ in range(num_layers)
126
+ ])
127
+ self.norm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
128
+
129
+ @property
130
+ def device(self) -> torch.device:
131
+ return self.cls_token.device
132
+
133
+ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
134
+ B, _, H, W = pixel_values.shape
135
+ x = self.patch_embed(pixel_values)
136
+ hp, wp = H // self.patch_size, W // self.patch_size
137
+ cos, sin = self.rope(hp, wp, x.device, x.dtype)
138
+ x = torch.cat([self.cls_token.expand(B, -1, -1),
139
+ self.register_tokens.expand(B, -1, -1), x], dim=1)
140
+ num_prefix = 1 + self.num_register_tokens
141
+ for block in self.blocks:
142
+ x = block(x, cos, sin, num_prefix_tokens=num_prefix)
143
+ return self.norm(x)
144
+
145
+ def load_safetensors(self, path: str) -> None:
146
+ state_dict = safetensors.torch.load_file(path)
147
+ our_sd = self.state_dict()
148
+ loaded = {}
149
+ for hf_key in state_dict:
150
+ k = (hf_key
151
+ .replace("embeddings.patch_embeddings.", "patch_embed.proj.")
152
+ .replace("embeddings.cls_token", "cls_token")
153
+ .replace("embeddings.mask_token", "mask_token")
154
+ .replace("embeddings.register_tokens", "register_tokens"))
155
+ m = re.match(r"layer\.(\d+)\.(.+)", k)
156
+ if m:
157
+ rest = m.group(2)
158
+ for proj in ["q_proj", "k_proj", "v_proj", "o_proj"]:
159
+ rest = rest.replace(f"attention.{proj}", f"attn.{proj}")
160
+ rest = (rest.replace("layer_scale1.lambda1", "ls1")
161
+ .replace("layer_scale2.lambda1", "ls2"))
162
+ k = f"blocks.{m.group(1)}.{rest}"
163
+ if k in our_sd:
164
+ assert state_dict[hf_key].shape == our_sd[k].shape, \
165
+ f"Shape mismatch {k}: {state_dict[hf_key].shape} vs {our_sd[k].shape}"
166
+ loaded[k] = state_dict[hf_key]
167
+ check_sd = {k: v for k, v in our_sd.items() if k != "mask_token"}
168
+ missing = set(check_sd) - set(loaded)
169
+ unexpected = set(loaded) - set(check_sd)
170
+ if missing:
171
+ raise KeyError(f"[DINOv3] Missing keys: {missing}")
172
+ if unexpected:
173
+ raise KeyError(f"[DINOv3] Unexpected keys: {unexpected}")
174
+ self.load_state_dict(loaded, strict=True)
175
+
176
+
177
+ # ---------------------------------------------------------------------------
178
+ # Flux2 VAE Encoder
179
+ # ---------------------------------------------------------------------------
180
+
181
+ class Flux2ResnetBlock(nn.Module):
182
+ def __init__(self, in_channels, out_channels, use_shortcut=False):
183
+ super().__init__()
184
+ self.norm1 = nn.GroupNorm(32, in_channels, eps=1e-6)
185
+ self.conv1 = nn.Conv2d(in_channels, out_channels, 3, 1, 1)
186
+ self.norm2 = nn.GroupNorm(32, out_channels, eps=1e-6)
187
+ self.conv2 = nn.Conv2d(out_channels, out_channels, 3, 1, 1)
188
+ self.conv_shortcut = nn.Conv2d(in_channels, out_channels, 1, 1, 0) if use_shortcut else None
189
+
190
+ def forward(self, x):
191
+ h = F.silu(self.norm1(x))
192
+ h = F.silu(self.norm2(self.conv1(h)))
193
+ h = self.conv2(h)
194
+ return h + (self.conv_shortcut(x) if self.conv_shortcut is not None else x)
195
+
196
+
197
+ class Flux2Downsampler(nn.Module):
198
+ def __init__(self, channels):
199
+ super().__init__()
200
+ self.conv = nn.Conv2d(channels, channels, 3, 2, 0)
201
+
202
+ def forward(self, x):
203
+ return self.conv(F.pad(x, (0, 1, 0, 1)))
204
+
205
+
206
+ class Flux2Attention(nn.Module):
207
+ def __init__(self, channels):
208
+ super().__init__()
209
+ self.group_norm = nn.GroupNorm(32, channels, eps=1e-6)
210
+ self.to_q = nn.Linear(channels, channels)
211
+ self.to_k = nn.Linear(channels, channels)
212
+ self.to_v = nn.Linear(channels, channels)
213
+ self.to_out = nn.ModuleList([nn.Linear(channels, channels), nn.Identity()])
214
+
215
+ def forward(self, x):
216
+ B, C, H, W = x.shape
217
+ h = self.group_norm(x).reshape(B, C, H * W).transpose(1, 2)
218
+ q = self.to_q(h).reshape(B, -1, 1, C).permute(0, 2, 1, 3)
219
+ k = self.to_k(h).reshape(B, -1, 1, C).permute(0, 2, 1, 3)
220
+ v = self.to_v(h).reshape(B, -1, 1, C).permute(0, 2, 1, 3)
221
+ out = F.scaled_dot_product_attention(q, k, v)
222
+ out = self.to_out[0](out.permute(0, 2, 1, 3).reshape(B, -1, C))
223
+ return x + out.transpose(1, 2).reshape(B, C, H, W)
224
+
225
+
226
+ class Flux2Encoder(nn.Module):
227
+ def __init__(self):
228
+ super().__init__()
229
+ self.conv_in = nn.Conv2d(3, 128, 3, 1, 1)
230
+ self.down_0_resnets = nn.ModuleList([Flux2ResnetBlock(128, 128), Flux2ResnetBlock(128, 128)])
231
+ self.down_0_sampler = Flux2Downsampler(128)
232
+ self.down_1_resnets = nn.ModuleList([Flux2ResnetBlock(128, 256, use_shortcut=True), Flux2ResnetBlock(256, 256)])
233
+ self.down_1_sampler = Flux2Downsampler(256)
234
+ self.down_2_resnets = nn.ModuleList([Flux2ResnetBlock(256, 512, use_shortcut=True), Flux2ResnetBlock(512, 512)])
235
+ self.down_2_sampler = Flux2Downsampler(512)
236
+ self.down_3_resnets = nn.ModuleList([Flux2ResnetBlock(512, 512), Flux2ResnetBlock(512, 512)])
237
+ self.mid_attn = Flux2Attention(512)
238
+ self.mid_resnets = nn.ModuleList([Flux2ResnetBlock(512, 512), Flux2ResnetBlock(512, 512)])
239
+ self.conv_norm_out = nn.GroupNorm(32, 512, eps=1e-6)
240
+ self.conv_out = nn.Conv2d(512, 64, 3, 1, 1)
241
+
242
+ def forward(self, x):
243
+ x = self.conv_in(x)
244
+ for r in self.down_0_resnets: x = r(x)
245
+ x = self.down_0_sampler(x)
246
+ for r in self.down_1_resnets: x = r(x)
247
+ x = self.down_1_sampler(x)
248
+ for r in self.down_2_resnets: x = r(x)
249
+ x = self.down_2_sampler(x)
250
+ for r in self.down_3_resnets: x = r(x)
251
+ x = self.mid_resnets[0](x)
252
+ x = self.mid_attn(x)
253
+ x = self.mid_resnets[1](x)
254
+ return self.conv_out(F.silu(self.conv_norm_out(x)))
255
+
256
+
257
+ class Flux2VAEEncoder(nn.Module):
258
+ def __init__(self):
259
+ super().__init__()
260
+ self.encoder = Flux2Encoder()
261
+ self.quant_conv = nn.Conv2d(64, 64, 1, 1, 0)
262
+ self.bn = nn.BatchNorm1d(128, eps=1e-5, momentum=0.1, affine=False, track_running_stats=True)
263
+
264
+ def load_safetensors(self, path: str):
265
+ sd = safetensors.torch.load_file(path)
266
+ remapped = {}
267
+ for k, v in sd.items():
268
+ # Skip the decoder half of a full Flux2-VAE ckpt — we only need the encoder.
269
+ if k.startswith(("decoder.", "post_quant_conv.")):
270
+ continue
271
+ # Comfy / diffusers-style naming → our flattened naming.
272
+ m = re.match(r"encoder\.down_blocks\.(\d+)\.resnets\.(\d+)\.(.+)", k)
273
+ if m:
274
+ remapped[f"encoder.down_{m.group(1)}_resnets.{m.group(2)}.{m.group(3)}"] = v
275
+ continue
276
+ m = re.match(r"encoder\.down_blocks\.(\d+)\.downsamplers\.0\.(.+)", k)
277
+ if m:
278
+ remapped[f"encoder.down_{m.group(1)}_sampler.{m.group(2)}"] = v
279
+ continue
280
+ m = re.match(r"encoder\.mid_block\.resnets\.(\d+)\.(.+)", k)
281
+ if m:
282
+ remapped[f"encoder.mid_resnets.{m.group(1)}.{m.group(2)}"] = v
283
+ continue
284
+ m = re.match(r"encoder\.mid_block\.attentions\.0\.(.+)", k)
285
+ if m:
286
+ remapped[f"encoder.mid_attn.{m.group(1)}"] = v
287
+ continue
288
+ remapped[k] = v
289
+ missing, unexpected = self.load_state_dict(remapped, strict=False)
290
+ if missing:
291
+ raise KeyError(f"[VAE] Missing keys: {missing}")
292
+ if unexpected:
293
+ raise KeyError(f"[VAE] Unexpected keys: {unexpected}")
294
+
295
+ def encode(self, images, deterministic: bool = True, generator: torch.Generator = None):
296
+ moments = self.quant_conv(self.encoder(images))
297
+ mean, logvar = moments.chunk(2, dim=1)
298
+ if deterministic:
299
+ latents = mean
300
+ else:
301
+ noise = torch.randn(mean.shape, dtype=mean.dtype, device=mean.device, generator=generator)
302
+ latents = mean + torch.exp(0.5 * logvar) * noise
303
+ B, C, H, W = latents.shape
304
+ latents = latents.view(B, C, H // 2, 2, W // 2, 2).permute(0, 1, 3, 5, 2, 4)
305
+ latents = latents.reshape(B, C * 4, H // 2, W // 2)
306
+ bn_mean = self.bn.running_mean.view(1, -1, 1, 1).to(latents.device, latents.dtype)
307
+ bn_std = torch.sqrt(self.bn.running_var.view(1, -1, 1, 1) + self.bn.eps).to(latents.device, latents.dtype)
308
+ return ((latents - bn_mean) / bn_std).to(torch.float32).flatten(2).transpose(1, 2).contiguous()
309
+
310
+ return rgba
311
+
312
+
313
+ # ---------------------------------------------------------------------------
314
+ # BiRefNet background removal (Swin-L + ASPP-deformable decoder)
315
+ # ---------------------------------------------------------------------------
316
+
317
+ # -- timm-style helpers, inlined to avoid a timm dependency --------------------
318
+
319
+ def _trunc_normal_(tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0,
320
+ a: float = -2.0, b: float = 2.0) -> torch.Tensor:
321
+ # Initialization helper — only used at __init__ time, the released ckpt
322
+ # overwrites everything in load_safetensors so the exact distribution here
323
+ # is unimportant.
324
+ with torch.no_grad():
325
+ tensor.normal_(mean, std).clamp_(mean + a * std, mean + b * std)
326
+ return tensor
327
+
328
+
329
+ # -- Swin Transformer (Swin-Large preset) -------------------------------------
330
+
331
+ class _SwinMlp(nn.Module):
332
+ def __init__(self, in_features, hidden_features=None, out_features=None):
333
+ super().__init__()
334
+ hidden_features = hidden_features or in_features
335
+ out_features = out_features or in_features
336
+ self.fc1 = nn.Linear(in_features, hidden_features)
337
+ self.act = nn.GELU()
338
+ self.fc2 = nn.Linear(hidden_features, out_features)
339
+
340
+ def forward(self, x):
341
+ return self.fc2(self.act(self.fc1(x)))
342
+
343
+
344
+ def _window_partition(x, window_size):
345
+ B, H, W, C = x.shape
346
+ x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
347
+ return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
348
+
349
+
350
+ def _window_reverse(windows, window_size, H, W):
351
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
352
+ x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
353
+ return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
354
+
355
+
356
+ class _WindowAttention(nn.Module):
357
+ def __init__(self, dim, window_size, num_heads):
358
+ super().__init__()
359
+ self.dim = dim
360
+ self.window_size = window_size # (Wh, Ww)
361
+ self.num_heads = num_heads
362
+ head_dim = dim // num_heads
363
+ self.scale = head_dim ** -0.5
364
+
365
+ self.relative_position_bias_table = nn.Parameter(
366
+ torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))
367
+ coords_h = torch.arange(window_size[0])
368
+ coords_w = torch.arange(window_size[1])
369
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing="ij"))
370
+ coords_flatten = torch.flatten(coords, 1)
371
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
372
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous()
373
+ relative_coords[:, :, 0] += window_size[0] - 1
374
+ relative_coords[:, :, 1] += window_size[1] - 1
375
+ relative_coords[:, :, 0] *= 2 * window_size[1] - 1
376
+ self.register_buffer("relative_position_index", relative_coords.sum(-1))
377
+
378
+ self.qkv = nn.Linear(dim, dim * 3, bias=True)
379
+ self.proj = nn.Linear(dim, dim)
380
+ _trunc_normal_(self.relative_position_bias_table, std=0.02)
381
+
382
+ def forward(self, x, mask=None):
383
+ B_, N, C = x.shape
384
+ qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
385
+ q, k, v = qkv[0], qkv[1], qkv[2]
386
+ q = q * self.scale
387
+ attn = q @ k.transpose(-2, -1)
388
+ bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
389
+ self.window_size[0] * self.window_size[1],
390
+ self.window_size[0] * self.window_size[1], -1)
391
+ attn = attn + bias.permute(2, 0, 1).contiguous().unsqueeze(0)
392
+ if mask is not None:
393
+ nW = mask.shape[0]
394
+ attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
395
+ attn = attn.view(-1, self.num_heads, N, N)
396
+ attn = attn.softmax(dim=-1)
397
+ x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
398
+ return self.proj(x)
399
+
400
+
401
+ class _SwinBlock(nn.Module):
402
+ def __init__(self, dim, num_heads, window_size, shift_size, mlp_ratio=4.0):
403
+ super().__init__()
404
+ self.dim = dim
405
+ self.window_size = window_size
406
+ self.shift_size = shift_size
407
+ self.norm1 = nn.LayerNorm(dim)
408
+ self.attn = _WindowAttention(dim, (window_size, window_size), num_heads)
409
+ self.norm2 = nn.LayerNorm(dim)
410
+ self.mlp = _SwinMlp(dim, int(dim * mlp_ratio))
411
+ self.H = None
412
+ self.W = None
413
+
414
+ def forward(self, x, mask_matrix):
415
+ B, L, C = x.shape
416
+ H, W = self.H, self.W
417
+ shortcut = x
418
+ x = self.norm1(x).view(B, H, W, C)
419
+ pad_r = (self.window_size - W % self.window_size) % self.window_size
420
+ pad_b = (self.window_size - H % self.window_size) % self.window_size
421
+ x = F.pad(x, (0, 0, 0, pad_r, 0, pad_b))
422
+ _, Hp, Wp, _ = x.shape
423
+ if self.shift_size > 0:
424
+ shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
425
+ attn_mask = mask_matrix
426
+ else:
427
+ shifted_x = x
428
+ attn_mask = None
429
+ x_windows = _window_partition(shifted_x, self.window_size).view(
430
+ -1, self.window_size * self.window_size, C)
431
+ attn_windows = self.attn(x_windows, mask=attn_mask).view(
432
+ -1, self.window_size, self.window_size, C)
433
+ shifted_x = _window_reverse(attn_windows, self.window_size, Hp, Wp)
434
+ if self.shift_size > 0:
435
+ x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
436
+ else:
437
+ x = shifted_x
438
+ if pad_r > 0 or pad_b > 0:
439
+ x = x[:, :H, :W, :].contiguous()
440
+ x = x.view(B, H * W, C)
441
+ x = shortcut + x
442
+ x = x + self.mlp(self.norm2(x))
443
+ return x
444
+
445
+
446
+ class _PatchMerging(nn.Module):
447
+ def __init__(self, dim):
448
+ super().__init__()
449
+ self.dim = dim
450
+ self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
451
+ self.norm = nn.LayerNorm(4 * dim)
452
+
453
+ def forward(self, x, H, W):
454
+ B, L, C = x.shape
455
+ x = x.view(B, H, W, C)
456
+ if H % 2 == 1 or W % 2 == 1:
457
+ x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
458
+ x0 = x[:, 0::2, 0::2, :]
459
+ x1 = x[:, 1::2, 0::2, :]
460
+ x2 = x[:, 0::2, 1::2, :]
461
+ x3 = x[:, 1::2, 1::2, :]
462
+ x = torch.cat([x0, x1, x2, x3], -1).view(B, -1, 4 * C)
463
+ return self.reduction(self.norm(x))
464
+
465
+
466
+ class _SwinBasicLayer(nn.Module):
467
+ def __init__(self, dim, depth, num_heads, window_size, mlp_ratio=4.0, downsample=True):
468
+ super().__init__()
469
+ self.window_size = window_size
470
+ self.shift_size = window_size // 2
471
+ self.depth = depth
472
+ self.blocks = nn.ModuleList([
473
+ _SwinBlock(dim=dim, num_heads=num_heads, window_size=window_size,
474
+ shift_size=0 if (i % 2 == 0) else window_size // 2,
475
+ mlp_ratio=mlp_ratio)
476
+ for i in range(depth)
477
+ ])
478
+ self.downsample = _PatchMerging(dim) if downsample else None
479
+
480
+ def forward(self, x, H, W):
481
+ Hp = int(math.ceil(H / self.window_size)) * self.window_size
482
+ Wp = int(math.ceil(W / self.window_size)) * self.window_size
483
+ img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device)
484
+ h_slices = (slice(0, -self.window_size),
485
+ slice(-self.window_size, -self.shift_size),
486
+ slice(-self.shift_size, None))
487
+ w_slices = (slice(0, -self.window_size),
488
+ slice(-self.window_size, -self.shift_size),
489
+ slice(-self.shift_size, None))
490
+ cnt = 0
491
+ for h in h_slices:
492
+ for w in w_slices:
493
+ img_mask[:, h, w, :] = cnt
494
+ cnt += 1
495
+ mask_windows = _window_partition(img_mask, self.window_size).view(-1, self.window_size ** 2)
496
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
497
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)) \
498
+ .masked_fill(attn_mask == 0, float(0.0)).to(x.dtype)
499
+ for blk in self.blocks:
500
+ blk.H, blk.W = H, W
501
+ x = blk(x, attn_mask)
502
+ if self.downsample is not None:
503
+ x_down = self.downsample(x, H, W)
504
+ Wh, Ww = (H + 1) // 2, (W + 1) // 2
505
+ return x, H, W, x_down, Wh, Ww
506
+ return x, H, W, x, H, W
507
+
508
+
509
+ class _SwinPatchEmbed(nn.Module):
510
+ def __init__(self, patch_size=4, in_channels=3, embed_dim=192):
511
+ super().__init__()
512
+ self.patch_size = (patch_size, patch_size)
513
+ self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
514
+ self.norm = nn.LayerNorm(embed_dim)
515
+ self.embed_dim = embed_dim
516
+
517
+ def forward(self, x):
518
+ _, _, H, W = x.shape
519
+ if W % self.patch_size[1] != 0:
520
+ x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
521
+ if H % self.patch_size[0] != 0:
522
+ x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
523
+ x = self.proj(x)
524
+ Wh, Ww = x.size(2), x.size(3)
525
+ x = x.flatten(2).transpose(1, 2)
526
+ x = self.norm(x)
527
+ return x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)
528
+
529
+
530
+ class _SwinLarge(nn.Module):
531
+ """Swin-Large backbone matching the BiRefNet HF release.
532
+
533
+ embed_dim=192, depths=[2,2,18,2], num_heads=[6,12,24,48], window_size=12.
534
+ """
535
+ def __init__(self):
536
+ super().__init__()
537
+ embed_dim = 192
538
+ depths = [2, 2, 18, 2]
539
+ num_heads = [6, 12, 24, 48]
540
+ window_size = 12
541
+ self.num_layers = len(depths)
542
+ self.embed_dim = embed_dim
543
+ self.patch_embed = _SwinPatchEmbed(patch_size=4, in_channels=3, embed_dim=embed_dim)
544
+ self.layers = nn.ModuleList([
545
+ _SwinBasicLayer(
546
+ dim=int(embed_dim * 2 ** i),
547
+ depth=depths[i],
548
+ num_heads=num_heads[i],
549
+ window_size=window_size,
550
+ downsample=(i < self.num_layers - 1),
551
+ ) for i in range(self.num_layers)
552
+ ])
553
+ num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]
554
+ self.num_features = num_features
555
+ for i in range(self.num_layers):
556
+ self.add_module(f"norm{i}", nn.LayerNorm(num_features[i]))
557
+
558
+ def forward(self, x):
559
+ x = self.patch_embed(x)
560
+ Wh, Ww = x.size(2), x.size(3)
561
+ x = x.flatten(2).transpose(1, 2)
562
+ outs = []
563
+ for i in range(self.num_layers):
564
+ x_out, H, W, x, Wh, Ww = self.layers[i](x, Wh, Ww)
565
+ norm_layer = getattr(self, f"norm{i}")
566
+ x_out = norm_layer(x_out)
567
+ out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous()
568
+ outs.append(out)
569
+ return tuple(outs)
570
+
571
+
572
+ # -- ASPP-Deformable -----------------------------------------------------------
573
+
574
+ class _DeformableConv2d(nn.Module):
575
+ def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False):
576
+ super().__init__()
577
+ if isinstance(kernel_size, int):
578
+ kernel_size = (kernel_size, kernel_size)
579
+ self.stride = (stride, stride) if isinstance(stride, int) else stride
580
+ self.padding = padding
581
+ self.offset_conv = nn.Conv2d(in_channels, 2 * kernel_size[0] * kernel_size[1],
582
+ kernel_size=kernel_size, stride=stride, padding=padding, bias=True)
583
+ self.modulator_conv = nn.Conv2d(in_channels, 1 * kernel_size[0] * kernel_size[1],
584
+ kernel_size=kernel_size, stride=stride, padding=padding, bias=True)
585
+ self.regular_conv = nn.Conv2d(in_channels, out_channels,
586
+ kernel_size=kernel_size, stride=stride, padding=padding, bias=bias)
587
+
588
+ def forward(self, x):
589
+ offset = self.offset_conv(x)
590
+ modulator = 2.0 * torch.sigmoid(self.modulator_conv(x))
591
+ return deform_conv2d(
592
+ input=x, offset=offset,
593
+ weight=self.regular_conv.weight, bias=self.regular_conv.bias,
594
+ padding=self.padding, mask=modulator, stride=self.stride,
595
+ )
596
+
597
+
598
+ class _ASPPModuleDeformable(nn.Module):
599
+ def __init__(self, in_channels, planes, kernel_size, padding):
600
+ super().__init__()
601
+ self.atrous_conv = _DeformableConv2d(in_channels, planes, kernel_size=kernel_size,
602
+ stride=1, padding=padding, bias=False)
603
+ self.bn = nn.BatchNorm2d(planes)
604
+ self.relu = nn.ReLU(inplace=True)
605
+
606
+ def forward(self, x):
607
+ return self.relu(self.bn(self.atrous_conv(x)))
608
+
609
+
610
+ class _ASPPDeformable(nn.Module):
611
+ def __init__(self, in_channels, out_channels=None, parallel_block_sizes=(1, 3, 7)):
612
+ super().__init__()
613
+ if out_channels is None:
614
+ out_channels = in_channels
615
+ inter = 256
616
+ self.aspp1 = _ASPPModuleDeformable(in_channels, inter, 1, padding=0)
617
+ self.aspp_deforms = nn.ModuleList([
618
+ _ASPPModuleDeformable(in_channels, inter, k, padding=k // 2)
619
+ for k in parallel_block_sizes
620
+ ])
621
+ self.global_avg_pool = nn.Sequential(
622
+ nn.AdaptiveAvgPool2d((1, 1)),
623
+ nn.Conv2d(in_channels, inter, 1, stride=1, bias=False),
624
+ nn.BatchNorm2d(inter),
625
+ nn.ReLU(inplace=True),
626
+ )
627
+ self.conv1 = nn.Conv2d(inter * (2 + len(self.aspp_deforms)), out_channels, 1, bias=False)
628
+ self.bn1 = nn.BatchNorm2d(out_channels)
629
+ self.relu = nn.ReLU(inplace=True)
630
+
631
+ def forward(self, x):
632
+ x1 = self.aspp1(x)
633
+ x_aspp_deforms = [m(x) for m in self.aspp_deforms]
634
+ x5 = self.global_avg_pool(x)
635
+ x5 = F.interpolate(x5, size=x1.size()[2:], mode='bilinear', align_corners=True)
636
+ y = torch.cat((x1, *x_aspp_deforms, x5), dim=1)
637
+ return self.relu(self.bn1(self.conv1(y)))
638
+
639
+
640
+ # -- Decoder blocks ------------------------------------------------------------
641
+
642
+ class _BasicDecBlk(nn.Module):
643
+ def __init__(self, in_channels, out_channels, inter_channels=64):
644
+ super().__init__()
645
+ self.conv_in = nn.Conv2d(in_channels, inter_channels, 3, 1, padding=1)
646
+ self.bn_in = nn.BatchNorm2d(inter_channels)
647
+ self.relu_in = nn.ReLU(inplace=True)
648
+ self.dec_att = _ASPPDeformable(in_channels=inter_channels)
649
+ self.conv_out = nn.Conv2d(inter_channels, out_channels, 3, 1, padding=1)
650
+ self.bn_out = nn.BatchNorm2d(out_channels)
651
+
652
+ def forward(self, x):
653
+ x = self.relu_in(self.bn_in(self.conv_in(x)))
654
+ x = self.dec_att(x)
655
+ x = self.bn_out(self.conv_out(x))
656
+ return x
657
+
658
+
659
+ class _BasicLatBlk(nn.Module):
660
+ def __init__(self, in_channels, out_channels):
661
+ super().__init__()
662
+ self.conv = nn.Conv2d(in_channels, out_channels, 1, 1, 0)
663
+
664
+ def forward(self, x):
665
+ return self.conv(x)
666
+
667
+
668
+ class _SimpleConvs(nn.Module):
669
+ def __init__(self, in_channels, out_channels, inter_channels=64):
670
+ super().__init__()
671
+ self.conv1 = nn.Conv2d(in_channels, inter_channels, 3, 1, 1)
672
+ self.conv_out = nn.Conv2d(inter_channels, out_channels, 3, 1, 1)
673
+
674
+ def forward(self, x):
675
+ return self.conv_out(self.conv1(x))
676
+
677
+
678
+ # -- Image → patch-stack helper -----------------------------------------------
679
+
680
+ def _image2patches(image, patch_ref):
681
+ """`einops` rearrange 'b c (hg h) (wg w) -> b (c hg wg) h w' replacement.
682
+
683
+ Splits `image` into hg×wg non-overlapping patches and stacks them along
684
+ the channel axis. `hg`/`wg` are inferred from image and patch_ref sizes.
685
+ """
686
+ b, c, h_full, w_full = image.shape
687
+ hg, wg = h_full // patch_ref.shape[-2], w_full // patch_ref.shape[-1]
688
+ h, w = h_full // hg, w_full // wg
689
+ # (b, c, hg*h, wg*w) -> (b, c, hg, h, wg, w) -> (b, c, hg, wg, h, w) -> (b, c*hg*wg, h, w)
690
+ return image.view(b, c, hg, h, wg, w).permute(0, 1, 2, 4, 3, 5).reshape(b, c * hg * wg, h, w)
691
+
692
+
693
+ # -- Decoder + top-level BiRefNet ---------------------------------------------
694
+
695
+ class _BiRefNetDecoder(nn.Module):
696
+ def __init__(self, channels=(3072, 1536, 768, 384)):
697
+ super().__init__()
698
+ c = channels # high-to-low resolution channel counts
699
+ # input-modulator blocks (one per resolution; channels are
700
+ # `3 * patch_grid**2`, see _image2patches docstring).
701
+ self.ipt_blk5 = _SimpleConvs(2 ** 10 * 3, c[0] // 8, inter_channels=64)
702
+ self.ipt_blk4 = _SimpleConvs(2 ** 8 * 3, c[0] // 8, inter_channels=64)
703
+ self.ipt_blk3 = _SimpleConvs(2 ** 6 * 3, c[1] // 8, inter_channels=64)
704
+ self.ipt_blk2 = _SimpleConvs(2 ** 4 * 3, c[2] // 8, inter_channels=64)
705
+ self.ipt_blk1 = _SimpleConvs(2 ** 0 * 3, c[3] // 8, inter_channels=64)
706
+
707
+ self.decoder_block4 = _BasicDecBlk(c[0] + c[0] // 8, c[1])
708
+ self.decoder_block3 = _BasicDecBlk(c[1] + c[0] // 8, c[2])
709
+ self.decoder_block2 = _BasicDecBlk(c[2] + c[1] // 8, c[3])
710
+ self.decoder_block1 = _BasicDecBlk(c[3] + c[2] // 8, c[3] // 2)
711
+ self.conv_out1 = nn.Sequential(nn.Conv2d(c[3] // 2 + c[3] // 8, 1, 1, 1, 0))
712
+
713
+ self.lateral_block4 = _BasicLatBlk(c[1], c[1])
714
+ self.lateral_block3 = _BasicLatBlk(c[2], c[2])
715
+ self.lateral_block2 = _BasicLatBlk(c[3], c[3])
716
+
717
+ # multi-scale supervision heads (training only — kept for state_dict
718
+ # parity with the released checkpoint; not consumed at inference).
719
+ self.conv_ms_spvn_4 = nn.Conv2d(c[1], 1, 1, 1, 0)
720
+ self.conv_ms_spvn_3 = nn.Conv2d(c[2], 1, 1, 1, 0)
721
+ self.conv_ms_spvn_2 = nn.Conv2d(c[3], 1, 1, 1, 0)
722
+
723
+ # gradient-decoder-triggering (gdt) attention: used at inference to
724
+ # gate p4/p3/p2.
725
+ _N = 16
726
+ def _gdt_branch(in_c):
727
+ return nn.Sequential(nn.Conv2d(in_c, _N, 3, 1, 1), nn.BatchNorm2d(_N), nn.ReLU(inplace=True))
728
+ self.gdt_convs_4 = _gdt_branch(c[1])
729
+ self.gdt_convs_3 = _gdt_branch(c[2])
730
+ self.gdt_convs_2 = _gdt_branch(c[3])
731
+
732
+ def _head_1x1():
733
+ return nn.Sequential(nn.Conv2d(_N, 1, 1, 1, 0))
734
+ # multi-scale supervision heads on the gdt branch (training only)
735
+ self.gdt_convs_pred_4 = _head_1x1()
736
+ self.gdt_convs_pred_3 = _head_1x1()
737
+ self.gdt_convs_pred_2 = _head_1x1()
738
+ # attention heads
739
+ self.gdt_convs_attn_4 = _head_1x1()
740
+ self.gdt_convs_attn_3 = _head_1x1()
741
+ self.gdt_convs_attn_2 = _head_1x1()
742
+
743
+ def forward(self, x, x1, x2, x3, x4):
744
+ x4 = torch.cat((x4, self.ipt_blk5(_image2patches(x, x4))), 1)
745
+ p4 = self.decoder_block4(x4)
746
+ p4 = p4 * self.gdt_convs_attn_4(self.gdt_convs_4(p4)).sigmoid()
747
+ _p4 = F.interpolate(p4, size=x3.shape[2:], mode='bilinear', align_corners=True)
748
+ _p3 = _p4 + self.lateral_block4(x3)
749
+
750
+ _p3 = torch.cat((_p3, self.ipt_blk4(_image2patches(x, _p3))), 1)
751
+ p3 = self.decoder_block3(_p3)
752
+ p3 = p3 * self.gdt_convs_attn_3(self.gdt_convs_3(p3)).sigmoid()
753
+ _p3 = F.interpolate(p3, size=x2.shape[2:], mode='bilinear', align_corners=True)
754
+ _p2 = _p3 + self.lateral_block3(x2)
755
+
756
+ _p2 = torch.cat((_p2, self.ipt_blk3(_image2patches(x, _p2))), 1)
757
+ p2 = self.decoder_block2(_p2)
758
+ p2 = p2 * self.gdt_convs_attn_2(self.gdt_convs_2(p2)).sigmoid()
759
+ _p2 = F.interpolate(p2, size=x1.shape[2:], mode='bilinear', align_corners=True)
760
+ _p1 = _p2 + self.lateral_block2(x1)
761
+
762
+ _p1 = torch.cat((_p1, self.ipt_blk2(_image2patches(x, _p1))), 1)
763
+ _p1 = self.decoder_block1(_p1)
764
+ _p1 = F.interpolate(_p1, size=x.shape[2:], mode='bilinear', align_corners=True)
765
+
766
+ _p1 = torch.cat((_p1, self.ipt_blk1(_image2patches(x, _p1))), 1)
767
+ return self.conv_out1(_p1)
768
+
769
+
770
+ class BiRefNet(nn.Module):
771
+ """BiRefNet (ZhengPeng7/BiRefNet) with Swin-L backbone, multi-scale input
772
+ concatenation, ASPP-deformable squeeze block, and the 4-level
773
+ input-modulating decoder used in the v1 release.
774
+
775
+ `forward(x)` returns a single 1-channel alpha map in `[0, 1]` (post-sigmoid).
776
+ `remove_background(pil_img)` is the PIL helper used by the pipeline —
777
+ accepts a PIL RGB image and returns an RGBA copy with the predicted matte
778
+ in the alpha channel.
779
+ """
780
+
781
+ INPUT_SIZE = (1024, 1024)
782
+ # backbone channel counts post mul_scl_ipt='cat' (doubled from raw Swin-L)
783
+ _CHANNELS = (3072, 1536, 768, 384)
784
+ # ImageNet normalization used by the BiRefNet recipe
785
+ _NORM_MEAN = (0.485, 0.456, 0.406)
786
+ _NORM_STD = (0.229, 0.224, 0.225)
787
+
788
+ def __init__(self):
789
+ super().__init__()
790
+ self.bb = _SwinLarge()
791
+ cxt = list(self._CHANNELS[1:][::-1][-3:]) # = [384, 768, 1536]
792
+ self.squeeze_module = nn.Sequential(
793
+ _BasicDecBlk(self._CHANNELS[0] + sum(cxt), self._CHANNELS[0])
794
+ )
795
+ self.decoder = _BiRefNetDecoder(channels=self._CHANNELS)
796
+
797
+ @property
798
+ def device(self):
799
+ return next(self.parameters()).device
800
+
801
+ @property
802
+ def dtype(self):
803
+ return next(self.parameters()).dtype
804
+
805
+ def _forward_enc(self, x):
806
+ x1, x2, x3, x4 = self.bb(x)
807
+ # mul_scl_ipt='cat': re-run backbone at half resolution, concat features
808
+ B, C, H, W = x.shape
809
+ x1_, x2_, x3_, x4_ = self.bb(F.interpolate(x, size=(H // 2, W // 2),
810
+ mode='bilinear', align_corners=True))
811
+ x1 = torch.cat([x1, F.interpolate(x1_, size=x1.shape[2:], mode='bilinear', align_corners=True)], 1)
812
+ x2 = torch.cat([x2, F.interpolate(x2_, size=x2.shape[2:], mode='bilinear', align_corners=True)], 1)
813
+ x3 = torch.cat([x3, F.interpolate(x3_, size=x3.shape[2:], mode='bilinear', align_corners=True)], 1)
814
+ x4 = torch.cat([x4, F.interpolate(x4_, size=x4.shape[2:], mode='bilinear', align_corners=True)], 1)
815
+ # cxt: upsample x1/x2/x3 to x4 spatial and concat for the squeeze input
816
+ x4 = torch.cat([
817
+ F.interpolate(x1, size=x4.shape[2:], mode='bilinear', align_corners=True),
818
+ F.interpolate(x2, size=x4.shape[2:], mode='bilinear', align_corners=True),
819
+ F.interpolate(x3, size=x4.shape[2:], mode='bilinear', align_corners=True),
820
+ x4,
821
+ ], 1)
822
+ return x1, x2, x3, x4
823
+
824
+ def forward(self, x):
825
+ x1, x2, x3, x4 = self._forward_enc(x)
826
+ x4 = self.squeeze_module(x4)
827
+ logits = self.decoder(x, x1, x2, x3, x4)
828
+ return torch.sigmoid(logits)
829
+
830
+ def load_safetensors(self, path: str) -> None:
831
+ sd = safetensors.torch.load_file(path)
832
+ # The decoder's gdt_convs_pred_* / conv_ms_spvn_* heads are training-only
833
+ # but are kept as submodules for state_dict parity. strict=True works.
834
+ missing, unexpected = self.load_state_dict(sd, strict=False)
835
+ if unexpected:
836
+ raise KeyError(f"[birefnet] unexpected keys (e.g. {unexpected[:3]})")
837
+ if missing:
838
+ raise KeyError(f"[birefnet] missing keys (e.g. {missing[:3]})")
839
+
840
+ @torch.no_grad()
841
+ def remove_background(self, image) -> "Image.Image":
842
+ from PIL import Image
843
+ if image.mode != "RGB":
844
+ image = image.convert("RGB")
845
+ W, H = image.size
846
+ arr = np.array(image, dtype=np.float32) / 255.0
847
+ t = torch.from_numpy(arr).permute(2, 0, 1).unsqueeze(0)
848
+ t = F.interpolate(t, size=self.INPUT_SIZE, mode='bilinear', align_corners=True)
849
+ mean = torch.tensor(self._NORM_MEAN).view(1, 3, 1, 1)
850
+ std = torch.tensor(self._NORM_STD).view(1, 3, 1, 1)
851
+ t = ((t - mean) / std).to(device=self.device, dtype=self.dtype)
852
+ alpha = self.forward(t)
853
+ alpha = F.interpolate(alpha.float(), size=(H, W), mode='bilinear', align_corners=True)[0, 0]
854
+ a = (alpha.clamp(0, 1) * 255).to(torch.uint8).cpu().numpy()
855
+ rgba = image.copy()
856
+ rgba.putalpha(Image.fromarray(a, mode="L"))
857
+ return rgba
858
+
859
+
860
+ # ---------------------------------------------------------------------------
861
+ # Shared transformer helpers
862
+ # ---------------------------------------------------------------------------
863
+
864
+ class LayerNorm32(nn.LayerNorm):
865
+ def forward(self, x):
866
+ origin_dtype = x.dtype
867
+ return F.layer_norm(
868
+ x.float(),
869
+ self.normalized_shape,
870
+ self.weight.float() if self.weight is not None else None,
871
+ self.bias.float() if self.bias is not None else None,
872
+ self.eps,
873
+ ).to(origin_dtype)
874
+
875
+
876
+ class MultiHeadRMSNorm(nn.Module):
877
+ def __init__(self, dim, heads):
878
+ super().__init__()
879
+ self.scale = dim ** 0.5
880
+ self.gamma = nn.Parameter(torch.ones(heads, dim))
881
+
882
+ def forward(self, x):
883
+ origin_dtype = x.dtype
884
+ return (F.normalize(x.float(), dim=-1) * self.gamma.float() * self.scale).to(origin_dtype)
885
+
886
+
887
+ def apply_rotary_emb(hidden_states, freqs):
888
+ x_rotated = torch.view_as_complex(hidden_states.float().reshape(*hidden_states.shape[:-1], -1, 2))
889
+ x_rotated = x_rotated * freqs
890
+ x_out = torch.view_as_real(x_rotated).reshape(*x_rotated.shape[:-1], -1)
891
+ return x_out.type_as(hidden_states)
892
+
893
+
894
+ def clamp_mul(x, f):
895
+ f_t = f.tanh()
896
+ return x * f_t + x.detach() * (f - f_t)
897
+
898
+
899
+ def scaled_dot_product_attention(qkv=None, q=None, k=None, v=None, kv=None):
900
+ if qkv is not None:
901
+ q, k, v = qkv.unbind(dim=2)
902
+ elif kv is not None:
903
+ k, v = kv.unbind(dim=2)
904
+ q, k, v = q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3)
905
+ return F.scaled_dot_product_attention(q, k, v).permute(0, 2, 1, 3)
906
+
907
+
908
+ # ---------------------------------------------------------------------------
909
+ # Positional embeddings
910
+ # ---------------------------------------------------------------------------
911
+
912
+ class RePo3DRotaryEmbedding(nn.Module):
913
+ def __init__(self, model_channels, num_heads, head_dim, repo_hidden_ratio=0.125, max_freq=16.0):
914
+ super().__init__()
915
+ self.num_heads = num_heads
916
+ self.head_dim = head_dim
917
+ repo_hidden_size = int(model_channels * repo_hidden_ratio)
918
+ self.norm = LayerNorm32(model_channels)
919
+ self.gate_map = nn.Linear(model_channels, repo_hidden_size, bias=False)
920
+ self.content_map = nn.Linear(model_channels, repo_hidden_size, bias=False)
921
+ self.act = nn.SiLU()
922
+ self.final_map = nn.Linear(repo_hidden_size, 3 * num_heads, bias=False)
923
+ self.dim_0 = 2 * (head_dim // 6)
924
+ self.dim_1 = 2 * (head_dim // 6)
925
+ self.dim_2 = head_dim - self.dim_0 - self.dim_1
926
+ dims = [self.dim_0, self.dim_1, self.dim_2]
927
+ freqs_list = []
928
+ for d in dims:
929
+ freq_dim = d // 2
930
+ freqs_list.append(torch.linspace(1.0, float(max_freq), steps=freq_dim, dtype=torch.float32))
931
+ self.freqs_0 = nn.Parameter(freqs_list[0])
932
+ self.freqs_1 = nn.Parameter(freqs_list[1])
933
+ self.freqs_2 = nn.Parameter(freqs_list[2])
934
+
935
+ def forward(self, hidden_states):
936
+ h = self.norm(hidden_states)
937
+ feat = self.act(self.gate_map(h)) * self.content_map(h)
938
+ out = self.final_map(feat)
939
+ B, L, _ = out.shape
940
+ delta_pos = out.reshape(B, L, self.num_heads, 3)
941
+ ang_0 = clamp_mul(delta_pos[..., 0].unsqueeze(-1), self.freqs_0) * torch.pi
942
+ ang_1 = clamp_mul(delta_pos[..., 1].unsqueeze(-1), self.freqs_1) * torch.pi
943
+ ang_2 = clamp_mul(delta_pos[..., 2].unsqueeze(-1), self.freqs_2) * torch.pi
944
+ ang = torch.cat([ang_0, ang_1, ang_2], dim=-1).float() # fp32 needed for torch.polar → complex64
945
+ return torch.polar(torch.ones_like(ang), ang).type(torch.complex64)
946
+
947
+
948
+ class PcdAbsolutePositionEmbedder(nn.Module):
949
+ def __init__(self, channels: int, in_channels: int = 3, max_res: int = 16):
950
+ super().__init__()
951
+ self.channels = channels
952
+ self.in_channels = in_channels
953
+ self.max_res = max_res
954
+ self.freq_dim = channels // in_channels // 2
955
+
956
+ def _freqs(self, device):
957
+ freqs_2exp = torch.arange(self.max_res, dtype=torch.float32, device=device)
958
+ res_dim = max(0, self.freq_dim - self.max_res)
959
+ freqs_res = (torch.arange(res_dim, dtype=torch.float32, device=device) / max(res_dim, 1) * self.max_res
960
+ if res_dim > 0 else torch.empty(0, device=device))
961
+ freqs = torch.cat([freqs_2exp, freqs_res], dim=0)[:self.freq_dim]
962
+ return torch.pow(2.0, freqs)
963
+
964
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
965
+ orig_dtype = x.dtype
966
+ x = x.float()
967
+ *dims, D = x.shape
968
+ out = torch.outer(x.reshape(-1), self._freqs(x.device)) * 2 * torch.pi
969
+ out = torch.cat([out.sin(), out.cos()], dim=-1).reshape(*dims, -1)
970
+ if out.shape[-1] < self.channels:
971
+ out = torch.cat([out, torch.zeros(*dims, self.channels - out.shape[-1],
972
+ device=out.device, dtype=out.dtype)], dim=-1)
973
+ return out.to(orig_dtype)
974
+
975
+
976
+ class PcdAbsolutePositionEmbedderV2(nn.Module):
977
+ def __init__(self, channels: int, in_channels: int = 3, max_res: int = 10):
978
+ super().__init__()
979
+ self.channels = channels
980
+ self.in_channels = in_channels
981
+ self.max_res = max_res
982
+ self.freq_dim = channels // in_channels // 2
983
+
984
+ def _freqs(self, device):
985
+ logs = torch.linspace(0.0, float(self.max_res), steps=self.freq_dim, dtype=torch.float32, device=device)
986
+ return torch.pow(2.0, logs)
987
+
988
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
989
+ orig_dtype = x.dtype
990
+ x = x.float()
991
+ N, D = x.shape
992
+ ang = x.unsqueeze(-1) * self._freqs(x.device) * torch.pi
993
+ embed = torch.cat([torch.sin(ang), torch.cos(ang)], dim=-1).reshape(N, -1)
994
+ if embed.shape[1] < self.channels:
995
+ embed = torch.cat([embed, torch.zeros(N, self.channels - embed.shape[1],
996
+ device=embed.device, dtype=embed.dtype)], dim=-1)
997
+ return embed.to(orig_dtype)
998
+
999
+
1000
+ # ---------------------------------------------------------------------------
1001
+ # Transformer building blocks
1002
+ # ---------------------------------------------------------------------------
1003
+
1004
+ class FeedForwardNet(nn.Module):
1005
+ def __init__(self, channels, mlp_ratio=4.0, channels_out=None):
1006
+ super().__init__()
1007
+ self.mlp = nn.Sequential(
1008
+ nn.Linear(channels, int(channels * mlp_ratio)),
1009
+ nn.GELU(approximate="tanh"),
1010
+ nn.Linear(int(channels * mlp_ratio), channels if channels_out is None else channels_out),
1011
+ )
1012
+
1013
+ def forward(self, x):
1014
+ return self.mlp(x)
1015
+
1016
+
1017
+ class MLP(nn.Module):
1018
+ def __init__(self, channels: int, inner_channels: int, channels_out: Optional[int] = None,
1019
+ mlp_layer_num: int = 2):
1020
+ super().__init__()
1021
+ layers = []
1022
+ for i in range(mlp_layer_num - 1):
1023
+ layers.append(nn.Linear(channels if i == 0 else inner_channels, inner_channels))
1024
+ layers.append(nn.GELU(approximate="tanh"))
1025
+ layers.append(nn.Linear(inner_channels, channels if channels_out is None else channels_out))
1026
+ self.mlp = nn.Sequential(*layers)
1027
+
1028
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
1029
+ return self.mlp(x)
1030
+
1031
+
1032
+ class RopeMultiHeadAttention(nn.Module):
1033
+ def __init__(self, channels, num_heads, ctx_channels=None, type="self",
1034
+ attn_mode="full", qkv_bias=True, qk_rms_norm=False, use_rope=False):
1035
+ super().__init__()
1036
+ self.channels = channels
1037
+ self.num_heads = num_heads
1038
+ self.head_dim = channels // num_heads
1039
+ self.ctx_channels = ctx_channels if ctx_channels is not None else channels
1040
+ self._type = type
1041
+ self.qk_rms_norm = qk_rms_norm
1042
+ self.use_rope = use_rope
1043
+ if self._type == "self":
1044
+ self.qkv = nn.Linear(channels, channels * 3, bias=qkv_bias)
1045
+ else:
1046
+ self.q = nn.Linear(channels, channels, bias=qkv_bias)
1047
+ self.kv = nn.Linear(self.ctx_channels, channels * 2, bias=qkv_bias)
1048
+ if self.qk_rms_norm:
1049
+ self.q_norm = MultiHeadRMSNorm(self.head_dim, num_heads)
1050
+ self.k_norm = MultiHeadRMSNorm(self.head_dim, num_heads)
1051
+ self.out = nn.Linear(channels, channels)
1052
+
1053
+ def forward(self, x, context=None, rope_emb=None):
1054
+ B, L, C = x.shape
1055
+ if self._type == "self":
1056
+ qkv = self.qkv(x).reshape(B, L, 3, self.num_heads, self.head_dim)
1057
+ q, k, v = qkv.unbind(2)
1058
+ if self.use_rope:
1059
+ q = apply_rotary_emb(q, rope_emb)
1060
+ k = apply_rotary_emb(k, rope_emb)
1061
+ else:
1062
+ q = self.q(x).reshape(B, L, self.num_heads, self.head_dim)
1063
+ if context is None:
1064
+ raise ValueError("Context must be provided for cross attention")
1065
+ kv = self.kv(context).reshape(B, context.shape[1], 2, self.num_heads, self.head_dim)
1066
+ k, v = kv.unbind(2)
1067
+ if self.qk_rms_norm:
1068
+ q = self.q_norm(q)
1069
+ k = self.k_norm(k)
1070
+ h = scaled_dot_product_attention(q=q, k=k, v=v)
1071
+ return self.out(h.reshape(B, L, C))
1072
+
1073
+
1074
+ class MultiHeadAttention(nn.Module):
1075
+ def __init__(self, channels, num_heads, ctx_channels=None, type="self",
1076
+ attn_mode="full", qkv_bias=True, qk_rms_norm=False):
1077
+ super().__init__()
1078
+ assert channels % num_heads == 0
1079
+ assert type in ["self", "cross"]
1080
+ assert attn_mode == "full"
1081
+ self.channels = channels
1082
+ self.head_dim = channels // num_heads
1083
+ self.ctx_channels = ctx_channels if ctx_channels is not None else channels
1084
+ self.num_heads = num_heads
1085
+ self._type = type
1086
+ self.qk_rms_norm = qk_rms_norm
1087
+ if self._type == "self":
1088
+ self.to_qkv = nn.Linear(channels, channels * 3, bias=qkv_bias)
1089
+ else:
1090
+ self.to_q = nn.Linear(channels, channels, bias=qkv_bias)
1091
+ self.to_kv = nn.Linear(self.ctx_channels, channels * 2, bias=qkv_bias)
1092
+ if self.qk_rms_norm:
1093
+ self.q_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads)
1094
+ self.k_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads)
1095
+ self.to_out = nn.Linear(channels, channels)
1096
+
1097
+ def forward(self, x, context=None):
1098
+ B, L, C = x.shape
1099
+ if self._type == "self":
1100
+ qkv = self.to_qkv(x).reshape(B, L, 3, self.num_heads, -1)
1101
+ if self.qk_rms_norm:
1102
+ q, k, v = qkv.unbind(dim=2)
1103
+ q = self.q_rms_norm(q)
1104
+ k = self.k_rms_norm(k)
1105
+ qkv = torch.stack([q, k, v], dim=2)
1106
+ h = scaled_dot_product_attention(qkv=qkv)
1107
+ else:
1108
+ Lkv = context.shape[1]
1109
+ q = self.to_q(x).reshape(B, L, self.num_heads, -1)
1110
+ kv = self.to_kv(context).reshape(B, Lkv, 2, self.num_heads, -1)
1111
+ if self.qk_rms_norm:
1112
+ q = self.q_rms_norm(q)
1113
+ k, v = kv.unbind(dim=2)
1114
+ k = self.k_rms_norm(k)
1115
+ h = scaled_dot_product_attention(q=q, k=k, v=v)
1116
+ else:
1117
+ h = scaled_dot_product_attention(q=q, kv=kv)
1118
+ return self.to_out(h.reshape(B, L, -1))
1119
+
1120
+
1121
+ class UnifiedTransformerBlock(nn.Module):
1122
+ def __init__(self, channels, num_heads, mlp_ratio=4.0, attn_mode="full",
1123
+ use_checkpoint=False, use_rope=False, qk_rms_norm=False, qkv_bias=True,
1124
+ modulation=True, share_mod=False, use_shift_table=False):
1125
+ super().__init__()
1126
+ self.modulation = modulation
1127
+ self.share_mod = share_mod
1128
+ self.norm1 = LayerNorm32(channels, elementwise_affine=not modulation, eps=1e-6)
1129
+ self.norm2 = LayerNorm32(channels, elementwise_affine=not modulation, eps=1e-6)
1130
+ self.attn = RopeMultiHeadAttention(channels, num_heads=num_heads, type="self",
1131
+ attn_mode=attn_mode, qkv_bias=qkv_bias,
1132
+ use_rope=use_rope, qk_rms_norm=qk_rms_norm)
1133
+ self.mlp = FeedForwardNet(channels, mlp_ratio=mlp_ratio)
1134
+ if modulation:
1135
+ if not share_mod:
1136
+ self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(channels, 6 * channels, bias=True))
1137
+ self.shift_table = nn.Parameter(torch.randn(1, 6 * channels) / channels ** 0.5) if use_shift_table else None
1138
+
1139
+ def forward(self, x, mod=None, rotary_emb=None):
1140
+ if self.modulation:
1141
+ if not self.share_mod:
1142
+ mod = self.adaLN_modulation(mod)
1143
+ if hasattr(self, 'shift_table') and self.shift_table is not None:
1144
+ mod = mod + self.shift_table.type(mod.dtype)
1145
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod.chunk(6, dim=1)
1146
+ h = self.norm1(x)
1147
+ h = h * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1)
1148
+ h = self.attn(h, rope_emb=rotary_emb)
1149
+ x = x + h * gate_msa.unsqueeze(1)
1150
+ h = self.norm2(x)
1151
+ h = h * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1)
1152
+ x = x + self.mlp(h) * gate_mlp.unsqueeze(1)
1153
+ else:
1154
+ x = x + self.attn(self.norm1(x), rope_emb=rotary_emb)
1155
+ x = x + self.mlp(self.norm2(x))
1156
+ return x
1157
+
1158
+
1159
+ # ---------------------------------------------------------------------------
1160
+ # Quasi-random sampling utilities
1161
+ # ---------------------------------------------------------------------------
1162
+
1163
+ PRIMES = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53]
1164
+
1165
+
1166
+ def radical_inverse(base, n):
1167
+ val = 0
1168
+ inv_base = 1.0 / base
1169
+ inv_base_n = inv_base
1170
+ while n > 0:
1171
+ digit = n % base
1172
+ val += digit * inv_base_n
1173
+ n //= base
1174
+ inv_base_n *= inv_base
1175
+ return val
1176
+
1177
+
1178
+ def halton_sequence(dim, n):
1179
+ return [radical_inverse(PRIMES[dim], n) for dim in range(dim)]
1180
+
1181
+
1182
+ def hammersley_sequence(dim, n, num_samples):
1183
+ return [n / num_samples] + halton_sequence(dim - 1, n)
1184
+
1185
+
1186
+ @torch.no_grad()
1187
+ def sample_probs(probs, counts, algo="systematic"):
1188
+ batch_shape = counts.shape
1189
+ B = counts.numel()
1190
+ P = probs.size(-1)
1191
+ device = probs.device
1192
+ probs = probs.view(B, P)
1193
+ counts = counts.view(B)
1194
+
1195
+ probs = probs.to(torch.float32).clamp_min_(0)
1196
+ row_sums = probs.sum(1, keepdim=True)
1197
+ zero_mask = row_sums.eq(0)
1198
+ probs = probs / row_sums.clamp_min_(1)
1199
+ if zero_mask.any():
1200
+ probs = probs.clone()
1201
+ probs[zero_mask.expand_as(probs)] = 1.0 / P
1202
+
1203
+ counts = counts.to(device=device, dtype=torch.long)
1204
+ out = torch.zeros(B, P, dtype=torch.long, device=device)
1205
+ cdf = probs.cumsum(dim=1).clamp(max=1.0 - 1e-12)
1206
+ unique_n, inv = counts.unique(sorted=False, return_inverse=True)
1207
+ for i, n in enumerate(unique_n.tolist()):
1208
+ if n == 0:
1209
+ continue
1210
+ rows = (inv == i).nonzero(as_tuple=False).squeeze(1)
1211
+ r = rows.numel()
1212
+ U0 = torch.rand(r, 1, device=device) / float(n)
1213
+ grid = torch.arange(n, device=device, dtype=torch.float32)[None, :] / float(n)
1214
+ us = (U0 + grid).clamp(max=1.0 - 1e-12)
1215
+ cdf_rows = cdf.index_select(0, rows)
1216
+ idx = torch.searchsorted(cdf_rows, us).clamp_max(probs.size(1) - 1)
1217
+ buf = torch.zeros(r, P, dtype=torch.float32, device=device)
1218
+ buf.scatter_add_(1, idx, torch.ones_like(idx, dtype=buf.dtype))
1219
+ out.index_copy_(0, rows, buf.to(torch.long))
1220
+
1221
+ return out.view(*batch_shape, P)
1222
+
1223
+
1224
+ # ---------------------------------------------------------------------------
1225
+ # VAE decoders
1226
+ # ---------------------------------------------------------------------------
1227
+
1228
+ class LevelEmbedder(nn.Module):
1229
+ def __init__(self, hidden_size, frequency_embedding_size=256, max_period=1024):
1230
+ super().__init__()
1231
+ self.mlp = nn.Sequential(
1232
+ nn.Linear(frequency_embedding_size, hidden_size, bias=True),
1233
+ nn.SiLU(),
1234
+ nn.Linear(hidden_size, hidden_size, bias=True),
1235
+ )
1236
+ self.frequency_embedding_size = frequency_embedding_size
1237
+ self.max_period = max_period
1238
+
1239
+ @staticmethod
1240
+ def level_embedding(t, dim, max_period=1024):
1241
+ half = dim // 2
1242
+ freqs = torch.exp(-np.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(device=t.device)
1243
+ args = t[:, None].float() * freqs[None] * 2 * torch.pi
1244
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
1245
+ if dim % 2:
1246
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
1247
+ return embedding
1248
+
1249
+ def forward(self, t):
1250
+ emb = self.level_embedding(t, self.frequency_embedding_size, self.max_period)
1251
+ return self.mlp(emb.to(self.mlp[0].weight.dtype))
1252
+
1253
+
1254
+ class ModulatedTransformerCrossOnlyBlock(nn.Module):
1255
+ def __init__(self, channels, ctx_channels, num_heads, mlp_ratio=4.0, share_mod=False,
1256
+ qk_rms_norm_cross=True, qkv_bias=True):
1257
+ super().__init__()
1258
+ self.share_mod = share_mod
1259
+ self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
1260
+ self.norm2 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
1261
+ self.cross_attn = MultiHeadAttention(channels, ctx_channels=ctx_channels, num_heads=num_heads,
1262
+ type="cross", attn_mode="full", qkv_bias=qkv_bias,
1263
+ qk_rms_norm=qk_rms_norm_cross)
1264
+ self.mlp = FeedForwardNet(channels, mlp_ratio=mlp_ratio)
1265
+ if not share_mod:
1266
+ self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(channels, 6 * channels, bias=True))
1267
+
1268
+ def forward(self, x, mod, context):
1269
+ if self.share_mod:
1270
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod.chunk(6, dim=1)
1271
+ else:
1272
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1)
1273
+ h = self.norm1(x) * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1)
1274
+ x = x + self.cross_attn(h, context) * gate_msa.unsqueeze(1)
1275
+ h = self.norm2(x) * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1)
1276
+ x = x + self.mlp(h) * gate_mlp.unsqueeze(1)
1277
+ return x
1278
+
1279
+
1280
+ class ModulatedCrossOnlyTransformerBase(nn.Module):
1281
+ def __init__(self, in_channels, model_channels, cond_channels, num_blocks, num_heads=None,
1282
+ num_head_channels=64, mlp_ratio=4.0, share_mod=False, additional_level_embed=False,
1283
+ qk_rms_norm_cross=True):
1284
+ super().__init__()
1285
+ self.model_channels = model_channels
1286
+ self.cond_channels = cond_channels
1287
+ self.num_blocks = num_blocks
1288
+ self.num_heads = num_heads or model_channels // num_head_channels
1289
+ self.mlp_ratio = mlp_ratio
1290
+ self.share_mod = share_mod
1291
+ self.qk_rms_norm_cross = qk_rms_norm_cross
1292
+
1293
+ self.input_layer = nn.Linear(in_channels, model_channels)
1294
+ self.l_embedder = LevelEmbedder(model_channels)
1295
+ self.l_embedder2 = LevelEmbedder(model_channels, max_period=100) if additional_level_embed else None
1296
+ if share_mod:
1297
+ self.adaLN_modulation = nn.Sequential(
1298
+ nn.SiLU(), nn.Linear(model_channels, 6 * model_channels, bias=True))
1299
+ if cond_channels is not None:
1300
+ self.blocks = nn.ModuleList([
1301
+ ModulatedTransformerCrossOnlyBlock(
1302
+ model_channels, ctx_channels=cond_channels, num_heads=self.num_heads,
1303
+ mlp_ratio=self.mlp_ratio, qk_rms_norm_cross=self.qk_rms_norm_cross,
1304
+ share_mod=self.share_mod)
1305
+ for _ in range(num_blocks)
1306
+ ])
1307
+
1308
+ @property
1309
+ def dtype(self) -> torch.dtype:
1310
+ return next(self.parameters()).dtype
1311
+
1312
+ @property
1313
+ def device(self) -> torch.device:
1314
+ return next(self.parameters()).device
1315
+
1316
+ def forward(self, x, l, cond, l2=None):
1317
+ h = self.input_layer(x)
1318
+ l_emb = self.l_embedder(l)
1319
+ if self.l_embedder2 is not None and l2 is not None:
1320
+ l_emb = l_emb + self.l_embedder2(l2)
1321
+ if self.share_mod:
1322
+ l_emb = self.adaLN_modulation(l_emb)
1323
+ for block in self.blocks:
1324
+ h = block(h, l_emb, cond)
1325
+ return h
1326
+
1327
+
1328
+ class OctreeProbabilityFixedlenDecoder(ModulatedCrossOnlyTransformerBase):
1329
+ def __init__(self, model_channels, cond_channels, num_blocks, num_heads=None,
1330
+ num_head_channels=64, mlp_ratio=4.0, share_mod=False,
1331
+ additional_level_embed=False, qk_rms_norm_cross=True, *,
1332
+ no_norm=False):
1333
+ super().__init__(
1334
+ in_channels=model_channels, model_channels=model_channels,
1335
+ cond_channels=cond_channels, num_blocks=num_blocks,
1336
+ num_heads=num_heads, num_head_channels=num_head_channels,
1337
+ mlp_ratio=mlp_ratio, share_mod=share_mod,
1338
+ additional_level_embed=additional_level_embed,
1339
+ qk_rms_norm_cross=qk_rms_norm_cross,
1340
+ )
1341
+ self.out_proj = nn.Linear(self.model_channels, 8)
1342
+ self.no_norm = no_norm
1343
+ self.in_proj = nn.Linear(3, self.model_channels)
1344
+ self.pos_embedder = PcdAbsolutePositionEmbedderV2(channels=model_channels, in_channels=3)
1345
+
1346
+ def forward(self, x, l, cond, l2=None):
1347
+ d = self.dtype
1348
+ B, L, C = x.shape
1349
+ h = self.in_proj(x.to(d)) + self.pos_embedder(x.reshape(-1, 3)).reshape(B, L, -1).to(d)
1350
+ if l2 is not None:
1351
+ l2 = torch.log2(l2)
1352
+ h = super().forward(h, l, cond.to(d), l2)
1353
+ h = F.layer_norm(h.float(), h.shape[-1:]).to(d) if not self.no_norm else h / (1 + 2 * self.num_blocks) ** 0.5
1354
+ logits = self.out_proj(h)
1355
+ return {"logits": logits, "probs": torch.softmax(logits, dim=-1)}
1356
+
1357
+ @staticmethod
1358
+ def sample(model, cond, num_points, level, temperature=1.0, algo="systematic"):
1359
+ B = cond.shape[0]
1360
+ device = cond.device
1361
+ child_offset = torch.tensor([[i, j, k] for k in [0, 1] for j in [0, 1] for i in [0, 1]],
1362
+ dtype=torch.long, device=device)
1363
+ prev_coords_int = torch.zeros(B, 1, 3, dtype=torch.long, device=device)
1364
+ prev_counts = torch.full((B, 1), num_points, dtype=torch.long, device=device)
1365
+ prev_log_probs = torch.zeros(B, 1, dtype=torch.float32, device=device)
1366
+ batch_indices_range = torch.arange(B, device=device).unsqueeze(1)
1367
+ num_tensor = torch.full((B,), num_points, dtype=torch.long, device=device)
1368
+
1369
+ for lv in range(1, level + 1):
1370
+ res_p = 1 << (lv - 1)
1371
+ res = 1 << lv
1372
+ parent_coords_norm = (prev_coords_int.to(torch.float32) + 0.5) / res_p
1373
+ res_tensor = torch.full((B,), res, dtype=torch.long, device=device)
1374
+ pred_logits = model(parent_coords_norm, res_tensor, cond, num_tensor)["logits"] / temperature
1375
+ pred_probs = torch.softmax(pred_logits, dim=-1)
1376
+ pred_log_probs = torch.log_softmax(pred_logits, dim=-1)
1377
+ sampled = sample_probs(pred_probs, prev_counts, algo=algo).flatten(1, 2)
1378
+ pred_log_probs = pred_log_probs.flatten(1, 2)
1379
+ prev_log_probs_expanded = prev_log_probs.repeat_interleave(8, dim=1)
1380
+ child_coords_int = (prev_coords_int[:, :, None, :] * 2 + child_offset[None, None, :, :]).flatten(1, 2)
1381
+ mask = sampled > 0
1382
+ max_valid = mask.sum(dim=1).max().item()
1383
+ scatter_indices = mask.cumsum(dim=1) - 1
1384
+ valid_scatter_indices = scatter_indices[mask]
1385
+ valid_batch_indices = batch_indices_range.expand_as(mask)[mask]
1386
+ next_prev_coords_int = torch.zeros(B, max_valid, 3, dtype=child_coords_int.dtype, device=device)
1387
+ next_prev_coords_int[valid_batch_indices, valid_scatter_indices] = child_coords_int[mask]
1388
+ next_prev_counts = torch.zeros(B, max_valid, dtype=sampled.dtype, device=device)
1389
+ next_prev_counts[valid_batch_indices, valid_scatter_indices] = sampled[mask]
1390
+ next_prev_log_probs = torch.zeros(B, max_valid, dtype=prev_log_probs.dtype, device=device)
1391
+ next_prev_log_probs[valid_batch_indices, valid_scatter_indices] = (prev_log_probs_expanded + pred_log_probs)[mask]
1392
+ prev_coords_int = next_prev_coords_int
1393
+ prev_counts = next_prev_counts
1394
+ prev_log_probs = next_prev_log_probs
1395
+
1396
+ res = 1 << level
1397
+ prev_log_probs = torch.repeat_interleave(prev_log_probs.flatten(0, 1), prev_counts.flatten(0, 1), dim=0).reshape(B, num_points)
1398
+ coords_int = torch.repeat_interleave(prev_coords_int.flatten(0, 1), prev_counts.flatten(0, 1), dim=0).reshape(B, num_points, -1)
1399
+ coords_norm = (coords_int.to(torch.float32) + torch.rand_like(coords_int, dtype=torch.float32)) / res
1400
+ return {"points": coords_norm, "log_probs": prev_log_probs}
1401
+
1402
+
1403
+ class TransformerCrossBlock(nn.Module):
1404
+ def __init__(self, channels, ctx_channels, num_heads, mlp_ratio=4.0, attn_mode="full",
1405
+ qk_rms_norm=True, qk_rms_norm_cross=True, qkv_bias=True):
1406
+ super().__init__()
1407
+ self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
1408
+ self.norm2 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6)
1409
+ self.norm3 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
1410
+ self.self_attn = MultiHeadAttention(channels, num_heads=num_heads, type="self",
1411
+ attn_mode=attn_mode, qkv_bias=qkv_bias,
1412
+ qk_rms_norm=qk_rms_norm)
1413
+ self.cross_attn = MultiHeadAttention(channels, ctx_channels=ctx_channels, num_heads=num_heads,
1414
+ type="cross", attn_mode="full", qkv_bias=qkv_bias,
1415
+ qk_rms_norm=qk_rms_norm_cross)
1416
+ self.mlp = FeedForwardNet(channels, mlp_ratio=mlp_ratio)
1417
+
1418
+ def forward(self, x, context):
1419
+ x = x + self.self_attn(self.norm1(x))
1420
+ x = x + self.cross_attn(self.norm2(x), context)
1421
+ x = x + self.mlp(self.norm3(x))
1422
+ return x
1423
+
1424
+
1425
+ class TransformerBase(nn.Module):
1426
+ def __init__(self, in_channels, model_channels, cond_channels, num_blocks, num_heads=None,
1427
+ num_head_channels=64, mlp_ratio=4.0, attn_mode="full", window_num=None,
1428
+ qk_rms_norm=True, qk_rms_norm_cross=True):
1429
+ super().__init__()
1430
+ self.model_channels = model_channels
1431
+ self.cond_channels = cond_channels
1432
+ self.num_blocks = num_blocks
1433
+ self.num_heads = num_heads or model_channels // num_head_channels
1434
+ self.mlp_ratio = mlp_ratio
1435
+ self.input_layer = nn.Linear(in_channels, model_channels)
1436
+ if cond_channels is not None:
1437
+ self.blocks = nn.ModuleList([
1438
+ TransformerCrossBlock(model_channels, ctx_channels=cond_channels,
1439
+ num_heads=self.num_heads, mlp_ratio=self.mlp_ratio,
1440
+ attn_mode="full", qk_rms_norm=qk_rms_norm,
1441
+ qk_rms_norm_cross=qk_rms_norm_cross)
1442
+ for _ in range(num_blocks)
1443
+ ])
1444
+
1445
+ @property
1446
+ def dtype(self) -> torch.dtype:
1447
+ return next(self.parameters()).dtype
1448
+
1449
+ def forward(self, x, cond=None, l=None, cond2=None):
1450
+ h = self.input_layer(x)
1451
+ for block in self.blocks:
1452
+ h = block(h, cond)
1453
+ return h
1454
+
1455
+
1456
+ class FixedlenDecoder(TransformerBase):
1457
+ def __init__(self, in_channels, model_channels, cond_channels, num_blocks, num_heads=None,
1458
+ num_head_channels=64, mlp_ratio=4.0, attn_mode="full", window_num=None,
1459
+ qk_rms_norm=True, qk_rms_norm_cross=True):
1460
+ super().__init__(in_channels=model_channels, model_channels=model_channels,
1461
+ cond_channels=cond_channels, num_blocks=num_blocks,
1462
+ num_heads=num_heads, num_head_channels=num_head_channels,
1463
+ mlp_ratio=mlp_ratio, attn_mode=attn_mode, window_num=window_num,
1464
+ qk_rms_norm=qk_rms_norm, qk_rms_norm_cross=qk_rms_norm_cross)
1465
+ self.in_proj = nn.Linear(in_channels, model_channels)
1466
+ self.pos_embedder = PcdAbsolutePositionEmbedderV2(channels=model_channels, in_channels=3)
1467
+
1468
+ def forward(self, x=None, cond=None):
1469
+ pcd = x["points"]
1470
+ d = self.dtype
1471
+ B, L, C = pcd.shape
1472
+ h = self.in_proj(pcd.to(d)) + self.pos_embedder(pcd.reshape(-1, 3)).reshape(B, L, -1).to(d)
1473
+ return super().forward(h, cond.to(d))
1474
+
1475
+
1476
+ class ElasticGaussianFixedlenDecoder(FixedlenDecoder):
1477
+ def __init__(self, in_channels, model_channels, cond_channels, num_blocks, num_heads=None,
1478
+ num_head_channels=64, mlp_ratio=4.0, attn_mode="full", window_num=None,
1479
+ *, no_norm=False, representation_config=None,
1480
+ use_learned_offset_scale=True, use_per_offset=True,
1481
+ qk_rms_norm=True, qk_rms_norm_cross=True):
1482
+ self.rep_config = representation_config
1483
+ self.use_learned_offset_scale = use_learned_offset_scale
1484
+ self.use_per_offset = use_per_offset
1485
+ self.out_channels = self._calc_layout()
1486
+ super().__init__(in_channels=in_channels, model_channels=model_channels,
1487
+ cond_channels=cond_channels, num_blocks=num_blocks,
1488
+ num_heads=num_heads, num_head_channels=num_head_channels,
1489
+ mlp_ratio=mlp_ratio, attn_mode=attn_mode, window_num=window_num,
1490
+ qk_rms_norm=qk_rms_norm, qk_rms_norm_cross=qk_rms_norm_cross)
1491
+ self.out_proj = nn.Linear(model_channels, self.out_channels)
1492
+ self.no_norm = no_norm
1493
+ self._build_perturbation()
1494
+
1495
+ def _calc_layout(self):
1496
+ ng = self.rep_config['num_gaussians']
1497
+ self.layout = {
1498
+ '_xyz': {'shape': (ng, 3), 'size': ng * 3},
1499
+ '_features_dc': {'shape': (ng, 1, 3), 'size': ng * 3},
1500
+ '_scaling': {'shape': (ng, 3), 'size': ng * 3},
1501
+ '_rotation': {'shape': (ng, 4), 'size': ng * 4},
1502
+ '_opacity': {'shape': (ng, 1), 'size': ng},
1503
+ }
1504
+ if self.use_learned_offset_scale and self.use_per_offset:
1505
+ self.layout['_offset_scale'] = {'shape': (ng, 1), 'size': ng}
1506
+ start = 0
1507
+ for k, v in self.layout.items():
1508
+ v['range'] = (start, start + v['size'])
1509
+ start += v['size']
1510
+ return start
1511
+
1512
+ def _build_perturbation(self):
1513
+ ng = self.rep_config['num_gaussians']
1514
+ perturbation = torch.tensor([hammersley_sequence(3, i, ng) for i in range(ng)]).float()
1515
+ perturbation = torch.atanh((perturbation * 2 - 1) / self.rep_config['perturbe_size'])
1516
+ self.register_buffer('points_offset_perturbation', perturbation)
1517
+ if self.use_learned_offset_scale:
1518
+ base = torch.tensor(self.rep_config['offset_scale'])
1519
+ self.register_buffer('base_offset_scale', torch.log(torch.exp(base) - 1.0))
1520
+
1521
+ def _get_offset(self, h):
1522
+ B = h.shape[0]
1523
+ if self.use_learned_offset_scale:
1524
+ r = self.layout['_offset_scale']['range']
1525
+ _offset_scale = F.softplus(
1526
+ h[:, :, r[0]:r[1]].reshape(B, -1, *self.layout['_offset_scale']['shape'])
1527
+ + self.base_offset_scale)
1528
+
1529
+ r = self.layout['_xyz']['range']
1530
+ offset = h[:, :, r[0]:r[1]].reshape(B, -1, *self.layout['_xyz']['shape'])
1531
+ offset = offset * self.rep_config['lr']['_xyz']
1532
+ if self.rep_config['perturb_offset']:
1533
+ offset = offset + self.points_offset_perturbation
1534
+ offset = torch.tanh(offset) * 0.5 * self.rep_config['perturbe_size']
1535
+ offset = offset * (_offset_scale if self.use_learned_offset_scale else self.rep_config['offset_scale'])
1536
+ return offset
1537
+
1538
+ def forward(self, x=None, cond=None):
1539
+ h = super().forward(x, cond)
1540
+ h = F.layer_norm(h.float(), h.shape[-1:]).to(h.dtype) if not self.no_norm else h / (1 + 3 * self.num_blocks) ** 0.5
1541
+ return {"features": self.out_proj(h)}
1542
+
1543
+
1544
+ # ---------------------------------------------------------------------------
1545
+ # Flow matching denoiser
1546
+ # ---------------------------------------------------------------------------
1547
+
1548
+ class TimestepEmbedder(nn.Module):
1549
+ def __init__(self, hidden_size, frequency_embedding_size=256):
1550
+ super().__init__()
1551
+ self.mlp = nn.Sequential(
1552
+ nn.Linear(frequency_embedding_size, hidden_size, bias=True),
1553
+ nn.SiLU(),
1554
+ nn.Linear(hidden_size, hidden_size, bias=True),
1555
+ )
1556
+ self.frequency_embedding_size = frequency_embedding_size
1557
+
1558
+ @staticmethod
1559
+ def timestep_embedding(t, dim, max_period=10000):
1560
+ half = dim // 2
1561
+ freqs = torch.exp(-np.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(device=t.device)
1562
+ args = t[:, None].float() * freqs[None]
1563
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
1564
+ if dim % 2:
1565
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
1566
+ return embedding
1567
+
1568
+ def forward(self, t):
1569
+ emb = self.timestep_embedding(t, self.frequency_embedding_size)
1570
+ return self.mlp(emb.to(self.mlp[0].weight.dtype))
1571
+
1572
+
1573
+ class LatentSeqMMFlowModel(nn.Module):
1574
+ def __init__(self, q_token_length, in_channels, model_channels, cond_channels,
1575
+ out_channels, num_blocks, num_refiner_blocks=2, num_heads=None,
1576
+ num_head_channels=64, cam_channels=None, cond2_channels=None,
1577
+ mlp_ratio=4, share_mod=True, qk_rms_norm=False, use_shift_table=False):
1578
+ super().__init__()
1579
+ self.q_token_length = q_token_length
1580
+ self.in_channels = in_channels
1581
+ self.cam_channels = cam_channels
1582
+ self.model_channels = model_channels
1583
+ self.cond_channels = cond_channels
1584
+ self.cond2_channels = cond2_channels
1585
+ self.out_channels = out_channels
1586
+ self.num_blocks = num_blocks
1587
+ self.num_refiner_blocks = num_refiner_blocks
1588
+ self.num_heads = num_heads or model_channels // num_head_channels
1589
+ self.mlp_ratio = mlp_ratio
1590
+ self.share_mod = share_mod
1591
+ self.qk_rms_norm = qk_rms_norm
1592
+ self.use_shift_table = use_shift_table
1593
+
1594
+ self.t_embedder = TimestepEmbedder(model_channels)
1595
+ if share_mod:
1596
+ self.adaLN_modulation = nn.Sequential(
1597
+ nn.SiLU(), nn.Linear(model_channels, 6 * model_channels, bias=True))
1598
+
1599
+ self.input_layer = nn.Linear(in_channels, model_channels)
1600
+ self.cond_embedder = nn.Linear(cond_channels, model_channels)
1601
+ self.cond_embedder2 = nn.Linear(cond2_channels, model_channels) if cond2_channels is not None else None
1602
+
1603
+ sobol_seq = torch.quasirandom.SobolEngine(dimension=3, scramble=True, seed=123).draw(q_token_length)
1604
+ self.pos_pe = sobol_seq.unsqueeze(0)
1605
+ self.pos_embedder = PcdAbsolutePositionEmbedder(model_channels)
1606
+ self.noise_repo_layers = nn.ModuleList([
1607
+ RePo3DRotaryEmbedding(model_channels, num_heads=self.num_heads, head_dim=num_head_channels)
1608
+ for _ in range(num_refiner_blocks)])
1609
+ self.context_repo_layers = nn.ModuleList([
1610
+ RePo3DRotaryEmbedding(model_channels, num_heads=self.num_heads, head_dim=num_head_channels)
1611
+ for _ in range(num_refiner_blocks)])
1612
+ self.repo_layers = nn.ModuleList([
1613
+ RePo3DRotaryEmbedding(model_channels, num_heads=self.num_heads, head_dim=num_head_channels)
1614
+ for _ in range(num_blocks)])
1615
+
1616
+ block_kwargs = dict(num_heads=self.num_heads, mlp_ratio=self.mlp_ratio, attn_mode='full',
1617
+ use_rope=True, qk_rms_norm=self.qk_rms_norm,
1618
+ use_shift_table=self.use_shift_table)
1619
+ self.noise_refiner = nn.ModuleList([
1620
+ UnifiedTransformerBlock(model_channels, modulation=True, share_mod=self.share_mod, **block_kwargs)
1621
+ for _ in range(num_refiner_blocks)])
1622
+ self.context_refiner = nn.ModuleList([
1623
+ UnifiedTransformerBlock(model_channels, modulation=False, **block_kwargs)
1624
+ for _ in range(num_refiner_blocks)])
1625
+ if self.cam_channels is not None:
1626
+ self.cam_refiner = MLP(self.cam_channels, model_channels, model_channels,
1627
+ mlp_layer_num=num_refiner_blocks)
1628
+ self.blocks = nn.ModuleList([
1629
+ UnifiedTransformerBlock(model_channels, modulation=True, share_mod=self.share_mod, **block_kwargs)
1630
+ for _ in range(num_blocks)])
1631
+ self.shift_table = nn.Parameter(torch.randn(1, 2, model_channels) / model_channels**0.5) if use_shift_table else None
1632
+ self.out_layer = nn.Linear(model_channels, out_channels)
1633
+ if cam_channels is not None:
1634
+ self.cam_out_layer = nn.Linear(model_channels, cam_channels)
1635
+
1636
+ @property
1637
+ def dtype(self) -> torch.dtype:
1638
+ return next(self.parameters()).dtype
1639
+
1640
+ @property
1641
+ def device(self) -> torch.device:
1642
+ return next(self.parameters()).device
1643
+
1644
+ def load_safetensors(self, path: str) -> None:
1645
+ self.load_state_dict(safetensors.torch.load_file(path), strict=True)
1646
+
1647
+ def forward(self, x_t, t, cond):
1648
+ d = self.dtype
1649
+ z = x_t['latent'].to(d)
1650
+ feat1 = cond['feature1'].to(d)
1651
+ feat2 = cond['feature2'].to(d) if self.cond_embedder2 is not None else None
1652
+ self.pos_pe = self.pos_pe.to(z.device)
1653
+
1654
+ h_x = self.input_layer(z)
1655
+ h_cond = self.cond_embedder(feat1)
1656
+ if feat2 is not None:
1657
+ h_cond = h_cond + self.cond_embedder2(feat2)
1658
+ t_emb = self.t_embedder(t)
1659
+ t_mod = self.adaLN_modulation(t_emb) if self.share_mod else t_emb
1660
+
1661
+ h_x = h_x + self.pos_embedder(self.pos_pe).to(d)
1662
+
1663
+ for i, block in enumerate(self.noise_refiner):
1664
+ h_x = block(h_x, mod=t_mod, rotary_emb=self.noise_repo_layers[i](h_x))
1665
+
1666
+ for i, block in enumerate(self.context_refiner):
1667
+ h_cond = block(h_cond, mod=None, rotary_emb=self.context_repo_layers[i](h_cond))
1668
+
1669
+ if self.cam_channels is not None:
1670
+ cam = x_t.get('camera').to(d)
1671
+ h_cam = self.cam_refiner(cam)
1672
+
1673
+ h = torch.cat([h_x, h_cond], dim=1)
1674
+ if self.cam_channels is not None:
1675
+ h = torch.cat([h, h_cam], dim=1)
1676
+
1677
+ for i, block in enumerate(self.blocks):
1678
+ h = block(h, mod=t_mod, rotary_emb=self.repo_layers[i](h))
1679
+
1680
+ h_x = F.layer_norm(h[:, :z.shape[1]].float(), h.shape[-1:]).type(d)
1681
+ if self.cam_channels is not None:
1682
+ h_cam = F.layer_norm(h[:, -cam.shape[1]:].float(), h.shape[-1:]).type(d)
1683
+
1684
+ if self.use_shift_table:
1685
+ shift, scale = (self.shift_table + t_emb.unsqueeze(1)).chunk(2, dim=1)
1686
+ h_x = h_x * (1 + scale) + shift
1687
+ if self.cam_channels is not None:
1688
+ h_cam = h_cam * (1 + scale) + shift
1689
+
1690
+ out = {'latent': self.out_layer(h_x)}
1691
+ if self.cam_channels is not None:
1692
+ out['camera'] = self.cam_out_layer(h_cam)
1693
+ return out
1694
+
1695
+
1696
+ # ---------------------------------------------------------------------------
1697
+ # OctreeGaussianDecoder
1698
+ # ---------------------------------------------------------------------------
1699
+
1700
+ class OctreeGaussianDecoder(nn.Module):
1701
+ _MAX_VOXEL_LEVEL = 8
1702
+
1703
+ def __init__(self, octree_args: dict, gs_args: dict):
1704
+ super().__init__()
1705
+ self.octree = OctreeProbabilityFixedlenDecoder(**octree_args)
1706
+ self.gs = ElasticGaussianFixedlenDecoder(**gs_args)
1707
+
1708
+ def load_safetensors(self, path: str) -> None:
1709
+ self.load_state_dict(safetensors.torch.load_file(path), strict=True)
1710
+
1711
+ @property
1712
+ def gaussians_per_point(self) -> int:
1713
+ return self.gs.rep_config['num_gaussians']
1714
+
1715
+ @torch.no_grad()
1716
+ def decode(self, latent: torch.Tensor, num_gaussians: int):
1717
+ from triposplat import _build_gaussians # local import: avoid model.py ↔ triposplat.py cycle
1718
+ num_decoder_tokens = max(1, num_gaussians // self.gaussians_per_point)
1719
+ points_pred = OctreeProbabilityFixedlenDecoder.sample(
1720
+ self.octree, latent,
1721
+ num_points=num_decoder_tokens, level=self._MAX_VOXEL_LEVEL,
1722
+ temperature=1.0, algo='systematic',
1723
+ )
1724
+ pred = self.gs(x=points_pred, cond=latent)
1725
+ return _build_gaussians(self.gs, points_pred, pred)[0]
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ torch
2
+ torchvision
3
+ numpy
4
+ safetensors
5
+ pillow
6
+ tqdm
static/viewer/viewer.html ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <!doctype html>
2
+ <html lang="en">
3
+ <head>
4
+ <meta charset="utf-8">
5
+ <meta name="viewport" content="width=device-width,initial-scale=1">
6
+ <title>3DGS Viewer</title>
7
+ <style>
8
+ *, *::before, *::after { box-sizing: border-box; margin: 0; padding: 0; }
9
+ html, body { width: 100%; height: 100%; overflow: hidden;
10
+ background: radial-gradient(ellipse at 50% 60%, #1a2035 0%, #080b12 100%); }
11
+ canvas { display: block; }
12
+ #hud {
13
+ position: fixed; top: 12px; left: 16px;
14
+ color: #94a3b8; font: 12px/1.5 system-ui, sans-serif;
15
+ pointer-events: none; user-select: none;
16
+ }
17
+ #loading {
18
+ position: fixed; inset: 0;
19
+ display: flex; flex-direction: column;
20
+ align-items: center; justify-content: center;
21
+ color: #94a3b8; font: 14px system-ui, sans-serif;
22
+ gap: 12px;
23
+ }
24
+ #spinner {
25
+ width: 36px; height: 36px;
26
+ border: 3px solid #334155;
27
+ border-top-color: #60a5fa;
28
+ border-radius: 50%;
29
+ animation: spin 0.9s linear infinite;
30
+ }
31
+ @keyframes spin { to { transform: rotate(360deg); } }
32
+ #error {
33
+ display: none;
34
+ position: fixed; inset: 0;
35
+ align-items: center; justify-content: center;
36
+ color: #f87171; font: 13px system-ui, sans-serif;
37
+ padding: 32px; text-align: center; white-space: pre-wrap;
38
+ }
39
+ </style>
40
+ </head>
41
+ <body>
42
+ <div id="loading">
43
+ <div id="spinner"></div>
44
+ <div id="loading-label">Loading splat…</div>
45
+ </div>
46
+ <div id="error"></div>
47
+ <div id="hud">drag to orbit &nbsp;·&nbsp; scroll to zoom &nbsp;·&nbsp; right-drag to pan</div>
48
+
49
+ <script type="importmap">
50
+ {
51
+ "imports": {
52
+ "three": "https://cdnjs.cloudflare.com/ajax/libs/three.js/0.180.0/three.module.js",
53
+ "three/addons/": "https://unpkg.com/three@0.180.0/examples/jsm/",
54
+ "@sparkjsdev/spark": "https://unpkg.com/@sparkjsdev/spark@2.0.0/dist/spark.module.js"
55
+ }
56
+ }
57
+ </script>
58
+
59
+ <script type="module">
60
+ import * as THREE from "three";
61
+ import { OrbitControls } from "three/addons/controls/OrbitControls.js";
62
+ import { SparkRenderer, SplatMesh } from "@sparkjsdev/spark";
63
+
64
+ const params = new URLSearchParams(location.search);
65
+ const plyURL = params.get("ply");
66
+
67
+ const loadingEl = document.getElementById("loading");
68
+ const loadLabel = document.getElementById("loading-label");
69
+ const errorEl = document.getElementById("error");
70
+
71
+ function showError(msg) {
72
+ loadingEl.style.display = "none";
73
+ errorEl.style.display = "flex";
74
+ errorEl.textContent = msg;
75
+ }
76
+
77
+ if (!plyURL) {
78
+ showError("No ?ply= parameter provided.");
79
+ } else {
80
+ init(plyURL);
81
+ }
82
+
83
+ function init(url) {
84
+ const scene = new THREE.Scene();
85
+
86
+ const camera = new THREE.PerspectiveCamera(45, window.innerWidth / window.innerHeight, 0.01, 1000);
87
+ camera.position.set(0, 0.3, 1.8);
88
+
89
+ const renderer = new THREE.WebGLRenderer({ antialias: false });
90
+ renderer.setPixelRatio(window.devicePixelRatio);
91
+ renderer.setSize(window.innerWidth, window.innerHeight);
92
+ document.body.appendChild(renderer.domElement);
93
+
94
+ const spark = new SparkRenderer({ renderer });
95
+ scene.add(spark);
96
+
97
+ const controls = new OrbitControls(camera, renderer.domElement);
98
+ controls.enableDamping = true;
99
+ controls.dampingFactor = 0.07;
100
+ controls.minDistance = 0.2;
101
+ controls.maxDistance = 50;
102
+ controls.target.set(0, 0, 0);
103
+
104
+ const splat = new SplatMesh({ url });
105
+ // DEG saves PLYs with -Y up and the front of the object facing roughly
106
+ // along the horizontal axis. Re-orient for Three.js (+Y up, camera on
107
+ // +Z looking toward -Z): a Group lets us apply a yaw first (to bring the
108
+ // front-facing side to +Z) and then flip 180° around X to put up on +Y.
109
+ const splatRoot = new THREE.Group();
110
+ splatRoot.add(splat);
111
+ splat.rotation.y = Math.PI / 2; // yaw the model so its front faces the camera
112
+ splatRoot.rotation.x = Math.PI; // flip Y/Z so it stands upright
113
+ scene.add(splatRoot);
114
+
115
+ let framed = false;
116
+ function tryFrame() {
117
+ if (framed) return;
118
+ const box = new THREE.Box3().setFromObject(splatRoot);
119
+ if (box.isEmpty() || !isFinite(box.min.x)) return;
120
+ framed = true;
121
+ // const center = new THREE.Vector3();
122
+ // const size = new THREE.Vector3();
123
+ // box.getCenter(center);
124
+ // box.getSize(size);
125
+ // console.log(size);
126
+ // const maxDim = Math.max(size.x, size.y, size.z);
127
+ // // With a 45° FOV, half-fov tan ≈ 0.414, so the minimum distance to fit
128
+ // // a sphere of diameter `maxDim` in view is maxDim / (2*0.414) ≈ 1.21*maxDim.
129
+ // // Use a small margin (1.15) so the object nearly fills the viewport.
130
+ // const dist = maxDim * 1.15;
131
+ // camera.position.copy(center).add(new THREE.Vector3(0, maxDim * 0.15, dist));
132
+ // controls.target.copy(center);
133
+ // controls.update();
134
+ loadingEl.style.display = "none";
135
+ }
136
+
137
+ fetch(url, { method: "HEAD" }).then(r => {
138
+ if (!r.ok) throw new Error(`HTTP ${r.status} ${r.statusText}`);
139
+ const len = r.headers.get("content-length");
140
+ if (len) loadLabel.textContent = `Loading splat (${(len / 1024 / 1024).toFixed(1)} MB)…`;
141
+ }).catch(e => showError("Fetch failed: " + e.message));
142
+
143
+ let checkFrames = 0;
144
+ function checkLoaded() {
145
+ checkFrames++;
146
+ if (checkFrames < 5) return;
147
+ tryFrame();
148
+ }
149
+
150
+ window.addEventListener("resize", () => {
151
+ camera.aspect = window.innerWidth / window.innerHeight;
152
+ camera.updateProjectionMatrix();
153
+ renderer.setSize(window.innerWidth, window.innerHeight);
154
+ });
155
+
156
+ renderer.setAnimationLoop(() => {
157
+ controls.update();
158
+ renderer.render(scene, camera);
159
+ checkLoaded();
160
+ });
161
+
162
+ // Fallback: hide loading after 4s regardless of bbox-readiness.
163
+ setTimeout(() => { loadingEl.style.display = "none"; }, 4000);
164
+ }
165
+ </script>
166
+ </body>
167
+ </html>
triposplat.py ADDED
@@ -0,0 +1,598 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn.functional as F
4
+ import safetensors.torch
5
+ from PIL import Image, ImageFilter
6
+ from torchvision import transforms
7
+ from tqdm.auto import tqdm
8
+
9
+ from model import (
10
+ DinoV3ViT, Flux2VAEEncoder, BiRefNet,
11
+ OctreeProbabilityFixedlenDecoder, ElasticGaussianFixedlenDecoder,
12
+ LatentSeqMMFlowModel, OctreeGaussianDecoder,
13
+ )
14
+
15
+
16
+ # ---------------------------------------------------------------------------
17
+ # Gaussian
18
+ # ---------------------------------------------------------------------------
19
+
20
+ class Gaussian:
21
+ def __init__(self, aabb: list, sh_degree: int = 0, mininum_kernel_size: float = 0.0,
22
+ scaling_bias: float = 0.01, opacity_bias: float = 0.1,
23
+ scaling_activation: str = "exp", device='cuda'):
24
+ self.sh_degree = sh_degree
25
+ self.mininum_kernel_size = mininum_kernel_size
26
+ self.scaling_bias = scaling_bias
27
+ self.opacity_bias = opacity_bias
28
+ self.device = device
29
+ self.aabb = torch.tensor(aabb, dtype=torch.float32, device=device)
30
+
31
+ if scaling_activation == "exp":
32
+ self._scaling_activation = torch.exp
33
+ self._inverse_scaling_activation = torch.log
34
+ elif scaling_activation == "softplus":
35
+ self._scaling_activation = F.softplus
36
+ self._inverse_scaling_activation = lambda x: x + torch.log(-torch.expm1(-x))
37
+
38
+ self._opacity_activation = torch.sigmoid
39
+ self._inverse_opacity_activation = lambda x: torch.log(x / (1 - x))
40
+
41
+ self.scale_bias = self._inverse_scaling_activation(torch.tensor(self.scaling_bias)).to(self.device)
42
+ self.rots_bias = torch.zeros(4, device=self.device)
43
+ self.rots_bias[0] = 1
44
+ self.opacity_bias_val = self._inverse_opacity_activation(torch.tensor(self.opacity_bias)).to(self.device)
45
+
46
+ self._storage = {}
47
+
48
+ def _get_store(self, name):
49
+ return self._storage.get(name)
50
+
51
+ def _set_store(self, name, value):
52
+ self._storage[name] = value
53
+
54
+ @property
55
+ def _xyz(self):
56
+ return self._get_store("_xyz")
57
+ @_xyz.setter
58
+ def _xyz(self, value):
59
+ if value is None:
60
+ self._set_store("_xyz", None); self._set_store("xyz", None); return
61
+ self._set_store("_xyz", value)
62
+ self._set_store("xyz", value * self.aabb[None, 3:] + self.aabb[None, :3])
63
+
64
+ @property
65
+ def get_xyz(self):
66
+ return self._get_store("xyz")
67
+
68
+ @property
69
+ def _features_dc(self):
70
+ return self._get_store("_features_dc")
71
+ @_features_dc.setter
72
+ def _features_dc(self, value):
73
+ self._set_store("_features_dc", value)
74
+
75
+ @property
76
+ def _opacity(self):
77
+ return self._get_store("_opacity")
78
+ @_opacity.setter
79
+ def _opacity(self, value):
80
+ if value is None:
81
+ self._set_store("_opacity", None); self._set_store("opacity", None); return
82
+ self._set_store("_opacity", value)
83
+ self._set_store("opacity", self._opacity_activation(value + self.opacity_bias_val))
84
+
85
+ @property
86
+ def get_opacity(self):
87
+ return self._get_store("opacity")
88
+
89
+ @property
90
+ def _scaling(self):
91
+ return self._get_store("_scaling")
92
+ @_scaling.setter
93
+ def _scaling(self, value):
94
+ if value is None:
95
+ self._set_store("_scaling", None); self._set_store("scaling", None); return
96
+ self._set_store("_scaling", value)
97
+ s = self._scaling_activation(value + self.scale_bias)
98
+ s = torch.square(s) + self.mininum_kernel_size ** 2
99
+ self._set_store("scaling", torch.sqrt(s))
100
+
101
+ @property
102
+ def get_scaling(self):
103
+ return self._get_store("scaling")
104
+
105
+ @property
106
+ def _rotation(self):
107
+ return self._get_store("_rotation")
108
+ @_rotation.setter
109
+ def _rotation(self, value):
110
+ self._set_store("_rotation", value)
111
+
112
+ def construct_list_of_attributes(self):
113
+ l = ['x', 'y', 'z', 'nx', 'ny', 'nz']
114
+ dc = self._features_dc
115
+ for i in range(dc.shape[1] * dc.shape[2]):
116
+ l.append(f'f_dc_{i}')
117
+ l.append('opacity')
118
+ for i in range(self._scaling.shape[1]):
119
+ l.append(f'scale_{i}')
120
+ for i in range(self._rotation.shape[1]):
121
+ l.append(f'rot_{i}')
122
+ return l
123
+
124
+ _DEFAULT_TRANSFORM = [[1, 0, 0], [0, 0, -1], [0, 1, 0]]
125
+
126
+ def _get_ply_data(self, transform=None):
127
+ xyz = self.get_xyz.detach().cpu().numpy()
128
+ normals = np.zeros_like(xyz)
129
+ f_dc = self._features_dc.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy()
130
+ opacities = self._inverse_opacity_activation(self.get_opacity).detach().cpu().numpy()
131
+ scale = torch.log(self.get_scaling).detach().cpu().numpy()
132
+ rotation = (self._rotation + self.rots_bias[None, :]).detach().cpu().numpy()
133
+ if transform is not None:
134
+ transform = np.array(transform)
135
+ xyz = np.matmul(xyz, transform.T)
136
+ R_mat = _quat_to_matrix(rotation)
137
+ R_mat = np.matmul(transform, R_mat)
138
+ rotation = _matrix_to_quat(R_mat)
139
+ return xyz, normals, f_dc, opacities, scale, rotation
140
+
141
+ def _transformed_xyz_rot(self, transform=None):
142
+ if transform is None:
143
+ transform = self._DEFAULT_TRANSFORM
144
+ transform = np.array(transform, dtype=np.float32)
145
+ xyz = self.get_xyz.detach().cpu().numpy().astype(np.float32)
146
+ rotation = (self._rotation + self.rots_bias[None, :]).detach().cpu().numpy()
147
+ xyz = np.matmul(xyz, transform.T)
148
+ R_mat = _quat_to_matrix(rotation)
149
+ R_mat = np.matmul(transform, R_mat)
150
+ rotation = _matrix_to_quat(R_mat)
151
+ return xyz, rotation
152
+
153
+ def to_ply_bytes(self, transform=None) -> bytes:
154
+ if transform is None:
155
+ transform = self._DEFAULT_TRANSFORM
156
+ xyz, normals, f_dc, opacities, scale, rotation = self._get_ply_data(transform=transform)
157
+ dtype_full = [(attr, 'f4') for attr in self.construct_list_of_attributes()]
158
+ elements = np.empty(xyz.shape[0], dtype=dtype_full)
159
+ elements[:] = list(map(tuple, np.concatenate((xyz, normals, f_dc, opacities, scale, rotation), axis=1)))
160
+ return _binary_ply_bytes(elements, dtype_full)
161
+
162
+ def to_splat_bytes(self, transform=None) -> bytes:
163
+ if transform is None:
164
+ transform = self._DEFAULT_TRANSFORM
165
+ xyz, rotation = self._transformed_xyz_rot(transform=transform)
166
+ scale = self.get_scaling.detach().cpu().numpy().astype(np.float32)
167
+ opacity = self.get_opacity.detach().cpu().numpy()
168
+ f_dc = self._features_dc.detach().cpu().numpy()
169
+ C0 = 0.28209479177387814
170
+ # .splat packs color as 4 bytes RGBA: RGB from the SH DC term, A from opacity.
171
+ rgb = np.clip((f_dc[:, 0, :] * C0 + 0.5) * 255, 0, 255).astype(np.uint8)
172
+ alpha = np.clip(opacity[:, 0:1] * 255, 0, 255).astype(np.uint8)
173
+ rgba = np.concatenate([rgb, alpha], axis=1)
174
+ rot = rotation / np.linalg.norm(rotation, axis=-1, keepdims=True)
175
+ rot_u8 = np.clip(rot * 128 + 128, 0, 255).astype(np.uint8)
176
+ order = np.argsort(-opacity[:, 0] * np.prod(scale, axis=-1))
177
+ xyz, scale, rgba, rot_u8 = xyz[order], scale[order], rgba[order], rot_u8[order]
178
+ # Per-splat record is exactly 32 bytes: xyz(12) + scale(12) + rgba(4) + rot(4).
179
+ data = np.concatenate([
180
+ xyz.astype(np.float32).view(np.uint8).reshape(-1, 12),
181
+ scale.astype(np.float32).view(np.uint8).reshape(-1, 12),
182
+ rgba.reshape(-1, 4),
183
+ rot_u8.reshape(-1, 4),
184
+ ], axis=1).reshape(-1)
185
+ return data.tobytes()
186
+
187
+ def save_ply(self, path, transform=None):
188
+ with open(path, 'wb') as f:
189
+ f.write(self.to_ply_bytes(transform=transform))
190
+
191
+ def save_splat(self, path, transform=None):
192
+ with open(path, 'wb') as f:
193
+ f.write(self.to_splat_bytes(transform=transform))
194
+
195
+
196
+ def _binary_ply_bytes(elements, dtype_full) -> bytes:
197
+ num_vertices = len(elements)
198
+ header = "ply\nformat binary_little_endian 1.0\n"
199
+ header += f"element vertex {num_vertices}\n"
200
+ type_map = {'f4': 'float', 'u1': 'uchar', 'i4': 'int'}
201
+ for name, t in dtype_full:
202
+ header += f"property {type_map.get(t, t)} {name}\n"
203
+ header += "end_header\n"
204
+ return header.encode('ascii') + elements.tobytes()
205
+
206
+
207
+ def _quat_to_matrix(q):
208
+ q = q / np.linalg.norm(q, axis=-1, keepdims=True)
209
+ w, x, y, z = q[:, 0], q[:, 1], q[:, 2], q[:, 3]
210
+ R = np.stack([
211
+ 1 - 2*(y*y + z*z), 2*(x*y - w*z), 2*(x*z + w*y),
212
+ 2*(x*y + w*z), 1 - 2*(x*x + z*z), 2*(y*z - w*x),
213
+ 2*(x*z - w*y), 2*(y*z + w*x), 1 - 2*(x*x + y*y),
214
+ ], axis=-1).reshape(-1, 3, 3)
215
+ return R
216
+
217
+
218
+ def _matrix_to_quat(R):
219
+ trace = R[:, 0, 0] + R[:, 1, 1] + R[:, 2, 2]
220
+ q = np.zeros((R.shape[0], 4), dtype=R.dtype)
221
+ s = np.sqrt(np.maximum(trace + 1, 0)) * 2
222
+ q[:, 0] = 0.25 * s
223
+ q[:, 1] = (R[:, 2, 1] - R[:, 1, 2]) / np.where(s != 0, s, 1)
224
+ q[:, 2] = (R[:, 0, 2] - R[:, 2, 0]) / np.where(s != 0, s, 1)
225
+ q[:, 3] = (R[:, 1, 0] - R[:, 0, 1]) / np.where(s != 0, s, 1)
226
+ m01 = (R[:, 0, 0] >= R[:, 1, 1]) & (R[:, 0, 0] >= R[:, 2, 2]) & (s == 0)
227
+ s1 = np.sqrt(np.maximum(1 + R[:, 0, 0] - R[:, 1, 1] - R[:, 2, 2], 0)) * 2
228
+ q[m01, 0] = (R[m01, 2, 1] - R[m01, 1, 2]) / s1[m01]
229
+ q[m01, 1] = 0.25 * s1[m01]
230
+ q[m01, 2] = (R[m01, 0, 1] + R[m01, 1, 0]) / s1[m01]
231
+ q[m01, 3] = (R[m01, 0, 2] + R[m01, 2, 0]) / s1[m01]
232
+ m11 = (R[:, 1, 1] > R[:, 0, 0]) & (R[:, 1, 1] >= R[:, 2, 2]) & (s == 0)
233
+ s2 = np.sqrt(np.maximum(1 + R[:, 1, 1] - R[:, 0, 0] - R[:, 2, 2], 0)) * 2
234
+ q[m11, 0] = (R[m11, 0, 2] - R[m11, 2, 0]) / s2[m11]
235
+ q[m11, 1] = (R[m11, 0, 1] + R[m11, 1, 0]) / s2[m11]
236
+ q[m11, 2] = 0.25 * s2[m11]
237
+ q[m11, 3] = (R[m11, 1, 2] + R[m11, 2, 1]) / s2[m11]
238
+ m21 = (R[:, 2, 2] > R[:, 0, 0]) & (R[:, 2, 2] > R[:, 1, 1]) & (s == 0)
239
+ s3 = np.sqrt(np.maximum(1 + R[:, 2, 2] - R[:, 0, 0] - R[:, 1, 1], 0)) * 2
240
+ q[m21, 0] = (R[m21, 1, 0] - R[m21, 0, 1]) / s3[m21]
241
+ q[m21, 1] = (R[m21, 0, 2] + R[m21, 2, 0]) / s3[m21]
242
+ q[m21, 2] = (R[m21, 1, 2] + R[m21, 2, 1]) / s3[m21]
243
+ q[m21, 3] = 0.25 * s3[m21]
244
+ return q / np.linalg.norm(q, axis=-1, keepdims=True)
245
+
246
+
247
+ def _build_gaussians(decoder: ElasticGaussianFixedlenDecoder, points_pred: dict, pred: dict):
248
+ x = points_pred
249
+ offset = decoder._get_offset(pred['features'])
250
+ h = pred["features"]
251
+ ret = []
252
+ for i in range(h.shape[0]):
253
+ g = Gaussian(
254
+ sh_degree=0,
255
+ aabb=[-0.5, -0.5, -0.5, 1.0, 1.0, 1.0],
256
+ mininum_kernel_size=decoder.rep_config['filter_kernel_size_3d'],
257
+ scaling_bias=decoder.rep_config['scaling_bias'],
258
+ opacity_bias=decoder.rep_config['opacity_bias'],
259
+ scaling_activation=decoder.rep_config['scaling_activation'],
260
+ )
261
+ _x = x["points"][i, :, None, :]
262
+ for k, v in decoder.layout.items():
263
+ if k == '_xyz':
264
+ setattr(g, k, (offset[i] + _x).flatten(0, 1))
265
+ elif k in ('_xyz_center', '_offset_scale'):
266
+ continue
267
+ else:
268
+ feats = h[i][:, v['range'][0]:v['range'][1]].reshape(-1, *v['shape']).flatten(0, 1)
269
+ setattr(g, k, feats * decoder.rep_config['lr'][k])
270
+ ret.append(g)
271
+ return ret
272
+
273
+
274
+ # ---------------------------------------------------------------------------
275
+ # Euler flow sampler
276
+ # ---------------------------------------------------------------------------
277
+
278
+ class FlowEulerCfgSampler:
279
+ def __init__(self, sigma_min: float = 1e-5):
280
+ self.sigma_min = sigma_min
281
+
282
+ def _get_batch_size(self, x_t):
283
+ return next(iter(x_t.values())).shape[0] if isinstance(x_t, dict) else x_t.shape[0]
284
+
285
+ def _get_device(self, x_t):
286
+ return next(iter(x_t.values())).device if isinstance(x_t, dict) else x_t.device
287
+
288
+ def _inference_model(self, model, x_t, t, cond=None):
289
+ batch = self._get_batch_size(x_t)
290
+ device = self._get_device(x_t)
291
+ t_scaled = torch.tensor([1000 * t] * batch, device=device, dtype=torch.float32)
292
+ if isinstance(cond, dict):
293
+ for k, v in cond.items():
294
+ if isinstance(v, torch.Tensor) and v.shape[0] == 1 and batch > 1:
295
+ cond[k] = v.repeat(batch, *([1] * (len(v.shape) - 1)))
296
+ elif cond is not None and cond.shape[0] == 1 and batch > 1:
297
+ cond = cond.repeat(batch, *([1] * (len(cond.shape) - 1)))
298
+ return model(x_t, t_scaled, cond)
299
+
300
+ def _cfg_prediction(self, model, x_t, t, cond, neg_cond, guidance_scale):
301
+ # Diffusers-style convention: guidance_scale == 1 (or <= 1, or None) means no CFG —
302
+ # only the conditional pass runs, halving the per-step cost. > 1 enables CFG and
303
+ # blends as `pred = s * cond + (1 - s) * uncond = s * cond - (s - 1) * uncond`.
304
+ pred_v = self._inference_model(model, x_t, t, cond)
305
+ if isinstance(guidance_scale, dict):
306
+ if not any(s > 1 for s in guidance_scale.values()):
307
+ return pred_v
308
+ neg_pred_v = self._inference_model(model, x_t, t, neg_cond)
309
+ for key in pred_v:
310
+ s = guidance_scale.get(key, 1.0)
311
+ if s > 1:
312
+ pred_v[key] = s * pred_v[key] - (s - 1) * neg_pred_v[key]
313
+ return pred_v
314
+ if guidance_scale is None or guidance_scale <= 1:
315
+ return pred_v
316
+ neg_pred_v = self._inference_model(model, x_t, t, neg_cond)
317
+ for key in pred_v:
318
+ pred_v[key] = guidance_scale * pred_v[key] - (guidance_scale - 1) * neg_pred_v[key]
319
+ return pred_v
320
+
321
+ @torch.no_grad()
322
+ def sample(self, model, noise, cond, neg_cond, steps=50, shift=1.0,
323
+ guidance_scale=None, show_progress=False, callback=None):
324
+ sample = noise
325
+ t_seq = shift * np.linspace(1, 0, steps + 1) / (1 + (shift - 1) * np.linspace(1, 0, steps + 1))
326
+ t_pairs = list(zip(t_seq[:-1], t_seq[1:]))
327
+ iterator = tqdm(t_pairs, desc="Sampling", total=steps) if show_progress else t_pairs
328
+ for i, (t, t_prev) in enumerate(iterator):
329
+ x_t = {k: v.clone() for k, v in sample.items()} if isinstance(sample, dict) else sample.clone()
330
+ pred_v = self._cfg_prediction(model, x_t, t, cond, neg_cond, guidance_scale)
331
+ dt = t - t_prev
332
+ if isinstance(sample, dict):
333
+ for key in sample:
334
+ sample[key] = sample[key] - pred_v[key] * dt
335
+ else:
336
+ sample = sample - pred_v * dt
337
+ if callback is not None:
338
+ callback(i + 1, steps)
339
+ return sample
340
+
341
+
342
+ # ---------------------------------------------------------------------------
343
+ # Component loaders
344
+ # ---------------------------------------------------------------------------
345
+
346
+ def _place(m, device, dtype):
347
+ if device is not None or dtype is not None:
348
+ m = m.to(device=device, dtype=dtype)
349
+ return m.eval()
350
+
351
+
352
+ def load_dinov3(path: str, device=None, dtype=None) -> DinoV3ViT:
353
+ m = DinoV3ViT()
354
+ m.load_safetensors(path)
355
+ return _place(m, device, dtype)
356
+
357
+
358
+ def load_vae_encoder(path: str, device=None, dtype=None) -> Flux2VAEEncoder:
359
+ m = Flux2VAEEncoder()
360
+ m.load_safetensors(path)
361
+ return _place(m, device, dtype)
362
+
363
+
364
+ def load_rmbg(path: str, device=None, dtype=None) -> BiRefNet:
365
+ m = BiRefNet()
366
+ m.load_safetensors(path)
367
+ return _place(m, device, dtype)
368
+
369
+
370
+ FLOW_MODEL_ARGS = dict(
371
+ q_token_length=8192, in_channels=16, cam_channels=5, out_channels=16,
372
+ model_channels=1024, cond_channels=1280, cond2_channels=128,
373
+ num_refiner_blocks=2, num_blocks=24, num_heads=16, mlp_ratio=4,
374
+ qk_rms_norm=True, share_mod=True, use_shift_table=True,
375
+ )
376
+
377
+
378
+ def load_flow_model(path: str, device=None, dtype=None) -> LatentSeqMMFlowModel:
379
+ m = LatentSeqMMFlowModel(**FLOW_MODEL_ARGS)
380
+ m.load_safetensors(path)
381
+ return _place(m, device, dtype)
382
+
383
+
384
+ OCTREE_DECODER_ARGS = dict(
385
+ model_channels=1024, cond_channels=16,
386
+ num_blocks=4, num_heads=16, mlp_ratio=4, share_mod=True,
387
+ )
388
+
389
+ GS_DECODER_ARGS = dict(
390
+ in_channels=3, model_channels=1024, cond_channels=16,
391
+ attn_mode="full", num_blocks=16, num_heads=16, mlp_ratio=4,
392
+ use_learned_offset_scale=True, use_per_offset=True,
393
+ representation_config=dict(
394
+ lr=dict(_xyz=1.0, _features_dc=1.0, _opacity=1.0, _scaling=1.0, _rotation=0.1),
395
+ perturb_offset=True, perturbe_size=1.5, offset_scale=0.05, num_gaussians=32,
396
+ filter_kernel_size_3d=0.0009, scaling_bias=0.004, opacity_bias=0.1,
397
+ scaling_activation="softplus",
398
+ ),
399
+ )
400
+
401
+
402
+ def load_decoder(path: str, device=None, dtype=None) -> OctreeGaussianDecoder:
403
+ m = OctreeGaussianDecoder(OCTREE_DECODER_ARGS, GS_DECODER_ARGS)
404
+ m.load_safetensors(path)
405
+ return _place(m, device, dtype)
406
+
407
+
408
+ # ---------------------------------------------------------------------------
409
+ # Pipeline stages
410
+ # ---------------------------------------------------------------------------
411
+
412
+ _CANVAS_SIZE = 1024
413
+
414
+
415
+ def _image_to_pil(image) -> Image.Image:
416
+ if isinstance(image, Image.Image):
417
+ return image
418
+ if isinstance(image, (str, bytes)) or hasattr(image, "__fspath__"):
419
+ return Image.open(image)
420
+ if isinstance(image, torch.Tensor):
421
+ t = image.detach().cpu()
422
+ if t.ndim == 4:
423
+ assert t.shape[0] == 1, (
424
+ f"batched image input is not supported (got B={t.shape[0]}); "
425
+ "pass one image at a time"
426
+ )
427
+ t = t[0]
428
+ arr = (t.clamp(0, 1) * 255).to(torch.uint8).numpy()
429
+ mode = "RGBA" if arr.shape[-1] == 4 else "RGB"
430
+ return Image.fromarray(arr, mode=mode)
431
+ raise TypeError(f"unsupported image type: {type(image)}")
432
+
433
+
434
+ def preprocess_image(image, rmbg: BiRefNet, erode_radius: int = 1) -> Image.Image:
435
+ image = _image_to_pil(image)
436
+ size = _CANVAS_SIZE
437
+ w, h = image.size
438
+ s = size / min(w, h)
439
+ image = image.resize((max(1, int(round(w * s))), max(1, int(round(h * s)))), Image.LANCZOS)
440
+ has_real_alpha = (image.mode == "RGBA"
441
+ and np.array(image.getchannel(3), dtype=np.int32).min() < 255)
442
+ if not has_real_alpha:
443
+ image = rmbg.remove_background(image.convert("RGB"))
444
+ if erode_radius > 0:
445
+ image.putalpha(image.getchannel(3).filter(ImageFilter.MinFilter(2 * erode_radius + 1)))
446
+ alpha = np.array(image.getchannel(3))
447
+ ys, xs = np.nonzero(alpha)
448
+ bbox = [xs.min(), ys.min(), xs.max(), ys.max()]
449
+ cx, cy = (bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2
450
+ half = max(bbox[2] - bbox[0], bbox[3] - bbox[1]) / 2 * 1.2
451
+ image = image.crop([int(cx - half), int(cy - half), int(cx + half), int(cy + half)])
452
+ image = image.resize((size, size), Image.LANCZOS)
453
+ bg = Image.new("RGB", (size, size), (0, 0, 0))
454
+ bg.paste(image, mask=image.split()[3])
455
+ return bg
456
+
457
+
458
+ _DINOV3_NORMALIZE = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
459
+
460
+
461
+ @torch.no_grad()
462
+ def encode_image(image: Image.Image, dinov3: DinoV3ViT, vae_encoder: Flux2VAEEncoder,
463
+ generator: torch.Generator = None) -> dict:
464
+ device = next(dinov3.parameters()).device
465
+ img_tensor = transforms.ToTensor()(image).unsqueeze(0).to(device=device, dtype=torch.float32)
466
+ img_normed = _DINOV3_NORMALIZE(img_tensor)
467
+ dinov3_dtype = next(dinov3.parameters()).dtype
468
+ vae_dtype = next(vae_encoder.parameters()).dtype
469
+ dinov3_feat = dinov3(pixel_values=img_normed.to(dinov3_dtype))
470
+ dinov3_feat = F.layer_norm(dinov3_feat.float(), dinov3_feat.shape[-1:])
471
+ vae_feat = vae_encoder.encode(img_tensor.to(vae_dtype) * 2 - 1,
472
+ deterministic=False, generator=generator)
473
+ # pad 5 zero tokens so feature2's token length matches feature1's (cls + 4 registers + patches)
474
+ zero_reg = torch.zeros(vae_feat.shape[0], 5, vae_feat.shape[2],
475
+ dtype=vae_feat.dtype, device=vae_feat.device)
476
+ vae_feat = torch.cat([zero_reg, vae_feat], dim=1)
477
+ return {'feature1': dinov3_feat, 'feature2': vae_feat}
478
+
479
+
480
+ @torch.no_grad()
481
+ def sample_latent(flow_model: LatentSeqMMFlowModel, cond: dict,
482
+ steps: int = 50, guidance_scale: float = 7.0, shift: float = 3.0,
483
+ generator: torch.Generator = None,
484
+ show_progress: bool = False, callback=None) -> dict:
485
+ device = flow_model.device
486
+ neg_cond = {k: torch.zeros_like(v) for k, v in cond.items()}
487
+ noise = {'latent': torch.randn(1, flow_model.q_token_length, flow_model.in_channels,
488
+ device=device, generator=generator)}
489
+ if flow_model.cam_channels is not None:
490
+ noise['camera'] = torch.randn(1, 1, flow_model.cam_channels,
491
+ device=device, generator=generator)
492
+ sampler = FlowEulerCfgSampler()
493
+ return sampler.sample(flow_model, noise, cond=cond, neg_cond=neg_cond,
494
+ steps=steps, guidance_scale=guidance_scale, shift=shift,
495
+ show_progress=show_progress, callback=callback)
496
+
497
+
498
+ # ---------------------------------------------------------------------------
499
+ # Pipeline
500
+ # ---------------------------------------------------------------------------
501
+
502
+ class TripoSplatPipeline:
503
+ def __init__(self, ckpt_path: str, decoder_path: str, dinov3_path: str,
504
+ flux2_vae_encoder_path: str, rmbg_path: str, device: str = "cuda"):
505
+ self._device = torch.device(device)
506
+ self.dinov3 = load_dinov3 (dinov3_path, device=self._device, dtype=torch.bfloat16)
507
+ self.vae_encoder = load_vae_encoder (flux2_vae_encoder_path, device=self._device, dtype=torch.bfloat16)
508
+ self.rmbg = load_rmbg (rmbg_path, device=self._device, dtype=torch.float16)
509
+ self.flow_model = load_flow_model (ckpt_path, device=self._device, dtype=torch.float16)
510
+ self.decoder = load_decoder (decoder_path, device=self._device, dtype=torch.float16)
511
+
512
+ def preprocess_image(self, image, erode_radius: int = 1) -> Image.Image:
513
+ return preprocess_image(image, self.rmbg, erode_radius=erode_radius)
514
+
515
+ def encode_image(self, image: Image.Image, generator: torch.Generator = None) -> dict:
516
+ return encode_image(image, self.dinov3, self.vae_encoder, generator=generator)
517
+
518
+ def sample_latent(self, cond: dict, steps: int = 50, guidance_scale: float = 7.0,
519
+ shift: float = 3.0, generator: torch.Generator = None,
520
+ show_progress: bool = False, callback=None) -> dict:
521
+ return sample_latent(self.flow_model, cond, steps=steps, guidance_scale=guidance_scale,
522
+ shift=shift, generator=generator,
523
+ show_progress=show_progress, callback=callback)
524
+
525
+ def decode_latent(self, latent: torch.Tensor, num_gaussians: int = 262144):
526
+ return self.decoder.decode(latent, num_gaussians=num_gaussians)
527
+
528
+ _NUM_GAUSSIANS_MIN = 32768
529
+ _NUM_GAUSSIANS_MAX = 262144
530
+
531
+ def _validate_num_gaussians(self, n: int) -> int:
532
+ assert self._NUM_GAUSSIANS_MIN <= n <= self._NUM_GAUSSIANS_MAX, (
533
+ f"num_gaussians must be in [{self._NUM_GAUSSIANS_MIN}, {self._NUM_GAUSSIANS_MAX}], got {n}"
534
+ )
535
+ gpp = self.decoder.gaussians_per_point
536
+ if n % gpp == 0:
537
+ return n
538
+ rounded = round(n / gpp) * gpp
539
+ print(f"[TripoSplatPipeline] num_gaussians={n} is not a multiple of {gpp}; rounding to {rounded}")
540
+ return rounded
541
+
542
+ @torch.no_grad()
543
+ def run(self, image, seed: int = 42, steps: int = 20, guidance_scale: float = 3.0,
544
+ shift: float = 3.0, num_gaussians=262144, erode_radius: int = 1,
545
+ show_progress: bool = False, callback=None):
546
+ """
547
+ Args:
548
+ image: Input image. Accepts a file path / PIL.Image / torch.Tensor
549
+ (`[1,H,W,C]` or `[H,W,C]`, float in `[0, 1]`, optional alpha
550
+ channel as the 4th channel).
551
+ seed: RNG seed for the VAE encoder's stochastic latent sampling and
552
+ the initial flow-matching noise. Same seed → same output.
553
+ steps: Number of Euler integrator steps in the flow-matching sampler.
554
+ More steps → better fidelity, linear runtime cost.
555
+ Recommend: 10~20.
556
+ guidance_scale: Classifier-free-guidance strength (diffusers
557
+ convention). `≤ 1.0` disables CFG. Higher → more detail,
558
+ stronger adherence to the input image; too high can cause color
559
+ oversaturation.
560
+ Recommend: 3.0.
561
+ shift: Flow-matching timestep schedule shift. `1.0` gives a uniform
562
+ schedule; `>1.0` allocates more steps to the early/high-noise end.
563
+ Recommend: 3.0.
564
+ num_gaussians: Target Gaussian-splat count. An `int` returns a
565
+ single `Gaussian`. A `list` / `tuple` of ints returns a
566
+ `list[Gaussian]`. Each count is rounded to the nearest multiple
567
+ of 32. More gaussians → more detail but higher rendering and
568
+ storage cost.
569
+ Recommend: 32768~262144.
570
+ erode_radius: Pixel radius used to erode the alpha matte after
571
+ background removal, to avoid segmentation-border bleed before
572
+ compositing on black. `0` disables; `1` is a 3×3 minimum filter.
573
+ Recommend: 1.
574
+ show_progress: Print a `tqdm` progress bar over sampler steps.
575
+ callback: Optional `fn(step, total)` invoked after each sampler step.
576
+ Useful for external progress UIs (e.g. ComfyUI's
577
+ `ProgressBar.update`).
578
+
579
+ Returns:
580
+ `(gaussian, prepared_image)` for an `int` `num_gaussians`, or
581
+ `(list_of_gaussians, prepared_image)` for a `list` / `tuple`. The
582
+ second element is the RGB composite the encoders actually saw —
583
+ useful for display / debugging.
584
+ """
585
+ if isinstance(num_gaussians, (list, tuple)):
586
+ counts = [self._validate_num_gaussians(n) for n in num_gaussians]
587
+ else:
588
+ counts = [self._validate_num_gaussians(num_gaussians)]
589
+
590
+ gen = torch.Generator(device=self._device).manual_seed(seed)
591
+ prepared = self.preprocess_image(image, erode_radius=erode_radius)
592
+ cond = self.encode_image(prepared, generator=gen)
593
+ out = self.sample_latent(cond, steps=steps, guidance_scale=guidance_scale, shift=shift,
594
+ generator=gen, show_progress=show_progress, callback=callback)
595
+ gaussians = [self.decode_latent(out['latent'], num_gaussians=n) for n in counts]
596
+ if isinstance(num_gaussians, (list, tuple)):
597
+ return gaussians, prepared
598
+ return gaussians[0], prepared