Update src/pipeline.py
Browse files- src/pipeline.py +4 -2
src/pipeline.py
CHANGED
|
@@ -37,11 +37,10 @@ def load_pipeline() -> Pipeline:
|
|
| 37 |
).to(memory_format=torch.channels_last)
|
| 38 |
|
| 39 |
vae = AutoencoderTiny.from_pretrained("RobertML/FLUX.1-schnell-vae_fx", revision="00c83cdfdfe46992eb0ed45921eee34261fcb56e", torch_dtype=dtype)
|
| 40 |
-
vae.encoder.load_state_dict(torch.load("encoder.pth"), strict=False)
|
| 41 |
-
vae.decoder.load_state_dict(torch.load("decoder.pth"), strict=False)
|
| 42 |
|
| 43 |
path = os.path.join(HF_HUB_CACHE, "models--RobertML--FLUX.1-schnell-int8wo/snapshots/307e0777d92df966a3c0f99f31a6ee8957a9857a")
|
| 44 |
model = FluxTransformer2DModel.from_pretrained(path, torch_dtype=dtype, use_safetensors=False).to(memory_format=torch.channels_last)
|
|
|
|
| 45 |
pipeline = FluxPipeline.from_pretrained(
|
| 46 |
ckpt_id,
|
| 47 |
vae=vae,
|
|
@@ -50,6 +49,9 @@ def load_pipeline() -> Pipeline:
|
|
| 50 |
text_encoder_2=text_encoder_2,
|
| 51 |
torch_dtype=dtype,
|
| 52 |
).to(device)
|
|
|
|
|
|
|
|
|
|
| 53 |
pipeline.transformer = torch.compile(pipeline.transformer, mode="max-autotune")
|
| 54 |
quantize_(pipeline.vae, int8_weight_only())
|
| 55 |
for _ in range(3):
|
|
|
|
| 37 |
).to(memory_format=torch.channels_last)
|
| 38 |
|
| 39 |
vae = AutoencoderTiny.from_pretrained("RobertML/FLUX.1-schnell-vae_fx", revision="00c83cdfdfe46992eb0ed45921eee34261fcb56e", torch_dtype=dtype)
|
|
|
|
|
|
|
| 40 |
|
| 41 |
path = os.path.join(HF_HUB_CACHE, "models--RobertML--FLUX.1-schnell-int8wo/snapshots/307e0777d92df966a3c0f99f31a6ee8957a9857a")
|
| 42 |
model = FluxTransformer2DModel.from_pretrained(path, torch_dtype=dtype, use_safetensors=False).to(memory_format=torch.channels_last)
|
| 43 |
+
|
| 44 |
pipeline = FluxPipeline.from_pretrained(
|
| 45 |
ckpt_id,
|
| 46 |
vae=vae,
|
|
|
|
| 49 |
text_encoder_2=text_encoder_2,
|
| 50 |
torch_dtype=dtype,
|
| 51 |
).to(device)
|
| 52 |
+
basepath = os.path.join(HF_HUB_CACHE, "models--manbeast3b--Flux.1.schnell-vae-kl-p10/snapshots/facb90ac7d8e13df9a8c177f18b0d450a3e1ed41")
|
| 53 |
+
pipeline.vae.encoder.load_state_dict(torch.load(os.path.join(basepath, "encoder.pth")), strict=False)
|
| 54 |
+
pipeline.vae.decoder.load_state_dict(torch.load(os.path.join(basepath, "decoder.pth")), strict=False)
|
| 55 |
pipeline.transformer = torch.compile(pipeline.transformer, mode="max-autotune")
|
| 56 |
quantize_(pipeline.vae, int8_weight_only())
|
| 57 |
for _ in range(3):
|