Update src/pipeline.py
Browse files — src/pipeline.py: +10 −12
src/pipeline.py
CHANGED
|
@@ -49,14 +49,6 @@ def load_quanto_text_encoder_2(text_repo_path):
|
|
| 49 |
|
| 50 |
def load_pipeline() -> Pipeline:
|
| 51 |
|
| 52 |
-
try:
|
| 53 |
-
text_repo_path = os.path.join(HF_HUB_CACHE, "models--RichardWilliam--XULF_T5_bf16/snapshots/63a3d9ef7b586655600ac9bd4e4747d038237761")
|
| 54 |
-
text_encoder_2 = load_quanto_text_encoder_2(text_repo_path=text_repo_path)
|
| 55 |
-
except:
|
| 56 |
-
text_encoder_2 = T5EncoderModel.from_pretrained("RichardWilliam/XULF_T5_bf16",
|
| 57 |
-
revision = "63a3d9ef7b586655600ac9bd4e4747d038237761",
|
| 58 |
-
torch_dtype=torch.bfloat16).to(memory_format=torch.channels_last)
|
| 59 |
-
|
| 60 |
origin_vae = AutoencoderTiny.from_pretrained("RichardWilliam/XULF_Vae",
|
| 61 |
revision="3ee225c539465c27adadec45c6e8af50a7397b7d",
|
| 62 |
torch_dtype=torch.bfloat16)
|
|
@@ -69,15 +61,20 @@ def load_pipeline() -> Pipeline:
|
|
| 69 |
transformer = origin_trans
|
| 70 |
|
| 71 |
pipeline = DiffusionPipeline.from_pretrained(CHECKPOINT,
|
| 72 |
-
revision=REVISION,
|
| 73 |
-
vae=origin_vae,
|
| 74 |
transformer=transformer,
|
| 75 |
-
text_encoder_2=text_encoder_2,
|
| 76 |
torch_dtype=torch.bfloat16)
|
| 77 |
pipeline.to("cuda")
|
| 78 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 79 |
for __ in range(3):
|
| 80 |
-
pipeline(prompt="
|
| 81 |
width=1024,
|
| 82 |
height=1024,
|
| 83 |
guidance_scale=0.0,
|
|
@@ -89,6 +86,7 @@ def load_pipeline() -> Pipeline:
|
|
| 89 |
def infer(request: TextToImageRequest, pipeline: Pipeline) -> Image:
|
| 90 |
|
| 91 |
reset_cache()
|
|
|
|
| 92 |
|
| 93 |
generator = Generator(pipeline.device).manual_seed(request.seed)
|
| 94 |
|
|
|
|
| 49 |
|
| 50 |
def load_pipeline() -> Pipeline:
|
| 51 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 52 |
origin_vae = AutoencoderTiny.from_pretrained("RichardWilliam/XULF_Vae",
|
| 53 |
revision="3ee225c539465c27adadec45c6e8af50a7397b7d",
|
| 54 |
torch_dtype=torch.bfloat16)
|
|
|
|
| 61 |
transformer = origin_trans
|
| 62 |
|
| 63 |
pipeline = DiffusionPipeline.from_pretrained(CHECKPOINT,
|
| 64 |
+
revision=REVISION,
|
|
|
|
| 65 |
transformer=transformer,
|
|
|
|
| 66 |
torch_dtype=torch.bfloat16)
|
| 67 |
pipeline.to("cuda")
|
| 68 |
|
| 69 |
+
try:
|
| 70 |
+
# pipeline.text_encoder_v2 = load_quanto_text_encoder_2(text_repo_path=None)
|
| 71 |
+
pipeline.enable_cuda_graph(type="max-autotune")
|
| 72 |
+
pipeline.text_encoder_v2 = load_quanto_text_encoder_2(text_repo_path=None)
|
| 73 |
+
except:
|
| 74 |
+
print("Something wrong here")
|
| 75 |
+
|
| 76 |
for __ in range(3):
|
| 77 |
+
pipeline(prompt="schoenobatist, halisteresis, chronometric, hallucinative",
|
| 78 |
width=1024,
|
| 79 |
height=1024,
|
| 80 |
guidance_scale=0.0,
|
|
|
|
| 86 |
def infer(request: TextToImageRequest, pipeline: Pipeline) -> Image:
|
| 87 |
|
| 88 |
reset_cache()
|
| 89 |
+
torch.cuda.empty_cache()
|
| 90 |
|
| 91 |
generator = Generator(pipeline.device).manual_seed(request.seed)
|
| 92 |
|