Update bridge/latent_bridge.py

bridge/latent_bridge.py  (+15 −14)
@@ -9,9 +9,9 @@ from diffusers import ModelMixin
 class LatentBridge(ModelMixin):
     def __init__(self, in_ch: int = 4, out_ch: int = 16):
         super().__init__()
-        # …
-        self.dec = nn.Conv2d(in_ch, out_ch, kernel_size=1)
+        # enc: 16→4 channels, dec: 4→16 channels
         self.enc = nn.Conv2d(out_ch, in_ch, kernel_size=1)
+        self.dec = nn.Conv2d(in_ch, out_ch, kernel_size=1)
 
     def forward(self, x):
         return x
 
@@ -20,31 +20,32 @@ class LatentBridge(ModelMixin):
     def load_config(cls, pretrained_model_name_or_path: str, **kwargs):
         sub = kwargs.get("subfolder", "")
         base = os.path.join(pretrained_model_name_or_path, sub) if sub else pretrained_model_name_or_path
-        with open(os.path.join(base, "config.json")) as f:
+        cfg_file = os.path.join(base, "config.json")
+        with open(cfg_file, "r") as f:
             cfg = json.load(f)
         cfg.pop("_class_name", None)
         return cls(**cfg)
 
     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs):
-        # 1) Build instance from config
+        # 1) Build the instance from config
         bridge = cls.load_config(pretrained_model_name_or_path, **kwargs)
 
-        # 2) Load raw
+        # 2) Load the raw state dict
         sub = kwargs.get("subfolder", "")
         base = os.path.join(pretrained_model_name_or_path, sub) if sub else pretrained_model_name_or_path
         raw_sd = torch.load(os.path.join(base, "pytorch_model.bin"), map_location="cpu")
 
-        # 3) Remap keys
-        …
+        # 3) Remap keys: to16.* → dec.*, to4.* → enc.*
+        remapped_sd = {}
         for k, v in raw_sd.items():
-            if k.startswith("…
-                …
-            elif k.startswith("…
-                …
+            if k.startswith("to16."):
+                remapped_sd[k.replace("to16", "dec", 1)] = v
+            elif k.startswith("to4."):
+                remapped_sd[k.replace("to4", "enc", 1)] = v
             else:
-                …
+                remapped_sd[k] = v
 
-        # 4) …
-        bridge.load_state_dict(…
+        # 4) Load the remapped weights
+        bridge.load_state_dict(remapped_sd)
         return bridge
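For reference, a minimal loading sketch; the checkpoint path and subfolder name below are placeholders, and `from_pretrained` expects `config.json` plus `pytorch_model.bin` inside the resolved directory:

```python
from bridge.latent_bridge import LatentBridge

# "checkpoints/run1" and "bridge" are hypothetical names; the resolved
# directory checkpoints/run1/bridge must contain config.json and
# pytorch_model.bin for the loader above to find them.
bridge = LatentBridge.from_pretrained("checkpoints/run1", subfolder="bridge")
```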
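And a round-trip check of the key remapping, a sketch assuming the module is importable as `bridge.latent_bridge`: it writes a temporary checkpoint whose weights still use the old `to16.*`/`to4.*` names and confirms they land in `dec`/`enc`. The config values and tensor shapes follow the defaults above; everything else is illustrative.

```python
import json
import os
import tempfile

import torch

from bridge.latent_bridge import LatentBridge

with tempfile.TemporaryDirectory() as ckpt_dir:
    # Minimal config matching LatentBridge.__init__ defaults
    with open(os.path.join(ckpt_dir, "config.json"), "w") as f:
        json.dump({"_class_name": "LatentBridge", "in_ch": 4, "out_ch": 16}, f)

    # State dict saved under the pre-rename key names
    old_sd = {
        "to16.weight": torch.randn(16, 4, 1, 1),  # -> dec.weight
        "to16.bias": torch.randn(16),             # -> dec.bias
        "to4.weight": torch.randn(4, 16, 1, 1),   # -> enc.weight
        "to4.bias": torch.randn(4),               # -> enc.bias
    }
    torch.save(old_sd, os.path.join(ckpt_dir, "pytorch_model.bin"))

    bridge = LatentBridge.from_pretrained(ckpt_dir)
    assert torch.equal(bridge.dec.weight, old_sd["to16.weight"])
    assert torch.equal(bridge.enc.bias, old_sd["to4.bias"])
```

Note that `load_state_dict` is strict by default, so any key the remap misses would raise instead of loading silently.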