# see-through-demo / inference/scripts/quantize_and_push.py
# Last commit by 24yearsold: "update: add ComfyUI Node Extension mention to description" (b55a1fc, verified)
"""Quantize See-through pipeline models (UNet + text encoders) to NF4 or FP8.
Saves quantized components locally and optionally pushes to HuggingFace.
VAE and TransparentVAE remain in bf16.
Usage (from repo root):
python inference/scripts/quantize_and_push.py --quant_mode nf4
python inference/scripts/quantize_and_push.py --quant_mode fp8 --pipeline layerdiff
python inference/scripts/quantize_and_push.py --quant_mode nf4 --push_to_hub \
--output_repo_layerdiff layerdifforg/seethroughv0.0.2_layerdiff3d_nf4 \
--output_repo_depth 24yearsold/seethroughv0.0.1_marigold_nf4
"""
import os.path as osp
import argparse
import sys
import os
sys.path.append(osp.dirname(osp.dirname(osp.abspath(__file__))))
import shutil
import torch
from huggingface_hub import snapshot_download, HfApi
# diffusers BitsAndBytesConfig — for UNet (inherits from diffusers ModelMixin)
from diffusers import BitsAndBytesConfig as DiffusersBnBConfig
# transformers BitsAndBytesConfig — for CLIP text encoders (transformers models)
from transformers import BitsAndBytesConfig as TransformersBnBConfig
from transformers import CLIPTextModel, CLIPTextModelWithProjection
from modules.layerdiffuse.layerdiff3d import UNetFrameConditionModel
def make_quant_configs(quant_mode: str):
    """Return ``(diffusers_config, transformers_config)`` for *quant_mode*.

    Both configs are built from the same bitsandbytes keyword arguments: the
    diffusers flavour is used for the UNet, the transformers flavour for the
    CLIP text encoders.

    * ``"nf4"`` -> 4-bit NF4 weights, bfloat16 compute dtype.
    * ``"fp8"`` -> ``load_in_8bit=True``. NOTE: in bitsandbytes this flag
      selects LLM.int8() (int8) quantization rather than an FP8 format; the
      mode name is kept as-is for CLI compatibility.

    Raises:
        ValueError: if *quant_mode* is neither ``"nf4"`` nor ``"fp8"``.
    """
    if quant_mode == "nf4":
        bnb_kwargs = {
            "load_in_4bit": True,
            "bnb_4bit_quant_type": "nf4",
            "bnb_4bit_compute_dtype": torch.bfloat16,
        }
    elif quant_mode == "fp8":
        bnb_kwargs = {"load_in_8bit": True}
    else:
        raise ValueError(f"Unknown quant_mode: {quant_mode!r}. Expected 'nf4' or 'fp8'.")

    # Same kwargs for both libraries; diffusers config is constructed first.
    return DiffusersBnBConfig(**bnb_kwargs), TransformersBnBConfig(**bnb_kwargs)
def copy_subfolder(src_root: str, dst_root: str, subfolder: str):
    """Mirror ``src_root/subfolder`` to ``dst_root/subfolder``.

    An existing destination tree is deleted first, so the result is an exact
    copy of the source rather than a merge. If the source subfolder does not
    exist, a warning is printed and nothing is copied.
    """
    source_dir = osp.join(src_root, subfolder)
    target_dir = osp.join(dst_root, subfolder)

    if not osp.isdir(source_dir):
        print(f" WARNING: subfolder {subfolder!r} not found in {src_root}")
        return

    # Remove leftovers from a previous run before copying.
    if osp.exists(target_dir):
        shutil.rmtree(target_dir)
    shutil.copytree(source_dir, target_dir)
    print(f" Copied subfolder: {subfolder}")
def copy_file(src_root: str, dst_root: str, filename: str):
    """Copy ``src_root/filename`` to ``dst_root/filename`` with its metadata.

    Prints a warning and does nothing if the source is not a regular file.
    ``dst_root`` must already exist.
    """
    source_path = osp.join(src_root, filename)
    if not osp.isfile(source_path):
        print(f" WARNING: file {filename!r} not found in {src_root}")
        return

    shutil.copy2(source_path, osp.join(dst_root, filename))
    print(f" Copied file: {filename}")
