Update src/pipeline.py
Browse files- src/pipeline.py +5 -9
src/pipeline.py
CHANGED
|
@@ -788,11 +788,9 @@ os.environ['PYTORCH_CUDA_ALLOC_CONF']="expandable_segments:True"
|
|
| 788 |
os.environ["TOKENIZERS_PARALLELISM"] = "True"
|
| 789 |
torch._dynamo.config.suppress_errors = True
|
| 790 |
|
| 791 |
-
Pipeline
|
| 792 |
-
|
| 793 |
-
|
| 794 |
-
REVISION = "741f7c3ce8b383c54771c7003378a50191e9efe9"
|
| 795 |
-
|
| 796 |
|
| 797 |
def load_pipeline() -> Pipeline:
|
| 798 |
text_encoder = CLIPTextModel.from_pretrained(CHECKPOINT, revision=REVISION, subfolder="text_encoder", local_files_only=True, torch_dtype=torch.bfloat16,)
|
|
@@ -826,11 +824,8 @@ def load_pipeline() -> Pipeline:
|
|
| 826 |
|
| 827 |
return pipeline
|
| 828 |
|
|
|
|
| 829 |
def infer(request: TextToImageRequest, pipeline: Pipeline) -> Image:
|
| 830 |
-
gc.collect()
|
| 831 |
-
torch.cuda.empty_cache()
|
| 832 |
-
torch.cuda.reset_peak_memory_stats()
|
| 833 |
-
|
| 834 |
generator = Generator(pipeline.device).manual_seed(request.seed)
|
| 835 |
|
| 836 |
return pipeline(
|
|
@@ -842,3 +837,4 @@ def infer(request: TextToImageRequest, pipeline: Pipeline) -> Image:
|
|
| 842 |
height=request.height,
|
| 843 |
width=request.width,
|
| 844 |
).images[0]
|
|
|
|
|
|
| 788 |
os.environ["TOKENIZERS_PARALLELISM"] = "True"
|
| 789 |
torch._dynamo.config.suppress_errors = True
|
| 790 |
|
| 791 |
+
Pipeline = None
|
| 792 |
+
ids = "slobers/Flux.1.Schnella"
|
| 793 |
+
Revision = "e34d670e44cecbbc90e4962e7aada2ac5ce8b55b"
|
|
|
|
|
|
|
| 794 |
|
| 795 |
def load_pipeline() -> Pipeline:
|
| 796 |
text_encoder = CLIPTextModel.from_pretrained(CHECKPOINT, revision=REVISION, subfolder="text_encoder", local_files_only=True, torch_dtype=torch.bfloat16,)
|
|
|
|
| 824 |
|
| 825 |
return pipeline
|
| 826 |
|
| 827 |
+
@torch.no_grad()
|
| 828 |
def infer(request: TextToImageRequest, pipeline: Pipeline) -> Image:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 829 |
generator = Generator(pipeline.device).manual_seed(request.seed)
|
| 830 |
|
| 831 |
return pipeline(
|
|
|
|
| 837 |
height=request.height,
|
| 838 |
width=request.width,
|
| 839 |
).images[0]
|
| 840 |
+
|