Update bridge/latent_bridge.py
Browse files- bridge/latent_bridge.py +12 -22
bridge/latent_bridge.py
CHANGED
|
@@ -18,43 +18,33 @@ class LatentBridge(ModelMixin):
|
|
| 18 |
|
| 19 |
@classmethod
def load_config(cls, pretrained_model_name_or_path: str, **kwargs):
    """Build a bridge instance from a saved ``config.json``.

    Honors an optional ``subfolder`` keyword when resolving the directory,
    drops the serialization-only ``_class_name`` entry, and forwards the
    remaining keys to the constructor.
    """
    subfolder = kwargs.get("subfolder", "")
    root = os.path.join(pretrained_model_name_or_path, subfolder) if subfolder else pretrained_model_name_or_path

    # Read and parse the serialized configuration.
    config_path = os.path.join(root, "config.json")
    with open(config_path, "r") as handle:
        config = json.load(handle)

    # "_class_name" is metadata written by the saver, not a ctor argument.
    config.pop("_class_name", None)
    return cls(**config)
|
| 31 |
|
| 32 |
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs):
    """Instantiate the bridge from a local directory and load its weights.

    The statements for loading the checkpoint and remapping its keys were
    missing/truncated; this reconstructs the full pipeline:

    1. build the module from ``config.json`` (via ``load_config``),
    2. load the raw state dict from ``pytorch_model.bin``,
    3. remap checkpoint prefixes onto this module's layer names
       (``to4.* -> enc.*`` for the 16->4 conv, ``to16.* -> dec.*`` for
       the 4->16 conv),
    4. load the remapped dict into the instantiated layers.
    """
    # 1) Instantiate from config (load_config honors the `subfolder` kwarg).
    bridge = cls.load_config(pretrained_model_name_or_path, **kwargs)

    # 2) Load the raw checkpoint.
    sub = kwargs.get("subfolder", "")
    base = os.path.join(pretrained_model_name_or_path, sub) if sub else pretrained_model_name_or_path
    # NOTE(review): torch.load unpickles arbitrary objects — only load
    # checkpoints from trusted sources.
    raw_sd = torch.load(os.path.join(base, "pytorch_model.bin"), map_location="cpu")

    # 3) Remap keys: to4.* -> enc.*, to16.* -> dec.*; everything else passes through.
    remapped = {}
    for k, v in raw_sd.items():
        if k.startswith("to4."):
            remapped[k.replace("to4", "enc", 1)] = v
        elif k.startswith("to16."):
            remapped[k.replace("to16", "dec", 1)] = v
        else:
            remapped[k] = v

    # 4) Load the remapped dict into the module's layers.
    bridge.load_state_dict(remapped)
    return bridge
|
|
|
|
| 18 |
|
| 19 |
@classmethod
def load_config(cls, pretrained_model_name_or_path: str, **kwargs):
    """Construct the bridge from the ``config.json`` found at the given path.

    An optional ``subfolder`` keyword narrows the lookup directory. The
    ``_class_name`` bookkeeping key is removed before the remaining entries
    are handed to the constructor.
    """
    folder = kwargs.get("subfolder", "")
    target_dir = os.path.join(pretrained_model_name_or_path, folder) if folder else pretrained_model_name_or_path

    config_file = os.path.join(target_dir, "config.json")
    with open(config_file, "r") as fp:
        parsed = json.load(fp)

    # Saver metadata, not a constructor parameter.
    parsed.pop("_class_name", None)
    return cls(**parsed)
|
| 27 |
|
| 28 |
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs):
    """Load a fully-weighted bridge from a local directory.

    Builds the module from its serialized config, reads the checkpoint
    from ``pytorch_model.bin``, renames checkpoint prefixes to this
    module's layer names, and loads the result into the instance.
    """
    # Build the module skeleton from config (honors the `subfolder` kwarg).
    bridge = cls.load_config(pretrained_model_name_or_path, **kwargs)

    # Resolve the weights directory and read the raw state dict.
    subfolder = kwargs.get("subfolder", "")
    weights_dir = os.path.join(pretrained_model_name_or_path, subfolder) if subfolder else pretrained_model_name_or_path
    # NOTE(review): torch.load relies on pickle — only load trusted checkpoints.
    state = torch.load(os.path.join(weights_dir, "pytorch_model.bin"), map_location="cpu")

    # Checkpoint prefix -> module attribute (to4: 16->4 conv, to16: 4->16 conv).
    prefix_table = (("to4.", "to4", "enc"), ("to16.", "to16", "dec"))
    renamed = {}
    for key, tensor in state.items():
        for guard, old, new in prefix_table:
            if key.startswith(guard):
                renamed[key.replace(old, new, 1)] = tensor
                break
        else:
            renamed[key] = tensor

    # Push the renamed weights into the freshly built layers.
    bridge.load_state_dict(renamed)
    return bridge
|