def _print_banner(title: str, quant_mode: str, repo_id: str, output_dir: str):
    """Print the header shown at the start of a pipeline quantization run."""
    print(f"\n{'='*60}")
    print(f"Quantizing {title} pipeline ({quant_mode})")
    print(f" Source: {repo_id}")
    print(f" Output: {output_dir}")
    print(f"{'='*60}")


def _download_source(repo_id: str, hf_token: str, step: str) -> str:
    """Snapshot *repo_id* from the Hub and return the local snapshot path."""
    print(f"\n{step} Downloading source repo...")
    local_repo = snapshot_download(
        repo_id,
        token=hf_token,
    )
    print(f" Downloaded to: {local_repo}")
    return local_repo


def _quantize_component(
    model_cls,
    local_repo: str,
    output_dir: str,
    subfolder: str,
    quant_config,
    *,
    step: str,
    description: str,
    saved_name: str,
):
    """Load ``local_repo/subfolder`` with *quant_config* and save it to ``output_dir/subfolder``.

    Args:
        model_cls: class exposing ``from_pretrained`` / ``save_pretrained``.
        quant_config: BitsAndBytes config matching *model_cls*'s library.
        step: progress prefix such as ``"[2/6]"``.
        description: text shown in the "Loading and quantizing ..." line.
        saved_name: component name shown in the "Saved quantized ..." line.
    """
    print(f"\n{step} Loading and quantizing {description}...")
    model = model_cls.from_pretrained(
        local_repo,
        subfolder=subfolder,
        quantization_config=quant_config,
        torch_dtype=torch.bfloat16,
    )
    dst = osp.join(output_dir, subfolder)
    # Start from an empty folder (as copy_subfolder does): save_pretrained only
    # overwrites the files it writes, and push_to_hub uploads the whole
    # directory, so stale weight files from an earlier run would be pushed too.
    if osp.isdir(dst):
        shutil.rmtree(dst)
    os.makedirs(dst, exist_ok=True)
    model.save_pretrained(dst)
    print(f" Saved quantized {saved_name} to: {dst}")
    # Drop the model and return cached CUDA blocks before the next component loads.
    del model
    torch.cuda.empty_cache()


def _copy_root_files(local_repo: str, output_dir: str):
    """Copy every top-level regular file (model_index.json, README, ...) of the snapshot."""
    for fname in sorted(os.listdir(local_repo)):
        if osp.isfile(osp.join(local_repo, fname)):
            copy_file(local_repo, output_dir, fname)


def quantize_layerdiff(
    repo_id: str,
    output_dir: str,
    quant_mode: str,
    hf_token: str = None,
):
    """Quantize the LayerDiff3D pipeline (SDXL-based).

    UNet, text_encoder and text_encoder_2 are quantized according to
    *quant_mode*. trans_vae, vae, scheduler, tokenizer and tokenizer_2 are
    copied unchanged, as are all top-level files of the source repo.

    Raises:
        ValueError: from make_quant_configs for an unknown *quant_mode*
            (raised before anything is downloaded).
    """
    _print_banner("LayerDiff3D", quant_mode, repo_id, output_dir)
    diffusers_config, transformers_config = make_quant_configs(quant_mode)

    local_repo = _download_source(repo_id, hf_token, "[1/6]")
    os.makedirs(output_dir, exist_ok=True)

    _quantize_component(
        UNetFrameConditionModel, local_repo, output_dir, "unet", diffusers_config,
        step="[2/6]", description="UNet", saved_name="UNet",
    )
    _quantize_component(
        CLIPTextModel, local_repo, output_dir, "text_encoder", transformers_config,
        step="[3/6]", description="text_encoder (CLIPTextModel)", saved_name="text_encoder",
    )
    _quantize_component(
        CLIPTextModelWithProjection, local_repo, output_dir, "text_encoder_2", transformers_config,
        step="[4/6]",
        description="text_encoder_2 (CLIPTextModelWithProjection)",
        saved_name="text_encoder_2",
    )

    print("\n[5/6] Copying bf16 components (VAE, TransparentVAE, scheduler, tokenizers)...")
    for sf in ("trans_vae", "vae", "scheduler", "tokenizer", "tokenizer_2"):
        copy_subfolder(local_repo, output_dir, sf)

    print("\n[6/6] Copying root config files...")
    _copy_root_files(local_repo, output_dir)

    print(f"\nLayerDiff3D quantization complete: {output_dir}")


