YashNagraj75 committed on
Commit
2bd73d5
·
1 Parent(s): 9774d79

Load checkpoint weights via load_state_dict instead of replacing the model with the object returned by torch.load

Browse files
Files changed (4) hide show
  1. model_blocks/controlnet.py +16 -7
  2. pyproject.toml +1 -0
  3. requirements.txt +1 -0
  4. uv.lock +14 -0
model_blocks/controlnet.py CHANGED
@@ -22,7 +22,7 @@ class ControlNet(nn.Module):
22
  """
23
 
24
  def __init__(
25
- self, device, model_config, trained_ckpt_path=None, model_locked=True
26
  ) -> None:
27
  super().__init__()
28
 
@@ -30,14 +30,23 @@ class ControlNet(nn.Module):
30
  self.model = UNet(model_config)
31
  self.model_locked = model_locked
32
 
33
- if trained_ckpt_path is not None:
34
- print("Loading Checkpoint")
35
- self.model = torch.load(trained_ckpt_path).to(device)
 
 
 
36
 
37
- # False the upblocks (Decoder blocks) from the DDPM and uses only the encoder
 
38
  self.control_copy = UNet(model_config, use_up=False)
39
- if trained_ckpt_path is not None:
40
- self.control_copy.load_state_dict(self.model.state_dict(), strict=False)
 
 
 
 
 
41
 
42
  # Hint Block for ControlNet
43
  # Stack of Conv Activation and Zero Convolution at the end
 
22
  """
23
 
24
  def __init__(
25
+ self, device, model_config, model_ckpt=None, model_locked=True
26
  ) -> None:
27
  super().__init__()
28
 
 
30
  self.model = UNet(model_config)
31
  self.model_locked = model_locked
32
 
33
+ # Load weights for the trained model
34
+ if model_ckpt is not None and device is not None:
35
+ print("Loading Trained Diffusion Model")
36
+ self.model.load_state_dict(
37
+ torch.load(model_ckpt, map_location=device), strict=True
38
+ )
39
 
40
+ # ControlNet Copy of Trained DDPM
41
+ # use_up = False removes the upblocks(decoder layers) from DDPM Unet
42
  self.control_copy = UNet(model_config, use_up=False)
43
+ # Load same weights as the trained model
44
+
45
+ if model_ckpt is not None and device is not None:
46
+ print("Loading Control Diffusion Model")
47
+ self.control_copy.load_state_dict(
48
+ torch.load(model_ckpt, map_location=device), strict=False
49
+ )
50
 
51
  # Hint Block for ControlNet
52
  # Stack of Conv Activation and Zero Convolution at the end
pyproject.toml CHANGED
@@ -9,5 +9,6 @@ dependencies = [
9
  "pandas>=2.2.3",
10
  "torch>=2.6.0",
11
  "torchvision>=0.21.0",
 
12
  "wandb>=0.19.9",
13
  ]
 
9
  "pandas>=2.2.3",
10
  "torch>=2.6.0",
11
  "torchvision>=0.21.0",
12
+ "tqdm>=4.67.1",
13
  "wandb>=0.19.9",
14
  ]
requirements.txt CHANGED
@@ -45,6 +45,7 @@ smmap==5.0.2
45
  sympy==1.13.1
46
  torch==2.6.0
47
  torchvision==0.21.0
 
48
  triton==3.2.0
49
  typing-extensions==4.13.0
50
  typing-inspection==0.4.0
 
45
  sympy==1.13.1
46
  torch==2.6.0
47
  torchvision==0.21.0
48
+ tqdm==4.67.1
49
  triton==3.2.0
50
  typing-extensions==4.13.0
51
  typing-inspection==0.4.0
uv.lock CHANGED
@@ -76,6 +76,7 @@ dependencies = [
76
  { name = "pandas" },
77
  { name = "torch" },
78
  { name = "torchvision" },
 
79
  { name = "wandb" },
80
  ]
81
 
@@ -85,6 +86,7 @@ requires-dist = [
85
  { name = "pandas", specifier = ">=2.2.3" },
86
  { name = "torch", specifier = ">=2.6.0" },
87
  { name = "torchvision", specifier = ">=0.21.0" },
 
88
  { name = "wandb", specifier = ">=0.19.9" },
89
  ]
90
 
@@ -666,6 +668,18 @@ wheels = [
666
  { url = "https://files.pythonhosted.org/packages/ed/b4/fc60e3bc003879d3de842baea258fffc3586f4b49cd435a5ba1e09c33315/torchvision-0.21.0-cp313-cp313-win_amd64.whl", hash = "sha256:9147f5e096a9270684e3befdee350f3cacafd48e0c54ab195f45790a9c146d67", size = 1560519 },
667
  ]
668
 
 
 
 
 
 
 
 
 
 
 
 
 
669
  [[package]]
670
  name = "triton"
671
  version = "3.2.0"
 
76
  { name = "pandas" },
77
  { name = "torch" },
78
  { name = "torchvision" },
79
+ { name = "tqdm" },
80
  { name = "wandb" },
81
  ]
82
 
 
86
  { name = "pandas", specifier = ">=2.2.3" },
87
  { name = "torch", specifier = ">=2.6.0" },
88
  { name = "torchvision", specifier = ">=0.21.0" },
89
+ { name = "tqdm", specifier = ">=4.67.1" },
90
  { name = "wandb", specifier = ">=0.19.9" },
91
  ]
92
 
 
668
  { url = "https://files.pythonhosted.org/packages/ed/b4/fc60e3bc003879d3de842baea258fffc3586f4b49cd435a5ba1e09c33315/torchvision-0.21.0-cp313-cp313-win_amd64.whl", hash = "sha256:9147f5e096a9270684e3befdee350f3cacafd48e0c54ab195f45790a9c146d67", size = 1560519 },
669
  ]
670
 
671
+ [[package]]
672
+ name = "tqdm"
673
+ version = "4.67.1"
674
+ source = { registry = "https://pypi.org/simple" }
675
+ dependencies = [
676
+ { name = "colorama", marker = "sys_platform == 'win32'" },
677
+ ]
678
+ sdist = { url = "https://files.pythonhosted.org/packages/a8/4b/29b4ef32e036bb34e4ab51796dd745cdba7ed47ad142a9f4a1eb8e0c744d/tqdm-4.67.1.tar.gz", hash = "sha256:f8aef9c52c08c13a65f30ea34f4e5aac3fd1a34959879d7e59e63027286627f2", size = 169737 }
679
+ wheels = [
680
+ { url = "https://files.pythonhosted.org/packages/d0/30/dc54f88dd4a2b5dc8a0279bdd7270e735851848b762aeb1c1184ed1f6b14/tqdm-4.67.1-py3-none-any.whl", hash = "sha256:26445eca388f82e72884e0d580d5464cd801a3ea01e63e5601bdff9ba6a48de2", size = 78540 },
681
+ ]
682
+
683
  [[package]]
684
  name = "triton"
685
  version = "3.2.0"