Update src/pipeline.py
Browse files- src/pipeline.py +1 -1
src/pipeline.py
CHANGED
|
@@ -81,7 +81,7 @@ def load_pipeline():
|
|
| 81 |
sd = torch.load(p, map_location="cpu", weights_only=True)
|
| 82 |
f_sd = {k.strip(pfx): v for k, v in sd.items() if k.strip(pfx) in mod.state_dict() and v.size() == mod.state_dict()[k.strip(pfx)].size()}
|
| 83 |
mod.load_state_dict(f_sd, strict=False)
|
| 84 |
-
|
| 85 |
|
| 86 |
lsd("ko.pth", vae.encoder, "encoder.")
|
| 87 |
lsd("ok.pth", vae.decoder, "decoder.")
|
|
|
|
| 81 |
sd = torch.load(p, map_location="cpu", weights_only=True)
|
| 82 |
f_sd = {k.strip(pfx): v for k, v in sd.items() if k.strip(pfx) in mod.state_dict() and v.size() == mod.state_dict()[k.strip(pfx)].size()}
|
| 83 |
mod.load_state_dict(f_sd, strict=False)
|
| 84 |
+
mod.to(dtype=torch.bfloat16)
|
| 85 |
|
| 86 |
lsd("ko.pth", vae.encoder, "encoder.")
|
| 87 |
lsd("ok.pth", vae.decoder, "decoder.")
|