def quantize_marigold(
    repo_id: str,
    output_dir: str,
    quant_mode: str,
    hf_token: str = None,
):
    """Quantize the Marigold3D pipeline (SD1.5-based).

    UNet and the single text_encoder are quantized according to *quant_mode*.
    vae, scheduler and tokenizer are copied unchanged, as are all top-level
    files of the source repo.

    Raises:
        ValueError: from make_quant_configs for an unknown *quant_mode*
            (raised before anything is downloaded).
    """
    _print_banner("Marigold3D", quant_mode, repo_id, output_dir)
    diffusers_config, transformers_config = make_quant_configs(quant_mode)

    local_repo = _download_source(repo_id, hf_token, "[1/4]")
    os.makedirs(output_dir, exist_ok=True)

    _quantize_component(
        UNetFrameConditionModel, local_repo, output_dir, "unet", diffusers_config,
        step="[2/4]", description="UNet", saved_name="UNet",
    )
    _quantize_component(
        CLIPTextModel, local_repo, output_dir, "text_encoder", transformers_config,
        step="[3/4]", description="text_encoder (CLIPTextModel)", saved_name="text_encoder",
    )

    print("\n[4/4] Copying bf16 components (VAE, scheduler, tokenizer)...")
    for sf in ("vae", "scheduler", "tokenizer"):
        copy_subfolder(local_repo, output_dir, sf)

    print(" Copying root config files...")
    _copy_root_files(local_repo, output_dir)

    print(f"\nMarigold3D quantization complete: {output_dir}")
def push_to_hub(local_dir: str, repo_id: str, hf_token: str):
    """Upload the contents of *local_dir* to the Hub model repo *repo_id*.

    The repo is created first if it does not exist (``exist_ok=True``).
    Visibility is left to ``create_repo``'s default. NOTE(review): confirm
    that is acceptable if the source model is private.
    """
    print(f"\nPushing {local_dir} -> {repo_id}")
    repo_kind = "model"
    client = HfApi(token=hf_token)
    client.create_repo(repo_id, repo_type=repo_kind, exist_ok=True)
    client.upload_folder(folder_path=local_dir, repo_id=repo_id, repo_type=repo_kind)
    print(f" Uploaded to: https://huggingface.co/{repo_id}")
def _build_parser() -> argparse.ArgumentParser:
    """Return the command-line parser for this script."""
    parser = argparse.ArgumentParser(
        description="Quantize See-through models to NF4 or FP8 and optionally push to HuggingFace."
    )
    parser.add_argument(
        "--pipeline",
        type=str,
        default="both",
        choices=["layerdiff", "marigold", "both"],
        help="Which pipeline to quantize (default: both)",
    )
    parser.add_argument(
        "--repo_id_layerdiff",
        type=str,
        default="layerdifforg/seethroughv0.0.2_layerdiff3d",
        help="Source bf16 HF repo for LayerDiff3D",
    )
    parser.add_argument(
        "--repo_id_depth",
        type=str,
        default="24yearsold/seethroughv0.0.1_marigold",
        help="Source bf16 HF repo for Marigold3D",
    )
    parser.add_argument(
        "--output_repo_layerdiff",
        type=str,
        default=None,
        help="Target HF repo for quantized LayerDiff3D (required if --push_to_hub)",
    )
    parser.add_argument(
        "--output_repo_depth",
        type=str,
        default=None,
        help="Target HF repo for quantized Marigold3D (required if --push_to_hub)",
    )
    parser.add_argument(
        "--output_local",
        type=str,
        default="workspace/quantized_models/",
        help="Local save root directory (default: workspace/quantized_models/)",
    )
    parser.add_argument(
        "--quant_mode",
        type=str,
        default="nf4",
        choices=["nf4", "fp8"],
        help="Quantization mode (default: nf4)",
    )
    parser.add_argument(
        "--push_to_hub",
        action="store_true",
        help="Push quantized models to HuggingFace",
    )
    return parser


