Spaces:
Runtime error
Runtime error
| """Quantize See-through pipeline models (UNet + text encoders) to NF4 or FP8. | |
| Saves quantized components locally and optionally pushes to HuggingFace. | |
| VAE and TransparentVAE remain in bf16. | |
| Usage (from repo root): | |
| python inference/scripts/quantize_and_push.py --quant_mode nf4 | |
| python inference/scripts/quantize_and_push.py --quant_mode fp8 --pipeline layerdiff | |
| python inference/scripts/quantize_and_push.py --quant_mode nf4 --push_to_hub \ | |
| --output_repo_layerdiff layerdifforg/seethroughv0.0.2_layerdiff3d_nf4 \ | |
| --output_repo_depth 24yearsold/seethroughv0.0.1_marigold_nf4 | |
| """ | |
| import os.path as osp | |
| import argparse | |
| import sys | |
| import os | |
| sys.path.append(osp.dirname(osp.dirname(osp.abspath(__file__)))) | |
| import shutil | |
| import torch | |
| from huggingface_hub import snapshot_download, HfApi | |
| # diffusers BitsAndBytesConfig — for UNet (inherits from diffusers ModelMixin) | |
| from diffusers import BitsAndBytesConfig as DiffusersBnBConfig | |
| # transformers BitsAndBytesConfig — for CLIP text encoders (transformers models) | |
| from transformers import BitsAndBytesConfig as TransformersBnBConfig | |
| from transformers import CLIPTextModel, CLIPTextModelWithProjection | |
| from modules.layerdiffuse.layerdiff3d import UNetFrameConditionModel | |
def make_quant_configs(quant_mode: str):
    """Return ``(diffusers_config, transformers_config)`` for ``quant_mode``.

    Two parallel ``BitsAndBytesConfig`` objects are needed: the diffusers one
    for the UNet and the transformers one for the CLIP text encoders.

    Args:
        quant_mode: ``"nf4"`` for 4-bit NF4 weights with bf16 compute, or
            ``"fp8"``. Despite its name, ``"fp8"`` maps to bitsandbytes'
            ``load_in_8bit=True`` path (LLM.int8), not an FP8 format.

    Raises:
        ValueError: if ``quant_mode`` is neither ``"nf4"`` nor ``"fp8"``.
    """
    if quant_mode not in ("nf4", "fp8"):
        raise ValueError(f"Unknown quant_mode: {quant_mode!r}. Expected 'nf4' or 'fp8'.")

    if quant_mode == "fp8":
        return (
            DiffusersBnBConfig(load_in_8bit=True),
            TransformersBnBConfig(load_in_8bit=True),
        )

    # Identical NF4 settings for both libraries; ** unpacking hands each
    # constructor its own copy of the kwargs.
    nf4_kwargs = dict(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
    return DiffusersBnBConfig(**nf4_kwargs), TransformersBnBConfig(**nf4_kwargs)
def copy_subfolder(src_root: str, dst_root: str, subfolder: str):
    """Mirror the directory ``src_root/subfolder`` into ``dst_root/subfolder``.

    An existing destination directory is removed first (``shutil.rmtree``), so
    the result is a fresh copy rather than a merge. When the source directory
    does not exist, a warning is printed and nothing on disk is changed.
    """
    source_dir = osp.join(src_root, subfolder)
    target_dir = osp.join(dst_root, subfolder)
    if not osp.isdir(source_dir):
        print(f" WARNING: subfolder {subfolder!r} not found in {src_root}")
        return
    if osp.exists(target_dir):
        shutil.rmtree(target_dir)
    shutil.copytree(source_dir, target_dir)
    print(f" Copied subfolder: {subfolder}")
def copy_file(src_root: str, dst_root: str, filename: str):
    """Copy the regular file ``src_root/filename`` to ``dst_root/filename``.

    Uses ``shutil.copy2``, which overwrites an existing destination and keeps
    file metadata. When the source is not a regular file (missing, or a
    directory), only a warning is printed.
    """
    source_path = osp.join(src_root, filename)
    if not osp.isfile(source_path):
        print(f" WARNING: file {filename!r} not found in {src_root}")
        return
    shutil.copy2(source_path, osp.join(dst_root, filename))
    print(f" Copied file: {filename}")
def quantize_layerdiff(
    repo_id: str,
    output_dir: str,
    quant_mode: str,
    hf_token: str = None,
):
    """Quantize the LayerDiff3D (SDXL-based) pipeline into ``output_dir``.

    The UNet, ``text_encoder`` and ``text_encoder_2`` are reloaded with the
    bitsandbytes configs from ``make_quant_configs`` and re-saved. The
    ``trans_vae``, ``vae``, ``scheduler``, ``tokenizer`` and ``tokenizer_2``
    subfolders, plus every top-level file of the source snapshot, are copied
    unchanged.

    Args:
        repo_id: HuggingFace repo holding the bf16 source pipeline.
        output_dir: Local destination directory (created if missing).
        quant_mode: ``"nf4"`` or ``"fp8"``; forwarded to ``make_quant_configs``.
        hf_token: Optional HF token passed to ``snapshot_download``.
    """
    rule = "=" * 60
    print(f"\n{rule}")
    print(f"Quantizing LayerDiff3D pipeline ({quant_mode})")
    print(f" Source: {repo_id}")
    print(f" Output: {output_dir}")
    print(rule)
    diffusers_cfg, transformers_cfg = make_quant_configs(quant_mode)

    print("\n[1/6] Downloading source repo...")
    snapshot_dir = snapshot_download(repo_id, token=hf_token)
    print(f" Downloaded to: {snapshot_dir}")
    os.makedirs(output_dir, exist_ok=True)

    # Each row: (progress line, model class, subfolder, label for the "Saved" line, quant config).
    # The UNet is a diffusers model and takes the diffusers config; the CLIP
    # encoders are transformers models and take the transformers config.
    jobs = (
        ("\n[2/6] Loading and quantizing UNet...",
         UNetFrameConditionModel, "unet", "UNet", diffusers_cfg),
        ("\n[3/6] Loading and quantizing text_encoder (CLIPTextModel)...",
         CLIPTextModel, "text_encoder", "text_encoder", transformers_cfg),
        ("\n[4/6] Loading and quantizing text_encoder_2 (CLIPTextModelWithProjection)...",
         CLIPTextModelWithProjection, "text_encoder_2", "text_encoder_2", transformers_cfg),
    )
    for progress, model_cls, subfolder, label, quant_cfg in jobs:
        print(progress)
        model = model_cls.from_pretrained(
            snapshot_dir,
            subfolder=subfolder,
            quantization_config=quant_cfg,
            torch_dtype=torch.bfloat16,
        )
        target_dir = osp.join(output_dir, subfolder)
        os.makedirs(target_dir, exist_ok=True)
        model.save_pretrained(target_dir)
        print(f" Saved quantized {label} to: {target_dir}")
        # Drop the only reference before the next load so its memory can be reclaimed.
        del model
        torch.cuda.empty_cache()

    print("\n[5/6] Copying bf16 components (VAE, TransparentVAE, scheduler, tokenizers)...")
    for subfolder in ("trans_vae", "vae", "scheduler", "tokenizer", "tokenizer_2"):
        copy_subfolder(snapshot_dir, output_dir, subfolder)

    print("\n[6/6] Copying root config files...")
    for name in os.listdir(snapshot_dir):
        if not osp.isfile(osp.join(snapshot_dir, name)):
            continue  # subfolders were handled above (or deliberately skipped)
        copy_file(snapshot_dir, output_dir, name)
    print(f"\nLayerDiff3D quantization complete: {output_dir}")
def quantize_marigold(
    repo_id: str,
    output_dir: str,
    quant_mode: str,
    hf_token: str = None,
):
    """Quantize the Marigold3D (SD1.5-based) pipeline into ``output_dir``.

    The UNet and the single ``text_encoder`` are reloaded with the
    bitsandbytes configs from ``make_quant_configs`` and re-saved. The
    ``vae``, ``scheduler`` and ``tokenizer`` subfolders, plus every top-level
    file of the source snapshot, are copied unchanged.

    Args:
        repo_id: HuggingFace repo holding the bf16 source pipeline.
        output_dir: Local destination directory (created if missing).
        quant_mode: ``"nf4"`` or ``"fp8"``; forwarded to ``make_quant_configs``.
        hf_token: Optional HF token passed to ``snapshot_download``.
    """
    rule = "=" * 60
    print(f"\n{rule}")
    print(f"Quantizing Marigold3D pipeline ({quant_mode})")
    print(f" Source: {repo_id}")
    print(f" Output: {output_dir}")
    print(rule)
    diffusers_cfg, transformers_cfg = make_quant_configs(quant_mode)

    print("\n[1/4] Downloading source repo...")
    snapshot_dir = snapshot_download(repo_id, token=hf_token)
    print(f" Downloaded to: {snapshot_dir}")
    os.makedirs(output_dir, exist_ok=True)

    # Each row: (progress line, model class, subfolder, label for the "Saved" line, quant config).
    # NOTE(review): the Marigold UNet is loaded through the LayerDiff3D
    # `UNetFrameConditionModel` class; confirm this matches the class the
    # Marigold repo's unet weights were saved from.
    jobs = (
        ("\n[2/4] Loading and quantizing UNet...",
         UNetFrameConditionModel, "unet", "UNet", diffusers_cfg),
        ("\n[3/4] Loading and quantizing text_encoder (CLIPTextModel)...",
         CLIPTextModel, "text_encoder", "text_encoder", transformers_cfg),
    )
    for progress, model_cls, subfolder, label, quant_cfg in jobs:
        print(progress)
        model = model_cls.from_pretrained(
            snapshot_dir,
            subfolder=subfolder,
            quantization_config=quant_cfg,
            torch_dtype=torch.bfloat16,
        )
        target_dir = osp.join(output_dir, subfolder)
        os.makedirs(target_dir, exist_ok=True)
        model.save_pretrained(target_dir)
        print(f" Saved quantized {label} to: {target_dir}")
        # Drop the only reference before the next load so its memory can be reclaimed.
        del model
        torch.cuda.empty_cache()

    print("\n[4/4] Copying bf16 components (VAE, scheduler, tokenizer)...")
    for subfolder in ("vae", "scheduler", "tokenizer"):
        copy_subfolder(snapshot_dir, output_dir, subfolder)

    print(" Copying root config files...")
    for name in os.listdir(snapshot_dir):
        if not osp.isfile(osp.join(snapshot_dir, name)):
            continue  # subfolders were handled above (or deliberately skipped)
        copy_file(snapshot_dir, output_dir, name)
    print(f"\nMarigold3D quantization complete: {output_dir}")
def push_to_hub(local_dir: str, repo_id: str, hf_token: str):
    """Upload the contents of ``local_dir`` to the HuggingFace model repo ``repo_id``.

    The repo is created first if it does not exist (``exist_ok=True``), then the
    whole folder is sent with ``upload_folder``.

    NOTE(review): ``create_repo`` is called without ``private=``, so a new repo
    gets the Hub's default visibility; if the source models are private, confirm
    a public target is intended.
    """
    print(f"\nPushing {local_dir} -> {repo_id}")
    hub = HfApi(token=hf_token)
    target = dict(repo_id=repo_id, repo_type="model")
    hub.create_repo(exist_ok=True, **target)
    hub.upload_folder(folder_path=local_dir, **target)
    print(f" Uploaded to: https://huggingface.co/{repo_id}")
def main():
    """Command-line entry point.

    Parses CLI arguments and resolves a HuggingFace token. It then quantizes the
    selected pipeline(s) into ``<output_local>/layerdiff`` and/or
    ``<output_local>/marigold``, optionally uploads them, and prints a summary.
    Invalid argument combinations (for example ``--push_to_hub`` without a target
    repo or a token) exit through ``parser.error`` before any download starts.
    """
    parser = argparse.ArgumentParser(
        description="Quantize See-through models to NF4 or FP8 and optionally push to HuggingFace."
    )
    parser.add_argument(
        "--pipeline",
        type=str,
        default="both",
        choices=["layerdiff", "marigold", "both"],
        help="Which pipeline to quantize (default: both)",
    )
    parser.add_argument(
        "--repo_id_layerdiff",
        type=str,
        default="layerdifforg/seethroughv0.0.2_layerdiff3d",
        help="Source bf16 HF repo for LayerDiff3D",
    )
    parser.add_argument(
        "--repo_id_depth",
        type=str,
        default="24yearsold/seethroughv0.0.1_marigold",
        help="Source bf16 HF repo for Marigold3D",
    )
    parser.add_argument(
        "--output_repo_layerdiff",
        type=str,
        default=None,
        help="Target HF repo for quantized LayerDiff3D (required if --push_to_hub)",
    )
    parser.add_argument(
        "--output_repo_depth",
        type=str,
        default=None,
        help="Target HF repo for quantized Marigold3D (required if --push_to_hub)",
    )
    parser.add_argument(
        "--output_local",
        type=str,
        default="workspace/quantized_models/",
        help="Local save root directory (default: workspace/quantized_models/)",
    )
    parser.add_argument(
        "--quant_mode",
        type=str,
        default="nf4",
        choices=["nf4", "fp8"],
        help="Quantization mode (default: nf4)",
    )
    parser.add_argument(
        "--push_to_hub",
        action="store_true",
        help="Push quantized models to HuggingFace",
    )
    args = parser.parse_args()
    run_layerdiff = args.pipeline in ("layerdiff", "both")
    run_marigold = args.pipeline in ("marigold", "both")

    # Resolve the HF token. A non-empty `hf_credential` file three directory
    # levels above this script takes precedence over the HF_TOKEN env variable.
    hf_token = None
    hf_credential_path = osp.join(
        osp.dirname(osp.dirname(osp.dirname(osp.abspath(__file__)))), "hf_credential"
    )
    if osp.isfile(hf_credential_path):
        # `with` guarantees the handle is closed. An empty file is treated as
        # "no token" so the HF_TOKEN fallback below still applies.
        with open(hf_credential_path, encoding="utf-8") as fh:
            hf_token = fh.read().strip() or None
        if hf_token is not None:
            print(f"Loaded HF token from {hf_credential_path}")
        else:
            print(f"WARNING: {hf_credential_path} is empty; ignoring it.")
    if hf_token is None:
        env_token = os.environ.get("HF_TOKEN")
        if env_token:
            hf_token = env_token
            print("Using HF_TOKEN from environment")
        else:
            print("WARNING: No HF token found. Private repos will fail to download.")

    # Validate push arguments up front, before any expensive download/quantization.
    if args.push_to_hub:
        if run_layerdiff and not args.output_repo_layerdiff:
            parser.error("--output_repo_layerdiff is required when --push_to_hub is set for layerdiff pipeline")
        if run_marigold and not args.output_repo_depth:
            parser.error("--output_repo_depth is required when --push_to_hub is set for marigold pipeline")
        if hf_token is None:
            parser.error("--push_to_hub requires an HF token (hf_credential file or HF_TOKEN env var)")

    layerdiff_dir = osp.join(args.output_local, "layerdiff")
    marigold_dir = osp.join(args.output_local, "marigold")
    print(f"\nQuantization mode: {args.quant_mode}")
    print(f"Pipeline(s): {args.pipeline}")

    # --- Quantize ---
    if run_layerdiff:
        quantize_layerdiff(
            repo_id=args.repo_id_layerdiff,
            output_dir=layerdiff_dir,
            quant_mode=args.quant_mode,
            hf_token=hf_token,
        )
    if run_marigold:
        quantize_marigold(
            repo_id=args.repo_id_depth,
            output_dir=marigold_dir,
            quant_mode=args.quant_mode,
            hf_token=hf_token,
        )

    # --- Push to HF ---
    if args.push_to_hub:
        if run_layerdiff:
            push_to_hub(layerdiff_dir, args.output_repo_layerdiff, hf_token)
        if run_marigold:
            push_to_hub(marigold_dir, args.output_repo_depth, hf_token)

    # --- Summary ---
    rule = "=" * 60
    print(f"\n{rule}")
    print("Summary")
    print(rule)
    print(f"Quantization mode: {args.quant_mode}")
    if run_layerdiff:
        print("LayerDiff3D:")
        print(f" Source: {args.repo_id_layerdiff}")
        print(f" Local: {osp.abspath(layerdiff_dir)}")
        print(" Quantized: unet, text_encoder, text_encoder_2")
        print(" Kept bf16: trans_vae, vae, scheduler, tokenizer, tokenizer_2")
        if args.push_to_hub:
            print(f" Pushed to: {args.output_repo_layerdiff}")
    if run_marigold:
        print("Marigold3D:")
        print(f" Source: {args.repo_id_depth}")
        print(f" Local: {osp.abspath(marigold_dir)}")
        print(" Quantized: unet, text_encoder")
        print(" Kept bf16: vae, scheduler, tokenizer")
        if args.push_to_hub:
            print(f" Pushed to: {args.output_repo_depth}")
    print(rule)
# Script entry point: run the CLI only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()