AbstractPhil commited on
Commit
f434818
·
verified ·
1 Parent(s): 3a369d2

Update bridge/latent_bridge.py

Browse files
Files changed (1) hide show
  1. bridge/latent_bridge.py +12 -22
bridge/latent_bridge.py CHANGED
@@ -18,43 +18,33 @@ class LatentBridge(ModelMixin):
18
 
19
  @classmethod
20
  def load_config(cls, pretrained_model_name_or_path: str, **kwargs):
21
- # Determine the folder (honor subfolder kw if provided)
22
  sub = kwargs.get("subfolder", "")
23
  base = os.path.join(pretrained_model_name_or_path, sub) if sub else pretrained_model_name_or_path
24
-
25
- # Load and parse config.json
26
- config_file = os.path.join(base, "config.json")
27
- with open(config_file, "r") as f:
28
  cfg = json.load(f)
29
  cfg.pop("_class_name", None)
30
  return cls(**cfg)
31
 
32
  @classmethod
33
  def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs):
34
- # Resolve base folder for both config & weights
35
- sub = kwargs.get("subfolder", "")
36
- base = os.path.join(pretrained_model_name_or_path, sub) if sub else pretrained_model_name_or_path
37
-
38
- # 1) Instantiate from config
39
  bridge = cls.load_config(pretrained_model_name_or_path, **kwargs)
40
 
41
- # 2) Load original state dict
42
- bin_path = os.path.join(base, "pytorch_model.bin")
43
- state_dict = torch.load(bin_path, map_location="cpu")
 
44
 
45
- # 3) Remap keys:
46
- # to4.* → enc.* (16→4 conv)
47
- # to16.* → dec.* (4→16 conv)
48
  remapped = {}
49
- for k, v in state_dict.items():
50
  if k.startswith("to4."):
51
- new_key = k.replace("to4", "enc", 1)
52
  elif k.startswith("to16."):
53
- new_key = k.replace("to16", "dec", 1)
54
  else:
55
- new_key = k
56
- remapped[new_key] = v
57
 
58
- # 4) Load into your bridge
59
  bridge.load_state_dict(remapped)
60
  return bridge
 
18
 
19
  @classmethod
20
  def load_config(cls, pretrained_model_name_or_path: str, **kwargs):
 
21
  sub = kwargs.get("subfolder", "")
22
  base = os.path.join(pretrained_model_name_or_path, sub) if sub else pretrained_model_name_or_path
23
+ with open(os.path.join(base, "config.json"), "r") as f:
 
 
 
24
  cfg = json.load(f)
25
  cfg.pop("_class_name", None)
26
  return cls(**cfg)
27
 
28
  @classmethod
29
  def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs):
30
+ # 1) Build instance from config
 
 
 
 
31
  bridge = cls.load_config(pretrained_model_name_or_path, **kwargs)
32
 
33
+ # 2) Load raw weights
34
+ sub = kwargs.get("subfolder", "")
35
+ base = os.path.join(pretrained_model_name_or_path, sub) if sub else pretrained_model_name_or_path
36
+ raw_sd = torch.load(os.path.join(base, "pytorch_model.bin"), map_location="cpu")
37
 
38
+ # 3) Remap keys
 
 
39
  remapped = {}
40
+ for k, v in raw_sd.items():
41
  if k.startswith("to4."):
42
+ remapped[k.replace("to4", "enc", 1)] = v
43
  elif k.startswith("to16."):
44
+ remapped[k.replace("to16", "dec", 1)] = v
45
  else:
46
+ remapped[k] = v
 
47
 
48
+ # 4) **Load** the remapped dict into your layers
49
  bridge.load_state_dict(remapped)
50
  return bridge