def _load_hf_token():
    """Return an HF token, or None if none is configured.

    Lookup order: the ``hf_credential`` file three directory levels above
    this script, then the ``HF_TOKEN`` environment variable. Empty values are
    treated as absent, so an empty credential file no longer masks HF_TOKEN.
    """
    hf_credential_path = osp.join(
        osp.dirname(osp.dirname(osp.dirname(osp.abspath(__file__)))), "hf_credential"
    )
    if osp.isfile(hf_credential_path):
        # Context manager so the handle is closed deterministically.
        with open(hf_credential_path, encoding="utf-8") as fh:
            token = fh.read().strip()
        if token:
            print(f"Loaded HF token from {hf_credential_path}")
            return token
        print(f"WARNING: {hf_credential_path} is empty; ignoring it.")
    if os.environ.get("HF_TOKEN"):
        print("Using HF_TOKEN from environment")
        return os.environ["HF_TOKEN"]
    print("WARNING: No HF token found. Private repos will fail to download.")
    return None


def _print_summary(args, layerdiff_dir: str, marigold_dir: str):
    """Print which components were quantized/kept and where they were written."""
    print(f"\n{'='*60}")
    print("Summary")
    print(f"{'='*60}")
    print(f"Quantization mode: {args.quant_mode}")
    if args.pipeline in ("layerdiff", "both"):
        print("LayerDiff3D:")
        print(f" Source: {args.repo_id_layerdiff}")
        print(f" Local: {osp.abspath(layerdiff_dir)}")
        print(" Quantized: unet, text_encoder, text_encoder_2")
        print(" Kept bf16: trans_vae, vae, scheduler, tokenizer, tokenizer_2")
        if args.push_to_hub:
            print(f" Pushed to: {args.output_repo_layerdiff}")
    if args.pipeline in ("marigold", "both"):
        print("Marigold3D:")
        print(f" Source: {args.repo_id_depth}")
        print(f" Local: {osp.abspath(marigold_dir)}")
        print(" Quantized: unet, text_encoder")
        print(" Kept bf16: vae, scheduler, tokenizer")
        if args.push_to_hub:
            print(f" Pushed to: {args.output_repo_depth}")
    print(f"{'='*60}")


def main():
    """CLI entry point.

    Parses arguments, resolves the HF token, and quantizes the selected
    pipeline(s) into ``<output_local>/layerdiff`` and/or
    ``<output_local>/marigold``. Optionally pushes the results to the Hub,
    then prints a summary. Invalid push arguments exit via ``parser.error``
    (SystemExit, status 2) before any download starts.
    """
    parser = _build_parser()
    args = parser.parse_args()
    hf_token = _load_hf_token()

    run_layerdiff = args.pipeline in ("layerdiff", "both")
    run_marigold = args.pipeline in ("marigold", "both")

    # Validate push args up front so misconfiguration fails fast.
    if args.push_to_hub:
        if run_layerdiff and not args.output_repo_layerdiff:
            parser.error("--output_repo_layerdiff is required when --push_to_hub is set for layerdiff pipeline")
        if run_marigold and not args.output_repo_depth:
            parser.error("--output_repo_depth is required when --push_to_hub is set for marigold pipeline")
        if hf_token is None:
            parser.error("--push_to_hub requires an HF token (hf_credential file or HF_TOKEN env var)")

    layerdiff_dir = osp.join(args.output_local, "layerdiff")
    marigold_dir = osp.join(args.output_local, "marigold")
    print(f"\nQuantization mode: {args.quant_mode}")
    print(f"Pipeline(s): {args.pipeline}")

    # --- Quantize ---
    if run_layerdiff:
        quantize_layerdiff(
            repo_id=args.repo_id_layerdiff,
            output_dir=layerdiff_dir,
            quant_mode=args.quant_mode,
            hf_token=hf_token,
        )
    if run_marigold:
        quantize_marigold(
            repo_id=args.repo_id_depth,
            output_dir=marigold_dir,
            quant_mode=args.quant_mode,
            hf_token=hf_token,
        )

    # --- Push to HF ---
    if args.push_to_hub:
        if run_layerdiff:
            push_to_hub(layerdiff_dir, args.output_repo_layerdiff, hf_token)
        if run_marigold:
            push_to_hub(marigold_dir, args.output_repo_depth, hf_token)

    _print_summary(args, layerdiff_dir, marigold_dir)


if __name__ == "__main__":
    main()