manbeast3b commited on
Commit
86e604c
·
verified ·
1 Parent(s): a5fe896

Update src/pipeline.py

Browse files
Files changed (1) hide show
  1. src/pipeline.py +1 -1
src/pipeline.py CHANGED
@@ -81,7 +81,7 @@ def load_pipeline():
81
  sd = torch.load(p, map_location="cpu", weights_only=True)
82
  f_sd = {k.strip(pfx): v for k, v in sd.items() if k.strip(pfx) in mod.state_dict() and v.size() == mod.state_dict()[k.strip(pfx)].size()}
83
  mod.load_state_dict(f_sd, strict=False)
84
- #mod.to(dtype=torch.bfloat16)
85
 
86
  lsd("ko.pth", vae.encoder, "encoder.")
87
  lsd("ok.pth", vae.decoder, "decoder.")
 
81
  sd = torch.load(p, map_location="cpu", weights_only=True)
82
  f_sd = {k.strip(pfx): v for k, v in sd.items() if k.strip(pfx) in mod.state_dict() and v.size() == mod.state_dict()[k.strip(pfx)].size()}
83
  mod.load_state_dict(f_sd, strict=False)
84
+ mod.to(dtype=torch.bfloat16)
85
 
86
  lsd("ko.pth", vae.encoder, "encoder.")
87
  lsd("ok.pth", vae.decoder, "decoder.")