# NOTE: the following banner text was scraped from the hosting page
# ("Spaces: Running on Zero") and is not code; kept here as a comment so
# the module remains importable.
| """ | |
| FlowSlider: 3プロンプト方向性分解による連続スケール制御 | |
| FlowEditの実画像編集能力を維持しながら、FreeSlidersの考え方を取り入れて | |
| 連続的な編集強度制御を可能にする手法。 | |
| 数式: | |
| V_steer = V_tar_pos - V_tar_neg (純粋な編集方向) | |
| V_fid = V_tar_neg - V_src (ベース変化) | |
| V_delta_s = V_fid + strength * V_steer | |
| strength=0: tar_neg方向への編集(例:劣化なし) | |
| strength=1: tar_pos方向への編集(例:完全劣化) | |
| 0<strength<1: 連続的な中間状態 | |
| """ | |
| from typing import Optional, Tuple, Union, Dict, List, Any | |
| import torch | |
| import torch.nn.functional as F | |
| from diffusers import FlowMatchEulerDiscreteScheduler | |
| from tqdm import tqdm | |
| import numpy as np | |
| import json | |
| import os | |
| from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import retrieve_timesteps | |
| # FlowEdit_utils.pyから必要な関数をインポート | |
| from FlowEdit_utils import scale_noise, calculate_shift, calc_v_flux | |
| # ============================================ | |
| # Vector Logging and Visualization Utilities | |
| # ============================================ | |
def compute_vector_stats(
    V_fid: torch.Tensor,
    V_steer: torch.Tensor,
    V_delta_s: torch.Tensor,
    zt_edit: torch.Tensor,
    prev_V_steer: Optional[torch.Tensor] = None,
    prev_zt_edit: Optional[torch.Tensor] = None,
) -> Dict[str, float]:
    """
    Compute diagnostic statistics for the velocity fields at one timestep.

    Args:
        V_fid: Base velocity (V_neg - V_src)
        V_steer: Direction velocity (V_pos - V_neg)
        V_delta_s: Combined velocity (V_fid + strength * V_steer)
        zt_edit: Current edited latent
        prev_V_steer: V_steer from the previous timestep (for cosine similarity)
        prev_zt_edit: zt_edit from the previous timestep (for delta computation)

    Returns:
        Dictionary mapping statistic names to scalar values.
    """
    def _mean_norm(t: torch.Tensor) -> float:
        # L2 norm over the last (channel) axis, averaged over remaining dims.
        return t.norm(dim=-1).mean().item()

    stats = {
        "V_fid_norm": _mean_norm(V_fid),
        "V_steer_norm": _mean_norm(V_steer),
        "V_delta_s_norm": _mean_norm(V_delta_s),
        "zt_edit_norm": _mean_norm(zt_edit),
    }

    # Directional consistency of V_steer relative to the previous timestep.
    if prev_V_steer is None:
        stats["V_steer_cosine"] = 1.0  # first step: nothing to compare against
    else:
        cur_flat = V_steer.reshape(1, -1)
        prev_flat = prev_V_steer.reshape(1, -1)
        stats["V_steer_cosine"] = F.cosine_similarity(cur_flat, prev_flat).item()

    # Angle (in degrees) between the fidelity and steering velocities.
    cos_val = F.cosine_similarity(
        V_fid.reshape(1, -1), V_steer.reshape(1, -1)
    ).item()
    cos_val = min(1.0, max(-1.0, cos_val))  # guard arccos against rounding
    stats["V_fid_V_steer_angle"] = np.degrees(np.arccos(cos_val))

    # Per-step movement of the edited latent (0.0 on the first step).
    stats["zt_edit_delta"] = (
        0.0 if prev_zt_edit is None else _mean_norm(zt_edit - prev_zt_edit)
    )
    return stats
def save_vector_stats(
    stats_list: List[Dict[str, Any]],
    output_dir: str,
    strength: float,
):
    """
    Persist per-timestep vector statistics as a JSON file.

    Args:
        stats_list: One statistics dict per timestep (as produced by
            compute_vector_stats, plus a "timestep" entry).
        output_dir: Directory to write into (created if missing).
        strength: Scale value used for this run; embedded in the filename.

    Returns:
        Path of the written JSON file.
    """
    os.makedirs(output_dir, exist_ok=True)

    # Pivot the list of per-step dicts into per-metric series so that
    # plotting code can consume each metric as one array.
    metric_keys = (
        "V_fid_norm",
        "V_steer_norm",
        "V_delta_s_norm",
        "V_steer_cosine",
        "V_fid_V_steer_angle",
        "zt_edit_norm",
        "zt_edit_delta",
    )
    payload = {
        "strength": strength,
        "timesteps": [entry["timestep"] for entry in stats_list],
    }
    for key in metric_keys:
        payload[key] = [entry[key] for entry in stats_list]

    dest = os.path.join(output_dir, f"stats_scale_{strength:.2f}.json")
    with open(dest, "w") as fh:
        json.dump(payload, fh, indent=2)
    return dest
def plot_vector_stats(
    stats_path: str,
    output_dir: str,
):
    """
    Render the four diagnostic plots for a single strength run.

    Args:
        stats_path: Path to a stats JSON file written by save_vector_stats.
        output_dir: Directory that receives the PNG plots (created if missing).
    """
    import matplotlib.pyplot as plt

    with open(stats_path, "r") as fh:
        data = json.load(fh)
    strength = data["strength"]
    timesteps = data["timesteps"]
    os.makedirs(output_dir, exist_ok=True)

    def _finish(ax, ylabel, title, fname, ylim=None, legend=True):
        # Shared decoration + save/close boilerplate for every plot.
        ax.set_xlabel("Timestep", fontsize=12)
        ax.set_ylabel(ylabel, fontsize=12)
        ax.set_title(title, fontsize=14)
        if ylim is not None:
            ax.set_ylim(*ylim)
        if legend:
            ax.legend(fontsize=11)
        ax.grid(True, alpha=0.3)
        ax.invert_xaxis()  # diffusion time runs 1.0 -> 0
        plt.tight_layout()
        plt.savefig(os.path.join(output_dir, fname), dpi=150)
        plt.close()

    # 1) Velocity-field norms.
    fig, ax = plt.subplots(figsize=(10, 6))
    ax.plot(timesteps, data["V_fid_norm"], 'b-', label="V_fid", linewidth=2)
    ax.plot(timesteps, data["V_steer_norm"], 'r-', label="V_steer", linewidth=2)
    ax.plot(timesteps, data["V_delta_s_norm"], 'g--', label="V_delta_s", linewidth=2)
    _finish(ax, "L2 Norm",
            f"Velocity Field Norms (strength={strength:.2f})",
            f"plot_norms_scale_{strength:.2f}.png")

    # 2) Step-to-step cosine similarity of V_steer (first entry is a
    #    placeholder, so it is skipped).
    fig, ax = plt.subplots(figsize=(10, 6))
    ax.plot(timesteps[1:], data["V_steer_cosine"][1:], 'purple', linewidth=2)
    ax.axhline(y=0.9, color='gray', linestyle='--', alpha=0.7, label="Stability threshold (0.9)")
    _finish(ax, "Cosine Similarity",
            f"V_steer Directional Consistency (strength={strength:.2f})",
            f"plot_cosine_scale_{strength:.2f}.png", ylim=(-0.1, 1.1))

    # 3) Angle between fidelity and steering velocities.
    fig, ax = plt.subplots(figsize=(10, 6))
    ax.plot(timesteps, data["V_fid_V_steer_angle"], 'orange', linewidth=2)
    ax.axhline(y=90, color='gray', linestyle='--', alpha=0.7, label="Orthogonal (90°)")
    _finish(ax, "Angle (degrees)",
            f"Angle between V_fid and V_steer (strength={strength:.2f})",
            f"plot_angles_scale_{strength:.2f}.png", ylim=(0, 180))

    # 4) Per-step latent movement (single unlabeled line, so no legend).
    fig, ax = plt.subplots(figsize=(10, 6))
    ax.plot(timesteps, data["zt_edit_delta"], 'teal', linewidth=2)
    _finish(ax, "Step Movement (L2 norm)",
            f"Edit Trajectory: Per-step Movement (strength={strength:.2f})",
            f"plot_trajectory_scale_{strength:.2f}.png", legend=False)
def plot_vector_comparison(
    log_dir: str,
    strengths: List[float],
    output_dir: str,
):
    """
    Overlay statistics from several strength runs on shared axes.

    Args:
        log_dir: Directory containing stats_scale_*.json files.
        strengths: Strength values whose stats files should be loaded.
        output_dir: Directory that receives the comparison PNGs.
    """
    import matplotlib.pyplot as plt

    os.makedirs(output_dir, exist_ok=True)

    # Load whichever per-strength stats files actually exist.
    all_data = {}
    for s in strengths:
        path = os.path.join(log_dir, f"stats_scale_{s:.2f}.json")
        if os.path.exists(path):
            with open(path, "r") as fh:
                all_data[s] = json.load(fh)
    if not all_data:
        print(f"No stats files found in {log_dir}")
        return

    # One viridis color per requested strength; zip below truncates if
    # some stats files were missing.
    colors = plt.cm.viridis(np.linspace(0, 1, len(strengths)))

    def _finish(ax, ylabel, title, fname, ylim=None):
        # Shared decoration + save/close boilerplate for every comparison plot.
        ax.set_xlabel("Timestep", fontsize=12)
        ax.set_ylabel(ylabel, fontsize=12)
        ax.set_title(title, fontsize=14)
        if ylim is not None:
            ax.set_ylim(*ylim)
        ax.legend(fontsize=10)
        ax.grid(True, alpha=0.3)
        ax.invert_xaxis()  # diffusion time runs 1.0 -> 0
        plt.tight_layout()
        plt.savefig(os.path.join(output_dir, fname), dpi=150)
        plt.close()

    # 1) V_steer norms across scales.
    fig, ax = plt.subplots(figsize=(12, 7))
    for (s, data), color in zip(all_data.items(), colors):
        ax.plot(data["timesteps"], data["V_steer_norm"],
                color=color, linewidth=2, label=f"strength={s:.1f}")
    _finish(ax, "V_steer L2 Norm", "V_steer Norm Comparison across Scales",
            "plot_comparison_norms.png")

    # 2) Cosine similarity (first placeholder entry skipped).
    fig, ax = plt.subplots(figsize=(12, 7))
    for (s, data), color in zip(all_data.items(), colors):
        ax.plot(data["timesteps"][1:], data["V_steer_cosine"][1:],
                color=color, linewidth=2, label=f"strength={s:.1f}")
    ax.axhline(y=0.9, color='gray', linestyle='--', alpha=0.7)
    _finish(ax, "Cosine Similarity", "V_steer Directional Consistency Comparison",
            "plot_comparison_cosine.png", ylim=(-0.1, 1.1))

    # 3) V_fid / V_steer angle.
    fig, ax = plt.subplots(figsize=(12, 7))
    for (s, data), color in zip(all_data.items(), colors):
        ax.plot(data["timesteps"], data["V_fid_V_steer_angle"],
                color=color, linewidth=2, label=f"strength={s:.1f}")
    ax.axhline(y=90, color='gray', linestyle='--', alpha=0.7)
    _finish(ax, "Angle (degrees)", "V_fid-V_steer Angle Comparison",
            "plot_comparison_angles.png", ylim=(0, 180))

    # 4) Cumulative latent movement (integrated per-step deltas).
    fig, ax = plt.subplots(figsize=(12, 7))
    for (s, data), color in zip(all_data.items(), colors):
        cumulative = np.cumsum(data["zt_edit_delta"])
        ax.plot(data["timesteps"], cumulative,
                color=color, linewidth=2, label=f"strength={s:.1f}")
    _finish(ax, "Cumulative Movement",
            "Edit Trajectory: Cumulative Distance from Source",
            "plot_comparison_trajectory.png")

    print(f"Comparison plots saved to {output_dir}")
