Instructions to use mhnakif/comfy2 with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Diffusers
How to use mhnakif/comfy2 with Diffusers:
pip install -U diffusers transformers accelerate
import torch
from diffusers import DiffusionPipeline

# switch to "mps" for apple devices
pipe = DiffusionPipeline.from_pretrained("mhnakif/comfy2", dtype=torch.bfloat16, device_map="cuda")

prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
image = pipe(prompt).images[0]
- Notebooks
- Google Colab
- Kaggle
| """Sharpness subspace calibration via sinusoidal grating stimuli. | |
| Replaces the previous Gaussian blur approach with narrowband frequency | |
| gratings, which achieve higher linearity (R²=0.94 vs 0.88) because each | |
| stimulus contains a single spatial frequency — a purer probe of the VAE's | |
| frequency encoding axis. | |
| The two methods discover the same 1D subspace (|cos|=0.986, 9.7° apart), | |
| but grating stimuli yield a cleaner PC1 direction. | |
| """ | |
| import math | |
| from dataclasses import dataclass | |
| from typing import List, Optional, Tuple | |
| import warnings | |
| import torch | |
| import comfy.utils | |
| from .patchify import patchify | |
| from .lcs_data import LCSData | |
@dataclass
class SharpnessData:
    """Calibration data for the sharpness subspace.

    Produced by PCA on FLUX VAE-encoded sinusoidal gratings at varying
    spatial frequencies. PC1 captures ~94% of variance with R²=0.94
    linearity vs log₂(frequency).

    The ``@dataclass`` decorator is required: instances are built with
    keyword arguments (see ``to``), which a plain class carrying only
    annotations would reject with ``TypeError``.
    """

    basis: torch.Tensor  # [64, K] PCA basis (columns), K typically 1-2
    mean: torch.Tensor  # [64] PCA mean (in color-removed space if lcs_data was used)
    sign: float  # +1 or -1: ensures positive strength = sharper
    lcs_basis: Optional[torch.Tensor] = None  # [64, 3] LCS basis used during calibration (for re-orthogonalization)

    def to(self, device, dtype=None) -> "SharpnessData":
        """Return a new SharpnessData with every tensor moved to ``device``.

        Args:
            device: Target device passed to ``Tensor.to``.
            dtype: Optional target dtype; when ``None`` each tensor keeps its dtype.

        Returns:
            A new instance; ``sign`` is copied as-is and ``lcs_basis`` stays
            ``None`` when it was ``None``.
        """
        kw = {"device": device}
        if dtype is not None:
            kw["dtype"] = dtype
        return SharpnessData(
            basis=self.basis.to(**kw),
            mean=self.mean.to(**kw),
            sign=self.sign,
            lcs_basis=self.lcs_basis.to(**kw) if self.lcs_basis is not None else None,
        )
def _generate_grating_batch(
    indices: List[int],
    angles: torch.Tensor,
    phases: torch.Tensor,
    frequencies: Tuple[float, ...],
    coord_x: torch.Tensor,
    coord_y: torch.Tensor,
) -> torch.Tensor:
    """Render sinusoidal grating stimuli for a list of flat stimulus indices.

    A flat index ``i`` is split with ``divmod(i, len(frequencies))`` into an
    orientation slot (which selects both ``angles`` and ``phases``) and a
    frequency slot (which selects from ``frequencies``). Each grating is

        0.5 + 0.3 * sin(2π · f · (x·cos θ + y·sin θ) + φ)

    evaluated on the ``coord_x`` / ``coord_y`` grids and repeated across
    three channels.

    Returns:
        Tensor of shape ``[len(indices), 3, H, W]``, where ``H, W`` is the
        (broadcast) shape of the coordinate grids.
    """
    per_orientation = len(frequencies)

    def render(flat_index: int) -> torch.Tensor:
        # One stimulus: pick orientation/phase and frequency, project the
        # coordinate grid onto the grating direction, then modulate.
        orientation, freq_slot = divmod(flat_index, per_orientation)
        theta = angles[orientation].item()
        phi = phases[orientation].item()
        projected = coord_x * math.cos(theta) + coord_y * math.sin(theta)
        wave = torch.sin(2 * math.pi * frequencies[freq_slot] * projected + phi)
        return (0.5 + 0.3 * wave).unsqueeze(0).expand(3, -1, -1)

    return torch.stack([render(i) for i in indices], dim=0)
