Chrissy1 committed on
Commit
f65f896
·
verified ·
1 Parent(s): 3117cfa

Initial commit with folder contents

Browse files
Files changed (1) hide show
  1. src/pipeline.py +83 -32
src/pipeline.py CHANGED
@@ -1,3 +1,10 @@
 
 
 
 
 
 
 
1
  from diffusers import (
2
  DiffusionPipeline,
3
  AutoencoderKL,
@@ -13,56 +20,100 @@ from transformers import (
13
  CLIPTokenizer,
14
  CLIPTextModel
15
  )
16
- import torch
17
- import torch._dynamo
18
- import gc
19
- from PIL import Image
20
  from pipelines.models import TextToImageRequest
21
  from torch import Generator
22
- import time
23
- import math
24
- from typing import Type, Dict, Any, Tuple, Callable, Optional, Union
25
- import numpy as np
26
- import torch.nn as nn
27
- import torch.nn.functional as F
28
- from torchao.quantization import quantize_, int8_weight_only, fpx_weight_only
29
 
30
- # preconfigs
31
- import os
32
- os.environ['PYTORCH_CUDA_ALLOC_CONF']="expandable_segments:True"
33
  os.environ["TOKENIZERS_PARALLELISM"] = "True"
34
  torch._dynamo.config.suppress_errors = True
35
  torch.backends.cuda.matmul.allow_tf32 = True
36
  torch.backends.cudnn.enabled = True
37
- # torch.backends.cudnn.benchmark = True
38
 
39
- # globals
40
  Pipeline = None
41
- ckpt_id = "black-forest-labs/FLUX.1-schnell"
42
- ckpt_revision = "741f7c3ce8b383c54771c7003378a50191e9efe9"
 
43
 
44
  def empty_cache():
 
45
  gc.collect()
46
  torch.cuda.empty_cache()
47
  torch.cuda.reset_max_memory_allocated()
48
  torch.cuda.reset_peak_memory_stats()
49
 
50
- def load_pipeline() -> Pipeline:
51
- text_encoder_2 = T5EncoderModel.from_pretrained("Chrissy1/extra0manQ0", revision = "c0db1e82d89825a4664ad873f20d261cbe46e737", subfolder="text_encoder_2",torch_dtype=torch.bfloat16).to(memory_format=torch.channels_last)
52
- path = os.path.join(HF_HUB_CACHE, "models--Chrissy1--extra0manQ0/snapshots/c0db1e82d89825a4664ad873f20d261cbe46e737/transformer")
53
- transformer = FluxTransformer2DModel.from_pretrained(path, torch_dtype=torch.bfloat16, use_safetensors=False).to(memory_format=torch.channels_last)
54
- quantize_(AutoencoderKL.from_pretrained(ckpt_id,revision=ckpt_revision, subfolder="vae", local_files_only=True, torch_dtype=torch.bfloat16,), int8_weight_only())
55
- pipeline = FluxPipeline.from_pretrained(ckpt_id, revision=ckpt_revision, transformer=transformer, text_encoder_2=text_encoder_2, torch_dtype=torch.bfloat16,)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
  pipeline.to("cuda")
 
 
57
  with torch.inference_mode():
58
- pipeline(prompt="insensible, timbale, pothery, electrovital, actinogram, taxis, intracerebellar, centrodesmus", width=1024, height=1024, guidance_scale=0.0, num_inference_steps=4, max_sequence_length=256)
 
 
 
 
 
 
 
 
59
  return pipeline
60
 
61
- sample = 1
62
  @torch.no_grad()
63
- def infer(request: TextToImageRequest, pipeline: Pipeline, generator: Generator) -> Image:
64
- global sample
65
- if not sample:
66
- sample=1
67
- empty_cache()
68
- return pipeline(request.prompt,generator=generator, guidance_scale=0.0, num_inference_steps=4, max_sequence_length=256, height=request.height, width=request.width, output_type="pil").images[0]
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import gc
3
+ import torch
4
+ import numpy as np
5
+ from PIL import Image
6
+ from typing import Optional
7
+
8
  from diffusers import (
9
  DiffusionPipeline,
10
  AutoencoderKL,
 
20
  CLIPTokenizer,
21
  CLIPTextModel
22
  )
 
 
 
 
23
  from pipelines.models import TextToImageRequest
24
  from torch import Generator
25
+ from torchao.quantization import quantize_, int8_weight_only
 
 
 
 
 
 
26
 
27
# Pre-configurations
# Allow the CUDA caching allocator to grow segments instead of failing on
# fragmentation — helps fit large FLUX models in memory.
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = "expandable_segments:True"
os.environ["TOKENIZERS_PARALLELISM"] = "True"
# Fall back to eager mode instead of raising when dynamo compilation fails.
torch._dynamo.config.suppress_errors = True
# TF32 matmuls: faster on Ampere+ at slightly reduced precision.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.enabled = True

# Global variables
# NOTE(review): `Pipeline` is a leftover placeholder (the old code used it as
# a return annotation); kept so any external `from pipeline import Pipeline`
# still resolves — confirm before removing.
Pipeline = None
# Base checkpoint pinned to an exact revision for reproducible loads.
CKPT_ID = "black-forest-labs/FLUX.1-schnell"
CKPT_REVISION = "741f7c3ce8b383c54771c7003378a50191e9efe9"
38
+
39
 
40
def empty_cache():
    """Reclaim cached GPU memory and reset CUDA peak-memory statistics."""
    # Drop unreachable Python objects first so their CUDA tensors are freed.
    gc.collect()
    # Then release the allocator's cache and zero the memory counters,
    # in the same order as before.
    for cuda_reset in (
        torch.cuda.empty_cache,
        torch.cuda.reset_max_memory_allocated,
        torch.cuda.reset_peak_memory_stats,
    ):
        cuda_reset()
46
 
47
+
48
def load_pipeline() -> FluxPipeline:
    """Load, assemble, and warm up the FLUX.1-schnell text-to-image pipeline.

    Returns:
        A CUDA-resident ``FluxPipeline`` using a custom transformer and
        ``text_encoder_2`` plus an int8 weight-only quantized VAE, already
        warmed up with one inference pass.
    """
    # Custom T5 text encoder; channels_last can speed up conv/attention kernels.
    text_encoder_2 = T5EncoderModel.from_pretrained(
        "Chrissy1/extra0manQ0",
        revision="c0db1e82d89825a4664ad873f20d261cbe46e737",
        subfolder="text_encoder_2",
        torch_dtype=torch.bfloat16,
    ).to(memory_format=torch.channels_last)

    # Custom transformer loaded directly from the local HF hub cache path.
    transformer_path = os.path.join(
        HF_HUB_CACHE,
        "models--Chrissy1--extra0manQ0/snapshots/c0db1e82d89825a4664ad873f20d261cbe46e737/transformer"
    )
    transformer = FluxTransformer2DModel.from_pretrained(
        transformer_path,
        torch_dtype=torch.bfloat16,
        use_safetensors=False
    ).to(memory_format=torch.channels_last)

    # Load and quantize the autoencoder (quantize_ mutates the module in place).
    vae = AutoencoderKL.from_pretrained(
        CKPT_ID,
        revision=CKPT_REVISION,
        subfolder="vae",
        local_files_only=True,
        torch_dtype=torch.bfloat16
    )
    quantize_(vae, int8_weight_only())

    # BUG FIX: previously the quantized VAE was built and then discarded —
    # it was never passed to from_pretrained, so the pipeline re-loaded the
    # default bf16 VAE. Pass it explicitly so the quantization takes effect.
    pipeline = FluxPipeline.from_pretrained(
        CKPT_ID,
        revision=CKPT_REVISION,
        transformer=transformer,
        text_encoder_2=text_encoder_2,
        vae=vae,
        torch_dtype=torch.bfloat16
    )
    pipeline.to("cuda")

    # Warm-up run: pays one-time kernel/compilation costs up front so the
    # first real request is fast.
    with torch.inference_mode():
        pipeline(
            prompt="insensible, timbale, pothery, electrovital, actinogram, taxis, intracerebellar, centrodesmus",
            width=1024,
            height=1024,
            guidance_scale=0.0,
            num_inference_steps=4,
            max_sequence_length=256
        )

    return pipeline
101
 
102
+
103
@torch.no_grad()
def infer(request: TextToImageRequest, pipeline: FluxPipeline, generator: Generator) -> Image.Image:
    """Generate one PIL image for *request* using the warmed-up pipeline.

    Args:
        request: Text-to-image request carrying the prompt and target
            height/width.
        pipeline: Pipeline previously returned by ``load_pipeline``.
        generator: Seeded torch ``Generator`` for reproducible sampling.

    Returns:
        The first (and only) generated ``PIL.Image.Image``.
    """
    # Reclaim fragmented GPU memory before every generation.
    empty_cache()  # Clear cache before inference

    # guidance_scale=0.0 / 4 steps match the schnell distilled model's
    # recommended settings (same values as the warm-up run).
    result = pipeline(
        prompt=request.prompt,
        generator=generator,
        guidance_scale=0.0,
        num_inference_steps=4,
        max_sequence_length=256,
        height=request.height,
        width=request.width,
        output_type="pil"
    )

    return result.images[0]