---
license: apache-2.0
---

# Repos

https://github.com/mit-han-lab/deepcompressor

# Installation

https://github.com/mit-han-lab/deepcompressor/issues/56
https://github.com/nunchaku-tech/deepcompressor/issues/80

# Windows

https://learn.microsoft.com/en-us/windows/wsl/install
https://www.anaconda.com/docs/getting-started/miniconda/install

# Environment

Hardware: Nvidia RTX 5060 Ti (Blackwell, sm_120)

Software (WSL):
- Python 3.12.11
- pip 25.1
- CUDA 12.8
- Torch 2.7.1+cu128
- Diffusers 0.35.0.dev0
- Transformers 4.53.2
- flash_attn 2.7.4.post1
- xformers 0.0.31.post1

# Calibration Dataset Preparation

https://github.com/nunchaku-tech/deepcompressor/blob/main/examples/diffusion/README.md#step-2-calibration-dataset-preparation

Example:

`python -m deepcompressor.app.diffusion.dataset.collect.calib svdq/flux.1-kontext-dev.yaml examples/diffusion/configs/collect/qdiff.yaml --pipeline-path svdq/flux.1-kontext-dev/`

Sample log:

```
In total 32 samples
Evaluating with batch size 1
Data:      3%|██▎       | 1/32 [13:57<7:12:32, 837.19s/it]
Sampling: 12%|█████████▍| 1/8 [01:34<11:01, 94.44s/it]
```

# Quantization

https://github.com/nunchaku-tech/deepcompressor/blob/main/examples/diffusion/README.md#step-3-model-quantization

Model path: https://github.com/nunchaku-tech/deepcompressor/issues/70#issuecomment-2788155233

Save model: `--save-model true` or `--save-model /PATH/TO/CHECKPOINT/DIR`

Example:

`python -m deepcompressor.app.diffusion.ptq svdq/flux.1-kontext-dev.yaml examples/diffusion/configs/svdquant/nvfp4.yaml --pipeline-path svdq/flux.1-kontext-dev/ --save-model ~/svdq/`

Model files structure: refer to [black-forest-labs/FLUX.1-Kontext-dev](https://huggingface.co/black-forest-labs/FLUX.1-Kontext-dev/tree/main).

# Deploy

https://github.com/nunchaku-tech/deepcompressor/blob/main/examples/diffusion/README.md#deployment

Example:

`python -m deepcompressor.backend.nunchaku.convert --quant-path ~/svdq/ --output-root ~/svdq/ --model-name flux.1-kontext-dev-svdq-fp4`

ComfyUI metadata reference:
- FP4: https://huggingface.co/mit-han-lab/nunchaku-flux.1-kontext-dev/blob/main/svdq-fp4_r32-flux.1-kontext-dev.safetensors
- INT4: https://huggingface.co/mit-han-lab/nunchaku-flux.1-kontext-dev/blob/main/svdq-int4_r32-flux.1-kontext-dev.safetensors

---

# Remarks

2025-07-23 test notes:
- The FP4-quantized model loads successfully in ComfyUI, but it is not fully functional yet; it needs further investigation and debugging.
- The calibration dataset appears misaligned; the sampling code may need to be revisited and adjusted for Flux.1 Kontext Dev.
- Later, consider running another test with the base [black-forest-labs/FLUX.1-dev](https://huggingface.co/black-forest-labs/FLUX.1-dev/tree/main) model for comparison.
- Check with the deepcompressor/nunchaku team and request their latest working implementation.

---

# Blockers

1) NotImplementedError: Cannot copy out of meta tensor; no data! Please use `torch.nn.Module.to_empty()` instead of `torch.nn.Module.to()` when moving module from meta to a different device.
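The same error can be reproduced with plain PyTorch whenever a module's parameters still live on the meta device; a minimal sketch (generic PyTorch, independent of the deepcompressor code) of why `.to()` fails while `.to_empty()` does not:

```python
import torch
import torch.nn as nn

# Parameters created on the meta device carry shape/dtype but no storage,
# which is what low-memory ("lazy") model initialization produces.
layer = nn.Linear(64, 64, device="meta")

try:
    layer.to("cpu")  # raises: cannot copy out of meta tensor; no data!
except NotImplementedError as err:
    print("to() failed:", err)

# to_empty() allocates uninitialized storage on the target device instead of
# copying, so the real weights must be restored afterwards via load_state_dict().
layer = layer.to_empty(device="cpu")
print(next(layer.parameters()).device)  # cpu
```

This is why the potential fixes below fall back to `to_empty()` and then try to reload the original state dict where possible.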
Potential fix: app.diffusion.pipeline.config.py

```python
@staticmethod
def _default_build(
    name: str, path: str, dtype: str | torch.dtype, device: str | torch.device, shift_activations: bool
) -> DiffusionPipeline:
    if not path:
        if name == "sdxl":
            path = "stabilityai/stable-diffusion-xl-base-1.0"
        elif name == "sdxl-turbo":
            path = "stabilityai/sdxl-turbo"
        elif name == "pixart-sigma":
            path = "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS"
        elif name == "flux.1-kontext-dev":
            path = "black-forest-labs/FLUX.1-Kontext-dev"
        elif name == "flux.1-dev":
            path = "black-forest-labs/FLUX.1-dev"
        elif name == "flux.1-canny-dev":
            path = "black-forest-labs/FLUX.1-Canny-dev"
        elif name == "flux.1-depth-dev":
            path = "black-forest-labs/FLUX.1-Depth-dev"
        elif name == "flux.1-fill-dev":
            path = "black-forest-labs/FLUX.1-Fill-dev"
        elif name == "flux.1-schnell":
            path = "black-forest-labs/FLUX.1-schnell"
        else:
            raise ValueError(f"Path for {name} is not specified.")
    if name in ["flux.1-kontext-dev"]:
        pipeline = FluxKontextPipeline.from_pretrained(path, torch_dtype=dtype)
    elif name in ["flux.1-canny-dev", "flux.1-depth-dev"]:
        pipeline = FluxControlPipeline.from_pretrained(path, torch_dtype=dtype)
    elif name == "flux.1-fill-dev":
        pipeline = FluxFillPipeline.from_pretrained(path, torch_dtype=dtype)
    elif name.startswith("sana-"):
        if dtype == torch.bfloat16:
            pipeline = SanaPipeline.from_pretrained(path, variant="bf16", torch_dtype=dtype, use_safetensors=True)
            pipeline.vae.to(dtype)
            pipeline.text_encoder.to(dtype)
        else:
            pipeline = SanaPipeline.from_pretrained(path, torch_dtype=dtype)
    else:
        pipeline = AutoPipelineForText2Image.from_pretrained(path, torch_dtype=dtype)

    # Debug output
    print(">>> DEVICE:", device)
    print(">>> PIPELINE TYPE:", type(pipeline))

    # Try to move each component using .to_empty(), falling back to .to()
    for component_name in ["unet", "transformer", "vae", "text_encoder"]:
        module = getattr(pipeline, component_name, None)
        if isinstance(module, torch.nn.Module):
            try:
                print(f">>> Moving {component_name} to {device} using to_empty()")
                module.to_empty(device=device)
            except Exception as e:
                print(f">>> WARNING: {component_name}.to_empty({device}) failed: {e}")
                try:
                    print(f">>> Falling back to {component_name}.to({device})")
                    module.to(device)
                except Exception as ee:
                    print(f">>> ERROR: {component_name}.to({device}) also failed: {ee}")

    # Identify the main model (for patching)
    model = getattr(pipeline, "unet", None) or getattr(pipeline, "transformer", None)
    if model is not None:
        replace_fused_linear_with_concat_linear(model)
        replace_up_block_conv_with_concat_conv(model)
        if shift_activations:
            shift_input_activations(model)
    else:
        print(">>> WARNING: No model (unet/transformer) found for patching")
    return pipeline
```

2) KeyError:

Potential fix: app.diffusion.nn.struct.py

```python
@staticmethod
def _default_construct(
    module: Attention,
    /,
    parent: tp.Optional["DiffusionTransformerBlockStruct"] = None,
    fname: str = "",
    rname: str = "",
    rkey: str = "",
    idx: int = 0,
    **kwargs,
) -> "DiffusionAttentionStruct":
    if isinstance(module, FluxAttention):
        # FluxAttention has different attribute names than standard attention
        with_rope = True
        num_query_heads = module.heads  # FluxAttention uses 'heads', not 'num_heads'
        num_key_value_heads = module.heads  # FLUX typically uses the same for q/k/v
        # FluxAttention doesn't have 'to_out', but may have other output projections.
        # Check which output projection attributes actually exist.
        o_proj = None
        o_proj_rname = ""
        # Try to find the correct output projection
        if hasattr(module, "to_out") and module.to_out is not None:
            is_seq = isinstance(module.to_out, (list, tuple, nn.ModuleList))
            o_proj = module.to_out[0] if is_seq else module.to_out
            o_proj_rname = "to_out.0" if is_seq else "to_out"
        elif hasattr(module, "to_add_out"):
            o_proj = module.to_add_out
            o_proj_rname = "to_add_out"
        q_proj, k_proj, v_proj = module.to_q, module.to_k, module.to_v
        q_proj_rname, k_proj_rname, v_proj_rname = "to_q", "to_k", "to_v"
        q, k, v = module.to_q, module.to_k, module.to_v
        q_rname, k_rname, v_rname = "to_q", "to_k", "to_v"
        # Handle the add_* projections that FluxAttention has
        add_q_proj = getattr(module, "add_q_proj", None)
        add_k_proj = getattr(module, "add_k_proj", None)
        add_v_proj = getattr(module, "add_v_proj", None)
        add_o_proj = getattr(module, "to_add_out", None)
        add_q_proj_rname = "add_q_proj" if add_q_proj else ""
        add_k_proj_rname = "add_k_proj" if add_k_proj else ""
        add_v_proj_rname = "add_v_proj" if add_v_proj else ""
        add_o_proj_rname = "to_add_out" if add_o_proj else ""
        kwargs = (
            "encoder_hidden_states",
            "attention_mask",
            "image_rotary_emb",
        )
        cross_attention = add_k_proj is not None
    elif module.is_cross_attention:
        q_proj, k_proj, v_proj = module.to_q, None, None
        add_q_proj, add_k_proj, add_v_proj, add_o_proj = None, module.to_k, module.to_v, None
        q_proj_rname, k_proj_rname, v_proj_rname = "to_q", "", ""
        add_q_proj_rname, add_k_proj_rname, add_v_proj_rname, add_o_proj_rname = "", "to_k", "to_v", ""
    else:
        q_proj, k_proj, v_proj = module.to_q, module.to_k, module.to_v
        add_q_proj = getattr(module, "add_q_proj", None)
        add_k_proj = getattr(module, "add_k_proj", None)
        add_v_proj = getattr(module, "add_v_proj", None)
        add_o_proj = getattr(module, "to_add_out", None)
        q_proj_rname, k_proj_rname, v_proj_rname = "to_q", "to_k", "to_v"
        add_q_proj_rname, add_k_proj_rname, add_v_proj_rname = "add_q_proj", "add_k_proj", "add_v_proj"
        add_o_proj_rname = "to_add_out"
    if getattr(module, "to_out", None) is not None:
        o_proj = module.to_out[0]
        o_proj_rname = "to_out.0"
        assert isinstance(o_proj, nn.Linear)
    elif parent is not None:
        assert isinstance(parent.module, FluxSingleTransformerBlock)
        assert isinstance(parent.module.proj_out, ConcatLinear)
        assert len(parent.module.proj_out.linears) == 2
        o_proj = parent.module.proj_out.linears[0]
        o_proj_rname = ".proj_out.linears.0"
    else:
        raise RuntimeError("Cannot find the output projection.")
    if isinstance(module.processor, DiffusionAttentionProcessor):
        with_rope = module.processor.rope is not None
    elif module.processor.__class__.__name__.startswith("Flux"):
        with_rope = True
    else:
        with_rope = False  # TODO: fix for other processors
    config = AttentionConfigStruct(
        hidden_size=q_proj.weight.shape[1],
        add_hidden_size=add_k_proj.weight.shape[1] if add_k_proj is not None else 0,
        inner_size=q_proj.weight.shape[0],
        num_query_heads=module.heads,
        num_key_value_heads=module.to_k.weight.shape[0] // (module.to_q.weight.shape[0] // module.heads),
        with_qk_norm=module.norm_q is not None,
        with_rope=with_rope,
        linear_attn=isinstance(module.processor, SanaLinearAttnProcessor2_0),
    )
    return DiffusionAttentionStruct(
        module=module,
        parent=parent,
        fname=fname,
        idx=idx,
        rname=rname,
        rkey=rkey,
        config=config,
        q_proj=q_proj,
        k_proj=k_proj,
        v_proj=v_proj,
        o_proj=o_proj,
        add_q_proj=add_q_proj,
        add_k_proj=add_k_proj,
        add_v_proj=add_v_proj,
        add_o_proj=add_o_proj,
        q=None,  # TODO: add q, k, v
        k=None,
        v=None,
        q_proj_rname=q_proj_rname,
        k_proj_rname=k_proj_rname,
        v_proj_rname=v_proj_rname,
        o_proj_rname=o_proj_rname,
        add_q_proj_rname=add_q_proj_rname,
        add_k_proj_rname=add_k_proj_rname,
        add_v_proj_rname=add_v_proj_rname,
        add_o_proj_rname=add_o_proj_rname,
        q_rname="",
        k_rname="",
        v_rname="",
    )
```
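Since the KeyError stems from mapping FluxAttention projections whose attribute names differ from the standard `Attention` module, it can help to dump the `nn.Linear` submodules the Flux transformer actually exposes before wiring them up in `_default_construct`. A minimal sketch, assuming a local checkpoint at `svdq/flux.1-kontext-dev/` and the diffusers submodule names `transformer_blocks.0.attn` / `single_transformer_blocks.0.attn` (adjust to your install):

```python
import torch
from diffusers import FluxTransformer2DModel

# Load only the transformer; path and subfolder are assumptions for this sketch.
transformer = FluxTransformer2DModel.from_pretrained(
    "svdq/flux.1-kontext-dev", subfolder="transformer", torch_dtype=torch.bfloat16
)

# Print the linear projections of the first dual-stream and single-stream attention
# modules so the rnames used in _default_construct can be checked against reality.
for block_name in ("transformer_blocks.0.attn", "single_transformer_blocks.0.attn"):
    attn = transformer.get_submodule(block_name)
    linears = [n for n, m in attn.named_modules() if isinstance(m, torch.nn.Linear)]
    print(block_name, "->", linears)
```

Dual-stream blocks are expected to expose `to_q`/`to_k`/`to_v`/`to_out.0` plus `add_q_proj`/`add_k_proj`/`add_v_proj`/`to_add_out`, while single-stream blocks lack `to_out` (the output projection is fused into the block's `proj_out`); anything missing from the dump is what the struct mapping has to guard against.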
3) ValueError: Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined.

Potential fix: app.diffusion.dataset.collect.calib.py

```python
def collect(config: DiffusionPtqRunConfig, dataset: datasets.Dataset):
    samples_dirpath = os.path.join(config.output.root, "samples")
    caches_dirpath = os.path.join(config.output.root, "caches")
    os.makedirs(samples_dirpath, exist_ok=True)
    os.makedirs(caches_dirpath, exist_ok=True)
    caches = []
    pipeline = config.pipeline.build()
    model = pipeline.unet if hasattr(pipeline, "unet") else pipeline.transformer
    assert isinstance(model, nn.Module)
    model.register_forward_hook(CollectHook(caches=caches), with_kwargs=True)
    batch_size = config.eval.batch_size
    print(f"In total {len(dataset)} samples")
    print(f"Evaluating with batch size {batch_size}")
    pipeline.set_progress_bar_config(desc="Sampling", leave=False, dynamic_ncols=True, position=1)
    for batch in tqdm(
        dataset.iter(batch_size=batch_size, drop_last_batch=False),
        desc="Data",
        leave=False,
        dynamic_ncols=True,
        total=(len(dataset) + batch_size - 1) // batch_size,
    ):
        filenames = batch["filename"]
        prompts = batch["prompt"]
        seeds = [hash_str_to_int(name) for name in filenames]
        generators = [torch.Generator(device=pipeline.device).manual_seed(seed) for seed in seeds]
        pipeline_kwargs = config.eval.get_pipeline_kwargs()
        task = config.pipeline.task
        control_root = config.eval.control_root
        if task in ["canny-to-image", "depth-to-image", "inpainting"]:
            controls = get_control(
                task,
                batch["image"],
                names=batch["filename"],
                data_root=os.path.join(
                    control_root, collect_config.dataset_name, f"{dataset.config_name}-{config.eval.num_samples}"
                ),
            )
            if task == "inpainting":
                pipeline_kwargs["image"] = controls[0]
                pipeline_kwargs["mask_image"] = controls[1]
            else:
                pipeline_kwargs["control_image"] = controls
        # Handle meta tensors by moving individual components
        try:
            pipeline = pipeline.to("cuda")
        except NotImplementedError:
            # Move individual pipeline components, falling back to to_empty() for meta tensors
            if hasattr(pipeline, "transformer") and pipeline.transformer is not None:
                try:
                    pipeline.transformer = pipeline.transformer.to("cuda")
                except NotImplementedError:
                    pipeline.transformer = pipeline.transformer.to_empty(device="cuda")
            if hasattr(pipeline, "text_encoder") and pipeline.text_encoder is not None:
                try:
                    pipeline.text_encoder = pipeline.text_encoder.to("cuda")
                except NotImplementedError:
                    pipeline.text_encoder = pipeline.text_encoder.to_empty(device="cuda")
            if hasattr(pipeline, "text_encoder_2") and pipeline.text_encoder_2 is not None:
                try:
                    pipeline.text_encoder_2 = pipeline.text_encoder_2.to("cuda")
                except NotImplementedError:
                    pipeline.text_encoder_2 = pipeline.text_encoder_2.to_empty(device="cuda")
            if hasattr(pipeline, "vae") and pipeline.vae is not None:
                try:
                    pipeline.vae = pipeline.vae.to("cuda")
                except NotImplementedError:
                    pipeline.vae = pipeline.vae.to_empty(device="cuda")
        result_images = pipeline(prompt=prompts, generator=generators, **pipeline_kwargs).images
        num_guidances = (len(caches) // batch_size) // config.eval.num_steps
        num_steps = len(caches) // (batch_size * num_guidances)
        assert (
            len(caches) == batch_size * num_steps * num_guidances
        ), f"Unexpected number of caches: {len(caches)} != {batch_size} * {config.eval.num_steps} * {num_guidances}"
        for j, (filename, image) in enumerate(zip(filenames, result_images, strict=True)):
            image.save(os.path.join(samples_dirpath, f"{filename}.png"))
            for s in range(num_steps):
                for g in range(num_guidances):
                    c = caches[s * batch_size * num_guidances + g * batch_size + j]
                    c["filename"] = filename
                    c["step"] = s
                    c["guidance"] = g
                    c = tree_map(lambda x: process(x), c)
                    torch.save(c, os.path.join(caches_dirpath, f"{filename}-{s:05d}-{g}.pt"))
        caches.clear()
```

4) RuntimeError: Tensor.item() cannot be called on meta tensors

Potential fix: quantizer.impl.scale.py

```python
def quantize_scale(
    s: torch.Tensor,
    /,
    *,
    quant_dtypes: tp.Sequence[QuantDataType],
    quant_spans: tp.Sequence[float],
    view_shapes: tp.Sequence[torch.Size],
) -> QuantScale:
    """Quantize the scale tensor.

    Args:
        s (`torch.Tensor`): The scale tensor.
        quant_dtypes (`Sequence[QuantDataType]`): The quantization dtypes of the scale tensor.
        quant_spans (`Sequence[float]`): The quantization spans of the scale tensor.
        view_shapes (`Sequence[torch.Size]`): The view shapes of the scale tensor.

    Returns:
        `QuantScale`: The quantized scale tensor.
    """
    # Validate the input at the start
    if s.numel() == 0:
        raise ValueError("Input tensor is empty")
    if s.isnan().any() or s.isinf().any():
        raise ValueError("Input tensor contains NaN or Inf values")
    if (s == 0).all():
        raise ValueError("Input tensor contains all zeros")
    # Reject meta tensors before any operations (they carry no data)
    if s.is_meta:
        raise RuntimeError("Cannot quantize scale with meta tensor. Ensure model is loaded on actual device.")
    scale = QuantScale()
    s = s.abs()
    for view_shape, quant_dtype, quant_span in zip(view_shapes[:-1], quant_dtypes[:-1], quant_spans[:-1], strict=True):
        s = s.view(view_shape)  # (#g0, rs0, #g1, rs1, #g2, rs2, ...)
        ss = s.amax(dim=list(range(1, len(view_shape), 2)), keepdim=True)  # i.e., s_dynamic_span
        ss = simple_quantize(
            ss / quant_span, has_zero_point=False, quant_dtype=quant_dtype
        )  # i.e., s_scale = s_dynamic_span / s_quant_span
        s = s / ss
        scale.append(ss)
    view_shape = view_shapes[-1]
    s = s.view(view_shape)
    if any(v != 1 for v in view_shape[1::2]):
        ss = s.amax(dim=list(range(1, len(view_shape), 2)), keepdim=True)
        ss = simple_quantize(ss / quant_spans[-1], has_zero_point=False, quant_dtype=quant_dtypes[-1])
    else:
        assert quant_spans[-1] == 1, "The last quant span must be 1."
        ss = simple_quantize(s, has_zero_point=False, quant_dtype=quant_dtypes[-1])
    scale.append(ss)
    scale.remove_zero()
    return scale


def quantize(
    self,
    *,
    # scale-based quantization related arguments
    scale: torch.Tensor | None = None,
    zero: torch.Tensor | None = None,
    # range-based quantization related arguments
    tensor: torch.Tensor | None = None,
    dynamic_range: DynamicRange | None = None,
) -> tuple[QuantScale, torch.Tensor]:
    """Get the quantization scale and zero point of the tensor to be quantized.

    Args:
        scale (`torch.Tensor` or `None`, *optional*, defaults to `None`): The scale tensor.
        zero (`torch.Tensor` or `None`, *optional*, defaults to `None`): The zero point tensor.
        tensor (`torch.Tensor` or `None`, *optional*, defaults to `None`): The tensor to be quantized.
            This is only used for range-based quantization.
        dynamic_range (`DynamicRange` or `None`, *optional*, defaults to `None`):
            The dynamic range of the tensor to be quantized.

    Returns:
        `tuple[QuantScale, torch.Tensor]`: The scale and the zero point.
    """
    # region step 1: get the dynamic span for range-based scale or the scale tensor
    if scale is None:
        range_based = True
        assert isinstance(tensor, torch.Tensor), "View tensor must be a tensor."
        dynamic_range = dynamic_range or DynamicRange()
        dynamic_range = dynamic_range.measure(
            tensor.view(self.tensor_view_shape),
            zero_domain=self.tensor_zero_domain,
            is_float_point=self.tensor_quant_dtype.is_float_point,
        )
        dynamic_range = dynamic_range.intersect(self.tensor_range_bound)
        dynamic_span = (dynamic_range.max - dynamic_range.min) if self.has_zero_point else dynamic_range.max
    else:
        range_based = False
        scale = scale.view(self.scale_view_shapes[-1])
        assert isinstance(scale, torch.Tensor), "Scale must be a tensor."
    # endregion
    # region step 2: get the scale
    if self.linear_scale_quant_dtypes:
        if range_based:
            linear_scale = dynamic_span / self.linear_tensor_quant_span
        elif self.exponent_scale_quant_dtypes:
            linear_scale = scale.mul(self.exponent_tensor_quant_span).div(self.linear_tensor_quant_span)
        else:
            linear_scale = scale
        lin_s = quantize_scale(
            linear_scale,
            quant_dtypes=self.linear_scale_quant_dtypes,
            quant_spans=self.linear_scale_quant_spans,
            view_shapes=self.linear_scale_view_shapes,
        )
        assert lin_s.data is not None, "Linear scale tensor is None."
        if not lin_s.data.is_meta:
            assert not lin_s.data.isnan().any(), "Linear scale tensor contains NaN."
            assert not lin_s.data.isinf().any(), "Linear scale tensor contains Inf."
    else:
        lin_s = QuantScale()
    if self.exponent_scale_quant_dtypes:
        if range_based:
            exp_scale = dynamic_span / self.exponent_tensor_quant_span
        else:
            exp_scale = scale
        if lin_s.data is not None:
            lin_s.data = lin_s.data.expand(self.linear_scale_view_shapes[-1]).reshape(self.scale_view_shapes[-1])
            exp_scale = exp_scale / lin_s.data
        exp_s = quantize_scale(
            exp_scale,
            quant_dtypes=self.exponent_scale_quant_dtypes,
            quant_spans=self.exponent_scale_quant_spans,
            view_shapes=self.exponent_scale_view_shapes,
        )
        assert exp_s.data is not None, "Exponential scale tensor is None."
        assert not exp_s.data.isnan().any(), "Exponential scale tensor contains NaN."
        assert not exp_s.data.isinf().any(), "Exponential scale tensor contains Inf."
        s = exp_s if lin_s.data is None else lin_s.extend(exp_s)
    else:
        s = lin_s
    # Before the final assertions, add debugging and validation
    if s.data is None:
        # Log debugging information
        print(f"Linear scale dtypes: {self.linear_scale_quant_dtypes}")
        print(f"Exponent scale dtypes: {self.exponent_scale_quant_dtypes}")
        if hasattr(lin_s, "data") and lin_s.data is not None:
            print(f"Linear scale data shape: {lin_s.data.shape}")
        raise RuntimeError("Scale computation failed - resulting scale is None")
    assert s.data is not None, "Scale tensor is None."
    assert not s.data.isnan().any(), "Scale tensor contains NaN."
    assert not s.data.isinf().any(), "Scale tensor contains Inf."
    # endregion
    # region step 3: get the zero point
    if self.has_zero_point:
        if range_based:
            if self.tensor_zero_domain == ZeroPointDomain.PreScale:
                zero = self.tensor_quant_range.min - dynamic_range.min / s.data
            else:
                zero = self.tensor_quant_range.min * s.data - dynamic_range.min
        assert isinstance(zero, torch.Tensor), "Zero point must be a tensor."
        z = simple_quantize(zero, has_zero_point=True, quant_dtype=self.zero_quant_dtype)
    else:
        z = torch.tensor(0, dtype=s.data.dtype, device=s.data.device)
    assert not z.isnan().any(), "Zero point tensor contains NaN."
    assert not z.isinf().any(), "Zero point tensor contains Inf."
    # endregion
    return s, z
```

Potential fix: app.diffusion.ptq.py

```python
def ptq(  # noqa: C901
    model: DiffusionModelStruct,
    config: DiffusionQuantConfig,
    cache: DiffusionPtqCacheConfig | None = None,
    load_dirpath: str = "",
    save_dirpath: str = "",
    copy_on_save: bool = False,
    save_model: bool = False,
) -> DiffusionModelStruct:
    """Post-training quantization of a diffusion model.

    Args:
        model (`DiffusionModelStruct`): The diffusion model.
        config (`DiffusionQuantConfig`): The diffusion model post-training quantization configuration.
        cache (`DiffusionPtqCacheConfig`, *optional*, defaults to `None`):
            The diffusion model quantization cache path configuration.
        load_dirpath (`str`, *optional*, defaults to `""`): The directory path to load the quantization checkpoint.
        save_dirpath (`str`, *optional*, defaults to `""`): The directory path to save the quantization checkpoint.
        copy_on_save (`bool`, *optional*, defaults to `False`): Whether to copy the cache to the save directory.
        save_model (`bool`, *optional*, defaults to `False`): Whether to save the quantized model checkpoint.

    Returns:
        `DiffusionModelStruct`: The quantized diffusion model.
    """
    logger = tools.logging.getLogger(__name__)
    if not isinstance(model, DiffusionModelStruct):
        model = DiffusionModelStruct.construct(model)
    assert isinstance(model, DiffusionModelStruct)
    quant_wgts = config.enabled_wgts
    quant_ipts = config.enabled_ipts
    quant_opts = config.enabled_opts
    quant_acts = quant_ipts or quant_opts
    quant = quant_wgts or quant_acts
    load_model_path, load_path, save_path = "", None, None
    if load_dirpath:
        load_path = DiffusionQuantCacheConfig(
            smooth=os.path.join(load_dirpath, "smooth.pt"),
            branch=os.path.join(load_dirpath, "branch.pt"),
            wgts=os.path.join(load_dirpath, "wgts.pt"),
            acts=os.path.join(load_dirpath, "acts.pt"),
        )
        load_model_path = os.path.join(load_dirpath, "model.pt")
        if os.path.exists(load_model_path):
            if config.enabled_wgts and config.wgts.enabled_low_rank:
                if os.path.exists(load_path.branch):
                    load_model = True
                else:
                    logger.warning(f"Model low-rank branch checkpoint {load_path.branch} does not exist")
                    load_model = False
            else:
                load_model = True
            if load_model:
                logger.info(f"* Loading model from {load_model_path}")
                save_dirpath = ""  # do not save the model if loading
        else:
            logger.warning(f"Model checkpoint {load_model_path} does not exist")
            load_model = False
    else:
        load_model = False
    if save_dirpath:
        os.makedirs(save_dirpath, exist_ok=True)
        save_path = DiffusionQuantCacheConfig(
            smooth=os.path.join(save_dirpath, "smooth.pt"),
            branch=os.path.join(save_dirpath, "branch.pt"),
            wgts=os.path.join(save_dirpath, "wgts.pt"),
            acts=os.path.join(save_dirpath, "acts.pt"),
        )
    else:
        save_model = False
    if quant and config.enabled_rotation:
        logger.info("* Rotating model for quantization")
        tools.logging.Formatter.indent_inc()
        rotate_diffusion(model, config=config)
        tools.logging.Formatter.indent_dec()
        gc.collect()
        torch.cuda.empty_cache()
    # region smooth quantization
    if quant and config.enabled_smooth:
        logger.info("* Smoothing model for quantization")
        tools.logging.Formatter.indent_inc()
        load_from = ""
        if load_path and os.path.exists(load_path.smooth):
            load_from = load_path.smooth
        elif cache and cache.path.smooth and os.path.exists(cache.path.smooth):
            load_from = cache.path.smooth
        if load_from:
            logger.info(f"- Loading smooth scales from {load_from}")
            smooth_cache = torch.load(load_from)
            smooth_diffusion(model, config, smooth_cache=smooth_cache)
        else:
            logger.info("- Generating smooth scales")
            smooth_cache = smooth_diffusion(model, config)
            if cache and cache.path.smooth:
                logger.info(f"- Saving smooth scales to {cache.path.smooth}")
                os.makedirs(cache.dirpath.smooth, exist_ok=True)
                torch.save(smooth_cache, cache.path.smooth)
                load_from = cache.path.smooth
        if save_path:
            if not copy_on_save and load_from:
                logger.info(f"- Linking smooth scales to {save_path.smooth}")
                os.symlink(os.path.relpath(load_from, save_dirpath), save_path.smooth)
            else:
                logger.info(f"- Saving smooth scales to {save_path.smooth}")
                torch.save(smooth_cache, save_path.smooth)
        del smooth_cache
        tools.logging.Formatter.indent_dec()
        gc.collect()
        torch.cuda.empty_cache()
    # endregion
    # region collect original state dict
    if config.needs_acts_quantizer_cache:
        if load_path and os.path.exists(load_path.acts):
            orig_state_dict = None
        elif cache and cache.path.acts and os.path.exists(cache.path.acts):
            orig_state_dict = None
        else:
            orig_state_dict: dict[str, torch.Tensor] = {
                name: param.detach().clone()
                for name, param in model.module.named_parameters()
                if param.ndim > 1
            }
    else:
        orig_state_dict = None
    # endregion
    if load_model:
        logger.info(f"* Loading model checkpoint from {load_model_path}")
        load_diffusion_weights_state_dict(
            model,
            config,
            state_dict=torch.load(load_model_path),
            branch_state_dict=torch.load(load_path.branch) if os.path.exists(load_path.branch) else None,
        )
        gc.collect()
        torch.cuda.empty_cache()
    elif quant_wgts:
        logger.info("* Ensuring model is on actual device before quantization")
        # Check if model has meta tensors
        has_meta_tensors = any(param.is_meta for param in model.module.parameters())
        if has_meta_tensors:
            logger.info("* Model contains meta tensors, materializing to actual device")
            # Option 1: use to_empty() and reload weights (recommended)
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            # Store the original state dict if available
            try:
                original_state_dict = model.module.state_dict()
                model.module = model.module.to_empty(device=device)
                model.module.load_state_dict(original_state_dict)
                logger.info("* Successfully materialized model with original weights")
            except Exception as e:
                logger.warning(f"* Failed to preserve weights during materialization: {e}")
                # Fallback: just move to empty device (weights will be uninitialized)
                model.module = model.module.to_empty(device=device)
                logger.warning("* Model moved to device but weights may be uninitialized")
        else:
            # Model already has real tensors, just ensure it is on the right device
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            model.module = model.module.to(device)
        # Verify no meta tensors remain
        remaining_meta = [name for name, param in model.module.named_parameters() if param.is_meta]
        if remaining_meta:
            raise RuntimeError(f"Parameters still on meta device: {remaining_meta}")
        logger.info("* Model successfully prepared for quantization")
        logger.info("* Quantizing weights")
        tools.logging.Formatter.indent_inc()
        quantizer_state_dict, quantizer_load_from = None, ""
        if load_path and os.path.exists(load_path.wgts):
            quantizer_load_from = load_path.wgts
        elif cache and cache.path.wgts and os.path.exists(cache.path.wgts):
            quantizer_load_from = cache.path.wgts
        if quantizer_load_from:
            logger.info(f"- Loading weight settings from {quantizer_load_from}")
            quantizer_state_dict = torch.load(quantizer_load_from)
        branch_state_dict, branch_load_from = None, ""
        if load_path and os.path.exists(load_path.branch):
            branch_load_from = load_path.branch
        elif cache and cache.path.branch and os.path.exists(cache.path.branch):
            branch_load_from = cache.path.branch
        if branch_load_from:
            logger.info(f"- Loading branch settings from {branch_load_from}")
            branch_state_dict = torch.load(branch_load_from)
        if not quantizer_load_from:
            logger.info("- Generating weight settings")
        if not branch_load_from:
            logger.info("- Generating branch settings")
        quantizer_state_dict, branch_state_dict, scale_state_dict = quantize_diffusion_weights(
            model,
            config,
            quantizer_state_dict=quantizer_state_dict,
            branch_state_dict=branch_state_dict,
            return_with_scale_state_dict=bool(save_dirpath),
        )
        if not quantizer_load_from and cache and cache.dirpath.wgts:
            logger.info(f"- Saving weight settings to {cache.path.wgts}")
            os.makedirs(cache.dirpath.wgts, exist_ok=True)
            torch.save(quantizer_state_dict, cache.path.wgts)
            quantizer_load_from = cache.path.wgts
        if not branch_load_from and cache and cache.dirpath.branch:
            logger.info(f"- Saving branch settings to {cache.path.branch}")
            os.makedirs(cache.dirpath.branch, exist_ok=True)
            torch.save(branch_state_dict, cache.path.branch)
            branch_load_from = cache.path.branch
        if save_path:
            if not copy_on_save and quantizer_load_from:
                logger.info(f"- Linking weight settings to {save_path.wgts}")
                os.symlink(os.path.relpath(quantizer_load_from, save_dirpath), save_path.wgts)
            else:
                logger.info(f"- Saving weight settings to {save_path.wgts}")
                torch.save(quantizer_state_dict, save_path.wgts)
            if not copy_on_save and branch_load_from:
                logger.info(f"- Linking branch settings to {save_path.branch}")
                os.symlink(os.path.relpath(branch_load_from, save_dirpath), save_path.branch)
            else:
                logger.info(f"- Saving branch settings to {save_path.branch}")
                torch.save(branch_state_dict, save_path.branch)
        if save_model:
            logger.info(f"- Saving model to {save_dirpath}")
            torch.save(scale_state_dict, os.path.join(save_dirpath, "scale.pt"))
            torch.save(model.module.state_dict(), os.path.join(save_dirpath, "model.pt"))
        del quantizer_state_dict, branch_state_dict, scale_state_dict
        tools.logging.Formatter.indent_dec()
        gc.collect()
        torch.cuda.empty_cache()
    if quant_acts:
        logger.info(" * Quantizing activations")
        tools.logging.Formatter.indent_inc()
        if config.needs_acts_quantizer_cache:
            load_from = ""
            if load_path and os.path.exists(load_path.acts):
                load_from = load_path.acts
            elif cache and cache.path.acts and os.path.exists(cache.path.acts):
                load_from = cache.path.acts
            if load_from:
                logger.info(f"- Loading activation settings from {load_from}")
                quantizer_state_dict = torch.load(load_from)
                quantize_diffusion_activations(
                    model, config, quantizer_state_dict=quantizer_state_dict, orig_state_dict=orig_state_dict
                )
            else:
                logger.info("- Generating activation settings")
                quantizer_state_dict = quantize_diffusion_activations(model, config, orig_state_dict=orig_state_dict)
                if cache and cache.dirpath.acts and quantizer_state_dict is not None:
                    logger.info(f"- Saving activation settings to {cache.path.acts}")
                    os.makedirs(cache.dirpath.acts, exist_ok=True)
                    torch.save(quantizer_state_dict, cache.path.acts)
                    load_from = cache.path.acts
            if save_dirpath:
                if not copy_on_save and load_from:
                    logger.info(f"- Linking activation quantizer settings to {save_path.acts}")
                    os.symlink(os.path.relpath(load_from, save_dirpath), save_path.acts)
                else:
                    logger.info(f"- Saving activation quantizer settings to {save_path.acts}")
                    torch.save(quantizer_state_dict, save_path.acts)
            del quantizer_state_dict
        else:
            logger.info("- No need to generate/load activation quantizer settings")
            quantize_diffusion_activations(model, config, orig_state_dict=orig_state_dict)
        tools.logging.Formatter.indent_dec()
        del orig_state_dict
        gc.collect()
        torch.cuda.empty_cache()
    return model
```
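Before re-running PTQ, it can help to confirm that no parameters or buffers are still on the meta device, since that is what ultimately triggers the `Tensor.item()` failure inside the scale quantizer. A small hedged pre-flight check (generic PyTorch; the `pipeline` variable is assumed to be whatever `config.pipeline.build()` returned):

```python
import torch


def assert_no_meta_params(module: torch.nn.Module, name: str) -> None:
    """Raise early if any parameter or buffer still lives on the meta device."""
    meta = [n for n, p in module.named_parameters() if p.is_meta]
    meta += [n for n, b in module.named_buffers() if b.is_meta]
    if meta:
        raise RuntimeError(f"{name} still has meta tensors, e.g. {meta[:5]}")


# Example usage (hypothetical): check the transformer before quantization.
# assert_no_meta_params(pipeline.transformer, "transformer")
```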
5) RuntimeError: Dataset scripts are no longer supported, but found COCO.py

This error is likely raised by newer releases of the Hugging Face `datasets` library, which dropped support for script-based datasets such as COCO.py; pinning an older `datasets` version or switching to a script-free calibration dataset is a possible workaround.

# References

https://github.com/nunchaku-tech/nunchaku/commit/b99fb8be615bc98c6915bbe06a1e0092cbc074a5
https://github.com/nunchaku-tech/nunchaku/blob/main/examples/flux.1-kontext-dev.py
https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_flux.py#L266
https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/flux/pipeline_flux_kontext.py
https://github.com/nunchaku-tech/deepcompressor/issues/91
https://deepwiki.com/nunchaku-tech/deepcompressor
https://huggingface.co/mit-han-lab/nunchaku-flux.1-kontext-dev/tree/main

---

# Dependencies

https://github.com/Dao-AILab/flash-attention
https://github.com/facebookresearch/xformers
https://github.com/openai/CLIP
https://github.com/THUDM/ImageReward

# Wheels

https://huggingface.co/datasets/siraxe/PrecompiledWheels_Torch-2.8-cu128-cp312
https://huggingface.co/lldacing/flash-attention-windows-wheel
https://github.com/loscrossos/lib_flashattention/releases