AbstractPhil commited on
Commit
8a6546f
·
verified ·
1 Parent(s): f434818

Update bridge/latent_bridge.py

Browse files
Files changed (1) hide show
  1. bridge/latent_bridge.py +23 -14
bridge/latent_bridge.py CHANGED
@@ -9,9 +9,9 @@ from diffusers import ModelMixin
9
  class LatentBridge(ModelMixin):
10
  def __init__(self, in_ch: int = 4, out_ch: int = 16):
11
  super().__init__()
12
- # dec: 4→16, enc: 16→4
13
- self.dec = nn.Conv2d(in_ch, out_ch, kernel_size=1)
14
  self.enc = nn.Conv2d(out_ch, in_ch, kernel_size=1)
 
15
 
16
  def forward(self, x):
17
  return x
@@ -20,31 +20,40 @@ class LatentBridge(ModelMixin):
20
  def load_config(cls, pretrained_model_name_or_path: str, **kwargs):
21
  sub = kwargs.get("subfolder", "")
22
  base = os.path.join(pretrained_model_name_or_path, sub) if sub else pretrained_model_name_or_path
23
- with open(os.path.join(base, "config.json"), "r") as f:
 
24
  cfg = json.load(f)
25
  cfg.pop("_class_name", None)
26
  return cls(**cfg)
27
 
28
  @classmethod
29
  def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs):
30
- # 1) Build instance from config
31
  bridge = cls.load_config(pretrained_model_name_or_path, **kwargs)
32
 
33
- # 2) Load raw weights
34
  sub = kwargs.get("subfolder", "")
35
  base = os.path.join(pretrained_model_name_or_path, sub) if sub else pretrained_model_name_or_path
36
  raw_sd = torch.load(os.path.join(base, "pytorch_model.bin"), map_location="cpu")
37
 
38
- # 3) Remap keys
39
- remapped = {}
 
 
 
 
 
 
 
 
40
  for k, v in raw_sd.items():
41
- if k.startswith("to4."):
42
- remapped[k.replace("to4", "enc", 1)] = v
43
- elif k.startswith("to16."):
44
- remapped[k.replace("to16", "dec", 1)] = v
45
  else:
46
- remapped[k] = v
47
 
48
- # 4) **Load** the remapped dict into your layers
49
- bridge.load_state_dict(remapped)
50
  return bridge
 
9
  class LatentBridge(ModelMixin):
10
  def __init__(self, in_ch: int = 4, out_ch: int = 16):
11
  super().__init__()
12
+ # enc: 16→4 channels, dec: 4→16 channels
 
13
  self.enc = nn.Conv2d(out_ch, in_ch, kernel_size=1)
14
+ self.dec = nn.Conv2d(in_ch, out_ch, kernel_size=1)
15
 
16
  def forward(self, x):
17
  return x
 
20
  def load_config(cls, pretrained_model_name_or_path: str, **kwargs):
21
  sub = kwargs.get("subfolder", "")
22
  base = os.path.join(pretrained_model_name_or_path, sub) if sub else pretrained_model_name_or_path
23
+ cfg_file = os.path.join(base, "config.json")
24
+ with open(cfg_file, "r") as f:
25
  cfg = json.load(f)
26
  cfg.pop("_class_name", None)
27
  return cls(**cfg)
28
 
29
  @classmethod
30
  def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs):
31
+ # 1) Build the instance from config
32
  bridge = cls.load_config(pretrained_model_name_or_path, **kwargs)
33
 
34
+ # 2) Load raw state dict
35
  sub = kwargs.get("subfolder", "")
36
  base = os.path.join(pretrained_model_name_or_path, sub) if sub else pretrained_model_name_or_path
37
  raw_sd = torch.load(os.path.join(base, "pytorch_model.bin"), map_location="cpu")
38
 
39
+ # 3) Remap keys explicitly
40
+ remapped_sd = {
41
+ # map to16.* → dec.*
42
+ k.replace("to16", "dec", 1): v if k.startswith("to16") else
43
+ # map to4.* → enc.*
44
+ v
45
+ for k, v in raw_sd.items()
46
+ }
47
+ # But dictionary comprehension above only covers to16; let's do full loop:
48
+ final_sd = {}
49
  for k, v in raw_sd.items():
50
+ if k.startswith("to16."):
51
+ final_sd[k.replace("to16", "dec", 1)] = v
52
+ elif k.startswith("to4."):
53
+ final_sd[k.replace("to4", "enc", 1)] = v
54
  else:
55
+ final_sd[k] = v
56
 
57
+ # 4) Load into your layers
58
+ bridge.load_state_dict(final_sd)
59
  return bridge