"""Quantize See-through pipeline models (UNet + text encoders) to NF4 or FP8. Saves quantized components locally and optionally pushes to HuggingFace. VAE and TransparentVAE remain in bf16. Usage (from repo root): python inference/scripts/quantize_and_push.py --quant_mode nf4 python inference/scripts/quantize_and_push.py --quant_mode fp8 --pipeline layerdiff python inference/scripts/quantize_and_push.py --quant_mode nf4 --push_to_hub \ --output_repo_layerdiff layerdifforg/seethroughv0.0.2_layerdiff3d_nf4 \ --output_repo_depth 24yearsold/seethroughv0.0.1_marigold_nf4 """ import os.path as osp import argparse import sys import os sys.path.append(osp.dirname(osp.dirname(osp.abspath(__file__)))) import shutil import torch from huggingface_hub import snapshot_download, HfApi # diffusers BitsAndBytesConfig — for UNet (inherits from diffusers ModelMixin) from diffusers import BitsAndBytesConfig as DiffusersBnBConfig # transformers BitsAndBytesConfig — for CLIP text encoders (transformers models) from transformers import BitsAndBytesConfig as TransformersBnBConfig from transformers import CLIPTextModel, CLIPTextModelWithProjection from modules.layerdiffuse.layerdiff3d import UNetFrameConditionModel def make_quant_configs(quant_mode: str): """Build diffusers and transformers BitsAndBytesConfig for the chosen mode.""" if quant_mode == "nf4": diffusers_config = DiffusersBnBConfig( load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16, ) transformers_config = TransformersBnBConfig( load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16, ) elif quant_mode == "fp8": diffusers_config = DiffusersBnBConfig(load_in_8bit=True) transformers_config = TransformersBnBConfig(load_in_8bit=True) else: raise ValueError(f"Unknown quant_mode: {quant_mode!r}. Expected 'nf4' or 'fp8'.") return diffusers_config, transformers_config def copy_subfolder(src_root: str, dst_root: str, subfolder: str): """Copy a subfolder from src_root to dst_root, overwriting if it exists.""" src = osp.join(src_root, subfolder) dst = osp.join(dst_root, subfolder) if osp.isdir(src): if osp.exists(dst): shutil.rmtree(dst) shutil.copytree(src, dst) print(f" Copied subfolder: {subfolder}") else: print(f" WARNING: subfolder {subfolder!r} not found in {src_root}") def copy_file(src_root: str, dst_root: str, filename: str): """Copy a single file from src_root to dst_root.""" src = osp.join(src_root, filename) dst = osp.join(dst_root, filename) if osp.isfile(src): shutil.copy2(src, dst) print(f" Copied file: {filename}") else: print(f" WARNING: file {filename!r} not found in {src_root}") def quantize_layerdiff( repo_id: str, output_dir: str, quant_mode: str, hf_token: str = None, ): """Quantize the LayerDiff3D pipeline (SDXL-based).""" print(f"\n{'='*60}") print(f"Quantizing LayerDiff3D pipeline ({quant_mode})") print(f" Source: {repo_id}") print(f" Output: {output_dir}") print(f"{'='*60}") diffusers_config, transformers_config = make_quant_configs(quant_mode) # --- Download source repo --- print("\n[1/6] Downloading source repo...") local_repo = snapshot_download( repo_id, token=hf_token, ) print(f" Downloaded to: {local_repo}") os.makedirs(output_dir, exist_ok=True) # --- Quantize UNet --- print("\n[2/6] Loading and quantizing UNet...") unet = UNetFrameConditionModel.from_pretrained( local_repo, subfolder="unet", quantization_config=diffusers_config, torch_dtype=torch.bfloat16, ) unet_dir = osp.join(output_dir, "unet") os.makedirs(unet_dir, exist_ok=True) unet.save_pretrained(unet_dir) print(f" Saved quantized UNet to: {unet_dir}") del unet torch.cuda.empty_cache() # --- Quantize text_encoder (CLIPTextModel) --- print("\n[3/6] Loading and quantizing text_encoder (CLIPTextModel)...") text_encoder = CLIPTextModel.from_pretrained( local_repo, subfolder="text_encoder", quantization_config=transformers_config, torch_dtype=torch.bfloat16, ) te_dir = osp.join(output_dir, "text_encoder") os.makedirs(te_dir, exist_ok=True) text_encoder.save_pretrained(te_dir) print(f" Saved quantized text_encoder to: {te_dir}") del text_encoder torch.cuda.empty_cache() # --- Quantize text_encoder_2 (CLIPTextModelWithProjection) --- print("\n[4/6] Loading and quantizing text_encoder_2 (CLIPTextModelWithProjection)...") text_encoder_2 = CLIPTextModelWithProjection.from_pretrained( local_repo, subfolder="text_encoder_2", quantization_config=transformers_config, torch_dtype=torch.bfloat16, ) te2_dir = osp.join(output_dir, "text_encoder_2") os.makedirs(te2_dir, exist_ok=True) text_encoder_2.save_pretrained(te2_dir) print(f" Saved quantized text_encoder_2 to: {te2_dir}") del text_encoder_2 torch.cuda.empty_cache() # --- Copy bf16 components as-is --- print("\n[5/6] Copying bf16 components (VAE, TransparentVAE, scheduler, tokenizers)...") bf16_subfolders = ["trans_vae", "vae", "scheduler", "tokenizer", "tokenizer_2"] for sf in bf16_subfolders: copy_subfolder(local_repo, output_dir, sf) # --- Copy root config files --- print("\n[6/6] Copying root config files...") for fname in os.listdir(local_repo): fpath = osp.join(local_repo, fname) if osp.isfile(fpath): copy_file(local_repo, output_dir, fname) print(f"\nLayerDiff3D quantization complete: {output_dir}") def quantize_marigold( repo_id: str, output_dir: str, quant_mode: str, hf_token: str = None, ): """Quantize the Marigold3D pipeline (SD1.5-based).""" print(f"\n{'='*60}") print(f"Quantizing Marigold3D pipeline ({quant_mode})") print(f" Source: {repo_id}") print(f" Output: {output_dir}") print(f"{'='*60}") diffusers_config, transformers_config = make_quant_configs(quant_mode) # --- Download source repo --- print("\n[1/4] Downloading source repo...") local_repo = snapshot_download( repo_id, token=hf_token, ) print(f" Downloaded to: {local_repo}") os.makedirs(output_dir, exist_ok=True) # --- Quantize UNet --- print("\n[2/4] Loading and quantizing UNet...") unet = UNetFrameConditionModel.from_pretrained( local_repo, subfolder="unet", quantization_config=diffusers_config, torch_dtype=torch.bfloat16, ) unet_dir = osp.join(output_dir, "unet") os.makedirs(unet_dir, exist_ok=True) unet.save_pretrained(unet_dir) print(f" Saved quantized UNet to: {unet_dir}") del unet torch.cuda.empty_cache() # --- Quantize text_encoder (CLIPTextModel, single encoder for SD1.5) --- print("\n[3/4] Loading and quantizing text_encoder (CLIPTextModel)...") text_encoder = CLIPTextModel.from_pretrained( local_repo, subfolder="text_encoder", quantization_config=transformers_config, torch_dtype=torch.bfloat16, ) te_dir = osp.join(output_dir, "text_encoder") os.makedirs(te_dir, exist_ok=True) text_encoder.save_pretrained(te_dir) print(f" Saved quantized text_encoder to: {te_dir}") del text_encoder torch.cuda.empty_cache() # --- Copy bf16 components as-is --- print("\n[4/4] Copying bf16 components (VAE, scheduler, tokenizer)...") bf16_subfolders = ["vae", "scheduler", "tokenizer"] for sf in bf16_subfolders: copy_subfolder(local_repo, output_dir, sf) # Copy root config files print(" Copying root config files...") for fname in os.listdir(local_repo): fpath = osp.join(local_repo, fname) if osp.isfile(fpath): copy_file(local_repo, output_dir, fname) print(f"\nMarigold3D quantization complete: {output_dir}") def push_to_hub(local_dir: str, repo_id: str, hf_token: str): """Upload quantized model directory to HuggingFace.""" print(f"\nPushing {local_dir} -> {repo_id}") api = HfApi(token=hf_token) api.create_repo(repo_id, repo_type="model", exist_ok=True) api.upload_folder( folder_path=local_dir, repo_id=repo_id, repo_type="model", ) print(f" Uploaded to: https://huggingface.co/{repo_id}") def main(): parser = argparse.ArgumentParser( description="Quantize See-through models to NF4 or FP8 and optionally push to HuggingFace." ) parser.add_argument( "--pipeline", type=str, default="both", choices=["layerdiff", "marigold", "both"], help="Which pipeline to quantize (default: both)", ) parser.add_argument( "--repo_id_layerdiff", type=str, default="layerdifforg/seethroughv0.0.2_layerdiff3d", help="Source bf16 HF repo for LayerDiff3D", ) parser.add_argument( "--repo_id_depth", type=str, default="24yearsold/seethroughv0.0.1_marigold", help="Source bf16 HF repo for Marigold3D", ) parser.add_argument( "--output_repo_layerdiff", type=str, default=None, help="Target HF repo for quantized LayerDiff3D (required if --push_to_hub)", ) parser.add_argument( "--output_repo_depth", type=str, default=None, help="Target HF repo for quantized Marigold3D (required if --push_to_hub)", ) parser.add_argument( "--output_local", type=str, default="workspace/quantized_models/", help="Local save root directory (default: workspace/quantized_models/)", ) parser.add_argument( "--quant_mode", type=str, default="nf4", choices=["nf4", "fp8"], help="Quantization mode (default: nf4)", ) parser.add_argument( "--push_to_hub", action="store_true", help="Push quantized models to HuggingFace", ) args = parser.parse_args() # Read HF token hf_token = None hf_credential_path = osp.join(osp.dirname(osp.dirname(osp.dirname(osp.abspath(__file__)))), "hf_credential") if osp.isfile(hf_credential_path): hf_token = open(hf_credential_path).read().strip() print(f"Loaded HF token from {hf_credential_path}") elif os.environ.get("HF_TOKEN"): hf_token = os.environ["HF_TOKEN"] print("Using HF_TOKEN from environment") else: print("WARNING: No HF token found. Private repos will fail to download.") # Validate push args if args.push_to_hub: if args.pipeline in ("layerdiff", "both") and not args.output_repo_layerdiff: parser.error("--output_repo_layerdiff is required when --push_to_hub is set for layerdiff pipeline") if args.pipeline in ("marigold", "both") and not args.output_repo_depth: parser.error("--output_repo_depth is required when --push_to_hub is set for marigold pipeline") if hf_token is None: parser.error("--push_to_hub requires an HF token (hf_credential file or HF_TOKEN env var)") # Build output paths layerdiff_dir = osp.join(args.output_local, "layerdiff") marigold_dir = osp.join(args.output_local, "marigold") print(f"\nQuantization mode: {args.quant_mode}") print(f"Pipeline(s): {args.pipeline}") # --- Quantize --- if args.pipeline in ("layerdiff", "both"): quantize_layerdiff( repo_id=args.repo_id_layerdiff, output_dir=layerdiff_dir, quant_mode=args.quant_mode, hf_token=hf_token, ) if args.pipeline in ("marigold", "both"): quantize_marigold( repo_id=args.repo_id_depth, output_dir=marigold_dir, quant_mode=args.quant_mode, hf_token=hf_token, ) # --- Push to HF --- if args.push_to_hub: if args.pipeline in ("layerdiff", "both"): push_to_hub(layerdiff_dir, args.output_repo_layerdiff, hf_token) if args.pipeline in ("marigold", "both"): push_to_hub(marigold_dir, args.output_repo_depth, hf_token) # --- Summary --- print(f"\n{'='*60}") print("Summary") print(f"{'='*60}") print(f"Quantization mode: {args.quant_mode}") if args.pipeline in ("layerdiff", "both"): print(f"LayerDiff3D:") print(f" Source: {args.repo_id_layerdiff}") print(f" Local: {osp.abspath(layerdiff_dir)}") print(f" Quantized: unet, text_encoder, text_encoder_2") print(f" Kept bf16: trans_vae, vae, scheduler, tokenizer, tokenizer_2") if args.push_to_hub: print(f" Pushed to: {args.output_repo_layerdiff}") if args.pipeline in ("marigold", "both"): print(f"Marigold3D:") print(f" Source: {args.repo_id_depth}") print(f" Local: {osp.abspath(marigold_dir)}") print(f" Quantized: unet, text_encoder") print(f" Kept bf16: vae, scheduler, tokenizer") if args.push_to_hub: print(f" Pushed to: {args.output_repo_depth}") print(f"{'='*60}") if __name__ == "__main__": main()