Commit
·
4a8a6dd
1
Parent(s):
30e86c0
Initial commit
Browse files
app.py
CHANGED
|
@@ -120,7 +120,10 @@ class FaceLiftPipeline:
|
|
| 120 |
workspace_dir / "checkpoints/gslrm/ckpt_0000000000021125.pt",
|
| 121 |
map_location="cpu"
|
| 122 |
)
|
| 123 |
-
|
|
|
|
|
|
|
|
|
|
| 124 |
# Keep on CPU initially - will move to GPU in decorated function
|
| 125 |
|
| 126 |
self.color_prompt_embedding = torch.load(
|
|
|
|
| 120 |
workspace_dir / "checkpoints/gslrm/ckpt_0000000000021125.pt",
|
| 121 |
map_location="cpu"
|
| 122 |
)
|
| 123 |
+
# Filter out loss_calculator weights (training-only, not needed for inference)
|
| 124 |
+
state_dict = {k: v for k, v in checkpoint["model"].items()
|
| 125 |
+
if not k.startswith("loss_calculator.")}
|
| 126 |
+
self.gs_lrm_model.load_state_dict(state_dict)
|
| 127 |
# Keep on CPU initially - will move to GPU in decorated function
|
| 128 |
|
| 129 |
self.color_prompt_embedding = torch.load(
|