jokerbit committed
Commit 7f7da07 · verified · 1 Parent(s): 062697f

Upload folder using huggingface_hub

Files changed (2)
  1. .gitignore +1 -1
  2. src/pipeline.py +14 -16
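
The commit message says the files were pushed with huggingface_hub's folder upload. For reference, a minimal sketch of that kind of upload is shown below; the repo id, local folder, and target path are illustrative assumptions, not values taken from this commit.

# Sketch only: pushing a local folder to the Hub with huggingface_hub.
# repo_id, folder_path and path_in_repo are placeholders (assumptions).
from huggingface_hub import upload_folder

upload_folder(
    repo_id="jokerbit/flux-schnell-int8",
    folder_path="src",
    path_in_repo="src",
    commit_message="Upload folder using huggingface_hub",
)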
.gitignore CHANGED
@@ -5,4 +5,4 @@
 **/.venv
 .venv
 .git
-
+*.swp
src/pipeline.py CHANGED
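
The hunks below switch FLUX_CHECKPOINT from the Hub repo id "black-forest-labs/FLUX.1-schnell" to a hard-coded path into the local huggingface_hub cache, and add local_files_only=True to every config and weight load so all components are read offline from that snapshot. For reference only (not part of this commit), the same snapshot directory could be resolved without hard-coding the revision hash, assuming the model is already cached:

# Sketch only, not in this commit: locate the cached FLUX.1-schnell snapshot
# instead of hard-coding its path; requires the snapshot to already be in the cache.
from huggingface_hub import snapshot_download

FLUX_CHECKPOINT = snapshot_download(
    "black-forest-labs/FLUX.1-schnell",
    local_files_only=True,  # never hit the network; error if not cached
)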
@@ -10,14 +10,12 @@ from torch import Generator
 from torchao.quantization import quantize_, int8_weight_only
 from time import perf_counter
 
-FLUX_CHECKPOINT = "black-forest-labs/FLUX.1-schnell"
+
 HOME = os.environ["HOME"]
-# REPO_DIR = ".cache/huggingface/hub/models--jokerbit--flux-schnell-int8/snapshots/9510dd83d6d44ab375b5e8facec10afa81be2a8f"
-QUANTIZED_MODEL = ["transformer", "text_encoder", "text_encoder_2", "vae"]
-# QUANT_CKPT = {"transformer": os.path.join(HOME, REPO_DIR, "flux_schnell_transformer_int8wo.pt"),
-# "text_encoder": os.path.join(HOME, REPO_DIR, "flux_schnell_text_encoder_int8wo.pt"),
-# "text_encoder_2": os.path.join(HOME, REPO_DIR, "flux_schnell_text_encoder_2_int8wo.pt"),
-# "vae": os.path.join(HOME, REPO_DIR, "flux_schnell_vae_int8wo.pt")}
+FLUX_CHECKPOINT = os.path.join(HOME,
+                               ".cache/huggingface/hub/models--black-forest-labs--FLUX.1-schnell/snapshots/741f7c3ce8b383c54771c7003378a50191e9efe9/")
+QUANTIZED_MODEL = ["transformer", "text_encoder_2", "text_encoder", "vae"]
+
 
 QUANT_CONFIG = int8_weight_only()
 DTYPE = torch.bfloat16
@@ -25,7 +23,7 @@ NUM_STEPS = 4
 
 def get_transformer(quantize: bool = True, quant_config = int8_weight_only(), quant_ckpt: str = None):
     if quant_ckpt is not None:
-        config = FluxTransformer2DModel.load_config(FLUX_CHECKPOINT, subfolder="transformer")
+        config = FluxTransformer2DModel.load_config(FLUX_CHECKPOINT, subfolder="transformer", local_files_only=True)
         model = FluxTransformer2DModel.from_config(config).to(DTYPE)
         state_dict = torch.load(quant_ckpt, map_location="cpu")
         model.load_state_dict(state_dict, assign=True)
@@ -33,7 +31,7 @@ def get_transformer(quantize: bool = True, quant_config = int8_weight_only(), qu
         return model
 
     model = FluxTransformer2DModel.from_pretrained(
-        FLUX_CHECKPOINT, subfolder="transformer", torch_dtype=DTYPE,
+        FLUX_CHECKPOINT, subfolder="transformer", torch_dtype=DTYPE, local_files_only=True
     )
     if quantize:
         quantize_(model, quant_config)
@@ -42,7 +40,7 @@ def get_transformer(quantize: bool = True, quant_config = int8_weight_only(), qu
 
 def get_text_encoder(quantize: bool = True, quant_config = int8_weight_only(), quant_ckpt: str = None):
     if quant_ckpt is not None:
-        config = CLIPTextConfig.from_pretrained(FLUX_CHECKPOINT, subfolder="text_encoder")
+        config = CLIPTextConfig.from_pretrained(FLUX_CHECKPOINT, subfolder="text_encoder", local_files_only=True)
         model = CLIPTextModel(config).to(DTYPE)
         state_dict = torch.load(quant_ckpt, map_location="cpu")
         model.load_state_dict(state_dict, assign=True)
@@ -50,7 +48,7 @@ def get_text_encoder(quantize: bool = True, quant_config = int8_weight_only(), q
         return model
 
     model = CLIPTextModel.from_pretrained(
-        FLUX_CHECKPOINT, subfolder="text_encoder", torch_dtype=DTYPE
+        FLUX_CHECKPOINT, subfolder="text_encoder", torch_dtype=DTYPE, local_files_only=True
     )
     if quantize:
         quantize_(model, quant_config)
@@ -59,7 +57,7 @@ def get_text_encoder(quantize: bool = True, quant_config = int8_weight_only(), q
 
 def get_text_encoder_2(quantize: bool = True, quant_config = int8_weight_only(), quant_ckpt: str = None):
     if quant_ckpt is not None:
-        config = T5Config.from_pretrained(FLUX_CHECKPOINT, subfolder="text_encoder_2")
+        config = T5Config.from_pretrained(FLUX_CHECKPOINT, subfolder="text_encoder_2", local_files_only=True)
         model = T5EncoderModel(config).to(DTYPE)
         state_dict = torch.load(quant_ckpt, map_location="cpu")
         print(f"Loaded {quant_ckpt}")
@@ -67,7 +65,7 @@ def get_text_encoder_2(quantize: bool = True, quant_config = int8_weight_only(),
         return model
 
     model = T5EncoderModel.from_pretrained(
-        FLUX_CHECKPOINT, subfolder="text_encoder_2", torch_dtype=DTYPE
+        FLUX_CHECKPOINT, subfolder="text_encoder_2", torch_dtype=DTYPE, local_files_only=True
     )
     if quantize:
         quantize_(model, quant_config)
@@ -76,14 +74,14 @@ def get_text_encoder_2(quantize: bool = True, quant_config = int8_weight_only(),
 
 def get_vae(quantize: bool = True, quant_config = int8_weight_only(), quant_ckpt: str = None):
     if quant_ckpt is not None:
-        config = AutoencoderKL.load_config(FLUX_CHECKPOINT, subfolder="vae")
+        config = AutoencoderKL.load_config(FLUX_CHECKPOINT, subfolder="vae", local_files_only=True)
        model = AutoencoderKL.from_config(config).to(DTYPE)
         state_dict = torch.load(quant_ckpt, map_location="cpu")
         model.load_state_dict(state_dict, assign=True)
         print(f"Loaded {quant_ckpt}")
         return model
     model = AutoencoderKL.from_pretrained(
-        FLUX_CHECKPOINT, subfolder="vae", torch_dtype=DTYPE
+        FLUX_CHECKPOINT, subfolder="vae", torch_dtype=DTYPE, local_files_only=True
     )
     if quantize:
         quantize_(model, quant_config)
@@ -119,7 +117,7 @@ def infer(request: TextToImageRequest, _pipeline: FluxPipeline) -> Image:
     if request.seed is None:
         generator = None
     else:
-        generator = Generator(request.seed).device(_pipeline.device)
+        generator = Generator(device=_pipeline.device).manual_seed(request.seed)
 
     empty_cache()
     image = _pipeline(prompt=request.prompt,
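
The last hunk also fixes seeding: torch.Generator takes a device in its constructor and is seeded through manual_seed, so the old Generator(request.seed).device(...) expression treats .device (an attribute) as a callable and never applies the seed, while the new form builds a generator on the pipeline's device and seeds it. A self-contained sketch of that pattern (device and seed are placeholder values):

# Sketch: reproducible sampling with a seeded torch.Generator.
# "cpu" and 42 are placeholders; a diffusers pipeline would be called
# with generator=generator to make its outputs reproducible.
import torch

generator = torch.Generator(device="cpu").manual_seed(42)
sample = torch.randn(4, generator=generator)  # identical values on every run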