Update src/pipeline.py
Browse files- src/pipeline.py +2 -0
src/pipeline.py
CHANGED
|
@@ -21,6 +21,8 @@ ids = "slobers/Flux.1.Schnella"
|
|
| 21 |
Revision = "e34d670e44cecbbc90e4962e7aada2ac5ce8b55b"
|
| 22 |
|
| 23 |
def load_pipeline() -> Pipeline:
|
|
|
|
|
|
|
| 24 |
pipeline = FluxPipeline.from_pretrained(ids, revision=Revision, local_files_only=True, torch_dtype=torch.bfloat16,)
|
| 25 |
pipeline.to("cuda")
|
| 26 |
pipeline.transformer = torch.compile(pipeline.transformer, mode="max-autotune", fullgraph=True)
|
|
|
|
| 21 |
Revision = "e34d670e44cecbbc90e4962e7aada2ac5ce8b55b"
|
| 22 |
|
| 23 |
def load_pipeline() -> Pipeline:
|
| 24 |
+
path = os.path.join(HF_HUB_CACHE, "models--slobers--Flux.1.Schnella/snapshots/e34d670e44cecbbc90e4962e7aada2ac5ce8b55b/transformer")
|
| 25 |
+
transformer = FluxTransformer2DModel.from_pretrained(path, torch_dtype=torch.bfloat16, use_safetensors=False)
|
| 26 |
pipeline = FluxPipeline.from_pretrained(ids, revision=Revision, local_files_only=True, torch_dtype=torch.bfloat16,)
|
| 27 |
pipeline.to("cuda")
|
| 28 |
pipeline.transformer = torch.compile(pipeline.transformer, mode="max-autotune", fullgraph=True)
|