Update src/pipeline.py
Browse files- src/pipeline.py +5 -1
src/pipeline.py
CHANGED
|
@@ -35,6 +35,10 @@ def empty_cache():
|
|
| 35 |
torch.cuda.reset_peak_memory_stats()
|
| 36 |
print(f"Flush took: {time.time() - start}")
|
| 37 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
def load_pipeline() -> Pipeline:
|
| 39 |
empty_cache()
|
| 40 |
dtype, device = torch.bfloat16, "cuda"
|
|
@@ -49,7 +53,7 @@ def load_pipeline() -> Pipeline:
|
|
| 49 |
|
| 50 |
model_id = "manbeast3b/flux-schnell-int8"
|
| 51 |
transformer = FluxTransformer2DModel.from_pretrained(
|
| 52 |
-
model_id, subfolder="transformer", torch_dtype=torch.bfloat16,
|
| 53 |
)
|
| 54 |
text_encoder_2 = T5EncoderModel.from_pretrained(
|
| 55 |
model_id, subfolder="text_encoder_2", torch_dtype=torch.bfloat16
|
|
|
|
| 35 |
torch.cuda.reset_peak_memory_stats()
|
| 36 |
print(f"Flush took: {time.time() - start}")
|
| 37 |
|
| 38 |
+
|
| 39 |
+
cache_dir = "/root/.cache/huggingface/hub/models--manbeast3b--flux-schnell-int8/snapshots/eb656b7968de3088ccac7cda876f5782e5a2f721/"
|
| 40 |
+
|
| 41 |
+
|
| 42 |
def load_pipeline() -> Pipeline:
|
| 43 |
empty_cache()
|
| 44 |
dtype, device = torch.bfloat16, "cuda"
|
|
|
|
| 53 |
|
| 54 |
model_id = "manbeast3b/flux-schnell-int8"
|
| 55 |
transformer = FluxTransformer2DModel.from_pretrained(
|
| 56 |
+
model_id, subfolder="transformer", torch_dtype=torch.bfloat16, cache_dir=cache_dir, # quantization_config=config,
|
| 57 |
)
|
| 58 |
text_encoder_2 = T5EncoderModel.from_pretrained(
|
| 59 |
model_id, subfolder="text_encoder_2", torch_dtype=torch.bfloat16
|