def prepare_mask_for_flux(
    mask: torch.Tensor,
    target_height: int,
    target_width: int,
    device: torch.device,
    dtype: torch.dtype,
) -> torch.Tensor:
    """
    Prepare a binary mask for use with Flux's packed latent format.

    Args:
        mask: Input mask tensor. Accepted shapes:
            - (H, W): single-channel 2D mask
            - (1, H, W): single channel with batch dim
            - (B, H, W): batch of single-channel masks
            - (B, 1, H, W): full 4D tensor
        target_height: Target height in latent space (image H / 8). Must be even.
        target_width: Target width in latent space (image W / 8). Must be even.
        device: Target device
        dtype: Target dtype

    Returns:
        Mask tensor of shape (B, seq_len, 1) for the packed latent format,
        where seq_len = (target_height // 2) * (target_width // 2) because
        Flux packs latents in 2x2 patches.

    Raises:
        ValueError: If target_height/target_width are not even, or if the
            mask rank is not 2, 3, or 4.
    """
    # Flux's 2x2 packing requires even latent dimensions; fail early with a
    # clear message instead of an opaque reshape error below.
    if target_height % 2 != 0 or target_width % 2 != 0:
        raise ValueError(
            f"target_height and target_width must be even for 2x2 packing, "
            f"got ({target_height}, {target_width})"
        )

    # Normalize to a 4D (B, C, H, W) tensor.
    if mask.dim() == 2:
        mask = mask.unsqueeze(0).unsqueeze(0)  # (H, W) -> (1, 1, H, W)
    elif mask.dim() == 3:
        if mask.shape[0] == 1:
            mask = mask.unsqueeze(0)  # (1, H, W) -> (1, 1, H, W)
        else:
            mask = mask.unsqueeze(1)  # (B, H, W) -> (B, 1, H, W)
    elif mask.dim() != 4:
        raise ValueError(f"Unsupported mask rank: {mask.dim()} (expected 2, 3, or 4)")

    mask = mask.to(device=device, dtype=dtype)

    # Resize to the latent-space resolution.
    mask_resized = F.interpolate(
        mask,
        size=(target_height, target_width),
        mode='bilinear',
        align_corners=False,
    )

    # Pack into Flux's sequence format: (B, C, H, W) -> (B, H//2 * W//2, C*4),
    # then average the 4 values of each 2x2 patch to one scalar per token.
    B, C, H, W = mask_resized.shape
    mask_packed = mask_resized.view(B, C, H // 2, 2, W // 2, 2)
    mask_packed = mask_packed.permute(0, 2, 4, 1, 3, 5)  # (B, H//2, W//2, C, 2, 2)
    mask_packed = mask_packed.reshape(B, (H // 2) * (W // 2), C * 4)  # (B, seq_len, 4)
    mask_packed = mask_packed.mean(dim=-1, keepdim=True)  # (B, seq_len, 1)
    return mask_packed
def FlowEditFLUX_Slider(
    pipe,
    scheduler,
    x_src,
    src_prompt: str,
    tar_prompt: str,
    tar_prompt_neg: str,
    negative_prompt: str = "",
    strength: float = 1.0,
    T_steps: int = 28,
    n_avg: int = 1,
    src_guidance_scale: float = 1.5,
    tar_guidance_scale: float = 5.5,
    n_min: int = 0,
    n_max: int = 24,
    scale_mode: str = "slider",
    normalize_v_dir: bool = False,
    v_dir_target_norm: float = 1.0,
    log_vectors: bool = False,
    log_output_dir: Optional[str] = None,
):
    """
    FlowEdit with 3-prompt directional decomposition for continuous strength control.

    Args:
        pipe: FluxPipeline
        scheduler: FlowMatchEulerDiscreteScheduler
        x_src: Source image latent (B, C, H, W)
        src_prompt: Source prompt describing the original image (e.g., "a building")
        tar_prompt: Positive target prompt (e.g., "a severely decayed building")
        tar_prompt_neg: Negative target prompt (e.g., "a new building")
        negative_prompt: Negative prompt for CFG (usually empty for Flux).
            NOTE(review): currently unused by this implementation.
        strength: Edit intensity strength (0.0 = tar_neg direction, 1.0 = tar_pos direction)
        T_steps: Total number of timesteps
        n_avg: Number of velocity field averaging iterations
        src_guidance_scale: Guidance strength for source prompt
        tar_guidance_scale: Guidance strength for target prompts
        n_min: Number of final steps using regular sampling
        n_max: Maximum number of steps to apply flow editing
        scale_mode: Scaling method - "slider" (default), "interp", "step", "cfg", or "direct"
            - "slider": Scale the direction vector V_delta_s = V_fid + strength * V_steer (FreeSlider-like)
            - "interp": FlowEdit-based interpolation V_final = V_src + strength * (V_pos - V_src)
            - "step": Scale the step size dt (FlowEdit paper experiment, causes degradation)
            - "cfg": Scale the target guidance (tar_guidance_scale * strength)
            - "direct": Scale the full velocity difference V_delta_s = strength * (V_pos - V_src) without decomposition
        normalize_v_dir: If True, normalize V_steer to v_dir_target_norm before scaling.
            This stabilizes edit strength across different CFG settings and prevents
            both numerical instability (low CFG) and semantic over-editing (high CFG).
            Only applies when scale_mode="slider".
        v_dir_target_norm: Target L2 norm for V_steer normalization (default: 1.0).
            Higher values produce stronger edits per unit strength.
        log_vectors: If True, record vector statistics and generate visualization plots
        log_output_dir: Output directory for vector logs (required if log_vectors=True)

    Returns:
        Edited latent tensor (unpacked to (B, C, H, W) layout).

    Raises:
        ValueError: If log_vectors=True without log_output_dir, or if scale_mode
            is not one of the supported modes.
    """
    # Validate log_vectors arguments up front so the failure is immediate.
    if log_vectors and log_output_dir is None:
        raise ValueError("log_output_dir must be specified when log_vectors=True")

    # Logging state: per-step stats plus previous-step tensors used for
    # cosine-similarity / delta computations in compute_vector_stats.
    stats_list = [] if log_vectors else None
    prev_V_steer = None
    prev_zt_edit = None

    device = x_src.device
    # Note: orig_height/width should match the actual image dimensions for correct latent_image_ids.
    # x_src is a VAE-encoded latent (H/8, W/8), so multiply by vae_scale_factor to get original size.
    orig_height = x_src.shape[2] * pipe.vae_scale_factor
    orig_width = x_src.shape[3] * pipe.vae_scale_factor
    # Flux packs 2x2 latent patches, so transformer in_channels = 4 * latent channels.
    num_channels_latents = pipe.transformer.config.in_channels // 4

    pipe.check_inputs(
        prompt=src_prompt,
        prompt_2=None,
        height=orig_height,
        width=orig_width,
        callback_on_step_end_tensor_inputs=None,
        max_sequence_length=512,
    )

    # Prepare latents: passing latents=x_src makes prepare_latents reuse the
    # existing latent and return matching latent_image_ids.
    x_src, latent_src_image_ids = pipe.prepare_latents(
        batch_size=x_src.shape[0],
        num_channels_latents=num_channels_latents,
        height=orig_height,
        width=orig_width,
        dtype=x_src.dtype,
        device=x_src.device,
        generator=None,
        latents=x_src
    )
    # Pack (B, C, H, W) into Flux's (B, seq_len, C*4) sequence layout.
    x_src_packed = pipe._pack_latents(
        x_src, x_src.shape[0], num_channels_latents, x_src.shape[2], x_src.shape[3]
    )
    latent_image_ids = latent_src_image_ids

    # Prepare timesteps: resolution-dependent shift (mu) as in the Flux pipeline.
    sigmas = np.linspace(1.0, 1 / T_steps, T_steps)
    image_seq_len = x_src_packed.shape[1]
    mu = calculate_shift(
        image_seq_len,
        scheduler.config.base_image_seq_len,
        scheduler.config.max_image_seq_len,
        scheduler.config.base_shift,
        scheduler.config.max_shift,
    )
    timesteps, T_steps = retrieve_timesteps(
        scheduler,
        T_steps,
        device,
        timesteps=None,
        sigmas=sigmas,
        mu=mu,
    )
    # NOTE(review): num_warmup_steps is computed but never used below.
    num_warmup_steps = max(len(timesteps) - T_steps * pipe.scheduler.order, 0)
    pipe._num_timesteps = len(timesteps)

    # ============================================
    # Encode prompts (3 prompts)
    # ============================================
    # Source prompt (describes the image as-is).
    (
        src_prompt_embeds,
        src_pooled_prompt_embeds,
        src_text_ids,
    ) = pipe.encode_prompt(
        prompt=src_prompt,
        prompt_2=None,
        device=device,
    )
    # Target positive prompt (e.g., "severely decayed").
    (
        tar_pos_prompt_embeds,
        tar_pos_pooled_prompt_embeds,
        tar_pos_text_ids,
    ) = pipe.encode_prompt(
        prompt=tar_prompt,
        prompt_2=None,
        device=device,
    )
    # Target negative prompt (e.g., "new, pristine").
    (
        tar_neg_prompt_embeds,
        tar_neg_pooled_prompt_embeds,
        tar_neg_text_ids,
    ) = pipe.encode_prompt(
        prompt=tar_prompt_neg,
        prompt_2=None,
        device=device,
    )

    # ============================================
    # Handle guidance
    # ============================================
    # For cfg mode, strength is folded into tar_guidance_scale instead of
    # into the velocity combination below.
    effective_tar_guidance = tar_guidance_scale * strength if scale_mode == "cfg" else tar_guidance_scale
    if pipe.transformer.config.guidance_embeds:
        src_guidance = torch.tensor([src_guidance_scale], device=device)
        src_guidance = src_guidance.expand(x_src_packed.shape[0])
        tar_pos_guidance = torch.tensor([effective_tar_guidance], device=device)
        tar_pos_guidance = tar_pos_guidance.expand(x_src_packed.shape[0])
        tar_neg_guidance = torch.tensor([effective_tar_guidance], device=device)
        tar_neg_guidance = tar_neg_guidance.expand(x_src_packed.shape[0])
    else:
        src_guidance = None
        tar_pos_guidance = None
        tar_neg_guidance = None

    # Initialize the edit ODE at the source latent: zt_edit = x_src.
    zt_edit = x_src_packed.clone()

    # ============================================
    # Main editing loop
    # ============================================
    for i, t in tqdm(enumerate(timesteps), total=len(timesteps), desc=f"FlowSlider (strength={strength:.2f})"):
        # Skip the earliest (highest-noise) steps until only n_max remain.
        if T_steps - i > n_max:
            continue
        scheduler._init_step_index(t)
        t_i = scheduler.sigmas[scheduler.step_index]
        # NOTE(review): this condition is always True (i < len(timesteps) by
        # construction), so the else branch is dead; presumably it was meant
        # to be `i < len(timesteps) - 1`. It works in practice because
        # scheduler.sigmas carries a trailing terminal sigma entry.
        if i < len(timesteps):
            t_im1 = scheduler.sigmas[scheduler.step_index + 1]
        else:
            t_im1 = t_i

        if T_steps - i > n_min:
            # Flow-based editing phase: average V_delta_s over n_avg noise draws.
            V_delta_s_avg = torch.zeros_like(x_src_packed)
            for k in range(n_avg):
                # Fresh forward noise for this averaging iteration.
                fwd_noise = torch.randn_like(x_src_packed).to(x_src_packed.device)
                # Source trajectory: linear interpolation between data and noise.
                zt_src = (1 - t_i) * x_src_packed + t_i * fwd_noise
                # Target trajectory: same noise offset applied to the edited latent
                # (preserves the edit accumulated so far).
                zt_tar = zt_edit + zt_src - x_src_packed

                # ============================================
                # 3-prompt velocity computation (CORE CHANGE)
                # ============================================
                # Source velocity, conditioned on the source prompt.
                Vt_src = calc_v_flux(
                    pipe,
                    latents=zt_src,
                    prompt_embeds=src_prompt_embeds,
                    pooled_prompt_embeds=src_pooled_prompt_embeds,
                    guidance=src_guidance,
                    text_ids=src_text_ids,
                    latent_image_ids=latent_image_ids,
                    t=t
                )
                # Positive target velocity (edit fully applied direction).
                Vt_pos = calc_v_flux(
                    pipe,
                    latents=zt_tar,
                    prompt_embeds=tar_pos_prompt_embeds,
                    pooled_prompt_embeds=tar_pos_pooled_prompt_embeds,
                    guidance=tar_pos_guidance,
                    text_ids=tar_pos_text_ids,
                    latent_image_ids=latent_image_ids,
                    t=t
                )
                # Negative target velocity (edit-absent direction).
                Vt_neg = calc_v_flux(
                    pipe,
                    latents=zt_tar,
                    prompt_embeds=tar_neg_prompt_embeds,
                    pooled_prompt_embeds=tar_neg_pooled_prompt_embeds,
                    guidance=tar_neg_guidance,
                    text_ids=tar_neg_text_ids,
                    latent_image_ids=latent_image_ids,
                    t=t
                )

                # ============================================
                # Directional decomposition
                # ============================================
                # V_steer: pure edit direction (e.g., the "aging" direction).
                V_steer = Vt_pos - Vt_neg
                # V_fid: base change from source toward the negative target.
                V_fid = Vt_neg - Vt_src

                # V_delta_s computation depends on scale_mode.
                if scale_mode == "slider":
                    # Slider mode: scale only the direction vector (FreeSlider-like).
                    # strength=0 -> V_fid only (tar_neg direction)
                    # strength=1 -> V_fid + V_steer = Vt_pos - Vt_src (tar_pos direction)
                    # Apply V_steer normalization if enabled.
                    if normalize_v_dir:
                        # Current norm (mean over sequence dimension); epsilon
                        # guards against division by zero.
                        v_dir_norm = V_steer.norm(dim=-1, keepdim=True).mean()
                        # Rescale V_steer to the requested target norm.
                        V_steer_scaled = V_steer * (v_dir_target_norm / (v_dir_norm + 1e-8))
                    else:
                        V_steer_scaled = V_steer
                    V_delta_s = V_fid + strength * V_steer_scaled
                elif scale_mode == "direct":
                    # Direct mode: scale the full velocity difference without decomposition.
                    # V_delta_s = strength * (V_pos - V_src)
                    # This scales both V_fid and V_steer together, causing trajectory collapse at strength > 1.
                    V_delta_full = Vt_pos - Vt_src
                    V_delta_s = strength * V_delta_full
                elif scale_mode == "interp":
                    # Interp mode: FlowEdit-based interpolation.
                    # V_final = V_src + strength * (V_pos - V_src) = (1-strength)*V_src + strength*V_pos
                    V_delta_full = Vt_pos - Vt_src
                    V_final = Vt_src + strength * V_delta_full
                    # Store V_final in V_delta_s; the ODE propagation below
                    # treats it as a full velocity, not a delta.
                    V_delta_s = V_final  # This is actually V_final, not a delta
                elif scale_mode == "step":
                    # Step mode: V_delta_s is fixed at strength=1; dt is scaled later.
                    V_delta_s = V_fid + V_steer  # equivalent to strength=1
                elif scale_mode == "cfg":
                    # CFG mode: strength was already applied to guidance; use strength=1 here.
                    V_delta_s = V_fid + V_steer  # equivalent to strength=1
                else:
                    raise ValueError(f"Unknown scale_mode: {scale_mode}")

                # Running average over the n_avg noise draws.
                V_delta_s_avg += (1 / n_avg) * V_delta_s

            # ============================================
            # Vector Logging (if enabled)
            # ============================================
            if log_vectors:
                # Uses the last computed V_fid and V_steer (when n_avg > 1,
                # this is from the final averaging iteration). For slider mode
                # with normalize_v_dir, the raw (pre-normalization) V_steer is
                # logged so the original magnitudes are visible.
                step_stats = compute_vector_stats(
                    V_fid=V_fid,
                    V_steer=V_steer,
                    V_delta_s=V_delta_s_avg,
                    zt_edit=zt_edit,
                    prev_V_steer=prev_V_steer,
                    prev_zt_edit=prev_zt_edit,
                )
                step_stats["timestep"] = t_i.item() if hasattr(t_i, 'item') else float(t_i)
                # Record normalization details when active.
                if normalize_v_dir and scale_mode == "slider":
                    step_stats["normalize_v_dir"] = True
                    step_stats["v_dir_target_norm"] = v_dir_target_norm
                    step_stats["v_dir_original_norm"] = V_steer.norm(dim=-1).mean().item()
                stats_list.append(step_stats)
                # Keep current tensors for next-step comparisons.
                prev_V_steer = V_steer.clone()
                prev_zt_edit = zt_edit.clone()

            # Propagate the ODE one Euler step (in float32 for accumulation accuracy).
            zt_edit = zt_edit.to(torch.float32)
            if scale_mode == "step":
                # Step mode: scale the step size dt (FlowEdit paper experiment).
                zt_edit = zt_edit + strength * (t_im1 - t_i) * V_delta_s_avg
            elif scale_mode == "interp":
                # Interp mode: V_delta_s_avg is actually V_final, used directly.
                zt_edit = zt_edit + (t_im1 - t_i) * V_delta_s_avg
            else:
                # Slider, direct and CFG modes: normal dt.
                zt_edit = zt_edit + (t_im1 - t_i) * V_delta_s_avg
            zt_edit = zt_edit.to(V_delta_s_avg.dtype)

        else:  # Regular sampling for last n_min steps
            if i == T_steps - n_min:
                # Initialize SDEdit-style generation phase from the edited latent.
                fwd_noise = torch.randn_like(x_src_packed).to(x_src_packed.device)
                xt_src = scale_noise(scheduler, x_src_packed, t, noise=fwd_noise)
                xt_tar = zt_edit + xt_src - x_src_packed
            # For the final steps, condition on an embedding interpolated
            # between the negative and positive target prompts by strength.
            interp_prompt_embeds = (1 - strength) * tar_neg_prompt_embeds + strength * tar_pos_prompt_embeds
            interp_pooled_embeds = (1 - strength) * tar_neg_pooled_prompt_embeds + strength * tar_pos_pooled_prompt_embeds
            Vt_tar = calc_v_flux(
                pipe,
                latents=xt_tar,
                prompt_embeds=interp_prompt_embeds,
                pooled_prompt_embeds=interp_pooled_embeds,
                guidance=tar_pos_guidance,
                text_ids=tar_pos_text_ids,  # text_ids are typically the same
                latent_image_ids=latent_image_ids,
                t=t
            )
            # Euler step in float32, then cast back to the model dtype.
            xt_tar = xt_tar.to(torch.float32)
            prev_sample = xt_tar + (t_im1 - t_i) * Vt_tar
            prev_sample = prev_sample.to(Vt_tar.dtype)
            xt_tar = prev_sample

    # With n_min == 0 the whole run is flow-editing, so zt_edit is the result;
    # otherwise the SDEdit-style tail produced xt_tar.
    out = zt_edit if n_min == 0 else xt_tar
    unpacked_out = pipe._unpack_latents(out, orig_height, orig_width, pipe.vae_scale_factor)

    # ============================================
    # Save and visualize vector statistics
    # ============================================
    if log_vectors and stats_list:
        stats_path = save_vector_stats(stats_list, log_output_dir, strength)
        plot_vector_stats(stats_path, log_output_dir)
        print(f"Vector statistics saved to {log_output_dir}")

    return unpacked_out
def FlowEditFLUX_Slider_batch(
    pipe,
    scheduler,
    x_src,
    src_prompt: str,
    tar_prompt: str,
    tar_prompt_neg: str,
    negative_prompt: str = "",
    strengths: Optional[List[float]] = None,
    T_steps: int = 28,
    n_avg: int = 1,
    src_guidance_scale: float = 1.5,
    tar_guidance_scale: float = 5.5,
    n_min: int = 0,
    n_max: int = 24,
):
    """
    Run FlowEditFLUX_Slider for multiple strength values.

    Args:
        strengths: Strength values to generate. Defaults to
            [0.0, 0.25, 0.5, 0.75, 1.0] when None.
        (other args are forwarded unchanged to FlowEditFLUX_Slider)

    Returns:
        Dict[float, Tensor]: Mapping from strength to edited latent.

    Note:
        Each strength is processed by an independent call to
        FlowEditFLUX_Slider, so prompts are re-encoded per call; this is a
        convenience wrapper rather than a batching optimization.
    """
    # Use None as the default to avoid a shared mutable default argument.
    if strengths is None:
        strengths = [0.0, 0.25, 0.5, 0.75, 1.0]
    results = {}
    for strength in strengths:
        # Only `strength` varies between calls; everything else is forwarded.
        results[strength] = FlowEditFLUX_Slider(
            pipe=pipe,
            scheduler=scheduler,
            x_src=x_src,
            src_prompt=src_prompt,
            tar_prompt=tar_prompt,
            tar_prompt_neg=tar_prompt_neg,
            negative_prompt=negative_prompt,
            strength=strength,
            T_steps=T_steps,
            n_avg=n_avg,
            src_guidance_scale=src_guidance_scale,
            tar_guidance_scale=tar_guidance_scale,
            n_min=n_min,
            n_max=n_max,
        )
    return results
| def FlowEditFLUX_Slider_2prompt( | |
| pipe, | |
| scheduler, | |
| x_src, | |
| src_prompt: str, | |
| tar_prompt: str, | |
| negative_prompt: str = "", | |
| strength: float = 1.0, | |
| T_steps: int = 28, | |
| n_avg: int = 1, | |
| src_guidance_scale: float = 1.5, | |
| tar_guidance_scale: float = 5.5, | |
| n_min: int = 0, | |
| n_max: int = 24, | |
| ): | |
| """ | |
| FlowEdit with 2-prompt simple scaling (without neg prompt). | |
| 数式: V_delta_s = strength * (V_tar - V_src) | |
| strength=0: 編集なし(元画像のまま) | |
| strength=1: 通常のFlowEdit(tar方向への完全編集) | |
| strength>1: tar方向への過剰編集 | |
| Args: | |
| pipe: FluxPipeline | |
| scheduler: FlowMatchEulerDiscreteScheduler | |
| x_src: Source image latent | |
| src_prompt: Source prompt | |
| tar_prompt: Target prompt (e.g., "a decayed building") | |
| strength: Edit intensity (0=no change, 1=full edit) | |
| """ | |
| device = x_src.device | |
| orig_height = x_src.shape[2] * pipe.vae_scale_factor | |
| orig_width = x_src.shape[3] * pipe.vae_scale_factor | |
| num_channels_latents = pipe.transformer.config.in_channels // 4 | |
| pipe.check_inputs( | |
| prompt=src_prompt, | |
| prompt_2=None, | |
| height=orig_height, | |
| width=orig_width, | |
| callback_on_step_end_tensor_inputs=None, | |
| max_sequence_length=512, | |
| ) | |
| # Prepare latents | |
| x_src, latent_src_image_ids = pipe.prepare_latents( | |
| batch_size=x_src.shape[0], | |
| num_channels_latents=num_channels_latents, | |
| height=orig_height, | |
| width=orig_width, | |
| dtype=x_src.dtype, | |
| device=x_src.device, | |
| generator=None, | |
| latents=x_src | |
| ) | |
| x_src_packed = pipe._pack_latents( | |
| x_src, x_src.shape[0], num_channels_latents, x_src.shape[2], x_src.shape[3] | |
| ) | |
| latent_image_ids = latent_src_image_ids | |
| # Prepare timesteps | |
| sigmas = np.linspace(1.0, 1 / T_steps, T_steps) | |
| image_seq_len = x_src_packed.shape[1] | |
| mu = calculate_shift( | |
| image_seq_len, | |
| scheduler.config.base_image_seq_len, | |
| scheduler.config.max_image_seq_len, | |
| scheduler.config.base_shift, | |
| scheduler.config.max_shift, | |
| ) | |
| timesteps, T_steps = retrieve_timesteps( | |
| scheduler, | |
| T_steps, | |
| device, | |
| timesteps=None, | |
| sigmas=sigmas, | |
| mu=mu, | |
| ) | |
| # Encode prompts (2 prompts only) | |
| ( | |
| src_prompt_embeds, | |
| src_pooled_prompt_embeds, | |
| src_text_ids, | |
| ) = pipe.encode_prompt( | |
| prompt=src_prompt, | |
| prompt_2=None, | |
| device=device, | |
| ) | |
| ( | |
| tar_prompt_embeds, | |
| tar_pooled_prompt_embeds, | |
| tar_text_ids, | |
| ) = pipe.encode_prompt( | |
| prompt=tar_prompt, | |
| prompt_2=None, | |
| device=device, | |
| ) | |
| # Handle guidance | |
| if pipe.transformer.config.guidance_embeds: | |
| src_guidance = torch.tensor([src_guidance_scale], device=device) | |
| src_guidance = src_guidance.expand(x_src_packed.shape[0]) | |
| tar_guidance = torch.tensor([tar_guidance_scale], device=device) | |
| tar_guidance = tar_guidance.expand(x_src_packed.shape[0]) | |
| else: | |
| src_guidance = None | |
| tar_guidance = None | |
| # Initialize ODE | |
| zt_edit = x_src_packed.clone() | |
| # Main editing loop | |
| for i, t in tqdm(enumerate(timesteps), total=len(timesteps), desc=f"FlowEdit-2prompt (strength={strength:.2f})"): | |
| if T_steps - i > n_max: | |
| continue | |
| scheduler._init_step_index(t) | |
| t_i = scheduler.sigmas[scheduler.step_index] | |
| if i < len(timesteps): | |
| t_im1 = scheduler.sigmas[scheduler.step_index + 1] | |
| else: | |
| t_im1 = t_i | |
| if T_steps - i > n_min: | |
| V_delta_s_avg = torch.zeros_like(x_src_packed) | |
| for k in range(n_avg): | |
| fwd_noise = torch.randn_like(x_src_packed).to(x_src_packed.device) | |
| zt_src = (1 - t_i) * x_src_packed + t_i * fwd_noise | |
| zt_tar = zt_edit + zt_src - x_src_packed | |
| # 2-prompt velocity computation | |
| Vt_src = calc_v_flux( | |
| pipe, | |
| latents=zt_src, | |
| prompt_embeds=src_prompt_embeds, | |
| pooled_prompt_embeds=src_pooled_prompt_embeds, | |
| guidance=src_guidance, | |
| text_ids=src_text_ids, | |
| latent_image_ids=latent_image_ids, | |
| t=t | |
| ) | |
| Vt_tar = calc_v_flux( | |
| pipe, | |
| latents=zt_tar, | |
| prompt_embeds=tar_prompt_embeds, | |
| pooled_prompt_embeds=tar_pooled_prompt_embeds, | |
| guidance=tar_guidance, | |
| text_ids=tar_text_ids, | |
| latent_image_ids=latent_image_ids, | |
| t=t | |
| ) | |
| # Simple scaling: V_delta_s = strength * (V_tar - V_src) | |
| V_delta_s = strength * (Vt_tar - Vt_src) | |
| V_delta_s_avg += (1 / n_avg) * V_delta_s | |
| zt_edit = zt_edit.to(torch.float32) | |
| zt_edit = zt_edit + (t_im1 - t_i) * V_delta_s_avg | |
| zt_edit = zt_edit.to(V_delta_s_avg.dtype) | |
| else: | |
| if i == T_steps - n_min: | |
| fwd_noise = torch.randn_like(x_src_packed).to(x_src_packed.device) | |
| xt_src = scale_noise(scheduler, x_src_packed, t, noise=fwd_noise) | |
| xt_tar = zt_edit + xt_src - x_src_packed | |
| # Prompt interpolation for stability at high strengths | |
| # interp = (1 - strength) * src + strength * tar | |
| interp_prompt_embeds = (1 - strength) * src_prompt_embeds + strength * tar_prompt_embeds | |
| interp_pooled_embeds = (1 - strength) * src_pooled_prompt_embeds + strength * tar_pooled_prompt_embeds | |
| Vt_tar = calc_v_flux( | |
| pipe, | |
| latents=xt_tar, | |
| prompt_embeds=interp_prompt_embeds, | |
| pooled_prompt_embeds=interp_pooled_embeds, | |
| guidance=tar_guidance, | |
| text_ids=tar_text_ids, | |
| latent_image_ids=latent_image_ids, | |
| t=t | |
| ) | |
| xt_tar = xt_tar.to(torch.float32) | |
| prev_sample = xt_tar + (t_im1 - t_i) * Vt_tar | |
| prev_sample = prev_sample.to(Vt_tar.dtype) | |
| xt_tar = prev_sample | |
| out = zt_edit if n_min == 0 else xt_tar | |
| unpacked_out = pipe._unpack_latents(out, orig_height, orig_width, pipe.vae_scale_factor) | |
| return unpacked_out | |
def FlowEditFLUX_Slider_with_mask(
    pipe,
    scheduler,
    x_src,
    src_prompt: str,
    tar_prompt: str,
    tar_prompt_neg: str,
    mask: torch.Tensor,
    negative_prompt: str = "",
    strength: float = 1.0,
    T_steps: int = 28,
    n_avg: int = 1,
    src_guidance_scale: float = 1.5,
    tar_guidance_scale: float = 5.5,
    n_min: int = 0,
    n_max: int = 24,
    scale_mode: str = "slider",
):
    """
    FlowEdit with 3-prompt directional decomposition and mask-based local editing.

    Edits are applied only inside the masked region; unmasked areas are pinned
    to the source latent at every step (LocalBlend-style blending applied to
    the RESULT of the ODE step, not to the velocity).

    Args:
        pipe: FluxPipeline
        scheduler: FlowMatchEulerDiscreteScheduler
        x_src: Source image latent (B, C, H, W)
        src_prompt: Source prompt describing the original image
        tar_prompt: Positive target prompt (e.g., "a severely decayed building")
        tar_prompt_neg: Negative target prompt (e.g., "a new building")
        mask: Binary mask tensor indicating edit region.
            Shape: (H, W), (1, H, W), (B, H, W), or (B, 1, H, W)
            Values: 1 = edit region, 0 = preserve original
        negative_prompt: Unused for Flux (no CFG pair); kept for interface
            compatibility with the SD3/Z-Image slider variants.
        strength: Edit intensity (0.0 = tar_neg direction, 1.0 = tar_pos direction)
        T_steps: Total number of timesteps
        n_avg: Number of velocity-field averaging iterations per step
        src_guidance_scale: Guidance strength for the source prompt
        tar_guidance_scale: Guidance strength for the target prompts
        n_min: Number of final steps using regular (SDEdit-style) sampling
        n_max: Maximum number of steps to apply flow editing
        scale_mode: Scaling method - "slider" (default), "interp", "step", or "cfg"

    Returns:
        Edited latent tensor with edits applied only in the masked region

    Raises:
        ValueError: If `scale_mode` is not one of the supported modes.
    """
    device = x_src.device
    orig_height = x_src.shape[2] * pipe.vae_scale_factor
    orig_width = x_src.shape[3] * pipe.vae_scale_factor
    num_channels_latents = pipe.transformer.config.in_channels // 4
    pipe.check_inputs(
        prompt=src_prompt,
        prompt_2=None,
        height=orig_height,
        width=orig_width,
        callback_on_step_end_tensor_inputs=None,
        max_sequence_length=512,
    )
    # Prepare latents: passes x_src through unchanged (latents given) and
    # builds the positional image ids required by the Flux transformer.
    x_src, latent_src_image_ids = pipe.prepare_latents(
        batch_size=x_src.shape[0],
        num_channels_latents=num_channels_latents,
        height=orig_height,
        width=orig_width,
        dtype=x_src.dtype,
        device=x_src.device,
        generator=None,
        latents=x_src,
    )
    x_src_packed = pipe._pack_latents(
        x_src, x_src.shape[0], num_channels_latents, x_src.shape[2], x_src.shape[3]
    )
    latent_image_ids = latent_src_image_ids
    # Convert the spatial mask into the packed latent sequence layout.
    mask_packed = prepare_mask_for_flux(
        mask=mask,
        target_height=x_src.shape[2],
        target_width=x_src.shape[3],
        device=device,
        dtype=x_src.dtype,
    )
    # Prepare timesteps with Flux dynamic shifting (mu from sequence length).
    sigmas = np.linspace(1.0, 1 / T_steps, T_steps)
    image_seq_len = x_src_packed.shape[1]
    mu = calculate_shift(
        image_seq_len,
        scheduler.config.base_image_seq_len,
        scheduler.config.max_image_seq_len,
        scheduler.config.base_shift,
        scheduler.config.max_shift,
    )
    timesteps, T_steps = retrieve_timesteps(
        scheduler,
        T_steps,
        device,
        timesteps=None,
        sigmas=sigmas,
        mu=mu,
    )
    pipe._num_timesteps = len(timesteps)
    # Encode prompts (3 prompts: source, target-positive, target-negative)
    (
        src_prompt_embeds,
        src_pooled_prompt_embeds,
        src_text_ids,
    ) = pipe.encode_prompt(
        prompt=src_prompt,
        prompt_2=None,
        device=device,
    )
    (
        tar_pos_prompt_embeds,
        tar_pos_pooled_prompt_embeds,
        tar_pos_text_ids,
    ) = pipe.encode_prompt(
        prompt=tar_prompt,
        prompt_2=None,
        device=device,
    )
    (
        tar_neg_prompt_embeds,
        tar_neg_pooled_prompt_embeds,
        tar_neg_text_ids,
    ) = pipe.encode_prompt(
        prompt=tar_prompt_neg,
        prompt_2=None,
        device=device,
    )
    # Handle guidance embeddings (Flux-dev style).
    # For "cfg" mode, the slider strength modulates the target guidance scale.
    effective_tar_guidance = tar_guidance_scale * strength if scale_mode == "cfg" else tar_guidance_scale
    if pipe.transformer.config.guidance_embeds:
        src_guidance = torch.tensor([src_guidance_scale], device=device)
        src_guidance = src_guidance.expand(x_src_packed.shape[0])
        tar_pos_guidance = torch.tensor([effective_tar_guidance], device=device)
        tar_pos_guidance = tar_pos_guidance.expand(x_src_packed.shape[0])
        tar_neg_guidance = torch.tensor([effective_tar_guidance], device=device)
        tar_neg_guidance = tar_neg_guidance.expand(x_src_packed.shape[0])
    else:
        src_guidance = None
        tar_pos_guidance = None
        tar_neg_guidance = None
    # Initialize ODE state at the source latent: zt_edit = x_src
    zt_edit = x_src_packed.clone()
    # Main editing loop
    for i, t in tqdm(enumerate(timesteps), total=len(timesteps), desc=f"FlowEdit-Slider-Mask (strength={strength:.2f})"):
        if T_steps - i > n_max:
            continue  # skip the earliest (noisiest) steps entirely
        scheduler._init_step_index(t)
        t_i = scheduler.sigmas[scheduler.step_index]
        # FlowMatchEulerDiscreteScheduler appends a terminal sigma, so
        # step_index + 1 is normally valid even on the last step. Guard on the
        # sigma array itself (the previous `i < len(timesteps)` check was
        # always true inside enumerate, leaving a dead else-branch).
        if scheduler.step_index + 1 < len(scheduler.sigmas):
            t_im1 = scheduler.sigmas[scheduler.step_index + 1]
        else:
            t_im1 = t_i
        if T_steps - i > n_min:
            # Flow-based editing phase
            V_delta_s_avg = torch.zeros_like(x_src_packed)
            for k in range(n_avg):
                # Fresh forward noise per averaging iteration
                fwd_noise = torch.randn_like(x_src_packed).to(x_src_packed.device)
                # Source trajectory (linear noising at sigma t_i)
                zt_src = (1 - t_i) * x_src_packed + t_i * fwd_noise
                # Target trajectory (offset-preserving coupling to the source)
                zt_tar = zt_edit + zt_src - x_src_packed
                # 3-prompt velocity computation
                Vt_src = calc_v_flux(
                    pipe,
                    latents=zt_src,
                    prompt_embeds=src_prompt_embeds,
                    pooled_prompt_embeds=src_pooled_prompt_embeds,
                    guidance=src_guidance,
                    text_ids=src_text_ids,
                    latent_image_ids=latent_image_ids,
                    t=t,
                )
                Vt_pos = calc_v_flux(
                    pipe,
                    latents=zt_tar,
                    prompt_embeds=tar_pos_prompt_embeds,
                    pooled_prompt_embeds=tar_pos_pooled_prompt_embeds,
                    guidance=tar_pos_guidance,
                    text_ids=tar_pos_text_ids,
                    latent_image_ids=latent_image_ids,
                    t=t,
                )
                Vt_neg = calc_v_flux(
                    pipe,
                    latents=zt_tar,
                    prompt_embeds=tar_neg_prompt_embeds,
                    pooled_prompt_embeds=tar_neg_pooled_prompt_embeds,
                    guidance=tar_neg_guidance,
                    text_ids=tar_neg_text_ids,
                    latent_image_ids=latent_image_ids,
                    t=t,
                )
                # Directional decomposition:
                #   V_steer = pure edit direction (pos - neg)
                #   V_fid   = base change relative to the source (neg - src)
                V_steer = Vt_pos - Vt_neg
                V_fid = Vt_neg - Vt_src
                # V_delta_s computation depends on scale_mode
                if scale_mode == "slider":
                    V_delta_s = V_fid + strength * V_steer
                elif scale_mode == "interp":
                    # Interpolate the full velocity between src and pos.
                    # NOTE: this stores V_final (a full velocity), not a delta;
                    # the integration below is still the same Euler update.
                    V_delta_full = Vt_pos - Vt_src
                    V_final = Vt_src + strength * V_delta_full
                    V_delta_s = V_final
                elif scale_mode == "step":
                    V_delta_s = V_fid + V_steer  # strength is applied to dt below
                elif scale_mode == "cfg":
                    V_delta_s = V_fid + V_steer  # strength already folded into guidance
                else:
                    raise ValueError(f"Unknown scale_mode: {scale_mode}")
                V_delta_s_avg += (1 / n_avg) * V_delta_s
            # Propagate the ODE in float32 for numerical stability.
            zt_edit = zt_edit.to(torch.float32)
            if scale_mode == "step":
                # Step mode: strength scales the step size dt.
                zt_edit_new = zt_edit + strength * (t_im1 - t_i) * V_delta_s_avg
            else:
                # "slider"/"cfg" integrate a delta; "interp" stores V_final in
                # V_delta_s_avg but uses the identical Euler update, so the
                # formerly duplicated branches are merged here.
                zt_edit_new = zt_edit + (t_im1 - t_i) * V_delta_s_avg
            # MASK APPLICATION (LocalBlend style): blend the RESULT, not the
            # velocity, forcing unmasked regions to stay at the source latent:
            #   zt_edit = x_src + mask * (zt_edit_new - x_src)
            zt_edit = x_src_packed + mask_packed * (zt_edit_new - x_src_packed)
            zt_edit = zt_edit.to(V_delta_s_avg.dtype)
        else:  # Regular sampling for the last n_min steps
            if i == T_steps - n_min:
                # Initialize SDEdit-style generation phase once.
                fwd_noise = torch.randn_like(x_src_packed).to(x_src_packed.device)
                xt_src = scale_noise(scheduler, x_src_packed, t, noise=fwd_noise)
                xt_tar = zt_edit + xt_src - x_src_packed
            # Interpolate between neg and pos embeddings by strength for
            # stability of the final denoising steps.
            interp_prompt_embeds = (1 - strength) * tar_neg_prompt_embeds + strength * tar_pos_prompt_embeds
            interp_pooled_embeds = (1 - strength) * tar_neg_pooled_prompt_embeds + strength * tar_pos_pooled_prompt_embeds
            Vt_tar = calc_v_flux(
                pipe,
                latents=xt_tar,
                prompt_embeds=interp_prompt_embeds,
                pooled_prompt_embeds=interp_pooled_embeds,
                guidance=tar_pos_guidance,
                text_ids=tar_pos_text_ids,
                latent_image_ids=latent_image_ids,
                t=t,
            )
            xt_tar = xt_tar.to(torch.float32)
            xt_tar_new = xt_tar + (t_im1 - t_i) * Vt_tar
            # LocalBlend-style mask application for the n_min phase as well.
            xt_tar = x_src_packed + mask_packed * (xt_tar_new - x_src_packed)
            xt_tar = xt_tar.to(Vt_tar.dtype)
    # If no SDEdit phase was run (n_min == 0), the edited trajectory is zt_edit.
    out = zt_edit if n_min == 0 else xt_tar
    unpacked_out = pipe._unpack_latents(out, orig_height, orig_width, pipe.vae_scale_factor)
    return unpacked_out
| # ============================================ | |
| # SD3 Slider Implementation | |
| # ============================================ | |
def FlowEditSD3_Slider(
    pipe,
    scheduler,
    x_src,
    src_prompt: str,
    tar_prompt: str,
    tar_prompt_neg: str,
    negative_prompt: str = "",
    strength: float = 1.0,
    T_steps: int = 50,
    n_avg: int = 1,
    src_guidance_scale: float = 3.5,
    tar_guidance_scale: float = 13.5,
    n_min: int = 0,
    n_max: int = 33,
    scale_mode: str = "slider",
    normalize_v_dir: bool = False,
    v_dir_target_norm: float = 1.0,
    log_vectors: bool = False,
    log_output_dir: Optional[str] = None,
):
    """
    FlowSlider for SD3 with 3-prompt directional decomposition.

    Uses 6-way CFG batching for efficient computation:
    - [src_uncond, src_cond, tar_pos_uncond, tar_pos_cond, tar_neg_uncond, tar_neg_cond]

    Args:
        pipe: StableDiffusion3Pipeline
        scheduler: Scheduler (typically FlowMatchEulerDiscreteScheduler)
        x_src: Source image latent (B, C, H, W)
        src_prompt: Source prompt describing the original image
        tar_prompt: Positive target prompt (e.g., "a severely decayed building")
        tar_prompt_neg: Negative target prompt (e.g., "a new building")
        negative_prompt: Negative prompt for CFG (usually empty)
        strength: Edit intensity (0.0 = tar_neg direction, 1.0 = tar_pos direction)
        T_steps: Total number of timesteps (default: 50 for SD3)
        n_avg: Number of velocity-field averaging iterations per step
        src_guidance_scale: Guidance strength for the source prompt (default: 3.5)
        tar_guidance_scale: Guidance strength for the target prompts (default: 13.5)
        n_min: Number of final steps using regular sampling
        n_max: Maximum number of steps to apply flow editing (default: 33)
        scale_mode: Scaling method - "slider" (default), "interp", "step", "cfg", or "direct"
            - "slider": V_delta_s = V_fid + strength * V_steer
            - "direct": V_delta_s = strength * (V_pos - V_src) without decomposition
        normalize_v_dir: If True, normalize V_steer to v_dir_target_norm before scaling
        v_dir_target_norm: Target L2 norm for V_steer normalization
        log_vectors: If True, record per-timestep vector statistics
        log_output_dir: Output directory for vector logs (required when log_vectors)

    Returns:
        Edited latent tensor

    Raises:
        ValueError: If log_vectors is set without log_output_dir, or if
            scale_mode is not one of the supported modes.
    """
    # Validate log_vectors arguments up front so we fail before any heavy work.
    if log_vectors and log_output_dir is None:
        raise ValueError("log_output_dir must be specified when log_vectors=True")
    # Logging state (prev_* enable cross-step deltas/cosine similarities).
    stats_list = [] if log_vectors else None
    prev_V_steer = None
    prev_zt_edit = None
    device = x_src.device
    # Retrieve timesteps. Uses the module-level retrieve_timesteps import;
    # the redundant function-local re-import was removed.
    timesteps, T_steps = retrieve_timesteps(scheduler, T_steps, device, timesteps=None)
    pipe._num_timesteps = len(timesteps)
    # ============================================
    # Encode prompts (3 prompts with CFG pairs)
    # ============================================
    # Source prompt — pipe._guidance_scale drives do_classifier_free_guidance.
    pipe._guidance_scale = src_guidance_scale
    (
        src_prompt_embeds,
        src_negative_prompt_embeds,
        src_pooled_prompt_embeds,
        src_negative_pooled_prompt_embeds,
    ) = pipe.encode_prompt(
        prompt=src_prompt,
        prompt_2=None,
        prompt_3=None,
        negative_prompt=negative_prompt,
        do_classifier_free_guidance=pipe.do_classifier_free_guidance,
        device=device,
    )
    # Target positive prompt
    pipe._guidance_scale = tar_guidance_scale
    (
        tar_pos_prompt_embeds,
        tar_pos_negative_prompt_embeds,
        tar_pos_pooled_prompt_embeds,
        tar_pos_negative_pooled_prompt_embeds,
    ) = pipe.encode_prompt(
        prompt=tar_prompt,
        prompt_2=None,
        prompt_3=None,
        negative_prompt=negative_prompt,
        do_classifier_free_guidance=pipe.do_classifier_free_guidance,
        device=device,
    )
    # Target negative prompt (inherits the target guidance scale set above)
    (
        tar_neg_prompt_embeds,
        tar_neg_negative_prompt_embeds,
        tar_neg_pooled_prompt_embeds,
        tar_neg_negative_pooled_prompt_embeds,
    ) = pipe.encode_prompt(
        prompt=tar_prompt_neg,
        prompt_2=None,
        prompt_3=None,
        negative_prompt=negative_prompt,
        do_classifier_free_guidance=pipe.do_classifier_free_guidance,
        device=device,
    )
    # ============================================
    # Prepare 6-way CFG embeddings
    # [src_uncond, src_cond, tar_pos_uncond, tar_pos_cond, tar_neg_uncond, tar_neg_cond]
    # ============================================
    all_prompt_embeds = torch.cat([
        src_negative_prompt_embeds,       # src_uncond
        src_prompt_embeds,                # src_cond
        tar_pos_negative_prompt_embeds,   # tar_pos_uncond
        tar_pos_prompt_embeds,            # tar_pos_cond
        tar_neg_negative_prompt_embeds,   # tar_neg_uncond
        tar_neg_prompt_embeds,            # tar_neg_cond
    ], dim=0)
    all_pooled_prompt_embeds = torch.cat([
        src_negative_pooled_prompt_embeds,
        src_pooled_prompt_embeds,
        tar_pos_negative_pooled_prompt_embeds,
        tar_pos_pooled_prompt_embeds,
        tar_neg_negative_pooled_prompt_embeds,
        tar_neg_pooled_prompt_embeds,
    ], dim=0)
    # Initialize ODE state at the source latent: zt_edit = x_src
    zt_edit = x_src.clone()
    # ============================================
    # Main editing loop
    # ============================================
    for i, t in tqdm(enumerate(timesteps), total=len(timesteps), desc=f"SD3-FlowSlider (strength={strength:.2f})"):
        if T_steps - i > n_max:
            continue  # skip the earliest (noisiest) steps entirely
        # SD3 timesteps live in [0, 1000]; rescale to sigma-like [0, 1].
        t_i = t / 1000
        if i + 1 < len(timesteps):
            t_im1 = timesteps[i + 1] / 1000
        else:
            t_im1 = torch.zeros_like(t_i).to(t_i.device)
        if T_steps - i > n_min:
            # Flow-based editing phase
            V_delta_s_avg = torch.zeros_like(x_src)
            for k in range(n_avg):
                # Fresh forward noise per averaging iteration
                fwd_noise = torch.randn_like(x_src).to(x_src.device)
                # Source trajectory (linear noising at sigma t_i)
                zt_src = (1 - t_i) * x_src + t_i * fwd_noise
                # Target trajectory (offset-preserving coupling to the source)
                zt_tar = zt_edit + zt_src - x_src
                # 6-way CFG batched computation:
                # latents ordered [zt_src, zt_src, zt_tar, zt_tar, zt_tar, zt_tar]
                all_latents = torch.cat([
                    zt_src, zt_src,   # src_uncond, src_cond
                    zt_tar, zt_tar,   # tar_pos_uncond, tar_pos_cond
                    zt_tar, zt_tar,   # tar_neg_uncond, tar_neg_cond
                ])
                # Broadcast the scalar timestep across the batch
                timestep_batch = t.expand(all_latents.shape[0])
                # Single transformer call for all six predictions
                with torch.no_grad():
                    noise_pred_all = pipe.transformer(
                        hidden_states=all_latents,
                        timestep=timestep_batch,
                        encoder_hidden_states=all_prompt_embeds,
                        pooled_projections=all_pooled_prompt_embeds,
                        joint_attention_kwargs=None,
                        return_dict=False,
                    )[0]
                # Split into the six cond/uncond parts
                (
                    src_noise_uncond, src_noise_cond,
                    tar_pos_noise_uncond, tar_pos_noise_cond,
                    tar_neg_noise_uncond, tar_neg_noise_cond,
                ) = noise_pred_all.chunk(6)
                # Apply CFG to obtain the three guided velocities
                Vt_src = src_noise_uncond + src_guidance_scale * (src_noise_cond - src_noise_uncond)
                Vt_pos = tar_pos_noise_uncond + tar_guidance_scale * (tar_pos_noise_cond - tar_pos_noise_uncond)
                Vt_neg = tar_neg_noise_uncond + tar_guidance_scale * (tar_neg_noise_cond - tar_neg_noise_uncond)
                # Directional decomposition (same as the FLUX variant):
                #   V_steer = pure edit direction, V_fid = base change
                V_steer = Vt_pos - Vt_neg
                V_fid = Vt_neg - Vt_src
                # V_delta_s computation depends on scale_mode
                if scale_mode == "slider":
                    if normalize_v_dir:
                        # Rescale V_steer to a fixed L2 norm before applying strength
                        v_dir_norm = V_steer.norm(dim=-1, keepdim=True).mean()
                        V_steer_scaled = V_steer * (v_dir_target_norm / (v_dir_norm + 1e-8))
                    else:
                        V_steer_scaled = V_steer
                    V_delta_s = V_fid + strength * V_steer_scaled
                elif scale_mode == "direct":
                    # Scale the full velocity difference without decomposition
                    V_delta_full = Vt_pos - Vt_src
                    V_delta_s = strength * V_delta_full
                elif scale_mode == "interp":
                    # NOTE: stores V_final (a full velocity), not a delta
                    V_delta_full = Vt_pos - Vt_src
                    V_final = Vt_src + strength * V_delta_full
                    V_delta_s = V_final
                elif scale_mode == "step":
                    V_delta_s = V_fid + V_steer  # strength is applied to dt below
                elif scale_mode == "cfg":
                    V_delta_s = V_fid + V_steer  # strength assumed folded into guidance
                else:
                    raise ValueError(f"Unknown scale_mode: {scale_mode}")
                V_delta_s_avg += (1 / n_avg) * V_delta_s
            # ============================================
            # Vector logging — once per timestep, using the vectors from the
            # final n_avg iteration.
            # ============================================
            if log_vectors:
                step_stats = compute_vector_stats(
                    V_fid=V_fid,
                    V_steer=V_steer,
                    V_delta_s=V_delta_s_avg,
                    zt_edit=zt_edit,
                    prev_V_steer=prev_V_steer,
                    prev_zt_edit=prev_zt_edit,
                )
                step_stats["timestep"] = t_i.item() if hasattr(t_i, 'item') else float(t_i)
                if normalize_v_dir and scale_mode == "slider":
                    step_stats["normalize_v_dir"] = True
                    step_stats["v_dir_target_norm"] = v_dir_target_norm
                    step_stats["v_dir_original_norm"] = V_steer.norm(dim=-1).mean().item()
                stats_list.append(step_stats)
                prev_V_steer = V_steer.clone()
                prev_zt_edit = zt_edit.clone()
            # Propagate the ODE in float32 for numerical stability.
            zt_edit = zt_edit.to(torch.float32)
            if scale_mode == "step":
                # Step mode: strength scales the step size dt.
                zt_edit = zt_edit + strength * (t_im1 - t_i) * V_delta_s_avg
            else:
                zt_edit = zt_edit + (t_im1 - t_i) * V_delta_s_avg
            zt_edit = zt_edit.to(V_delta_s_avg.dtype)
        else:  # Regular sampling for the last n_min steps
            if i == T_steps - n_min:
                # Initialize SDEdit-style generation phase once.
                fwd_noise = torch.randn_like(x_src).to(x_src.device)
                xt_src = scale_noise(scheduler, x_src, t, noise=fwd_noise)
                xt_tar = zt_edit + xt_src - x_src
            # For the final steps, interpolate between neg and pos embeddings.
            interp_prompt_embeds = (1 - strength) * tar_neg_prompt_embeds + strength * tar_pos_prompt_embeds
            interp_pooled_embeds = (1 - strength) * tar_neg_pooled_prompt_embeds + strength * tar_pos_pooled_prompt_embeds
            # 2-way CFG for the interpolated target
            interp_all_embeds = torch.cat([tar_pos_negative_prompt_embeds, interp_prompt_embeds], dim=0)
            interp_all_pooled = torch.cat([tar_pos_negative_pooled_prompt_embeds, interp_pooled_embeds], dim=0)
            interp_latents = torch.cat([xt_tar, xt_tar])
            timestep_batch = t.expand(2)
            with torch.no_grad():
                noise_pred_interp = pipe.transformer(
                    hidden_states=interp_latents,
                    timestep=timestep_batch,
                    encoder_hidden_states=interp_all_embeds,
                    pooled_projections=interp_all_pooled,
                    joint_attention_kwargs=None,
                    return_dict=False,
                )[0]
            interp_uncond, interp_cond = noise_pred_interp.chunk(2)
            Vt_tar = interp_uncond + tar_guidance_scale * (interp_cond - interp_uncond)
            xt_tar = xt_tar.to(torch.float32)
            prev_sample = xt_tar + (t_im1 - t_i) * Vt_tar
            prev_sample = prev_sample.to(Vt_tar.dtype)
            xt_tar = prev_sample
    # If no SDEdit phase was run (n_min == 0), the edited trajectory is zt_edit.
    out = zt_edit if n_min == 0 else xt_tar
    # ============================================
    # Save and visualize vector statistics
    # ============================================
    if log_vectors and stats_list:
        stats_path = save_vector_stats(stats_list, log_output_dir, strength)
        plot_vector_stats(stats_path, log_output_dir)
        print(f"Vector statistics saved to {log_output_dir}")
    return out
| # ============================================ | |
| # Z-Image Slider Implementation | |
| # ============================================ | |
| def FlowEditZImage_Slider( | |
| pipe, | |
| scheduler, | |
| x_src, | |
| src_prompt: str, | |
| tar_prompt: str, | |
| tar_prompt_neg: str, | |
| negative_prompt: str = "", | |
| strength: float = 1.0, | |
| T_steps: int = 28, | |
| n_avg: int = 1, | |
| src_guidance_scale: float = 2.0, | |
| tar_guidance_scale: float = 6.0, | |
| n_min: int = 0, | |
| n_max: int = 20, | |
| max_sequence_length: int = 512, | |
| scale_mode: str = "slider", | |
| normalize_v_dir: bool = False, | |
| v_dir_target_norm: float = 1.0, | |
| log_vectors: bool = False, | |
| log_output_dir: Optional[str] = None, | |
| ): | |
| """ | |
| FlowSlider for Z-Image with 3-prompt directional decomposition. | |
| Uses 6-way CFG with list-based processing (Z-Image specific): | |
| - [src_uncond, src_cond, tar_pos_uncond, tar_pos_cond, tar_neg_uncond, tar_neg_cond] | |
| Args: | |
| pipe: ZImagePipeline | |
| scheduler: Scheduler (typically FlowMatchEulerDiscreteScheduler) | |
| x_src: Source image latent (B, C, H, W) | |
| src_prompt: Source prompt describing the original image | |
| tar_prompt: Positive target prompt (e.g., "a severely decayed building") | |
| tar_prompt_neg: Negative target prompt (e.g., "a new building") | |
| negative_prompt: Negative prompt for CFG (usually empty) | |
| strength: Edit intensity strength (0.0 = tar_neg direction, 1.0 = tar_pos direction) | |
| T_steps: Total number of timesteps (default: 28 for Z-Image) | |
| n_avg: Number of velocity field averaging iterations | |
| src_guidance_scale: Guidance strength for source prompt (default: 2.0 for Z-Image) | |
| tar_guidance_scale: Guidance strength for target prompts (default: 6.0 for Z-Image) | |
| n_min: Number of final steps using regular sampling | |
| n_max: Maximum number of steps to apply flow editing (default: 20 for Z-Image) | |
| max_sequence_length: Maximum prompt token length (default: 512) | |
| scale_mode: Scaling method - "slider" (default), "interp", "step", "cfg", or "direct" | |
| - "slider": Scale the direction vector V_delta_s = V_fid + strength * V_steer | |
| - "direct": Scale the full velocity difference V_delta_s = strength * (V_pos - V_src) without decomposition | |
| normalize_v_dir: If True, normalize V_steer to v_dir_target_norm before scaling | |
| v_dir_target_norm: Target L2 norm for V_steer normalization | |
| log_vectors: If True, record vector statistics | |
| log_output_dir: Output directory for vector logs | |
| Returns: | |
| Edited latent tensor | |
| """ | |
| from FlowEdit_utils import calculate_shift | |
| # Validate log_vectors arguments | |
| if log_vectors and log_output_dir is None: | |
| raise ValueError("log_output_dir must be specified when log_vectors=True") | |
| # Initialize logging variables | |
| stats_list = [] if log_vectors else None | |
| prev_V_steer = None | |
| prev_zt_edit = None | |
| device = x_src.device | |
| # ============================================ | |
| # Timestep preparation (Z-Image specific) | |
| # ============================================ | |
| image_seq_len = (x_src.shape[2] // 2) * (x_src.shape[3] // 2) | |
| mu = calculate_shift( | |
| image_seq_len, | |
| scheduler.config.get("base_image_seq_len", 256), | |
| scheduler.config.get("max_image_seq_len", 4096), | |
| scheduler.config.get("base_shift", 0.5), | |
| scheduler.config.get("max_shift", 1.15), | |
| ) | |
| scheduler.sigma_min = 0.0 | |
| timesteps, T_steps = retrieve_timesteps( | |
| scheduler, | |
| T_steps, | |
| device, | |
| sigmas=None, | |
| mu=mu, | |
| ) | |
| # ============================================ | |
| # Encode prompts (3 prompts with CFG) | |
| # Z-Image returns List[Tensor] format | |
| # ============================================ | |
| # Source prompt | |
| src_prompt_embeds, src_negative_prompt_embeds = pipe.encode_prompt( | |
| prompt=src_prompt, | |
| device=device, | |
| do_classifier_free_guidance=True, | |
| negative_prompt=negative_prompt, | |
| max_sequence_length=max_sequence_length, | |
| ) | |
| # Target positive prompt | |
| tar_pos_prompt_embeds, tar_pos_negative_prompt_embeds = pipe.encode_prompt( | |
| prompt=tar_prompt, | |
| device=device, | |
| do_classifier_free_guidance=True, | |
| negative_prompt=negative_prompt, | |
| max_sequence_length=max_sequence_length, | |
| ) | |
| # Target negative prompt | |
| tar_neg_prompt_embeds, tar_neg_negative_prompt_embeds = pipe.encode_prompt( | |
| prompt=tar_prompt_neg, | |
| device=device, | |
| do_classifier_free_guidance=True, | |
| negative_prompt=negative_prompt, | |
| max_sequence_length=max_sequence_length, | |
| ) | |
| # Extract embeddings from list format | |
| src_neg_emb = src_negative_prompt_embeds[0] if isinstance(src_negative_prompt_embeds, list) else src_negative_prompt_embeds | |
# Unwrap prompt embeddings: each may arrive as a bare tensor or a
# single-element list (Z-Image pipelines return lists); normalize to tensors.
# NOTE(review): src_neg_emb is unwrapped earlier, outside this excerpt.
src_pos_emb = src_prompt_embeds[0] if isinstance(src_prompt_embeds, list) else src_prompt_embeds
tar_pos_neg_emb = tar_pos_negative_prompt_embeds[0] if isinstance(tar_pos_negative_prompt_embeds, list) else tar_pos_negative_prompt_embeds
tar_pos_pos_emb = tar_pos_prompt_embeds[0] if isinstance(tar_pos_prompt_embeds, list) else tar_pos_prompt_embeds
tar_neg_neg_emb = tar_neg_negative_prompt_embeds[0] if isinstance(tar_neg_negative_prompt_embeds, list) else tar_neg_negative_prompt_embeds
tar_neg_pos_emb = tar_neg_prompt_embeds[0] if isinstance(tar_neg_prompt_embeds, list) else tar_neg_prompt_embeds

# 6-way prompt embeddings list, order must match latents_list below:
# [src_uncond, src_cond, tar_pos_uncond, tar_pos_cond, tar_neg_uncond, tar_neg_cond]
prompt_embeds_list = [
    src_neg_emb,      # src_uncond
    src_pos_emb,      # src_cond
    tar_pos_neg_emb,  # tar_pos_uncond
    tar_pos_pos_emb,  # tar_pos_cond
    tar_neg_neg_emb,  # tar_neg_uncond
    tar_neg_pos_emb,  # tar_neg_cond
]

# Initialize the edit ODE at the source latent: zt_edit = x_src.
zt_edit = x_src.clone()

# ============================================
# Main editing loop
# ============================================
for i, t in tqdm(enumerate(timesteps), total=len(timesteps), desc=f"ZImage-FlowSlider (strength={strength:.2f})"):
    # Skip the earliest (highest-noise) steps so that only the last n_max
    # timesteps are processed at all.
    if T_steps - i > n_max:
        continue

    # Current sigma t_i and next sigma t_im1 from the scheduler
    # (t_im1 falls back to 0 at the final step).
    scheduler._init_step_index(t)
    t_i = scheduler.sigmas[scheduler.step_index]
    if scheduler.step_index + 1 < len(scheduler.sigmas):
        t_im1 = scheduler.sigmas[scheduler.step_index + 1]
    else:
        t_im1 = torch.zeros_like(t_i)

    if T_steps - i > n_min:
        # --------------------------------------------
        # Flow-based editing phase: average the edit velocity over n_avg
        # independent forward-noise draws (Monte Carlo estimate).
        # --------------------------------------------
        V_delta_s_avg = torch.zeros_like(x_src)
        for k in range(n_avg):
            # Fresh forward noise for this sample.
            fwd_noise = torch.randn_like(x_src).to(device)
            # Source trajectory: rectified-flow interpolation x_src -> noise.
            zt_src = (1 - t_i) * x_src + t_i * fwd_noise
            # Target trajectory: same noise realization, shifted by the
            # accumulated edit offset (zt_edit - x_src) — offset preservation.
            zt_tar = zt_edit + zt_src - x_src

            # ============================================
            # 6-way CFG with list-based processing (Z-Image specific)
            # ============================================
            # Latents in the same order as prompt_embeds_list:
            # [src_uncond, src_cond, tar_pos_uncond, tar_pos_cond, tar_neg_uncond, tar_neg_cond]
            # Z-Image expects List[(C, 1, H, W)]; squeeze(0)/unsqueeze(1) maps
            # (1, C, H, W) -> (C, 1, H, W).
            # NOTE(review): assumes x_src carries a batch dimension of 1 — confirm with caller.
            transformer_dtype = pipe.transformer.dtype
            latents_list = [
                zt_src.squeeze(0).unsqueeze(1).to(transformer_dtype),  # src_uncond
                zt_src.squeeze(0).unsqueeze(1).to(transformer_dtype),  # src_cond
                zt_tar.squeeze(0).unsqueeze(1).to(transformer_dtype),  # tar_pos_uncond
                zt_tar.squeeze(0).unsqueeze(1).to(transformer_dtype),  # tar_pos_cond
                zt_tar.squeeze(0).unsqueeze(1).to(transformer_dtype),  # tar_neg_uncond
                zt_tar.squeeze(0).unsqueeze(1).to(transformer_dtype),  # tar_neg_cond
            ]
            # Z-Image timestep convention: (1000 - t) / 1000, one entry per latent.
            timestep_zimage = (1000 - t) / 1000
            timestep_batch = timestep_zimage.expand(len(latents_list))

            # Single batched transformer call covering all six branches.
            with torch.no_grad():
                noise_pred_list = pipe.transformer(
                    latents_list,
                    timestep_batch,
                    prompt_embeds_list,
                    return_dict=False,
                )[0]

            # Z-Image predicts the negated velocity; flip the sign and drop
            # the frame dimension: (C, 1, H, W) -> (C, H, W).
            noise_pred_list = [-pred.squeeze(1) for pred in noise_pred_list]

            # Split into the 6 branch predictions (same order as the inputs).
            (
                src_noise_uncond, src_noise_cond,
                tar_pos_noise_uncond, tar_pos_noise_cond,
                tar_neg_noise_uncond, tar_neg_noise_cond,
            ) = noise_pred_list

            # Classifier-free guidance for each of the 3 prompts.
            Vt_src = src_noise_uncond + src_guidance_scale * (src_noise_cond - src_noise_uncond)
            Vt_pos = tar_pos_noise_uncond + tar_guidance_scale * (tar_pos_noise_cond - tar_pos_noise_uncond)
            Vt_neg = tar_neg_noise_uncond + tar_guidance_scale * (tar_neg_noise_cond - tar_neg_noise_uncond)

            # ============================================
            # Directional decomposition (same as FLUX/SD3 variants):
            #   V_steer = pure edit direction, V_fid = base change.
            # ============================================
            V_steer = Vt_pos - Vt_neg
            V_fid = Vt_neg - Vt_src

            # How strength enters V_delta_s depends on scale_mode.
            if scale_mode == "slider":
                # FlowSlider formula: V_delta_s = V_fid + strength * V_steer,
                # optionally renormalizing V_steer to a fixed target norm so
                # strength has a comparable effect across timesteps.
                if normalize_v_dir:
                    v_dir_norm = V_steer.norm(dim=-1, keepdim=True).mean()
                    V_steer_scaled = V_steer * (v_dir_target_norm / (v_dir_norm + 1e-8))
                else:
                    V_steer_scaled = V_steer
                V_delta_s = V_fid + strength * V_steer_scaled
            elif scale_mode == "direct":
                # Direct mode: scale the full (pos - src) velocity difference
                # by strength, without the steer/fidelity decomposition.
                V_delta_full = Vt_pos - Vt_src
                V_delta_s = strength * V_delta_full
            elif scale_mode == "interp":
                # Interp mode: linear interpolation between the source velocity
                # and the positive-target velocity.
                V_delta_full = Vt_pos - Vt_src
                V_final = Vt_src + strength * V_delta_full
                V_delta_s = V_final
            elif scale_mode == "step":
                # Step mode: full edit velocity here; strength instead scales
                # the ODE step size below.
                V_delta_s = V_fid + V_steer
            elif scale_mode == "cfg":
                # CFG mode: same velocity as "step" at this point.
                # NOTE(review): strength is presumably applied via the guidance
                # scales set up outside this excerpt — confirm with caller.
                V_delta_s = V_fid + V_steer
            else:
                raise ValueError(f"Unknown scale_mode: {scale_mode}")

            # Re-add the batch dimension and accumulate the running average.
            V_delta_s_avg += (1 / n_avg) * V_delta_s.unsqueeze(0)

        # ============================================
        # Vector Logging (if enabled)
        # ============================================
        # Logs per-timestep statistics using V_fid/V_steer from the last
        # n_avg sample; stats_list, prev_V_steer, prev_zt_edit are
        # initialized outside this excerpt.
        if log_vectors:
            step_stats = compute_vector_stats(
                V_fid=V_fid.unsqueeze(0),
                V_steer=V_steer.unsqueeze(0),
                V_delta_s=V_delta_s_avg,
                zt_edit=zt_edit,
                prev_V_steer=prev_V_steer,
                prev_zt_edit=prev_zt_edit,
            )
            step_stats["timestep"] = t_i.item() if hasattr(t_i, 'item') else float(t_i)
            if normalize_v_dir and scale_mode == "slider":
                step_stats["normalize_v_dir"] = True
                step_stats["v_dir_target_norm"] = v_dir_target_norm
                step_stats["v_dir_original_norm"] = V_steer.norm(dim=-1).mean().item()
            stats_list.append(step_stats)
            prev_V_steer = V_steer.unsqueeze(0).clone()
            prev_zt_edit = zt_edit.clone()

        # Propagate the edit ODE one step: zt += (t_{i-1} - t_i) * V.
        # Done in float32 for numerical stability, then cast back.
        zt_edit = zt_edit.to(torch.float32)
        if scale_mode == "step":
            # "step" mode applies strength to the step size rather than the velocity.
            zt_edit = zt_edit + strength * (t_im1 - t_i) * V_delta_s_avg
        else:
            zt_edit = zt_edit + (t_im1 - t_i) * V_delta_s_avg
        zt_edit = zt_edit.to(V_delta_s_avg.dtype)
    else:  # Regular sampling for last n_min steps
        if i == T_steps - n_min:
            # Initialize SDEdit-style generation phase exactly once: re-noise
            # the source to the current level and carry over the edit offset.
            fwd_noise = torch.randn_like(x_src).to(device)
            xt_src = scale_noise(scheduler, x_src, t, noise=fwd_noise)
            xt_tar = zt_edit + xt_src - x_src

        # For the final steps, denoise toward an embedding interpolated
        # between the negative and positive targets by strength.
        interp_emb = (1 - strength) * tar_neg_pos_emb + strength * tar_pos_pos_emb

        # Standard 2-way CFG for the interpolated target.
        # NOTE(review): tar_pos_neg_emb is reused as the uncond embedding —
        # presumably both targets share one negative prompt; verify.
        transformer_dtype = pipe.transformer.dtype
        latents_list = [
            xt_tar.squeeze(0).unsqueeze(1).to(transformer_dtype),  # uncond
            xt_tar.squeeze(0).unsqueeze(1).to(transformer_dtype),  # cond
        ]
        prompt_embeds_2way = [tar_pos_neg_emb, interp_emb]
        timestep_zimage = (1000 - t) / 1000
        timestep_batch = timestep_zimage.expand(2)
        with torch.no_grad():
            noise_pred_list = pipe.transformer(
                latents_list,
                timestep_batch,
                prompt_embeds_2way,
                return_dict=False,
            )[0]

        # Sign inversion + frame-dim squeeze (Z-Image convention, as above).
        noise_pred_list = [-pred.squeeze(1) for pred in noise_pred_list]
        interp_uncond, interp_cond = noise_pred_list
        Vt_tar = interp_uncond + tar_guidance_scale * (interp_cond - interp_uncond)

        # Euler step in float32, then cast back; re-add the batch dim dropped above.
        xt_tar = xt_tar.to(torch.float32)
        prev_sample = xt_tar + (t_im1 - t_i) * Vt_tar.unsqueeze(0)
        prev_sample = prev_sample.to(Vt_tar.dtype)
        xt_tar = prev_sample

# If the SDEdit phase never ran (n_min == 0), the ODE result is final;
# otherwise return the sampled latent.
out = zt_edit if n_min == 0 else xt_tar

# ============================================
# Save and visualize vector statistics
# ============================================
if log_vectors and stats_list:
    stats_path = save_vector_stats(stats_list, log_output_dir, strength)
    plot_vector_stats(stats_path, log_output_dir)
    print(f"Vector statistics saved to {log_output_dir}")

return out