diff --git a/app.py b/app.py index 111f79149e1cebc36149a6c125742b33f4cbf5e6..65ab03622fa635e384b5755d5219840fa66df144 100644 --- a/app.py +++ b/app.py @@ -193,7 +193,7 @@ def generate_and_extract_glb( image_files = [image[0] for image in multiimages] # Generate 3D model - outputs = pipeline.run( + outputs, _, _ = pipeline.run( image=image_files, seed=seed, formats=["gaussian", "mesh"], @@ -210,6 +210,12 @@ def generate_and_extract_glb( ) # Render video + # import uuid + # output_id = str(uuid.uuid4()) + # os.makedirs(f"{TMP_DIR}/{output_id}", exist_ok=True) + # video_path = f"{TMP_DIR}/{output_id}/preview.mp4" + # glb_path = f"{TMP_DIR}/{output_id}/mesh.glb" + video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color'] video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal'] video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))] @@ -331,10 +337,8 @@ with demo: image_prompt = gr.Image(label="Image Prompt", format="png", visible=False, image_mode="RGBA", type="pil", height=300) multiimage_prompt = gr.Gallery(label="Image Prompt", format="png", type="pil", height=300, columns=3) gr.Markdown(""" - Input different views of the object in separate images. - - *NOTE: this is an experimental algorithm without training a specialized model. It may not produce the best results for all images, especially those having different poses or inconsistent details.* - """) + Input different views of the object in separate images. + """) with gr.Accordion(label="Generation Settings", open=False): seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1) @@ -434,7 +438,7 @@ with demo: # Launch the Gradio app if __name__ == "__main__": - pipeline = TrellisVGGTTo3DPipeline.from_pretrained("Stable-X/trellis-vggt-v0-1") + pipeline = TrellisVGGTTo3DPipeline.from_pretrained("Stable-X/trellis-vggt-v0-2") pipeline.cuda() pipeline.VGGT_model.cuda() pipeline.birefnet_model.cuda() diff --git a/trellis/__pycache__/__init__.cpython-310.pyc b/trellis/__pycache__/__init__.cpython-310.pyc index 23e2fd5458d5ea6efe46fe6ef3d794c9f51d5559..1c01ca03138795cb6ffcb7cfec0ed03e010a24ec 100644 Binary files a/trellis/__pycache__/__init__.cpython-310.pyc and b/trellis/__pycache__/__init__.cpython-310.pyc differ diff --git a/trellis/models/__init__.py b/trellis/models/__init__.py index 8c302e12b2c530c3cc0e844be83551bcc5814e4a..4d4e3383e7aa54d3e246c5ad1fccb55fc59eb134 100644 --- a/trellis/models/__init__.py +++ b/trellis/models/__init__.py @@ -9,6 +9,7 @@ __attributes = { 'SLatMeshDecoder': 'structured_latent_vae', 'SLatFlowModel': 'structured_latent_flow', 'ModulatedMultiViewCond': 'sparse_structure_flow', + 'ModulatedSLATMultiViewCond': 'structured_latent_flow', } __submodules = [] @@ -85,4 +86,4 @@ if __name__ == '__main__': from .sparse_structure_vae import SparseStructureEncoder, SparseStructureDecoder from .sparse_structure_flow import SparseStructureFlowModel, ModulatedMultiViewCond from .structured_latent_vae import SLatEncoder, SLatGaussianDecoder, SLatMeshDecoder - from .structured_latent_flow import SLatFlowModel + from .structured_latent_flow import SLatFlowModel, ModulatedSLATMultiViewCond diff --git a/trellis/models/__pycache__/__init__.cpython-310.pyc b/trellis/models/__pycache__/__init__.cpython-310.pyc index 49d5d11874115c1ee31886d315263c1495d44f5b..8c26051e6f16a4d9d1e60106117a895449cc8f05 100644 Binary files a/trellis/models/__pycache__/__init__.cpython-310.pyc and b/trellis/models/__pycache__/__init__.cpython-310.pyc differ diff --git a/trellis/models/__pycache__/sparse_structure_flow.cpython-310.pyc b/trellis/models/__pycache__/sparse_structure_flow.cpython-310.pyc index 202e2449ca45640ed8caf16011e835fe16555f6f..b69906d0316b44c234fb41255ce67484848ddbda 100644 Binary files a/trellis/models/__pycache__/sparse_structure_flow.cpython-310.pyc and b/trellis/models/__pycache__/sparse_structure_flow.cpython-310.pyc differ diff --git a/trellis/models/__pycache__/sparse_structure_vae.cpython-310.pyc b/trellis/models/__pycache__/sparse_structure_vae.cpython-310.pyc index 5bfef55b5aa9a1c02f3e2f354d5f4e2149a49d76..7d93e2396e0a132c5374275ffb64ae8fa4ff2ade 100644 Binary files a/trellis/models/__pycache__/sparse_structure_vae.cpython-310.pyc and b/trellis/models/__pycache__/sparse_structure_vae.cpython-310.pyc differ diff --git a/trellis/models/__pycache__/structured_latent_flow.cpython-310.pyc b/trellis/models/__pycache__/structured_latent_flow.cpython-310.pyc index 8d0e095cd6f99f10575fe0c83645a74c6aa58e3c..dc4180ecc2ef7d947719d4b2adf858a4d501b203 100644 Binary files a/trellis/models/__pycache__/structured_latent_flow.cpython-310.pyc and b/trellis/models/__pycache__/structured_latent_flow.cpython-310.pyc differ diff --git a/trellis/models/structured_latent_flow.py b/trellis/models/structured_latent_flow.py index e6468aa4f65415b301eb5d44b8babb7c5e78d3ce..13d25e01699094d500177d42bcd81f67bde4b96a 100644 --- a/trellis/models/structured_latent_flow.py +++ b/trellis/models/structured_latent_flow.py @@ -311,6 +311,11 @@ class SLatFlowModel(nn.Module): t_emb = self.adaLN_modulation(t_emb) t_emb = t_emb.type(self.dtype) + if isinstance(cond, list): + cond = [c.type(self.dtype) for c in cond] + else: + cond = cond.type(self.dtype) + skips = [] # pack with input blocks for block in self.input_blocks: @@ -320,15 +325,8 @@ class SLatFlowModel(nn.Module): if self.pe_mode == "ape": h = h + self.pos_embedder(h.coords[:, 1:]).type(self.dtype) - if isinstance(cond, list): - for i in range(len(cond)): - cond_tmp = cond[i].type(self.dtype) - for block in self.blocks: - h = block(h, t_emb, cond_tmp) - else: - cond = cond.type(self.dtype) - for block in self.blocks: - h = block(h, t_emb, cond) + for block in self.blocks: + h = block(h, t_emb, cond) # unpack with output blocks for block, skip in zip(self.out_blocks, reversed(skips)): @@ -340,3 +338,61 @@ class SLatFlowModel(nn.Module): h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:])) h = self.out_layer(h.type(x.dtype)) return h + +class ModulatedSLATMultiViewCond(nn.Module): + """ + Transformer cross-attention block (MSA + MCA + FFN) with adaptive layer norm conditioning. + """ + def __init__( + self, + channels: int, + ctx_channels: int, + dtype: Optional[torch.dtype] = torch.float32, + use_fp16: bool = True, + ): + super().__init__() + self.linear_blocks = nn.ModuleList([ + nn.Sequential( + nn.Linear(ctx_channels, channels, bias=True), + nn.ReLU(), + ) + for _ in range(4) + ]) + self.fuse_blocks = nn.ModuleList([ + nn.Sequential( + nn.Linear(ctx_channels, channels, bias=True), + nn.ReLU(), + ) + for _ in range(4) + ]) + self.use_fp16 = use_fp16 + if use_fp16: + self.dtype = torch.float16 + else: + self.dtype = dtype + + self.intermediate_layer_idx = [4, 11, 17, 23] + if use_fp16: + self.convert_to_fp16() + + def convert_to_fp16(self) -> None: + """ + Convert the torso of the model to float16. + """ + self.use_fp16 = True + self.dtype = torch.float16 + self.linear_blocks.apply(convert_module_to_f16) + + def forward(self, aggregated_tokens_list: List, image_cond: torch.Tensor): + + b, n, _, _ = aggregated_tokens_list[0].shape + idx = 0 + cond = image_cond.reshape(b*n, -1, 1024).to(self.dtype) + for layer_idx in self.intermediate_layer_idx: + x = aggregated_tokens_list[layer_idx] + # x = x.reshape(b, -1, 2048) + torch.cat([image_cond.reshape(b, -1, 1024), image_cond.reshape(b, -1, 1024)],dim=-1) + x = torch.cat([x.reshape(b*n, -1, 2048), cond.reshape(b*n, -1, 1024)],dim=-1).to(self.dtype) + x = self.linear_blocks[idx](x) + cond = x + image_cond.reshape(b*n, -1, 1024).to(self.dtype) + idx = idx + 1 + return cond \ No newline at end of file diff --git a/trellis/models/structured_latent_vae/__pycache__/__init__.cpython-310.pyc b/trellis/models/structured_latent_vae/__pycache__/__init__.cpython-310.pyc index 522a8c1d39a85fe580c4e8cff169c44fe71a7a6f..1fe66aab0911fef620a4a4d69ca2482d5f799a81 100644 Binary files a/trellis/models/structured_latent_vae/__pycache__/__init__.cpython-310.pyc and b/trellis/models/structured_latent_vae/__pycache__/__init__.cpython-310.pyc differ diff --git a/trellis/models/structured_latent_vae/__pycache__/base.cpython-310.pyc b/trellis/models/structured_latent_vae/__pycache__/base.cpython-310.pyc index 94c7ce79d4149bec801650c5acda79572c96ba50..6629eb084ce3395c040040586dd05f045d957a39 100644 Binary files a/trellis/models/structured_latent_vae/__pycache__/base.cpython-310.pyc and b/trellis/models/structured_latent_vae/__pycache__/base.cpython-310.pyc differ diff --git a/trellis/models/structured_latent_vae/__pycache__/decoder_gs.cpython-310.pyc b/trellis/models/structured_latent_vae/__pycache__/decoder_gs.cpython-310.pyc index ff121ce35266506a24172b4d8eb2925484732f1c..8be1e1d299136d2b6d793eda4d46c5f57339d4d7 100644 Binary files a/trellis/models/structured_latent_vae/__pycache__/decoder_gs.cpython-310.pyc and b/trellis/models/structured_latent_vae/__pycache__/decoder_gs.cpython-310.pyc differ diff --git a/trellis/models/structured_latent_vae/__pycache__/decoder_mesh.cpython-310.pyc b/trellis/models/structured_latent_vae/__pycache__/decoder_mesh.cpython-310.pyc index 3015c09fa91b24e0dbc90b056834349d4a84aa42..017e3ca18f2da190157b3495b06e2e18f36272ce 100644 Binary files a/trellis/models/structured_latent_vae/__pycache__/decoder_mesh.cpython-310.pyc and b/trellis/models/structured_latent_vae/__pycache__/decoder_mesh.cpython-310.pyc differ diff --git a/trellis/models/structured_latent_vae/__pycache__/encoder.cpython-310.pyc b/trellis/models/structured_latent_vae/__pycache__/encoder.cpython-310.pyc index 1cca13828f447ca049d502fad46f48b314ec5501..98d691c09d0030afa78c542a02263bdd5f1b0987 100644 Binary files a/trellis/models/structured_latent_vae/__pycache__/encoder.cpython-310.pyc and b/trellis/models/structured_latent_vae/__pycache__/encoder.cpython-310.pyc differ diff --git a/trellis/modules/__pycache__/norm.cpython-310.pyc b/trellis/modules/__pycache__/norm.cpython-310.pyc index 0f7287dfe3cf23e0e81b792837b4313725c31b72..a0c6289222896eae292d1a3eee1cfd503ea966ca 100644 Binary files a/trellis/modules/__pycache__/norm.cpython-310.pyc and b/trellis/modules/__pycache__/norm.cpython-310.pyc differ diff --git a/trellis/modules/__pycache__/spatial.cpython-310.pyc b/trellis/modules/__pycache__/spatial.cpython-310.pyc index b154fb6e5f4de8c8d060947a4b9b0d2d801f20ca..a95e42d41c5a0eeb84de6675e2eed98f3584a19d 100644 Binary files a/trellis/modules/__pycache__/spatial.cpython-310.pyc and b/trellis/modules/__pycache__/spatial.cpython-310.pyc differ diff --git a/trellis/modules/__pycache__/utils.cpython-310.pyc b/trellis/modules/__pycache__/utils.cpython-310.pyc index dd9a6303f0af67ed8395d0da57694dc282a4d8a6..819873be3693778bc5c0d0478122813c00bc05b9 100644 Binary files a/trellis/modules/__pycache__/utils.cpython-310.pyc and b/trellis/modules/__pycache__/utils.cpython-310.pyc differ diff --git a/trellis/modules/attention/__pycache__/__init__.cpython-310.pyc b/trellis/modules/attention/__pycache__/__init__.cpython-310.pyc index 22a07011334ff04709b1a0f6760439376bd74535..f36b87782f5a5a25cb3dd0c183a805003c9c44d1 100644 Binary files a/trellis/modules/attention/__pycache__/__init__.cpython-310.pyc and b/trellis/modules/attention/__pycache__/__init__.cpython-310.pyc differ diff --git a/trellis/modules/attention/__pycache__/full_attn.cpython-310.pyc b/trellis/modules/attention/__pycache__/full_attn.cpython-310.pyc index 278ec68f3b43303ee95413eb56ab5886795a30f7..43fb1bbf289908acb164ca7f715d2954d865ddde 100644 Binary files a/trellis/modules/attention/__pycache__/full_attn.cpython-310.pyc and b/trellis/modules/attention/__pycache__/full_attn.cpython-310.pyc differ diff --git a/trellis/modules/attention/__pycache__/modules.cpython-310.pyc b/trellis/modules/attention/__pycache__/modules.cpython-310.pyc index 423bcbe7586ebee9be62c3234e757d746d007bf2..c22ee7d03df9681e39aeb6e84972d2dd5ed66892 100644 Binary files a/trellis/modules/attention/__pycache__/modules.cpython-310.pyc and b/trellis/modules/attention/__pycache__/modules.cpython-310.pyc differ diff --git a/trellis/modules/sparse/__pycache__/__init__.cpython-310.pyc b/trellis/modules/sparse/__pycache__/__init__.cpython-310.pyc index ed807f72966edfb0f5326cb2471f319509edc26e..7fd77aa78f67910267addf0099fd4cc78ac495d6 100644 Binary files a/trellis/modules/sparse/__pycache__/__init__.cpython-310.pyc and b/trellis/modules/sparse/__pycache__/__init__.cpython-310.pyc differ diff --git a/trellis/modules/sparse/__pycache__/basic.cpython-310.pyc b/trellis/modules/sparse/__pycache__/basic.cpython-310.pyc index 43d040474e123ecb5dd8fddd2f7197eb0356e235..24cbf3a0e0502ecdde17a89b458db7be571e5b38 100644 Binary files a/trellis/modules/sparse/__pycache__/basic.cpython-310.pyc and b/trellis/modules/sparse/__pycache__/basic.cpython-310.pyc differ diff --git a/trellis/modules/sparse/__pycache__/linear.cpython-310.pyc b/trellis/modules/sparse/__pycache__/linear.cpython-310.pyc index 3a96a59de3c903bc055104d73b2d38b01cef55ca..be1fe4597b48dc768f3f1a386702a0f11cc9a781 100644 Binary files a/trellis/modules/sparse/__pycache__/linear.cpython-310.pyc and b/trellis/modules/sparse/__pycache__/linear.cpython-310.pyc differ diff --git a/trellis/modules/sparse/__pycache__/nonlinearity.cpython-310.pyc b/trellis/modules/sparse/__pycache__/nonlinearity.cpython-310.pyc index 731a3b564306a0cac9022a1c9ad7e94a9cbba817..d01fe927d4e634e06871353f31d9ffc86d7f8f59 100644 Binary files a/trellis/modules/sparse/__pycache__/nonlinearity.cpython-310.pyc and b/trellis/modules/sparse/__pycache__/nonlinearity.cpython-310.pyc differ diff --git a/trellis/modules/sparse/__pycache__/norm.cpython-310.pyc b/trellis/modules/sparse/__pycache__/norm.cpython-310.pyc index a4573decac983f6f5b0eea585e17c9a61980958d..c31b94b3c578c3da727319fd6a3a871f1791e09a 100644 Binary files a/trellis/modules/sparse/__pycache__/norm.cpython-310.pyc and b/trellis/modules/sparse/__pycache__/norm.cpython-310.pyc differ diff --git a/trellis/modules/sparse/__pycache__/spatial.cpython-310.pyc b/trellis/modules/sparse/__pycache__/spatial.cpython-310.pyc index e23c93b4939b088b91bd8c505addebc20b1f7409..09095d47cdaf9196c9c2b00903419cb3bb7d316d 100644 Binary files a/trellis/modules/sparse/__pycache__/spatial.cpython-310.pyc and b/trellis/modules/sparse/__pycache__/spatial.cpython-310.pyc differ diff --git a/trellis/modules/sparse/attention/__pycache__/__init__.cpython-310.pyc b/trellis/modules/sparse/attention/__pycache__/__init__.cpython-310.pyc index f593d4c1e17d1ad6c00e90b05dd019a90ed590da..12c22e6ab638dc458dc78e9978513c18ebede66f 100644 Binary files a/trellis/modules/sparse/attention/__pycache__/__init__.cpython-310.pyc and b/trellis/modules/sparse/attention/__pycache__/__init__.cpython-310.pyc differ diff --git a/trellis/modules/sparse/attention/__pycache__/full_attn.cpython-310.pyc b/trellis/modules/sparse/attention/__pycache__/full_attn.cpython-310.pyc index 68e865e9c460bfabd40f858962360334dbf1bd25..43a4a03b4ab81f33c7c414e6a30deef031f56082 100644 Binary files a/trellis/modules/sparse/attention/__pycache__/full_attn.cpython-310.pyc and b/trellis/modules/sparse/attention/__pycache__/full_attn.cpython-310.pyc differ diff --git a/trellis/modules/sparse/attention/__pycache__/modules.cpython-310.pyc b/trellis/modules/sparse/attention/__pycache__/modules.cpython-310.pyc index db6e497005c2cb7d7912fd7ba1325929e826c7ed..e3f74fc70bd89a2bc6fdb112865393be7402d687 100644 Binary files a/trellis/modules/sparse/attention/__pycache__/modules.cpython-310.pyc and b/trellis/modules/sparse/attention/__pycache__/modules.cpython-310.pyc differ diff --git a/trellis/modules/sparse/attention/__pycache__/serialized_attn.cpython-310.pyc b/trellis/modules/sparse/attention/__pycache__/serialized_attn.cpython-310.pyc index 47fdd114317f1e73d6ec0be5bccd8d6bc9961a26..fb5380951324fd88aa04611c2b6566f058d66274 100644 Binary files a/trellis/modules/sparse/attention/__pycache__/serialized_attn.cpython-310.pyc and b/trellis/modules/sparse/attention/__pycache__/serialized_attn.cpython-310.pyc differ diff --git a/trellis/modules/sparse/attention/__pycache__/windowed_attn.cpython-310.pyc b/trellis/modules/sparse/attention/__pycache__/windowed_attn.cpython-310.pyc index 1bf32dc2364b42eec3090693994a0115b1f29d15..f0c559af356abc76ced8ce6c9cb97c9821707ced 100644 Binary files a/trellis/modules/sparse/attention/__pycache__/windowed_attn.cpython-310.pyc and b/trellis/modules/sparse/attention/__pycache__/windowed_attn.cpython-310.pyc differ diff --git a/trellis/modules/sparse/conv/__pycache__/__init__.cpython-310.pyc b/trellis/modules/sparse/conv/__pycache__/__init__.cpython-310.pyc index 4c28690836369c624e802ccd4261d1003088521c..fb28db4e180c0a44114833e4db6620755baac9f0 100644 Binary files a/trellis/modules/sparse/conv/__pycache__/__init__.cpython-310.pyc and b/trellis/modules/sparse/conv/__pycache__/__init__.cpython-310.pyc differ diff --git a/trellis/modules/sparse/conv/__pycache__/conv_spconv.cpython-310.pyc b/trellis/modules/sparse/conv/__pycache__/conv_spconv.cpython-310.pyc index de10562e71f2d8868f0f6d6f46b5418700c25a89..42473e51a71c280b1893ecf00eaf4f944eda9f78 100644 Binary files a/trellis/modules/sparse/conv/__pycache__/conv_spconv.cpython-310.pyc and b/trellis/modules/sparse/conv/__pycache__/conv_spconv.cpython-310.pyc differ diff --git a/trellis/modules/sparse/transformer/__pycache__/__init__.cpython-310.pyc b/trellis/modules/sparse/transformer/__pycache__/__init__.cpython-310.pyc index b6160f08743d3a9c228bab36d007892740cafabb..498c60e03773e327a2638775d0ad712f5d3072c2 100644 Binary files a/trellis/modules/sparse/transformer/__pycache__/__init__.cpython-310.pyc and b/trellis/modules/sparse/transformer/__pycache__/__init__.cpython-310.pyc differ diff --git a/trellis/modules/sparse/transformer/__pycache__/blocks.cpython-310.pyc b/trellis/modules/sparse/transformer/__pycache__/blocks.cpython-310.pyc index 62021de1246d666327d9d3f5dfc3bb4f68ddd168..f722d65b54846dd5c2bd3ce7c848ab7c02ade486 100644 Binary files a/trellis/modules/sparse/transformer/__pycache__/blocks.cpython-310.pyc and b/trellis/modules/sparse/transformer/__pycache__/blocks.cpython-310.pyc differ diff --git a/trellis/modules/sparse/transformer/__pycache__/modulated.cpython-310.pyc b/trellis/modules/sparse/transformer/__pycache__/modulated.cpython-310.pyc index 107fc8f1ed75174d48e8f9641a8691e980d621c7..87af1164d2875540043e62820cd79e3d4f1522c8 100644 Binary files a/trellis/modules/sparse/transformer/__pycache__/modulated.cpython-310.pyc and b/trellis/modules/sparse/transformer/__pycache__/modulated.cpython-310.pyc differ diff --git a/trellis/modules/sparse/transformer/modulated.py b/trellis/modules/sparse/transformer/modulated.py index 4a8416559f39acbed9e5996e9891c97f95c80c8f..4c3151aadf01ecc8ef61ac3fe70b75bdcabd6840 100644 --- a/trellis/modules/sparse/transformer/modulated.py +++ b/trellis/modules/sparse/transformer/modulated.py @@ -150,8 +150,14 @@ class ModulatedSparseTransformerCrossBlock(nn.Module): h = h * gate_msa x = x + h h = x.replace(self.norm2(x.feats)) - h = self.cross_attn(h, context) - x = x + h + # h = self.cross_attn(h, context) + # x = x + h + if isinstance(context, list): + for ctx in context: + x = x + self.cross_attn(h, ctx) / len(context) + else: + h = self.cross_attn(h, context) + x = x + h h = x.replace(self.norm3(x.feats)) h = h * (1 + scale_mlp) + shift_mlp h = self.mlp(h) diff --git a/trellis/modules/transformer/__pycache__/__init__.cpython-310.pyc b/trellis/modules/transformer/__pycache__/__init__.cpython-310.pyc index 6614cf67a229fc59aae86938bdbd6000f8aa0def..c4fc1d1822eedcfafcc1ef194f87e83af38a8dd6 100644 Binary files a/trellis/modules/transformer/__pycache__/__init__.cpython-310.pyc and b/trellis/modules/transformer/__pycache__/__init__.cpython-310.pyc differ diff --git a/trellis/modules/transformer/__pycache__/blocks.cpython-310.pyc b/trellis/modules/transformer/__pycache__/blocks.cpython-310.pyc index bf7d94a799eaf030be44fcf8136207b6eda905db..bcc1d19075ea25cd05b92397fdf615ec9517cc88 100644 Binary files a/trellis/modules/transformer/__pycache__/blocks.cpython-310.pyc and b/trellis/modules/transformer/__pycache__/blocks.cpython-310.pyc differ diff --git a/trellis/modules/transformer/__pycache__/modulated.cpython-310.pyc b/trellis/modules/transformer/__pycache__/modulated.cpython-310.pyc index 8f842476dad14630b99b4a11ee852805887ebba9..619cec0fdecc7678de5532f894918abc29611207 100644 Binary files a/trellis/modules/transformer/__pycache__/modulated.cpython-310.pyc and b/trellis/modules/transformer/__pycache__/modulated.cpython-310.pyc differ diff --git a/trellis/modules/transformer/modulated.py b/trellis/modules/transformer/modulated.py index c85d8d551b2bf9e45fb86bd98539967a4cec6665..201465633ce3abcaf740f5a5eb8bdafdb11a48fe 100644 --- a/trellis/modules/transformer/modulated.py +++ b/trellis/modules/transformer/modulated.py @@ -138,13 +138,18 @@ class ModulatedTransformerCrossBlock(nn.Module): shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1) h = self.norm1(x) h = h * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1) - # h = torch.utils.checkpoint.checkpoint(self.self_attn, h) h = self.self_attn(h) h = h * gate_msa.unsqueeze(1) x = x + h h = self.norm2(x) - h = self.cross_attn(h, context) - x = x + h + # h = self.cross_attn(h, context) + # x = x + h + if isinstance(context, list): + for ctx in context: + x = x + self.cross_attn(h, ctx) / len(context) + else: + h = self.cross_attn(h, context) + x = x + h h = self.norm3(x) h = h * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1) h = self.mlp(h) diff --git a/trellis/pipelines/__pycache__/__init__.cpython-310.pyc b/trellis/pipelines/__pycache__/__init__.cpython-310.pyc index b505f96842e09d26f38d9441e724e5f774165762..99783060f81696d9394dcfabbb19883772f79ab1 100644 Binary files a/trellis/pipelines/__pycache__/__init__.cpython-310.pyc and b/trellis/pipelines/__pycache__/__init__.cpython-310.pyc differ diff --git a/trellis/pipelines/__pycache__/base.cpython-310.pyc b/trellis/pipelines/__pycache__/base.cpython-310.pyc index b36097d7794c20939ba974eae2d34b65a1a37893..d6ea712627e59cbb4919dd09757dde6b3ed63ce8 100644 Binary files a/trellis/pipelines/__pycache__/base.cpython-310.pyc and b/trellis/pipelines/__pycache__/base.cpython-310.pyc differ diff --git a/trellis/pipelines/__pycache__/trellis_image_to_3d.cpython-310.pyc b/trellis/pipelines/__pycache__/trellis_image_to_3d.cpython-310.pyc index d6993693d05d395fb94b38a1b1cee135661f8c54..ebd44d9e664ccd1eeaad85bf5e55e39cd4d3bdbe 100644 Binary files a/trellis/pipelines/__pycache__/trellis_image_to_3d.cpython-310.pyc and b/trellis/pipelines/__pycache__/trellis_image_to_3d.cpython-310.pyc differ diff --git a/trellis/pipelines/base.py b/trellis/pipelines/base.py index 30a2dabf298606c20830917c811f9da9dd7f4f51..3cab8e54ad1a062c20a9f880261bdf203da06f6c 100644 --- a/trellis/pipelines/base.py +++ b/trellis/pipelines/base.py @@ -24,6 +24,9 @@ class Pipeline: self.sparse_structure_flow_model = self.models['sparse_structure_flow_model'] if 'sparse_structure_vggt_cond' in self.models: self.sparse_structure_vggt_cond = self.models['sparse_structure_vggt_cond'] + if 'slat_vggt_cond' in self.models: + self.slat_vggt_cond = self.models['slat_vggt_cond'] + @staticmethod def from_pretrained(path: str) -> "Pipeline": """ diff --git a/trellis/pipelines/samplers/__pycache__/__init__.cpython-310.pyc b/trellis/pipelines/samplers/__pycache__/__init__.cpython-310.pyc index 050e5e93c51f8c0220ae144762d2ced1b9a9af9a..c6d8f44cd9134cbd333a5d8ac9524a236ff7f78b 100644 Binary files a/trellis/pipelines/samplers/__pycache__/__init__.cpython-310.pyc and b/trellis/pipelines/samplers/__pycache__/__init__.cpython-310.pyc differ diff --git a/trellis/pipelines/samplers/__pycache__/base.cpython-310.pyc b/trellis/pipelines/samplers/__pycache__/base.cpython-310.pyc index 1250f59798af8e9277b342bae5f15132d7f795af..9b979a3e3ff09057eef511b8bd7acc96da9654a9 100644 Binary files a/trellis/pipelines/samplers/__pycache__/base.cpython-310.pyc and b/trellis/pipelines/samplers/__pycache__/base.cpython-310.pyc differ diff --git a/trellis/pipelines/samplers/__pycache__/classifier_free_guidance_mixin.cpython-310.pyc b/trellis/pipelines/samplers/__pycache__/classifier_free_guidance_mixin.cpython-310.pyc index 9ddc31e608eee86ba6b390eec45f586bc3718397..f98e1a61b2360b880227d6a240519720e6a0de01 100644 Binary files a/trellis/pipelines/samplers/__pycache__/classifier_free_guidance_mixin.cpython-310.pyc and b/trellis/pipelines/samplers/__pycache__/classifier_free_guidance_mixin.cpython-310.pyc differ diff --git a/trellis/pipelines/samplers/__pycache__/flow_euler.cpython-310.pyc b/trellis/pipelines/samplers/__pycache__/flow_euler.cpython-310.pyc index 6cdea0a51d12ef248d67b240a74a9a8c6d699c59..46edcbfb319ce4777f0bdf664e96cc00613c7275 100644 Binary files a/trellis/pipelines/samplers/__pycache__/flow_euler.cpython-310.pyc and b/trellis/pipelines/samplers/__pycache__/flow_euler.cpython-310.pyc differ diff --git a/trellis/pipelines/samplers/__pycache__/flow_euler_old.cpython-310.pyc b/trellis/pipelines/samplers/__pycache__/flow_euler_old.cpython-310.pyc index d1ef50294691e3fab01f0399bb6cfbafa73a3832..e34fe2bff807a00e731012ad7b7c899e268f1fd5 100644 Binary files a/trellis/pipelines/samplers/__pycache__/flow_euler_old.cpython-310.pyc and b/trellis/pipelines/samplers/__pycache__/flow_euler_old.cpython-310.pyc differ diff --git a/trellis/pipelines/samplers/__pycache__/guidance_interval_mixin.cpython-310.pyc b/trellis/pipelines/samplers/__pycache__/guidance_interval_mixin.cpython-310.pyc index e4bffb043c3992d2fbeeebb815738e8de8807e8e..5b03810fbe94f0fcd52f7d165757e36c1fb3bc3f 100644 Binary files a/trellis/pipelines/samplers/__pycache__/guidance_interval_mixin.cpython-310.pyc and b/trellis/pipelines/samplers/__pycache__/guidance_interval_mixin.cpython-310.pyc differ diff --git a/trellis/pipelines/samplers/flow_euler.py b/trellis/pipelines/samplers/flow_euler.py index 56ce2794291fab862003f5b5eba9f146e7377719..e62c633f0055dfea4a48bf886c0ae3ee4c0b87e8 100644 --- a/trellis/pipelines/samplers/flow_euler.py +++ b/trellis/pipelines/samplers/flow_euler.py @@ -129,7 +129,56 @@ class FlowEulerSampler(Sampler): pred_x_prev = x_t - (t - t_prev) * pred_v return edict({"pred_x_prev": pred_x_prev, "pred_x_0": pred_x_0, "pred_eps": pred_eps}) - def sample_once_opt_delta_v( + def sample_ss_once_opt_delta_v( + self, + model, + ss_decoder, + learning_rate, + ss, + x_t, + t: float, + t_prev: float, + cond: Optional[Any] = None, + **kwargs + ): + """ + Sample x_{t-1} from the model using Euler method. + + Args: + model: The model to sample from. + x_t: The [N x C x ...] tensor of noisy inputs at time t. + t: The current timestep. + t_prev: The previous timestep. + cond: conditional information. + **kwargs: Additional arguments for model inference. + + Returns: + a dict containing the following + - 'pred_x_prev': x_{t-1}. + - 'pred_x_0': a prediction of x_0. + """ + torch.cuda.empty_cache() + with torch.no_grad(): + pred_x_0, pred_eps, pred_v = self._get_model_prediction(model, x_t, t, cond, **kwargs) + pred_v_opt = torch.nn.Parameter(pred_v.detach().clone()) + optimizer = torch.optim.Adam([pred_v_opt], betas=(0.5, 0.9), lr=learning_rate) + total_steps = 5 + with tqdm(total=total_steps, disable=True, desc='Sparse Structure (opt): optimizing') as pbar: + for step in range(total_steps): + optimizer.zero_grad() + pred_x_0, _ = self._v_to_xstart_eps(x_t=x_t, t=t, v=pred_v_opt) + logits = F.sigmoid(ss_decoder(pred_x_0)) + loss = 1 - (2 * (logits * ss.float()).sum() + 1) / (logits.sum() + ss.float().sum() + 1) + loss.backward() + optimizer.step() + pbar.set_postfix({'loss': loss.item()}) + pbar.update() + + pred_x_prev = x_t - (t - t_prev) * pred_v_opt.detach() + torch.cuda.empty_cache() + return edict({"pred_x_prev": pred_x_prev, "pred_x_0": pred_x_0, "pred_eps": pred_eps}) + + def sample_slat_once_opt_delta_v( self, model, slat_decoder_gs, @@ -178,7 +227,11 @@ class FlowEulerSampler(Sampler): rend_gs = render_utils.render_frames(pred_gs[0], extrinsics, intrinsics, {'resolution': 259, 'bg_color': (0, 0, 0)}, need_depth=True, opt=True)['color'] # rend_mesh = render_utils.render_frames_opt(pred_mesh[0], extrinsics, intrinsics, {'resolution': 518, 'bg_color': (0, 0, 0)}, need_depth=True)['color'] rend_gs = torch.stack(rend_gs, dim=0) - loss_gs = loss_utils.l1_loss(rend_gs, input_images) + (1 - loss_utils.ssim(rend_gs, input_images)) + loss_utils.lpips(rend_gs, input_images) + dreamsim_model(rend_gs, input_images).mean() + loss_gs = loss_utils.l1_loss(rend_gs, input_images, size_average=False).mean(dim=(1,2,3)) + \ + (1 - loss_utils.ssim(rend_gs, input_images, size_average=False)) + \ + loss_utils.lpips(rend_gs, input_images, size_average=False).mean(dim=(1,2,3)) + \ + dreamsim_model(rend_gs, input_images) + loss_gs = loss_gs[loss_gs <= 0.8].mean() # loss_gs = (1 - loss_utils.ssim(rend_gs, input_images)) + loss_utils.lpips(rend_gs, input_images) + dreamsim_model(rend_gs, input_images).mean() # loss_mesh = loss_utils.l1_loss(rend_mesh, input_images) + 0.2 * (1 - loss_utils.ssim(rend_mesh, input_images)) + 0.2 * loss_utils.lpips(rend_mesh, input_images) loss = loss_gs + 0.2 * loss_utils.l1_loss(pred_v_opt_feat, pred_v.feats) @@ -232,7 +285,62 @@ class FlowEulerSampler(Sampler): ret.samples = sample return ret - def sample_opt_delta_v( + def sample_ss_opt_delta_v( + self, + model, + ss_decoder, + ss_learning_rate, + ss_start_t, + ss, + noise, + cond: Optional[Any] = None, + steps: int = 50, + rescale_t: float = 1.0, + verbose: bool = True, + **kwargs + ): + """ + Generate samples from the model using Euler method. + + Args: + model: The model to sample from. + noise: The initial noise tensor. + cond: conditional information. + steps: The number of steps to sample. + rescale_t: The rescale factor for t. + verbose: If True, show a progress bar. + **kwargs: Additional arguments for model_inference. + + Returns: + a dict containing the following + - 'samples': the model samples. + - 'pred_x_t': a list of prediction of x_t. + - 'pred_x_0': a list of prediction of x_0. + """ + sample = noise + t_seq = np.linspace(1, 0, steps + 1) + t_seq = rescale_t * t_seq / (1 + (rescale_t - 1) * t_seq) + t_pairs = list((t_seq[i], t_seq[i + 1]) for i in range(steps)) + ret = edict({"samples": None, "pred_x_t": [], "pred_x_0": []}) + # def cosine_anealing(step, total_steps, start_lr, end_lr): + # return end_lr + 0.5 * (start_lr - end_lr) * (1 + np.cos(np.pi * step / total_steps)) + for i, (t, t_prev) in enumerate(tqdm(t_pairs, desc="Sampling", disable=not verbose)): + if t > ss_start_t: + out = self.sample_once(model, sample, t, t_prev, cond, **kwargs) + sample = out.pred_x_prev + ret.pred_x_t.append(out.pred_x_prev) + ret.pred_x_0.append(out.pred_x_0) + else: + # learning_rate = cosine_anealing(i - int(np.where(t_seq <= start_t)[0].min()), int(steps - np.where(t_seq <= start_t)[0].min()), apperance_learning_rate, 1e-5) + learning_rate = ss_learning_rate + out = self.sample_ss_once_opt_delta_v(model, ss_decoder, ss_learning_rate, ss, sample, t, t_prev, cond, **kwargs) + sample = out.pred_x_prev + ret.pred_x_t.append(out.pred_x_prev) + ret.pred_x_0.append(out.pred_x_0) + ret.samples = sample + return ret + + def sample_slat_opt_delta_v( self, model, slat_decoder_gs, @@ -284,7 +392,7 @@ class FlowEulerSampler(Sampler): else: # learning_rate = cosine_anealing(i - int(np.where(t_seq <= start_t)[0].min()), int(steps - np.where(t_seq <= start_t)[0].min()), apperance_learning_rate, 1e-5) learning_rate = apperance_learning_rate - out = self.sample_once_opt_delta_v(model, slat_decoder_gs, slat_decoder_mesh, dreamsim_model, learning_rate, input_images, extrinsics, intrinsics, sample, t, t_prev, cond, **kwargs) + out = self.sample_slat_once_opt_delta_v(model, slat_decoder_gs, slat_decoder_mesh, dreamsim_model, learning_rate, input_images, extrinsics, intrinsics, sample, t, t_prev, cond, **kwargs) sample = out.pred_x_prev ret.pred_x_t.append(out.pred_x_prev) ret.pred_x_0.append(out.pred_x_0) @@ -711,7 +819,48 @@ class FlowEulerGuidanceIntervalSampler(GuidanceIntervalSamplerMixin, FlowEulerSa """ return super().sample_opt(model, noise, cond, steps, rescale_t, verbose, neg_cond=neg_cond, cfg_strength=cfg_strength, cfg_interval=cfg_interval, **kwargs) - def sample_opt_delta_v( + def sample_ss_opt_delta_v( + self, + model, + ss_decoder, + ss_learning_rate, + ss_start_t, + ss, + noise, + cond, + neg_cond, + steps: int = 50, + rescale_t: float = 1.0, + cfg_strength: float = 3.0, + cfg_interval: Tuple[float, float] = (0.0, 1.0), + verbose: bool = True, + **kwargs + ): + """ + Generate samples from the model using Euler method. + + Args: + model: The model to sample from. + noise: The initial noise tensor. + cond: conditional information. + neg_cond: negative conditional information. + steps: The number of steps to sample. + rescale_t: The rescale factor for t. + cfg_strength: The strength of classifier-free guidance. + cfg_interval: The interval for classifier-free guidance. + verbose: If True, show a progress bar. + **kwargs: Additional arguments for model_inference. + + Returns: + a dict containing the following + - 'samples': the model samples. + - 'pred_x_t': a list of prediction of x_t. + - 'pred_x_0': a list of prediction of x_0. + """ + return super().sample_ss_opt_delta_v(model, ss_decoder, ss_learning_rate, ss_start_t, ss, noise, cond, steps, rescale_t, verbose, neg_cond=neg_cond, cfg_strength=cfg_strength, cfg_interval=cfg_interval, **kwargs) + + + def sample_slat_opt_delta_v( self, model, slat_decoder_gs, @@ -753,7 +902,7 @@ class FlowEulerGuidanceIntervalSampler(GuidanceIntervalSamplerMixin, FlowEulerSa - 'pred_x_t': a list of prediction of x_t. - 'pred_x_0': a list of prediction of x_0. """ - return super().sample_opt_delta_v(model, slat_decoder_gs, slat_decoder_mesh, dreamsim_model, apperance_learning_rate, start_t, input_images, extrinsics, intrinsics,noise, cond, steps, rescale_t, verbose, neg_cond=neg_cond, cfg_strength=cfg_strength, cfg_interval=cfg_interval, **kwargs) + return super().sample_slat_opt_delta_v(model, slat_decoder_gs, slat_decoder_mesh, dreamsim_model, apperance_learning_rate, start_t, input_images, extrinsics, intrinsics,noise, cond, steps, rescale_t, verbose, neg_cond=neg_cond, cfg_strength=cfg_strength, cfg_interval=cfg_interval, **kwargs) class LatentMatchGuidanceIntervalSampler(GuidanceIntervalSamplerMixin, LatentMatchSampler): diff --git a/trellis/pipelines/trellis_image_to_3d.py b/trellis/pipelines/trellis_image_to_3d.py index ce7ea715964c0af6dc82d481472b8c81a0de3ac4..2bcd437e638eba268321cc61effbdbab1c32c1cc 100644 --- a/trellis/pipelines/trellis_image_to_3d.py +++ b/trellis/pipelines/trellis_image_to_3d.py @@ -19,6 +19,7 @@ from typing import * from scipy.spatial.transform import Rotation from transformers import AutoModelForImageSegmentation import rembg +from dreamsim import dreamsim def export_point_cloud(xyz, color): # Convert tensors to numpy arrays if needed @@ -429,6 +430,51 @@ class TrellisImageTo3DPipeline(Pipeline): coords = torch.argwhere(decoder(z_s)>0)[:, [0, 2, 3, 4]].int() return coords + + def sample_sparse_structure_opt( + self, + cond: dict, + ss: torch.Tensor, + ss_learning_rate: float=1e-1, + ss_start_t: float=0.6, + num_samples: int = 1, + sampler_params: dict = {}, + noise: torch.Tensor = None, + ) -> torch.Tensor: + """ + Sample sparse structures with the given conditioning. + + Args: + cond (dict): The conditioning information. + num_samples (int): The number of samples to generate. + sampler_params (dict): Additional parameters for the sampler. + """ + # Sample occupancy latent + flow_model = self.models['sparse_structure_flow_model'] + ss_decoder = self.models['sparse_structure_decoder'] + ss = ss.float() + reso = flow_model.resolution + if noise is None: + noise = torch.randn(num_samples, flow_model.in_channels, reso, reso, reso).to(self.device) + sampler_params = {**self.sparse_structure_sampler_params, **sampler_params} + z_s = self.sparse_structure_sampler.sample_ss_opt_delta_v( + flow_model, + ss_decoder, + ss_learning_rate, + ss_start_t, + ss, + noise, + **cond, + **sampler_params, + verbose=True + ).samples + + # Decode occupancy latent + decoder = self.models['sparse_structure_decoder'] + coords = torch.argwhere(decoder(z_s)>0)[:, [0, 2, 3, 4]].int() + + return coords + def encode_slat( self, slat: sp.SparseTensor, @@ -498,6 +544,66 @@ class TrellisImageTo3DPipeline(Pipeline): slat = slat * std + mean return slat + def sample_slat_opt( + self, + apperance_learning_rate, + start_t, + input_images: torch.Tensor, + extrinsics: torch.Tensor, + intrinsics: torch.Tensor, + cond: dict, + coords: torch.Tensor, + sampler_params: dict = {}, + ) -> sp.SparseTensor: + """ + Sample structured latent with the given conditioning. + + Args: + cond (dict): The conditioning information. + coords (torch.Tensor): The coordinates of the sparse structure. + sampler_params (dict): Additional parameters for the sampler. + """ + # Sample structured latent + flow_model = self.models['slat_flow_model'] + slat_decoder_gs = self.models['slat_decoder_gs'] + slat_decoder_mesh = self.models['slat_decoder_mesh'] + noise = sp.SparseTensor( + feats=torch.randn(coords.shape[0], flow_model.in_channels).to(self.device), + coords=coords, + ) + sampler_params = {**self.slat_sampler_params, **sampler_params} + slat = self.slat_sampler.sample_slat_opt_delta_v( + flow_model, + slat_decoder_gs, + slat_decoder_mesh, + self.dreamsim_model, + apperance_learning_rate, + start_t, + input_images, + extrinsics, + intrinsics, + noise, + **cond, + **sampler_params, + verbose=True + ).samples + + std = torch.tensor(self.slat_normalization['std'])[None].to(slat.device) + mean = torch.tensor(self.slat_normalization['mean'])[None].to(slat.device) + slat = slat * std + mean + # from trellis.utils import render_utils, postprocessing_utils + # import imageio + # std = torch.tensor(self.slat_normalization['std'])[None].to(noise.device) + # mean = torch.tensor(self.slat_normalization['mean'])[None].to(noise.device) + # for i in range(sampler_params['steps']): + # latent = slat.pred_x_0[i] * std + mean + # outputs = self.decode_slat(latent, ["mesh", "gaussian"]) + # video_geo = render_utils.render_video(outputs['mesh'][0], resolution=512, pitch=0, inverse_direction=True, num_frames=120)['normal'] + # video_color = render_utils.render_video(outputs['gaussian'][0], resolution=512, pitch=0, inverse_direction=True, num_frames=120)['color'] + # video = [np.concatenate([video_color[i], video_geo[i]], axis=1) for i in range(len(video_color))] + # imageio.mimsave('outputs/slat_iter_{i:02d}.mp4'.format(i=i), video, fps=15) + return slat + def get_input(self, batch_data): std = torch.tensor(self.slat_normalization['std'])[None].to(self.device) mean = torch.tensor(self.slat_normalization['mean'])[None].to(self.device) @@ -693,6 +799,25 @@ class TrellisVGGTTo3DPipeline(TrellisImageTo3DPipeline): 'cond': cond, 'neg_cond': neg_cond, } + + def get_slat_cond(self, image_cond: torch.Tensor, aggregated_tokens_list: List, num_samples: int) -> dict: + """ + Get the conditioning information for the model. + + Args: + image (Union[torch.Tensor, list[Image.Image]]): The image prompts. + + Returns: + dict: The conditioning information + """ + b, n, _, _ = aggregated_tokens_list[0].shape + cond = self.slat_vggt_cond(aggregated_tokens_list, image_cond).reshape(b, n, -1, 1024) + cond = [c.squeeze(1) for c in cond.split(1, dim=1)] + neg_cond = [torch.zeros_like(c) for c in cond] + return { + 'cond': cond, + 'neg_cond': neg_cond, + } @torch.no_grad() def vggt_feat(self, image: Union[torch.Tensor, list[Image.Image]]) -> List: """ @@ -748,7 +873,7 @@ class TrellisVGGTTo3DPipeline(TrellisImageTo3DPipeline): ss_sampler_params = {**self.sparse_structure_sampler_params, **sparse_structure_sampler_params} reso = ss_flow_model.resolution ss_noise = torch.randn(num_samples, ss_flow_model.in_channels, reso, reso, reso).to(self.device) - ss_slat = self.sparse_structure_sampler.sample( + ss_latent = self.sparse_structure_sampler.sample( ss_flow_model, ss_noise, **ss_cond, @@ -757,17 +882,74 @@ class TrellisVGGTTo3DPipeline(TrellisImageTo3DPipeline): ).samples decoder = self.models['sparse_structure_decoder'] - coords = torch.argwhere(decoder(ss_slat)>0)[:, [0, 2, 3, 4]].int() + coords = torch.argwhere(decoder(ss_latent)>0)[:, [0, 2, 3, 4]].int() - cond = { - 'cond': image_cond.reshape(n, -1, 1024), - 'neg_cond': torch.zeros_like(image_cond.reshape(n, -1, 1024))[:1], - } + # cond = { + # 'cond': image_cond.reshape(n, -1, 1024), + # 'neg_cond': torch.zeros_like(image_cond.reshape(n, -1, 1024))[:1], + # } + # slat_steps = {**self.slat_sampler_params, **slat_sampler_params}.get('steps') + # with self.inject_sampler_multi_image('slat_sampler', len(image), slat_steps, mode=mode): + # slat = self.sample_slat(cond, coords, slat_sampler_params) + + slat_cond = self.get_slat_cond(image_cond, aggregated_tokens_list, num_samples) + slat = self.sample_slat(slat_cond, coords, slat_sampler_params) + return self.decode_slat(slat, formats), coords, ss_noise + + def run_refine( + self, + image: Union[torch.Tensor, list[Image.Image]], + ss_learning_rate: float, + ss_start_t: float, + apperance_learning_rate: float, + apperance_start_t: float, + extrinsics: torch.Tensor, + intrinsics: torch.Tensor, + ss_noise: torch.Tensor, + input_points: torch.Tensor, + coords: torch.Tensor = None, + num_samples: int = 1, + seed: int = 42, + sparse_structure_sampler_params: dict = {}, + slat_sampler_params: dict = {}, + formats: List[str] = ['mesh'], + mode: Literal['stochastic', 'multidiffusion'] = 'stochastic', + ): + + torch.manual_seed(seed) + aggregated_tokens_list, input_images = self.vggt_feat(image) + b, n, _, _ = aggregated_tokens_list[0].shape + image_cond = self.encode_image(image).reshape(b, n, -1, 1024) - slat_steps = {**self.slat_sampler_params, **slat_sampler_params}.get('steps') - with self.inject_sampler_multi_image('slat_sampler', len(image), slat_steps, mode=mode): - slat = self.sample_slat(cond, coords, slat_sampler_params) + if coords is None: + ss_cond = self.get_ss_cond(image_cond[:, :, 5:], aggregated_tokens_list, num_samples) + ss = torch.zeros(64, 64, 64, dtype=torch.long, device=image_cond.device) + ss = ss.index_put_((input_points[:,0], input_points[:,1], input_points[:,2]), torch.tensor(1, dtype=ss.dtype, device=ss.device)) + ss = ss[None, None] + torch.cuda.empty_cache() + # Sample structured latent + coords = self.sample_sparse_structure_opt(ss_cond, ss, ss_learning_rate, ss_start_t, num_samples, sparse_structure_sampler_params) + torch.cuda.empty_cache() + + # pcd = o3d.geometry.PointCloud() + # pcd.points = o3d.utility.Vector3dVector(coords[:,1:].cpu().numpy() / 64 - 0.5) + # o3d.io.write_point_cloud('outputs/after_coords.ply', pcd) + + # cond = { + # 'cond': image_cond.reshape(n, -1, 1024), + # 'neg_cond': torch.zeros_like(image_cond.reshape(n, -1, 1024))[:1], + # } + + # slat_steps = {**self.slat_sampler_params, **slat_sampler_params}.get('steps') + + # with self.inject_sampler_multi_image('slat_sampler', len(image), slat_steps, mode=mode): + # # slat = self.sample_slat(cond, coords, slat_sampler_params) + # slat = self.sample_slat_opt(apperance_learning_rate, apperance_start_t, input_images, extrinsics, intrinsics, cond, coords, slat_sampler_params) + + slat_cond = self.get_slat_cond(image_cond, aggregated_tokens_list, num_samples) + slat = self.sample_slat_opt(apperance_learning_rate, apperance_start_t, input_images, extrinsics, intrinsics, slat_cond, coords, slat_sampler_params) return self.decode_slat(slat, formats) + @staticmethod def from_pretrained(path: str) -> "TrellisVGGTTo3DPipeline": """ @@ -785,8 +967,8 @@ class TrellisVGGTTo3DPipeline(TrellisImageTo3DPipeline): new_pipeline.VGGT_model = VGGT_model.to(new_pipeline.device) del new_pipeline.VGGT_model.depth_head del new_pipeline.VGGT_model.track_head - del new_pipeline.VGGT_model.camera_head - del new_pipeline.VGGT_model.point_head + # del new_pipeline.VGGT_model.camera_head + # del new_pipeline.VGGT_model.point_head new_pipeline.VGGT_model.eval() new_pipeline.birefnet_model = AutoModelForImageSegmentation.from_pretrained( @@ -805,4 +987,8 @@ class TrellisVGGTTo3DPipeline(TrellisImageTo3DPipeline): new_pipeline._init_image_cond_model(args['image_cond_model']) + model, _ = dreamsim(pretrained=True, device=new_pipeline.device, dreamsim_type="dino_vitb16", cache_dir="weights/dreamsim") + new_pipeline.dreamsim_model = model + new_pipeline.dreamsim_model.eval() + return new_pipeline \ No newline at end of file diff --git a/trellis/renderers/__pycache__/__init__.cpython-310.pyc b/trellis/renderers/__pycache__/__init__.cpython-310.pyc index b2ad76fd451195e5881a48a6ffb8ef97effca1f6..46ad0022f5d383bdb083adfc558fd9ccc813e668 100644 Binary files a/trellis/renderers/__pycache__/__init__.cpython-310.pyc and b/trellis/renderers/__pycache__/__init__.cpython-310.pyc differ diff --git a/trellis/renderers/__pycache__/gaussian_render.cpython-310.pyc b/trellis/renderers/__pycache__/gaussian_render.cpython-310.pyc index 1c77d3c9f306cbb2883bebbf48428313b892eecd..1f0ee9a37eeedd3c98ee423d4f7c5a7a112ae0ad 100644 Binary files a/trellis/renderers/__pycache__/gaussian_render.cpython-310.pyc and b/trellis/renderers/__pycache__/gaussian_render.cpython-310.pyc differ diff --git a/trellis/renderers/__pycache__/mesh_renderer.cpython-310.pyc b/trellis/renderers/__pycache__/mesh_renderer.cpython-310.pyc index 56c182a03e1cb8448861f30b66e6cfa82968b02d..ceefb1d7a24b971e8cec10f3b24cd4584e428884 100644 Binary files a/trellis/renderers/__pycache__/mesh_renderer.cpython-310.pyc and b/trellis/renderers/__pycache__/mesh_renderer.cpython-310.pyc differ diff --git a/trellis/representations/__pycache__/__init__.cpython-310.pyc b/trellis/representations/__pycache__/__init__.cpython-310.pyc index 375729efd5bcdff2e52bb7e672e787e7a9f34e6c..acf68b5d76aa20ed8e7ed90e8c11bc0ae2d1adcc 100644 Binary files a/trellis/representations/__pycache__/__init__.cpython-310.pyc and b/trellis/representations/__pycache__/__init__.cpython-310.pyc differ diff --git a/trellis/representations/gaussian/__pycache__/__init__.cpython-310.pyc b/trellis/representations/gaussian/__pycache__/__init__.cpython-310.pyc index b7b49bf7d9175d445790ba1370a0296b8a0043c0..2414368d692e47e084bbdd76f7fb80b6d1c9b007 100644 Binary files a/trellis/representations/gaussian/__pycache__/__init__.cpython-310.pyc and b/trellis/representations/gaussian/__pycache__/__init__.cpython-310.pyc differ diff --git a/trellis/representations/gaussian/__pycache__/gaussian_model.cpython-310.pyc b/trellis/representations/gaussian/__pycache__/gaussian_model.cpython-310.pyc index 342ee620968a18e8df2da9066d78935b0f6a97c2..d99a2d8dbff8c82bda44c566367c95bd5aea69ba 100644 Binary files a/trellis/representations/gaussian/__pycache__/gaussian_model.cpython-310.pyc and b/trellis/representations/gaussian/__pycache__/gaussian_model.cpython-310.pyc differ diff --git a/trellis/representations/gaussian/__pycache__/general_utils.cpython-310.pyc b/trellis/representations/gaussian/__pycache__/general_utils.cpython-310.pyc index 2bc3af5fdc8f1ea1fbfecf0bb46ed229c56e9849..faab35b28f7414058f0dfcbd89e6446f1af4a0f8 100644 Binary files a/trellis/representations/gaussian/__pycache__/general_utils.cpython-310.pyc and b/trellis/representations/gaussian/__pycache__/general_utils.cpython-310.pyc differ diff --git a/trellis/representations/mesh/__pycache__/__init__.cpython-310.pyc b/trellis/representations/mesh/__pycache__/__init__.cpython-310.pyc index 15addeb8a3d731f62e6cf2b68da2a8b5f3f952b9..31415534f22dac7d3026914e490b7c74312de812 100644 Binary files a/trellis/representations/mesh/__pycache__/__init__.cpython-310.pyc and b/trellis/representations/mesh/__pycache__/__init__.cpython-310.pyc differ diff --git a/trellis/representations/mesh/__pycache__/cube2mesh.cpython-310.pyc b/trellis/representations/mesh/__pycache__/cube2mesh.cpython-310.pyc index fb331fd6fccc0ce8510ba2261134b29722dbe9af..8e4a1d721bd81dc15a1ef8f0fe9f416a01e9b8af 100644 Binary files a/trellis/representations/mesh/__pycache__/cube2mesh.cpython-310.pyc and b/trellis/representations/mesh/__pycache__/cube2mesh.cpython-310.pyc differ diff --git a/trellis/representations/mesh/__pycache__/flexicube.cpython-310.pyc b/trellis/representations/mesh/__pycache__/flexicube.cpython-310.pyc index 460812611dfbe0bef6900741bba44122cfd5d78a..73992f1151d4d10994d9eb37548c4c9d42a62671 100644 Binary files a/trellis/representations/mesh/__pycache__/flexicube.cpython-310.pyc and b/trellis/representations/mesh/__pycache__/flexicube.cpython-310.pyc differ diff --git a/trellis/representations/mesh/__pycache__/mc2mesh.cpython-310.pyc b/trellis/representations/mesh/__pycache__/mc2mesh.cpython-310.pyc index df1471dc56f41acb01e4ea455b920e80c23fbedd..43acf53d06c096b8dd210903964b0686d8d3dd6d 100644 Binary files a/trellis/representations/mesh/__pycache__/mc2mesh.cpython-310.pyc and b/trellis/representations/mesh/__pycache__/mc2mesh.cpython-310.pyc differ diff --git a/trellis/representations/mesh/__pycache__/tables.cpython-310.pyc b/trellis/representations/mesh/__pycache__/tables.cpython-310.pyc index 1117f4450a5ea87b2a2a6240edb6c9563b4e9a30..39271b5aef8b1a095cf85a8ac12bd4283a233f62 100644 Binary files a/trellis/representations/mesh/__pycache__/tables.cpython-310.pyc and b/trellis/representations/mesh/__pycache__/tables.cpython-310.pyc differ diff --git a/trellis/representations/mesh/__pycache__/utils_cube.cpython-310.pyc b/trellis/representations/mesh/__pycache__/utils_cube.cpython-310.pyc index 89876dfe8a48829eb8f9d3b9fc00b9db2521ba24..474ffc7b6e5304c695535dceec97696d0fb2cadc 100644 Binary files a/trellis/representations/mesh/__pycache__/utils_cube.cpython-310.pyc and b/trellis/representations/mesh/__pycache__/utils_cube.cpython-310.pyc differ diff --git a/trellis/representations/octree/__pycache__/__init__.cpython-310.pyc b/trellis/representations/octree/__pycache__/__init__.cpython-310.pyc index 6aae217a4c1ccdf3ffb8ce5cae03d4818363aff7..1aef1eec30179df6dc019b195a84e766e9319941 100644 Binary files a/trellis/representations/octree/__pycache__/__init__.cpython-310.pyc and b/trellis/representations/octree/__pycache__/__init__.cpython-310.pyc differ diff --git a/trellis/representations/octree/__pycache__/octree_dfs.cpython-310.pyc b/trellis/representations/octree/__pycache__/octree_dfs.cpython-310.pyc index c9d902ff45e85e9f26a4362e4f44bf7f0a6ed18b..63b1788308649bc8137e22cf0f2f9914ae98bd23 100644 Binary files a/trellis/representations/octree/__pycache__/octree_dfs.cpython-310.pyc and b/trellis/representations/octree/__pycache__/octree_dfs.cpython-310.pyc differ diff --git a/trellis/utils/__pycache__/__init__.cpython-310.pyc b/trellis/utils/__pycache__/__init__.cpython-310.pyc index 6e0e70298e509d245b1136414c365f875e61e067..eeed7e60a797484dffe87636cfb489810bc42c06 100644 Binary files a/trellis/utils/__pycache__/__init__.cpython-310.pyc and b/trellis/utils/__pycache__/__init__.cpython-310.pyc differ diff --git a/trellis/utils/__pycache__/general_utils.cpython-310.pyc b/trellis/utils/__pycache__/general_utils.cpython-310.pyc index 6bebee9d6fc3393c74bb6fa8c1edc2c85bb6f75f..163baa174a678aa974a4fecf16ec2e63374aed8b 100644 Binary files a/trellis/utils/__pycache__/general_utils.cpython-310.pyc and b/trellis/utils/__pycache__/general_utils.cpython-310.pyc differ diff --git a/trellis/utils/__pycache__/loss_utils.cpython-310.pyc b/trellis/utils/__pycache__/loss_utils.cpython-310.pyc index 715fb0164872d83570fd0181b37c7e0ad098e1ee..a250c4af7a0c90db50161617596ce3c1ad98fb8c 100644 Binary files a/trellis/utils/__pycache__/loss_utils.cpython-310.pyc and b/trellis/utils/__pycache__/loss_utils.cpython-310.pyc differ diff --git a/trellis/utils/__pycache__/postprocessing_utils.cpython-310.pyc b/trellis/utils/__pycache__/postprocessing_utils.cpython-310.pyc index 6491f77d28ed35075e19b70078649f30ed68ba25..f96811da400b0bbab850b8496b2482e34f887089 100644 Binary files a/trellis/utils/__pycache__/postprocessing_utils.cpython-310.pyc and b/trellis/utils/__pycache__/postprocessing_utils.cpython-310.pyc differ diff --git a/trellis/utils/__pycache__/random_utils.cpython-310.pyc b/trellis/utils/__pycache__/random_utils.cpython-310.pyc index 1c073c3a39622357e677185757e72e1e8b95f0f6..033d6c88ef31db8db20a344c6a62107b49adcbad 100644 Binary files a/trellis/utils/__pycache__/random_utils.cpython-310.pyc and b/trellis/utils/__pycache__/random_utils.cpython-310.pyc differ diff --git a/trellis/utils/__pycache__/render_utils.cpython-310.pyc b/trellis/utils/__pycache__/render_utils.cpython-310.pyc index a6eeeaca4eabb83d4cf565cdbcbec649cf3268b1..fa0ab75c1093dc804b5f56fedf9ab2477692ee0a 100644 Binary files a/trellis/utils/__pycache__/render_utils.cpython-310.pyc and b/trellis/utils/__pycache__/render_utils.cpython-310.pyc differ diff --git a/trellis/utils/loss_utils.py b/trellis/utils/loss_utils.py index 52049f69543f2700bc5525b09cbf2fb25c08aa9e..6629756cca81ab2ce8a396acb294089abe9e8763 100644 --- a/trellis/utils/loss_utils.py +++ b/trellis/utils/loss_utils.py @@ -11,8 +11,11 @@ def smooth_l1_loss(pred, target, beta=1.0): return loss.mean() -def l1_loss(network_output, gt): - return torch.abs((network_output - gt)).mean() +def l1_loss(network_output, gt, size_average=True): + if size_average: + return torch.abs((network_output - gt)).mean() + else: + return torch.abs((network_output - gt)) def l2_loss(network_output, gt): @@ -70,14 +73,17 @@ def _ssim(img1, img2, window, window_size, channel, size_average=True): loss_fn_vgg = None -def lpips(img1, img2, value_range=(0, 1)): +def lpips(img1, img2, value_range=(0, 1), size_average=True): global loss_fn_vgg if loss_fn_vgg is None: loss_fn_vgg = LPIPS(net='vgg').cuda().eval() # normalize to [-1, 1] img1 = (img1 - value_range[0]) / (value_range[1] - value_range[0]) * 2 - 1 img2 = (img2 - value_range[0]) / (value_range[1] - value_range[0]) * 2 - 1 - return loss_fn_vgg(img1, img2).mean() + if size_average: + return loss_fn_vgg(img1, img2).mean() + else: + return loss_fn_vgg(img1, img2) def normal_angle(pred, gt):