manbeast3b committed on
Commit
adefc74
·
verified ·
1 Parent(s): c305d4f

Update src/pipeline.py

Browse files
Files changed (1) hide show
  1. src/pipeline.py +4 -2
src/pipeline.py CHANGED
@@ -37,11 +37,10 @@ def load_pipeline() -> Pipeline:
37
  ).to(memory_format=torch.channels_last)
38
 
39
  vae = AutoencoderTiny.from_pretrained("RobertML/FLUX.1-schnell-vae_fx", revision="00c83cdfdfe46992eb0ed45921eee34261fcb56e", torch_dtype=dtype)
40
- vae.encoder.load_state_dict(torch.load("encoder.pth"), strict=False)
41
- vae.decoder.load_state_dict(torch.load("decoder.pth"), strict=False)
42
 
43
  path = os.path.join(HF_HUB_CACHE, "models--RobertML--FLUX.1-schnell-int8wo/snapshots/307e0777d92df966a3c0f99f31a6ee8957a9857a")
44
  model = FluxTransformer2DModel.from_pretrained(path, torch_dtype=dtype, use_safetensors=False).to(memory_format=torch.channels_last)
 
45
  pipeline = FluxPipeline.from_pretrained(
46
  ckpt_id,
47
  vae=vae,
@@ -50,6 +49,9 @@ def load_pipeline() -> Pipeline:
50
  text_encoder_2=text_encoder_2,
51
  torch_dtype=dtype,
52
  ).to(device)
 
 
 
53
  pipeline.transformer = torch.compile(pipeline.transformer, mode="max-autotune")
54
  quantize_(pipeline.vae, int8_weight_only())
55
  for _ in range(3):
 
37
  ).to(memory_format=torch.channels_last)
38
 
39
  vae = AutoencoderTiny.from_pretrained("RobertML/FLUX.1-schnell-vae_fx", revision="00c83cdfdfe46992eb0ed45921eee34261fcb56e", torch_dtype=dtype)
 
 
40
 
41
  path = os.path.join(HF_HUB_CACHE, "models--RobertML--FLUX.1-schnell-int8wo/snapshots/307e0777d92df966a3c0f99f31a6ee8957a9857a")
42
  model = FluxTransformer2DModel.from_pretrained(path, torch_dtype=dtype, use_safetensors=False).to(memory_format=torch.channels_last)
43
+
44
  pipeline = FluxPipeline.from_pretrained(
45
  ckpt_id,
46
  vae=vae,
 
49
  text_encoder_2=text_encoder_2,
50
  torch_dtype=dtype,
51
  ).to(device)
52
+ basepath = os.path.join(HF_HUB_CACHE, "models--manbeast3b--Flux.1.schnell-vae-kl-p10/snapshots/facb90ac7d8e13df9a8c177f18b0d450a3e1ed41")
53
+ pipeline.vae.encoder.load_state_dict(torch.load(os.path.join(basepath, "encoder.pth")), strict=False)
54
+ pipeline.vae.decoder.load_state_dict(torch.load(os.path.join(basepath, "decoder.pth")), strict=False)
55
  pipeline.transformer = torch.compile(pipeline.transformer, mode="max-autotune")
56
  quantize_(pipeline.vae, int8_weight_only())
57
  for _ in range(3):