Commit ·
2bd73d5
1
Parent(s): 9774d79
Add loading ckpt logic from state_dict and not directly from model class
Browse files- model_blocks/controlnet.py +16 -7
- pyproject.toml +1 -0
- requirements.txt +1 -0
- uv.lock +14 -0
model_blocks/controlnet.py
CHANGED
|
@@ -22,7 +22,7 @@ class ControlNet(nn.Module):
|
|
| 22 |
"""
|
| 23 |
|
| 24 |
def __init__(
|
| 25 |
-
self, device, model_config,
|
| 26 |
) -> None:
|
| 27 |
super().__init__()
|
| 28 |
|
|
@@ -30,14 +30,23 @@ class ControlNet(nn.Module):
|
|
| 30 |
self.model = UNet(model_config)
|
| 31 |
self.model_locked = model_locked
|
| 32 |
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
|
|
|
|
|
|
|
|
|
| 36 |
|
| 37 |
-
#
|
|
|
|
| 38 |
self.control_copy = UNet(model_config, use_up=False)
|
| 39 |
-
|
| 40 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
|
| 42 |
# Hint Block for ControlNet
|
| 43 |
# Stack of Conv Activation and Zero Convolution at the end
|
|
|
|
| 22 |
"""
|
| 23 |
|
| 24 |
def __init__(
|
| 25 |
+
self, device, model_config, model_ckpt=None, model_locked=True
|
| 26 |
) -> None:
|
| 27 |
super().__init__()
|
| 28 |
|
|
|
|
| 30 |
self.model = UNet(model_config)
|
| 31 |
self.model_locked = model_locked
|
| 32 |
|
| 33 |
+
# Load weights for the trained model
|
| 34 |
+
if model_ckpt is not None and device is not None:
|
| 35 |
+
print("Loading Trained Diffusion Model")
|
| 36 |
+
self.model.load_state_dict(
|
| 37 |
+
torch.load(model_ckpt, map_location=device), strict=True
|
| 38 |
+
)
|
| 39 |
|
| 40 |
+
# ControlNet Copy of Trained DDPM
|
| 41 |
+
# use_up = False removes the upblocks(decoder layers) from DDPM Unet
|
| 42 |
self.control_copy = UNet(model_config, use_up=False)
|
| 43 |
+
# Load same weights as the trained model
|
| 44 |
+
|
| 45 |
+
if model_ckpt is not None and device is not None:
|
| 46 |
+
print("Loading Control Diffusion Model")
|
| 47 |
+
self.control_copy.load_state_dict(
|
| 48 |
+
torch.load(model_ckpt, map_location=device), strict=False
|
| 49 |
+
)
|
| 50 |
|
| 51 |
# Hint Block for ControlNet
|
| 52 |
# Stack of Conv Activation and Zero Convolution at the end
|
pyproject.toml
CHANGED
|
@@ -9,5 +9,6 @@ dependencies = [
|
|
| 9 |
"pandas>=2.2.3",
|
| 10 |
"torch>=2.6.0",
|
| 11 |
"torchvision>=0.21.0",
|
|
|
|
| 12 |
"wandb>=0.19.9",
|
| 13 |
]
|
|
|
|
| 9 |
"pandas>=2.2.3",
|
| 10 |
"torch>=2.6.0",
|
| 11 |
"torchvision>=0.21.0",
|
| 12 |
+
"tqdm>=4.67.1",
|
| 13 |
"wandb>=0.19.9",
|
| 14 |
]
|
requirements.txt
CHANGED
|
@@ -45,6 +45,7 @@ smmap==5.0.2
|
|
| 45 |
sympy==1.13.1
|
| 46 |
torch==2.6.0
|
| 47 |
torchvision==0.21.0
|
|
|
|
| 48 |
triton==3.2.0
|
| 49 |
typing-extensions==4.13.0
|
| 50 |
typing-inspection==0.4.0
|
|
|
|
| 45 |
sympy==1.13.1
|
| 46 |
torch==2.6.0
|
| 47 |
torchvision==0.21.0
|
| 48 |
+
tqdm==4.67.1
|
| 49 |
triton==3.2.0
|
| 50 |
typing-extensions==4.13.0
|
| 51 |
typing-inspection==0.4.0
|
uv.lock
CHANGED
|
@@ -76,6 +76,7 @@ dependencies = [
|
|
| 76 |
{ name = "pandas" },
|
| 77 |
{ name = "torch" },
|
| 78 |
{ name = "torchvision" },
|
|
|
|
| 79 |
{ name = "wandb" },
|
| 80 |
]
|
| 81 |
|
|
@@ -85,6 +86,7 @@ requires-dist = [
|
|
| 85 |
{ name = "pandas", specifier = ">=2.2.3" },
|
| 86 |
{ name = "torch", specifier = ">=2.6.0" },
|
| 87 |
{ name = "torchvision", specifier = ">=0.21.0" },
|
|
|
|
| 88 |
{ name = "wandb", specifier = ">=0.19.9" },
|
| 89 |
]
|
| 90 |
|
|
@@ -666,6 +668,18 @@ wheels = [
|
|
| 666 |
{ url = "https://files.pythonhosted.org/packages/ed/b4/fc60e3bc003879d3de842baea258fffc3586f4b49cd435a5ba1e09c33315/torchvision-0.21.0-cp313-cp313-win_amd64.whl", hash = "sha256:9147f5e096a9270684e3befdee350f3cacafd48e0c54ab195f45790a9c146d67", size = 1560519 },
|
| 667 |
]
|
| 668 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 669 |
[[package]]
|
| 670 |
name = "triton"
|
| 671 |
version = "3.2.0"
|
|
|
|
| 76 |
{ name = "pandas" },
|
| 77 |
{ name = "torch" },
|
| 78 |
{ name = "torchvision" },
|
| 79 |
+
{ name = "tqdm" },
|
| 80 |
{ name = "wandb" },
|
| 81 |
]
|
| 82 |
|
|
|
|
| 86 |
{ name = "pandas", specifier = ">=2.2.3" },
|
| 87 |
{ name = "torch", specifier = ">=2.6.0" },
|
| 88 |
{ name = "torchvision", specifier = ">=0.21.0" },
|
| 89 |
+
{ name = "tqdm", specifier = ">=4.67.1" },
|
| 90 |
{ name = "wandb", specifier = ">=0.19.9" },
|
| 91 |
]
|
| 92 |
|
|
|
|
| 668 |
{ url = "https://files.pythonhosted.org/packages/ed/b4/fc60e3bc003879d3de842baea258fffc3586f4b49cd435a5ba1e09c33315/torchvision-0.21.0-cp313-cp313-win_amd64.whl", hash = "sha256:9147f5e096a9270684e3befdee350f3cacafd48e0c54ab195f45790a9c146d67", size = 1560519 },
|
| 669 |
]
|
| 670 |
|
| 671 |
+
[[package]]
|
| 672 |
+
name = "tqdm"
|
| 673 |
+
version = "4.67.1"
|
| 674 |
+
source = { registry = "https://pypi.org/simple" }
|
| 675 |
+
dependencies = [
|
| 676 |
+
{ name = "colorama", marker = "sys_platform == 'win32'" },
|
| 677 |
+
]
|
| 678 |
+
sdist = { url = "https://files.pythonhosted.org/packages/a8/4b/29b4ef32e036bb34e4ab51796dd745cdba7ed47ad142a9f4a1eb8e0c744d/tqdm-4.67.1.tar.gz", hash = "sha256:f8aef9c52c08c13a65f30ea34f4e5aac3fd1a34959879d7e59e63027286627f2", size = 169737 }
|
| 679 |
+
wheels = [
|
| 680 |
+
{ url = "https://files.pythonhosted.org/packages/d0/30/dc54f88dd4a2b5dc8a0279bdd7270e735851848b762aeb1c1184ed1f6b14/tqdm-4.67.1-py3-none-any.whl", hash = "sha256:26445eca388f82e72884e0d580d5464cd801a3ea01e63e5601bdff9ba6a48de2", size = 78540 },
|
| 681 |
+
]
|
| 682 |
+
|
| 683 |
[[package]]
|
| 684 |
name = "triton"
|
| 685 |
version = "3.2.0"
|