def _encode_patch_means(vae, imgs_bhwc: torch.Tensor) -> torch.Tensor:
    """VAE-encode a BHWC image batch and average each latent's patch vectors.

    Returns a CPU tensor with one row per leading entry of the patchified
    latent. Callers compare the row count with the number of input images to
    detect VAEs (e.g. video VAEs) that do not preserve the batch dimension.
    """
    latent = vae.encode(imgs_bhwc)
    patches, _, _, _ = patchify(latent)
    return patches.mean(dim=1).cpu()


def calibrate_sharpness(vae, num_samples: int = 64, image_size: int = 512,
                        frequencies: Tuple[float, ...] = (1, 2, 4, 8, 16, 32, 64),
                        batch_size: int = 8,
                        lcs_data: Optional[LCSData] = None,
                        # Legacy parameter — accepted but ignored
                        blur_levels: Optional[Tuple[float, ...]] = None,
                        ) -> SharpnessData:
    """Compute sharpness subspace data (PCA basis, mean, sign) from FLUX VAE.

    Generates sinusoidal gratings at varying spatial frequencies (one pure
    frequency per stimulus), VAE-encodes them, and runs PCA to find the
    sharpness/frequency direction in 64D patch space.

    If the VAE returns a different number of latents than images for a batch
    (as a video VAE treating the batch as frames may), that batch result is
    discarded and every remaining stimulus is encoded one image at a time, so
    each collected vector stays aligned with its stimulus frequency.

    Args:
        vae: ComfyUI VAE object
        num_samples: Number of orientations (each combined with all frequencies)
        image_size: Size of generated images
        frequencies: Spatial frequencies in cycles/image
        batch_size: Batch size for VAE encoding (must be >= 1)
        lcs_data: Optional LCS data for removing color component during calibration.
            When provided, the sharpness PC1 will be orthogonal to the color subspace,
            preventing color shifts during intervention.
        blur_levels: Deprecated and ignored; passing it emits a DeprecationWarning.

    Returns:
        SharpnessData with a [D, 2] basis (top two principal components), the
        PCA mean, and the sign that makes positive strength = higher frequency.

    Raises:
        ValueError: if ``batch_size < 1`` or fewer than 2 stimuli would be
            generated (PCA needs at least two samples to report PC1 and PC2).
        RuntimeError: if the number of encoded vectors does not match the
            number of stimuli.
    """
    if blur_levels is not None:
        warnings.warn(
            "blur_levels is deprecated and ignored; calibration now uses sinusoidal gratings",
            DeprecationWarning, stacklevel=2,
        )

    n_freqs = len(frequencies)
    total_images = num_samples * n_freqs
    # Validate before any VAE work so misconfiguration fails fast.
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    if total_images < 2:
        raise ValueError(
            f"need at least 2 stimuli for PCA, got {num_samples} orientations × {n_freqs} frequencies"
        )

    print(f"\n[LCS Sharpness Calibration] Starting: {num_samples} orientations × {n_freqs} frequencies = {total_images} stimuli")
    print(f"[LCS Sharpness Calibration] Frequencies: {list(frequencies)} cycles/image")

    # Pre-compute shared state for grating generation (fixed seed → reproducible stimuli)
    gen = torch.Generator().manual_seed(42)
    angles = torch.rand(num_samples, generator=gen) * math.pi  # [0, π)
    phases = torch.rand(num_samples, generator=gen) * 2 * math.pi  # [0, 2π)
    y_coords = torch.linspace(-0.5, 0.5, image_size).unsqueeze(1)
    x_coords = torch.linspace(-0.5, 0.5, image_size).unsqueeze(0)
    coord_y = y_coords.expand(image_size, image_size)
    coord_x = x_coords.expand(image_size, image_size)

    # Build frequency labels for all stimuli (flat index → frequency); this
    # ordering must match the order in which vectors are collected below.
    freq_labels = [frequencies[idx % n_freqs] for idx in range(total_images)]
    freq_labels_t = torch.tensor(freq_labels, dtype=torch.float32)
    log_freq = torch.log2(freq_labels_t.clamp(min=0.5))

    # Generate stimuli lazily per batch and VAE encode
    vectors: List[torch.Tensor] = []
    # Latched to True the first time the VAE does not return one latent per
    # image; later batches then skip the batched attempt entirely.
    per_image = False
    pbar = comfy.utils.ProgressBar(total_images)
    for batch_start in range(0, total_images, batch_size):
        batch_end = min(batch_start + batch_size, total_images)
        indices = list(range(batch_start, batch_end))
        batch = _generate_grating_batch(indices, angles, phases, frequencies, coord_x, coord_y)
        actual_batch = batch.shape[0]
        # Convert BCHW → BHWC for ComfyUI VAE
        imgs_bhwc = batch.permute(0, 2, 3, 1).contiguous().cpu()

        if not per_image:
            avg = _encode_patch_means(vae, imgs_bhwc)
            if avg.shape[0] == actual_batch:
                vectors.extend(avg.unbind(0))
            else:
                # The rows do not map one-to-one onto the stimuli, so keeping
                # any of them would misalign vectors with freq_labels.
                print(f"[LCS Sharpness Calibration] VAE returned {avg.shape[0]} latents for "
                      f"{actual_batch} images; switching to per-image encoding")
                per_image = True

        if per_image:
            for k in range(actual_batch):
                vectors.append(_encode_patch_means(vae, imgs_bhwc[k:k + 1])[0])
        pbar.update(actual_batch)

    if len(vectors) != total_images:
        raise RuntimeError(
            f"encoded {len(vectors)} vectors for {total_images} stimuli; cannot align with frequency labels"
        )

    # Stack all vectors: [N, 64]
    X = torch.stack(vectors, dim=0).float()
    print(f"[LCS Sharpness Calibration] Collected {X.shape[0]} vectors of dimension {X.shape[1]}")

    # Remove LCS color component FIRST, in the raw space where LCS was calibrated.
    # This must happen before per-vector DC removal, because the LCS basis has
    # significant DC components (PC1 ≈ brightness). Doing DC removal first would
    # shift vectors into a different space where B^T(x - mu) is incorrect.
    if lcs_data is not None:
        print("[LCS Sharpness Calibration] Removing LCS color component...")
        lcs_mean = lcs_data.mean.to(X.device, X.dtype)
        lcs_basis = lcs_data.basis.to(X.device, X.dtype)
        # Project out color: X' = X - B B^T (X - mu)
        # NOTE(review): exact only if lcs_data.basis has orthonormal columns
        # (as a PCA basis would) — confirm against LCS calibration.
        centered = X - lcs_mean
        lcs_coords = centered @ lcs_basis  # [N, 3]
        X = X - lcs_coords @ lcs_basis.T
        print("[LCS Sharpness Calibration] Color component removed")

    # Remove per-vector DC AFTER color removal.
    # VAE encoding shifts the latent mean depending on stimulus content.
    # Per-vector zero-mean forces PCA to find patterns in the relative channel
    # structure, not in the absolute level.
    X = X - X.mean(dim=1, keepdim=True)

    # PCA via SVD of the mean-centered data
    print("[LCS Sharpness Calibration] Computing PCA...")
    mean = X.mean(dim=0)  # [64]
    X_centered = X - mean
    _, S, Vh = torch.linalg.svd(X_centered, full_matrices=False)

    # Top 2 components (rows of Vh are principal directions)
    basis = Vh[:2].T  # [64, 2]

    # Variance explained
    total_var = (S ** 2).sum()
    explained = (S[:2] ** 2) / total_var
    print(f"[LCS Sharpness Calibration] PC1: {explained[0]:.1%}, PC2: {explained[1]:.1%} ({(explained[0]+explained[1]):.1%} total)")

    # Sign convention: correlate PC1 score with log₂(frequency).
    # Higher frequency = sharper → if positive correlation, sign = +1
    pc1_scores = X_centered @ basis[:, 0]  # [N]
    correlation = torch.corrcoef(torch.stack([pc1_scores, log_freq]))[0, 1]
    if not torch.isfinite(correlation):
        # corrcoef is NaN when either series is constant (e.g. one frequency).
        warnings.warn(
            "PC1-frequency correlation is undefined (e.g. only one distinct frequency); "
            "sign falls back to -1 and may not correspond to 'sharper'",
            RuntimeWarning, stacklevel=2,
        )
    sign = 1.0 if correlation > 0 else -1.0
    print(f"[LCS Sharpness Calibration] PC1-frequency correlation: {correlation:.3f} → sign = {sign:+.0f}")
    print(f"[LCS Sharpness Calibration] Complete! Basis shape: {basis.shape}")

    return SharpnessData(
        basis=basis,
        mean=mean,
        sign=sign,
        lcs_basis=lcs_data.basis.clone() if lcs_data is not None else None,
    )