Upload folder using huggingface_hub

- .gitignore +1 -1
- src/pipeline.py +14 -16
.gitignore CHANGED

@@ -5,4 +5,4 @@
 **/.venv
 .venv
 .git
-
+*.swp
src/pipeline.py CHANGED

@@ -10,14 +10,12 @@ from torch import Generator
 from torchao.quantization import quantize_, int8_weight_only
 from time import perf_counter
 
-
+
 HOME = os.environ["HOME"]
-
-
-
-
-# "text_encoder_2": os.path.join(HOME, REPO_DIR, "flux_schnell_text_encoder_2_int8wo.pt"),
-# "vae": os.path.join(HOME, REPO_DIR, "flux_schnell_vae_int8wo.pt")}
+FLUX_CHECKPOINT = os.path.join(HOME,
+    ".cache/huggingface/hub/models--black-forest-labs--FLUX.1-schnell/snapshots/741f7c3ce8b383c54771c7003378a50191e9efe9/")
+QUANTIZED_MODEL = ["transformer", "text_encoder_2", "text_encoder", "vae"]
+
 
 QUANT_CONFIG = int8_weight_only()
 DTYPE = torch.bfloat16
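Note: FLUX_CHECKPOINT now hardcodes a specific snapshot hash inside the local Hugging Face cache. As a hedged aside (not part of this commit), the same local path could be resolved without pinning the hash, assuming the model is already present in the cache:

# Sketch, not from this repo: resolve the cached FLUX.1-schnell snapshot path
# instead of hardcoding the snapshot hash. Fails if the model is not cached yet.
from huggingface_hub import snapshot_download

FLUX_CHECKPOINT = snapshot_download(
    "black-forest-labs/FLUX.1-schnell",
    local_files_only=True,  # never touch the network, mirroring local_files_only below
)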
@@ -25,7 +23,7 @@ NUM_STEPS = 4
 
 def get_transformer(quantize: bool = True, quant_config = int8_weight_only(), quant_ckpt: str = None):
     if quant_ckpt is not None:
-        config = FluxTransformer2DModel.load_config(FLUX_CHECKPOINT, subfolder="transformer")
+        config = FluxTransformer2DModel.load_config(FLUX_CHECKPOINT, subfolder="transformer", local_files_only=True)
         model = FluxTransformer2DModel.from_config(config).to(DTYPE)
         state_dict = torch.load(quant_ckpt, map_location="cpu")
         model.load_state_dict(state_dict, assign=True)

@@ -33,7 +31,7 @@ def get_transformer(quantize: bool = True, quant_config = int8_weight_only(), quant_ckpt: str = None):
         return model
 
     model = FluxTransformer2DModel.from_pretrained(
-        FLUX_CHECKPOINT, subfolder="transformer", torch_dtype=DTYPE,
+        FLUX_CHECKPOINT, subfolder="transformer", torch_dtype=DTYPE, local_files_only=True
     )
     if quantize:
         quantize_(model, quant_config)
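Note: the quant_ckpt branch rebuilds the transformer from its config and restores already-quantized weights via load_state_dict(..., assign=True). A checkpoint in that format could plausibly be produced as sketched below; the output filename is illustrative and not taken from this commit:

# Sketch, not from this repo: produce an int8 weight-only checkpoint that the
# quant_ckpt branch above can restore with load_state_dict(..., assign=True).
import torch
from diffusers import FluxTransformer2DModel
from torchao.quantization import quantize_, int8_weight_only

model = FluxTransformer2DModel.from_pretrained(
    FLUX_CHECKPOINT, subfolder="transformer", torch_dtype=torch.bfloat16, local_files_only=True
)
quantize_(model, int8_weight_only())  # swaps Linear weights for int8 tensors in place
torch.save(model.state_dict(), "flux_schnell_transformer_int8wo.pt")  # illustrative filename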
@@ -42,7 +40,7 @@ def get_transformer(quantize: bool = True, quant_config = int8_weight_only(), quant_ckpt: str = None):
 
 def get_text_encoder(quantize: bool = True, quant_config = int8_weight_only(), quant_ckpt: str = None):
     if quant_ckpt is not None:
-        config = CLIPTextConfig.from_pretrained(FLUX_CHECKPOINT, subfolder="text_encoder")
+        config = CLIPTextConfig.from_pretrained(FLUX_CHECKPOINT, subfolder="text_encoder", local_files_only=True)
         model = CLIPTextModel(config).to(DTYPE)
         state_dict = torch.load(quant_ckpt, map_location="cpu")
         model.load_state_dict(state_dict, assign=True)

@@ -50,7 +48,7 @@ def get_text_encoder(quantize: bool = True, quant_config = int8_weight_only(), quant_ckpt: str = None):
         return model
 
     model = CLIPTextModel.from_pretrained(
-        FLUX_CHECKPOINT, subfolder="text_encoder", torch_dtype=DTYPE
+        FLUX_CHECKPOINT, subfolder="text_encoder", torch_dtype=DTYPE, local_files_only=True
     )
     if quantize:
         quantize_(model, quant_config)
@@ -59,7 +57,7 @@ def get_text_encoder(quantize: bool = True, quant_config = int8_weight_only(), quant_ckpt: str = None):
 
 def get_text_encoder_2(quantize: bool = True, quant_config = int8_weight_only(), quant_ckpt: str = None):
     if quant_ckpt is not None:
-        config = T5Config.from_pretrained(FLUX_CHECKPOINT, subfolder="text_encoder_2")
+        config = T5Config.from_pretrained(FLUX_CHECKPOINT, subfolder="text_encoder_2", local_files_only=True)
         model = T5EncoderModel(config).to(DTYPE)
         state_dict = torch.load(quant_ckpt, map_location="cpu")
         print(f"Loaded {quant_ckpt}")

@@ -67,7 +65,7 @@ def get_text_encoder_2(quantize: bool = True, quant_config = int8_weight_only(), quant_ckpt: str = None):
         return model
 
     model = T5EncoderModel.from_pretrained(
-        FLUX_CHECKPOINT, subfolder="text_encoder_2", torch_dtype=DTYPE
+        FLUX_CHECKPOINT, subfolder="text_encoder_2", torch_dtype=DTYPE, local_files_only=True
     )
     if quantize:
         quantize_(model, quant_config)
@@ -76,14 +74,14 @@ def get_text_encoder_2(quantize: bool = True, quant_config = int8_weight_only(), quant_ckpt: str = None):
 
 def get_vae(quantize: bool = True, quant_config = int8_weight_only(), quant_ckpt: str = None):
     if quant_ckpt is not None:
-        config = AutoencoderKL.load_config(FLUX_CHECKPOINT, subfolder="vae")
+        config = AutoencoderKL.load_config(FLUX_CHECKPOINT, subfolder="vae", local_files_only=True)
         model = AutoencoderKL.from_config(config).to(DTYPE)
         state_dict = torch.load(quant_ckpt, map_location="cpu")
         model.load_state_dict(state_dict, assign=True)
         print(f"Loaded {quant_ckpt}")
         return model
     model = AutoencoderKL.from_pretrained(
-        FLUX_CHECKPOINT, subfolder="vae", torch_dtype=DTYPE
+        FLUX_CHECKPOINT, subfolder="vae", torch_dtype=DTYPE, local_files_only=True
     )
     if quantize:
         quantize_(model, quant_config)
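Note: all four loaders now share the same pattern: restore a pre-quantized checkpoint when quant_ckpt is given, otherwise load the bf16 weights from the local cache and quantize in place. The diff does not show how the components are assembled; a minimal sketch of wiring them into a FluxPipeline, assuming the rest of the file does something along these lines, would be:

# Sketch, assumption only: combine the quantized components into a FluxPipeline.
from diffusers import FluxPipeline

pipeline = FluxPipeline.from_pretrained(
    FLUX_CHECKPOINT,
    transformer=get_transformer(),
    text_encoder=get_text_encoder(),
    text_encoder_2=get_text_encoder_2(),
    vae=get_vae(),
    torch_dtype=DTYPE,
    local_files_only=True,
).to("cuda")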
@@ -119,7 +117,7 @@ def infer(request: TextToImageRequest, _pipeline: FluxPipeline) -> Image:
     if request.seed is None:
         generator = None
     else:
-        generator = Generator(
+        generator = Generator(device=_pipeline.device).manual_seed(request.seed)
 
     empty_cache()
     image = _pipeline(prompt=request.prompt,
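Note: with the seeded branch collapsed into a single expression, reproducible inference looks like the sketch below (prompt text is illustrative; the step count comes from the module-level NUM_STEPS = 4):

# Sketch, not from this repo: the same seed yields the same image across runs.
# `pipeline` is the FluxPipeline from the sketch above.
from torch import Generator

generator = Generator(device=pipeline.device).manual_seed(42)
image = pipeline(
    prompt="a tiny astronaut hatching from an egg on the moon",  # illustrative prompt
    num_inference_steps=NUM_STEPS,  # 4 steps, matching FLUX.1-schnell's few-step schedule
    generator=generator,
).images[0]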