manbeast3b commited on
Commit
f72dcb5
·
verified ·
1 Parent(s): e040adc

Update src/pipeline.py

Browse files
Files changed (1) hide show
  1. src/pipeline.py +2 -0
src/pipeline.py CHANGED
@@ -37,6 +37,8 @@ def load_pipeline() -> Pipeline:
37
  ).to(memory_format=torch.channels_last)
38
 
39
  vae = AutoencoderTiny.from_pretrained("RobertML/FLUX.1-schnell-vae_fx", revision="00c83cdfdfe46992eb0ed45921eee34261fcb56e", torch_dtype=dtype)
 
 
40
  path = os.path.join(HF_HUB_CACHE, "models--RobertML--FLUX.1-schnell-int8wo/snapshots/307e0777d92df966a3c0f99f31a6ee8957a9857a")
41
  model = FluxTransformer2DModel.from_pretrained(path, torch_dtype=dtype, use_safetensors=False).to(memory_format=torch.channels_last)
42
  pipeline = FluxPipeline.from_pretrained(
 
37
  ).to(memory_format=torch.channels_last)
38
 
39
  vae = AutoencoderTiny.from_pretrained("RobertML/FLUX.1-schnell-vae_fx", revision="00c83cdfdfe46992eb0ed45921eee34261fcb56e", torch_dtype=dtype)
40
+ vae.encoder.load_state_dict(torch.load("encoder.pth"), strict=False)
41
+ vae.decoder.load_state_dict(torch.load("decoder.pth"), strict=False)
42
  path = os.path.join(HF_HUB_CACHE, "models--RobertML--FLUX.1-schnell-int8wo/snapshots/307e0777d92df966a3c0f99f31a6ee8957a9857a")
43
  model = FluxTransformer2DModel.from_pretrained(path, torch_dtype=dtype, use_safetensors=False).to(memory_format=torch.channels_last)
44
  pipeline = FluxPipeline.from_pretrained(