""" FlowSlider: 3プロンプト方向性分解による連続スケール制御 FlowEditの実画像編集能力を維持しながら、FreeSlidersの考え方を取り入れて 連続的な編集強度制御を可能にする手法。 数式: V_steer = V_tar_pos - V_tar_neg (純粋な編集方向) V_fid = V_tar_neg - V_src (ベース変化) V_delta_s = V_fid + strength * V_steer strength=0: tar_neg方向への編集(例:劣化なし) strength=1: tar_pos方向への編集(例:完全劣化) 0 Dict[str, float]: """ Compute statistics for velocity field vectors at a single timestep. Args: V_fid: Base velocity (V_neg - V_src) V_steer: Direction velocity (V_pos - V_neg) V_delta_s: Combined velocity (V_fid + strength * V_steer) zt_edit: Current edited latent prev_V_steer: V_steer from previous timestep (for cosine similarity) prev_zt_edit: zt_edit from previous timestep (for delta computation) Returns: Dictionary of statistics """ stats = {} # Compute norms (average over sequence dimension) stats["V_fid_norm"] = V_fid.norm(dim=-1).mean().item() stats["V_steer_norm"] = V_steer.norm(dim=-1).mean().item() stats["V_delta_s_norm"] = V_delta_s.norm(dim=-1).mean().item() stats["zt_edit_norm"] = zt_edit.norm(dim=-1).mean().item() # Compute cosine similarity with previous V_steer if prev_V_steer is not None: # Flatten for cosine similarity computation v_dir_flat = V_steer.view(-1) prev_v_dir_flat = prev_V_steer.view(-1) cos_sim = F.cosine_similarity(v_dir_flat.unsqueeze(0), prev_v_dir_flat.unsqueeze(0)).item() stats["V_steer_cosine"] = cos_sim else: stats["V_steer_cosine"] = 1.0 # First step, no previous # Compute angle between V_fid and V_steer v_base_flat = V_fid.view(-1) v_dir_flat = V_steer.view(-1) cos_angle = F.cosine_similarity(v_base_flat.unsqueeze(0), v_dir_flat.unsqueeze(0)).item() # Clamp to avoid numerical issues with arccos cos_angle = max(-1.0, min(1.0, cos_angle)) angle_rad = np.arccos(cos_angle) stats["V_fid_V_steer_angle"] = np.degrees(angle_rad) # Compute zt_edit delta (movement from previous step) if prev_zt_edit is not None: delta = (zt_edit - prev_zt_edit).norm(dim=-1).mean().item() stats["zt_edit_delta"] = delta else: stats["zt_edit_delta"] = 0.0 return 
stats def save_vector_stats( stats_list: List[Dict[str, Any]], output_dir: str, strength: float, ): """ Save vector statistics to JSON file. Args: stats_list: List of statistics dictionaries (one per timestep) output_dir: Output directory strength: Scale value used """ os.makedirs(output_dir, exist_ok=True) # Reorganize data for easier plotting output_data = { "strength": strength, "timesteps": [s["timestep"] for s in stats_list], "V_fid_norm": [s["V_fid_norm"] for s in stats_list], "V_steer_norm": [s["V_steer_norm"] for s in stats_list], "V_delta_s_norm": [s["V_delta_s_norm"] for s in stats_list], "V_steer_cosine": [s["V_steer_cosine"] for s in stats_list], "V_fid_V_steer_angle": [s["V_fid_V_steer_angle"] for s in stats_list], "zt_edit_norm": [s["zt_edit_norm"] for s in stats_list], "zt_edit_delta": [s["zt_edit_delta"] for s in stats_list], } output_path = os.path.join(output_dir, f"stats_scale_{strength:.2f}.json") with open(output_path, "w") as f: json.dump(output_data, f, indent=2) return output_path def plot_vector_stats( stats_path: str, output_dir: str, ): """ Generate visualization plots from saved statistics. 
Args: stats_path: Path to stats JSON file output_dir: Output directory for plots """ import matplotlib.pyplot as plt with open(stats_path, "r") as f: data = json.load(f) strength = data["strength"] timesteps = data["timesteps"] os.makedirs(output_dir, exist_ok=True) # Plot 1: Norms fig, ax = plt.subplots(figsize=(10, 6)) ax.plot(timesteps, data["V_fid_norm"], 'b-', label="V_fid", linewidth=2) ax.plot(timesteps, data["V_steer_norm"], 'r-', label="V_steer", linewidth=2) ax.plot(timesteps, data["V_delta_s_norm"], 'g--', label="V_delta_s", linewidth=2) ax.set_xlabel("Timestep", fontsize=12) ax.set_ylabel("L2 Norm", fontsize=12) ax.set_title(f"Velocity Field Norms (strength={strength:.2f})", fontsize=14) ax.legend(fontsize=11) ax.grid(True, alpha=0.3) ax.invert_xaxis() # t goes from 1.0 to 0 plt.tight_layout() plt.savefig(os.path.join(output_dir, f"plot_norms_scale_{strength:.2f}.png"), dpi=150) plt.close() # Plot 2: V_steer Cosine Similarity fig, ax = plt.subplots(figsize=(10, 6)) ax.plot(timesteps[1:], data["V_steer_cosine"][1:], 'purple', linewidth=2) ax.axhline(y=0.9, color='gray', linestyle='--', alpha=0.7, label="Stability threshold (0.9)") ax.set_xlabel("Timestep", fontsize=12) ax.set_ylabel("Cosine Similarity", fontsize=12) ax.set_title(f"V_steer Directional Consistency (strength={strength:.2f})", fontsize=14) ax.set_ylim(-0.1, 1.1) ax.legend(fontsize=11) ax.grid(True, alpha=0.3) ax.invert_xaxis() plt.tight_layout() plt.savefig(os.path.join(output_dir, f"plot_cosine_scale_{strength:.2f}.png"), dpi=150) plt.close() # Plot 3: Angle between V_fid and V_steer fig, ax = plt.subplots(figsize=(10, 6)) ax.plot(timesteps, data["V_fid_V_steer_angle"], 'orange', linewidth=2) ax.axhline(y=90, color='gray', linestyle='--', alpha=0.7, label="Orthogonal (90°)") ax.set_xlabel("Timestep", fontsize=12) ax.set_ylabel("Angle (degrees)", fontsize=12) ax.set_title(f"Angle between V_fid and V_steer (strength={strength:.2f})", fontsize=14) ax.set_ylim(0, 180) ax.legend(fontsize=11) 
ax.grid(True, alpha=0.3) ax.invert_xaxis() plt.tight_layout() plt.savefig(os.path.join(output_dir, f"plot_angles_scale_{strength:.2f}.png"), dpi=150) plt.close() # Plot 4: Edit Trajectory (zt_edit movement) fig, ax = plt.subplots(figsize=(10, 6)) ax.plot(timesteps, data["zt_edit_delta"], 'teal', linewidth=2) ax.set_xlabel("Timestep", fontsize=12) ax.set_ylabel("Step Movement (L2 norm)", fontsize=12) ax.set_title(f"Edit Trajectory: Per-step Movement (strength={strength:.2f})", fontsize=14) ax.grid(True, alpha=0.3) ax.invert_xaxis() plt.tight_layout() plt.savefig(os.path.join(output_dir, f"plot_trajectory_scale_{strength:.2f}.png"), dpi=150) plt.close() def plot_vector_comparison( log_dir: str, strengths: List[float], output_dir: str, ): """ Generate comparison plots across multiple strengths. Args: log_dir: Directory containing stats JSON files strengths: List of strength values to compare output_dir: Output directory for comparison plots """ import matplotlib.pyplot as plt os.makedirs(output_dir, exist_ok=True) # Load all stats all_data = {} for strength in strengths: stats_path = os.path.join(log_dir, f"stats_scale_{strength:.2f}.json") if os.path.exists(stats_path): with open(stats_path, "r") as f: all_data[strength] = json.load(f) if not all_data: print(f"No stats files found in {log_dir}") return # Color map for different strengths colors = plt.cm.viridis(np.linspace(0, 1, len(strengths))) # Comparison Plot 1: Norms (V_steer) fig, ax = plt.subplots(figsize=(12, 7)) for (strength, data), color in zip(all_data.items(), colors): ax.plot(data["timesteps"], data["V_steer_norm"], color=color, linewidth=2, label=f"strength={strength:.1f}") ax.set_xlabel("Timestep", fontsize=12) ax.set_ylabel("V_steer L2 Norm", fontsize=12) ax.set_title("V_steer Norm Comparison across Scales", fontsize=14) ax.legend(fontsize=10) ax.grid(True, alpha=0.3) ax.invert_xaxis() plt.tight_layout() plt.savefig(os.path.join(output_dir, "plot_comparison_norms.png"), dpi=150) plt.close() # 
Comparison Plot 2: Cosine Similarity fig, ax = plt.subplots(figsize=(12, 7)) for (strength, data), color in zip(all_data.items(), colors): ax.plot(data["timesteps"][1:], data["V_steer_cosine"][1:], color=color, linewidth=2, label=f"strength={strength:.1f}") ax.axhline(y=0.9, color='gray', linestyle='--', alpha=0.7) ax.set_xlabel("Timestep", fontsize=12) ax.set_ylabel("Cosine Similarity", fontsize=12) ax.set_title("V_steer Directional Consistency Comparison", fontsize=14) ax.set_ylim(-0.1, 1.1) ax.legend(fontsize=10) ax.grid(True, alpha=0.3) ax.invert_xaxis() plt.tight_layout() plt.savefig(os.path.join(output_dir, "plot_comparison_cosine.png"), dpi=150) plt.close() # Comparison Plot 3: Angles fig, ax = plt.subplots(figsize=(12, 7)) for (strength, data), color in zip(all_data.items(), colors): ax.plot(data["timesteps"], data["V_fid_V_steer_angle"], color=color, linewidth=2, label=f"strength={strength:.1f}") ax.axhline(y=90, color='gray', linestyle='--', alpha=0.7) ax.set_xlabel("Timestep", fontsize=12) ax.set_ylabel("Angle (degrees)", fontsize=12) ax.set_title("V_fid-V_steer Angle Comparison", fontsize=14) ax.set_ylim(0, 180) ax.legend(fontsize=10) ax.grid(True, alpha=0.3) ax.invert_xaxis() plt.tight_layout() plt.savefig(os.path.join(output_dir, "plot_comparison_angles.png"), dpi=150) plt.close() # Comparison Plot 4: Trajectory (cumulative movement) fig, ax = plt.subplots(figsize=(12, 7)) for (strength, data), color in zip(all_data.items(), colors): cumulative = np.cumsum(data["zt_edit_delta"]) ax.plot(data["timesteps"], cumulative, color=color, linewidth=2, label=f"strength={strength:.1f}") ax.set_xlabel("Timestep", fontsize=12) ax.set_ylabel("Cumulative Movement", fontsize=12) ax.set_title("Edit Trajectory: Cumulative Distance from Source", fontsize=14) ax.legend(fontsize=10) ax.grid(True, alpha=0.3) ax.invert_xaxis() plt.tight_layout() plt.savefig(os.path.join(output_dir, "plot_comparison_trajectory.png"), dpi=150) plt.close() print(f"Comparison plots saved to 
{output_dir}") def prepare_mask_for_flux( mask: torch.Tensor, target_height: int, target_width: int, device: torch.device, dtype: torch.dtype, ) -> torch.Tensor: """ Prepare a binary mask for use with Flux's packed latent format. Args: mask: Input mask tensor. Can be: - (H, W): Single channel 2D mask - (1, H, W): Single channel with batch dim - (B, 1, H, W): Full 4D tensor - (B, H, W): 3D tensor target_height: Target height in latent space (H/8) target_width: Target width in latent space (W/8) device: Target device dtype: Target dtype Returns: Mask tensor of shape (1, seq_len, 1) for packed latent format where seq_len = (H/2) * (W/2) for Flux's 2x2 packing """ # Ensure 4D tensor (B, C, H, W) if mask.dim() == 2: mask = mask.unsqueeze(0).unsqueeze(0) # (H, W) -> (1, 1, H, W) elif mask.dim() == 3: if mask.shape[0] == 1: mask = mask.unsqueeze(0) # (1, H, W) -> (1, 1, H, W) else: mask = mask.unsqueeze(1) # (B, H, W) -> (B, 1, H, W) mask = mask.to(device=device, dtype=dtype) # Resize to latent space dimensions mask_resized = F.interpolate( mask, size=(target_height, target_width), mode='bilinear', align_corners=False ) # For Flux: pack into sequence format # Flux uses 2x2 packing, so (B, C, H, W) -> (B, H/2 * W/2, C*4) # For mask, we just need (B, seq_len, 1) where values are averaged over 2x2 patches B, C, H, W = mask_resized.shape # Reshape for 2x2 packing: (B, 1, H, W) -> (B, H//2, 2, W//2, 2) -> (B, H//2, W//2, 4) mask_packed = mask_resized.view(B, C, H // 2, 2, W // 2, 2) mask_packed = mask_packed.permute(0, 2, 4, 1, 3, 5) # (B, H//2, W//2, C, 2, 2) mask_packed = mask_packed.reshape(B, (H // 2) * (W // 2), C * 4) # (B, seq_len, 4) # Average across the 4 values in each 2x2 patch to get single mask value mask_packed = mask_packed.mean(dim=-1, keepdim=True) # (B, seq_len, 1) return mask_packed @torch.no_grad() def FlowEditFLUX_Slider( pipe, scheduler, x_src, src_prompt: str, tar_prompt: str, tar_prompt_neg: str, negative_prompt: str = "", strength: float = 1.0, 
T_steps: int = 28, n_avg: int = 1, src_guidance_scale: float = 1.5, tar_guidance_scale: float = 5.5, n_min: int = 0, n_max: int = 24, scale_mode: str = "slider", normalize_v_dir: bool = False, v_dir_target_norm: float = 1.0, log_vectors: bool = False, log_output_dir: Optional[str] = None, ): """ FlowEdit with 3-prompt directional decomposition for continuous strength control. Args: pipe: FluxPipeline scheduler: FlowMatchEulerDiscreteScheduler x_src: Source image latent (B, C, H, W) src_prompt: Source prompt describing the original image (e.g., "a building") tar_prompt: Positive target prompt (e.g., "a severely decayed building") tar_prompt_neg: Negative target prompt (e.g., "a new building") negative_prompt: Negative prompt for CFG (usually empty for Flux) strength: Edit intensity strength (0.0 = tar_neg direction, 1.0 = tar_pos direction) T_steps: Total number of timesteps n_avg: Number of velocity field averaging iterations src_guidance_scale: Guidance strength for source prompt tar_guidance_scale: Guidance strength for target prompts n_min: Number of final steps using regular sampling n_max: Maximum number of steps to apply flow editing scale_mode: Scaling method - "slider" (default), "interp", "step", "cfg", or "direct" - "slider": Scale the direction vector V_delta_s = V_fid + strength * V_steer (FreeSlider-like) - "interp": FlowEdit-based interpolation V_final = V_src + strength * (V_pos - V_src) - "step": Scale the step size dt (FlowEdit paper experiment, causes degradation) - "cfg": Scale the target guidance (tar_guidance_scale * strength) - "direct": Scale the full velocity difference V_delta_s = strength * (V_pos - V_src) without decomposition normalize_v_dir: If True, normalize V_steer to v_dir_target_norm before scaling. This stabilizes edit strength across different CFG settings and prevents both numerical instability (low CFG) and semantic over-editing (high CFG). Only applies when scale_mode="slider". 
v_dir_target_norm: Target L2 norm for V_steer normalization (default: 1.0). Higher values produce stronger edits per unit strength. log_vectors: If True, record vector statistics and generate visualization plots log_output_dir: Output directory for vector logs (required if log_vectors=True) Returns: Edited latent tensor """ # Validate log_vectors arguments if log_vectors and log_output_dir is None: raise ValueError("log_output_dir must be specified when log_vectors=True") # Initialize logging variables stats_list = [] if log_vectors else None prev_V_steer = None prev_zt_edit = None device = x_src.device # Note: orig_height/width should match the actual image dimensions for correct latent_image_ids # x_src is VAE-encoded latent (H/8, W/8), so multiply by vae_scale_factor to get original size orig_height = x_src.shape[2] * pipe.vae_scale_factor orig_width = x_src.shape[3] * pipe.vae_scale_factor num_channels_latents = pipe.transformer.config.in_channels // 4 pipe.check_inputs( prompt=src_prompt, prompt_2=None, height=orig_height, width=orig_width, callback_on_step_end_tensor_inputs=None, max_sequence_length=512, ) # Prepare latents x_src, latent_src_image_ids = pipe.prepare_latents( batch_size=x_src.shape[0], num_channels_latents=num_channels_latents, height=orig_height, width=orig_width, dtype=x_src.dtype, device=x_src.device, generator=None, latents=x_src ) x_src_packed = pipe._pack_latents( x_src, x_src.shape[0], num_channels_latents, x_src.shape[2], x_src.shape[3] ) latent_image_ids = latent_src_image_ids # Prepare timesteps sigmas = np.linspace(1.0, 1 / T_steps, T_steps) image_seq_len = x_src_packed.shape[1] mu = calculate_shift( image_seq_len, scheduler.config.base_image_seq_len, scheduler.config.max_image_seq_len, scheduler.config.base_shift, scheduler.config.max_shift, ) timesteps, T_steps = retrieve_timesteps( scheduler, T_steps, device, timesteps=None, sigmas=sigmas, mu=mu, ) num_warmup_steps = max(len(timesteps) - T_steps * pipe.scheduler.order, 0) 
pipe._num_timesteps = len(timesteps) # ============================================ # Encode prompts (3 prompts) # ============================================ # Source prompt ( src_prompt_embeds, src_pooled_prompt_embeds, src_text_ids, ) = pipe.encode_prompt( prompt=src_prompt, prompt_2=None, device=device, ) # Target positive prompt (e.g., "severely decayed") ( tar_pos_prompt_embeds, tar_pos_pooled_prompt_embeds, tar_pos_text_ids, ) = pipe.encode_prompt( prompt=tar_prompt, prompt_2=None, device=device, ) # Target negative prompt (e.g., "new, pristine") ( tar_neg_prompt_embeds, tar_neg_pooled_prompt_embeds, tar_neg_text_ids, ) = pipe.encode_prompt( prompt=tar_prompt_neg, prompt_2=None, device=device, ) # ============================================ # Handle guidance # ============================================ # For cfg mode, strength is applied to tar_guidance_scale effective_tar_guidance = tar_guidance_scale * strength if scale_mode == "cfg" else tar_guidance_scale if pipe.transformer.config.guidance_embeds: src_guidance = torch.tensor([src_guidance_scale], device=device) src_guidance = src_guidance.expand(x_src_packed.shape[0]) tar_pos_guidance = torch.tensor([effective_tar_guidance], device=device) tar_pos_guidance = tar_pos_guidance.expand(x_src_packed.shape[0]) tar_neg_guidance = torch.tensor([effective_tar_guidance], device=device) tar_neg_guidance = tar_neg_guidance.expand(x_src_packed.shape[0]) else: src_guidance = None tar_pos_guidance = None tar_neg_guidance = None # Initialize ODE: zt_edit = x_src zt_edit = x_src_packed.clone() # ============================================ # Main editing loop # ============================================ for i, t in tqdm(enumerate(timesteps), total=len(timesteps), desc=f"FlowSlider (strength={strength:.2f})"): if T_steps - i > n_max: continue scheduler._init_step_index(t) t_i = scheduler.sigmas[scheduler.step_index] if i < len(timesteps): t_im1 = scheduler.sigmas[scheduler.step_index + 1] else: t_im1 = t_i if 
T_steps - i > n_min: # Flow-based editing phase V_delta_s_avg = torch.zeros_like(x_src_packed) for k in range(n_avg): # Forward noise fwd_noise = torch.randn_like(x_src_packed).to(x_src_packed.device) # Source trajectory zt_src = (1 - t_i) * x_src_packed + t_i * fwd_noise # Target trajectory (with offset preservation) zt_tar = zt_edit + zt_src - x_src_packed # ============================================ # 3-prompt velocity computation (CORE CHANGE) # ============================================ # Source velocity Vt_src = calc_v_flux( pipe, latents=zt_src, prompt_embeds=src_prompt_embeds, pooled_prompt_embeds=src_pooled_prompt_embeds, guidance=src_guidance, text_ids=src_text_ids, latent_image_ids=latent_image_ids, t=t ) # Positive target velocity Vt_pos = calc_v_flux( pipe, latents=zt_tar, prompt_embeds=tar_pos_prompt_embeds, pooled_prompt_embeds=tar_pos_pooled_prompt_embeds, guidance=tar_pos_guidance, text_ids=tar_pos_text_ids, latent_image_ids=latent_image_ids, t=t ) # Negative target velocity Vt_neg = calc_v_flux( pipe, latents=zt_tar, prompt_embeds=tar_neg_prompt_embeds, pooled_prompt_embeds=tar_neg_pooled_prompt_embeds, guidance=tar_neg_guidance, text_ids=tar_neg_text_ids, latent_image_ids=latent_image_ids, t=t ) # ============================================ # Directional decomposition # ============================================ # V_steer: Pure edit direction (e.g., "aging" direction) V_steer = Vt_pos - Vt_neg # V_fid: Base change from source to negative target V_fid = Vt_neg - Vt_src # V_delta_s computation depends on scale_mode if scale_mode == "slider": # Slider mode: strength the direction vector (FreeSlider-like) # strength=0 -> V_fid only (tar_neg direction) # strength=1 -> V_fid + V_steer = Vt_pos - Vt_src (tar_pos direction) # Apply V_steer normalization if enabled if normalize_v_dir: # Compute current norm (mean over sequence dimension) v_dir_norm = V_steer.norm(dim=-1, keepdim=True).mean() # Normalize to target norm V_steer_scaled = V_steer * 
(v_dir_target_norm / (v_dir_norm + 1e-8)) else: V_steer_scaled = V_steer V_delta_s = V_fid + strength * V_steer_scaled elif scale_mode == "direct": # Direct mode: strength the full velocity difference without decomposition # V_delta_s = strength * (V_pos - V_src) # This strengths both V_fid and V_steer together, causing trajectory collapse at strength > 1 V_delta_full = Vt_pos - Vt_src V_delta_s = strength * V_delta_full elif scale_mode == "interp": # Interp mode: FlowEdit-based interpolation # V_final = V_src + strength * (V_pos - V_src) = (1-strength)*V_src + strength*V_pos # For 3-prompt: V_final = V_src + strength * (V_pos - V_neg) + (V_neg - V_src) # = V_neg + strength * V_steer # But we want: V_final = V_src + strength * V_delta_full # where V_delta_full = V_pos - V_src (full edit direction) V_delta_full = Vt_pos - Vt_src V_final = Vt_src + strength * V_delta_full # Store V_final directly, will be used differently in ODE propagation V_delta_s = V_final # This is actually V_final, not a delta elif scale_mode == "step": # Step mode: V_delta_s is fixed at strength=1, dt will be scaled later V_delta_s = V_fid + V_steer # equivalent to strength=1 elif scale_mode == "cfg": # CFG mode: strength was already applied to guidance, use strength=1 for direction V_delta_s = V_fid + V_steer # equivalent to strength=1 else: raise ValueError(f"Unknown scale_mode: {scale_mode}") V_delta_s_avg += (1 / n_avg) * V_delta_s # ============================================ # Vector Logging (if enabled) # ============================================ if log_vectors: # Use the last computed V_fid and V_steer for logging # (when n_avg > 1, this is from the last iteration) # Note: For slider mode with normalize_v_dir, log the original V_steer # to see the raw values before normalization step_stats = compute_vector_stats( V_fid=V_fid, V_steer=V_steer, V_delta_s=V_delta_s_avg, zt_edit=zt_edit, prev_V_steer=prev_V_steer, prev_zt_edit=prev_zt_edit, ) step_stats["timestep"] = t_i.item() if 
hasattr(t_i, 'item') else float(t_i) # Log normalization info if enabled if normalize_v_dir and scale_mode == "slider": step_stats["normalize_v_dir"] = True step_stats["v_dir_target_norm"] = v_dir_target_norm step_stats["v_dir_original_norm"] = V_steer.norm(dim=-1).mean().item() stats_list.append(step_stats) # Store current values for next iteration comparison prev_V_steer = V_steer.clone() prev_zt_edit = zt_edit.clone() # Propagate ODE zt_edit = zt_edit.to(torch.float32) if scale_mode == "step": # Step mode: strength the step size dt (FlowEdit paper experiment) zt_edit = zt_edit + strength * (t_im1 - t_i) * V_delta_s_avg elif scale_mode == "interp": # Interp mode: V_delta_s_avg is actually V_final, use directly zt_edit = zt_edit + (t_im1 - t_i) * V_delta_s_avg else: # Slider and CFG mode: normal dt zt_edit = zt_edit + (t_im1 - t_i) * V_delta_s_avg zt_edit = zt_edit.to(V_delta_s_avg.dtype) else: # Regular sampling for last n_min steps if i == T_steps - n_min: # Initialize SDEDIT-style generation phase fwd_noise = torch.randn_like(x_src_packed).to(x_src_packed.device) xt_src = scale_noise(scheduler, x_src_packed, t, noise=fwd_noise) xt_tar = zt_edit + xt_src - x_src_packed # For final steps, use interpolated target based on strength # Interpolate between neg and pos embeddings interp_prompt_embeds = (1 - strength) * tar_neg_prompt_embeds + strength * tar_pos_prompt_embeds interp_pooled_embeds = (1 - strength) * tar_neg_pooled_prompt_embeds + strength * tar_pos_pooled_prompt_embeds Vt_tar = calc_v_flux( pipe, latents=xt_tar, prompt_embeds=interp_prompt_embeds, pooled_prompt_embeds=interp_pooled_embeds, guidance=tar_pos_guidance, text_ids=tar_pos_text_ids, # text_ids are typically the same latent_image_ids=latent_image_ids, t=t ) xt_tar = xt_tar.to(torch.float32) prev_sample = xt_tar + (t_im1 - t_i) * Vt_tar prev_sample = prev_sample.to(Vt_tar.dtype) xt_tar = prev_sample out = zt_edit if n_min == 0 else xt_tar unpacked_out = pipe._unpack_latents(out, orig_height, 
orig_width, pipe.vae_scale_factor) # ============================================ # Save and visualize vector statistics # ============================================ if log_vectors and stats_list: stats_path = save_vector_stats(stats_list, log_output_dir, strength) plot_vector_stats(stats_path, log_output_dir) print(f"Vector statistics saved to {log_output_dir}") return unpacked_out @torch.no_grad() def FlowEditFLUX_Slider_batch( pipe, scheduler, x_src, src_prompt: str, tar_prompt: str, tar_prompt_neg: str, negative_prompt: str = "", strengths: list = [0.0, 0.25, 0.5, 0.75, 1.0], T_steps: int = 28, n_avg: int = 1, src_guidance_scale: float = 1.5, tar_guidance_scale: float = 5.5, n_min: int = 0, n_max: int = 24, ): """ Batch processing for multiple strengths. More efficient than calling FlowEditFLUX_Slider multiple times as prompt encoding is done only once. Args: strengths: List of strength values to generate (other args same as FlowEditFLUX_Slider) Returns: Dict[float, Tensor]: Mapping from strength to edited latent """ results = {} for strength in strengths: result = FlowEditFLUX_Slider( pipe=pipe, scheduler=scheduler, x_src=x_src, src_prompt=src_prompt, tar_prompt=tar_prompt, tar_prompt_neg=tar_prompt_neg, negative_prompt=negative_prompt, strength=strength, T_steps=T_steps, n_avg=n_avg, src_guidance_scale=src_guidance_scale, tar_guidance_scale=tar_guidance_scale, n_min=n_min, n_max=n_max, ) results[strength] = result return results @torch.no_grad() def FlowEditFLUX_Slider_2prompt( pipe, scheduler, x_src, src_prompt: str, tar_prompt: str, negative_prompt: str = "", strength: float = 1.0, T_steps: int = 28, n_avg: int = 1, src_guidance_scale: float = 1.5, tar_guidance_scale: float = 5.5, n_min: int = 0, n_max: int = 24, ): """ FlowEdit with 2-prompt simple scaling (without neg prompt). 
数式: V_delta_s = strength * (V_tar - V_src) strength=0: 編集なし(元画像のまま) strength=1: 通常のFlowEdit(tar方向への完全編集) strength>1: tar方向への過剰編集 Args: pipe: FluxPipeline scheduler: FlowMatchEulerDiscreteScheduler x_src: Source image latent src_prompt: Source prompt tar_prompt: Target prompt (e.g., "a decayed building") strength: Edit intensity (0=no change, 1=full edit) """ device = x_src.device orig_height = x_src.shape[2] * pipe.vae_scale_factor orig_width = x_src.shape[3] * pipe.vae_scale_factor num_channels_latents = pipe.transformer.config.in_channels // 4 pipe.check_inputs( prompt=src_prompt, prompt_2=None, height=orig_height, width=orig_width, callback_on_step_end_tensor_inputs=None, max_sequence_length=512, ) # Prepare latents x_src, latent_src_image_ids = pipe.prepare_latents( batch_size=x_src.shape[0], num_channels_latents=num_channels_latents, height=orig_height, width=orig_width, dtype=x_src.dtype, device=x_src.device, generator=None, latents=x_src ) x_src_packed = pipe._pack_latents( x_src, x_src.shape[0], num_channels_latents, x_src.shape[2], x_src.shape[3] ) latent_image_ids = latent_src_image_ids # Prepare timesteps sigmas = np.linspace(1.0, 1 / T_steps, T_steps) image_seq_len = x_src_packed.shape[1] mu = calculate_shift( image_seq_len, scheduler.config.base_image_seq_len, scheduler.config.max_image_seq_len, scheduler.config.base_shift, scheduler.config.max_shift, ) timesteps, T_steps = retrieve_timesteps( scheduler, T_steps, device, timesteps=None, sigmas=sigmas, mu=mu, ) # Encode prompts (2 prompts only) ( src_prompt_embeds, src_pooled_prompt_embeds, src_text_ids, ) = pipe.encode_prompt( prompt=src_prompt, prompt_2=None, device=device, ) ( tar_prompt_embeds, tar_pooled_prompt_embeds, tar_text_ids, ) = pipe.encode_prompt( prompt=tar_prompt, prompt_2=None, device=device, ) # Handle guidance if pipe.transformer.config.guidance_embeds: src_guidance = torch.tensor([src_guidance_scale], device=device) src_guidance = src_guidance.expand(x_src_packed.shape[0]) 
tar_guidance = torch.tensor([tar_guidance_scale], device=device) tar_guidance = tar_guidance.expand(x_src_packed.shape[0]) else: src_guidance = None tar_guidance = None # Initialize ODE zt_edit = x_src_packed.clone() # Main editing loop for i, t in tqdm(enumerate(timesteps), total=len(timesteps), desc=f"FlowEdit-2prompt (strength={strength:.2f})"): if T_steps - i > n_max: continue scheduler._init_step_index(t) t_i = scheduler.sigmas[scheduler.step_index] if i < len(timesteps): t_im1 = scheduler.sigmas[scheduler.step_index + 1] else: t_im1 = t_i if T_steps - i > n_min: V_delta_s_avg = torch.zeros_like(x_src_packed) for k in range(n_avg): fwd_noise = torch.randn_like(x_src_packed).to(x_src_packed.device) zt_src = (1 - t_i) * x_src_packed + t_i * fwd_noise zt_tar = zt_edit + zt_src - x_src_packed # 2-prompt velocity computation Vt_src = calc_v_flux( pipe, latents=zt_src, prompt_embeds=src_prompt_embeds, pooled_prompt_embeds=src_pooled_prompt_embeds, guidance=src_guidance, text_ids=src_text_ids, latent_image_ids=latent_image_ids, t=t ) Vt_tar = calc_v_flux( pipe, latents=zt_tar, prompt_embeds=tar_prompt_embeds, pooled_prompt_embeds=tar_pooled_prompt_embeds, guidance=tar_guidance, text_ids=tar_text_ids, latent_image_ids=latent_image_ids, t=t ) # Simple scaling: V_delta_s = strength * (V_tar - V_src) V_delta_s = strength * (Vt_tar - Vt_src) V_delta_s_avg += (1 / n_avg) * V_delta_s zt_edit = zt_edit.to(torch.float32) zt_edit = zt_edit + (t_im1 - t_i) * V_delta_s_avg zt_edit = zt_edit.to(V_delta_s_avg.dtype) else: if i == T_steps - n_min: fwd_noise = torch.randn_like(x_src_packed).to(x_src_packed.device) xt_src = scale_noise(scheduler, x_src_packed, t, noise=fwd_noise) xt_tar = zt_edit + xt_src - x_src_packed # Prompt interpolation for stability at high strengths # interp = (1 - strength) * src + strength * tar interp_prompt_embeds = (1 - strength) * src_prompt_embeds + strength * tar_prompt_embeds interp_pooled_embeds = (1 - strength) * src_pooled_prompt_embeds + strength 
* tar_pooled_prompt_embeds
            # Velocity under the interpolated target prompt for the SDEdit-style
            # tail phase of the preceding (non-mask) FLUX slider function.
            Vt_tar = calc_v_flux(
                pipe,
                latents=xt_tar,
                prompt_embeds=interp_prompt_embeds,
                pooled_prompt_embeds=interp_pooled_embeds,
                guidance=tar_guidance,
                text_ids=tar_text_ids,
                latent_image_ids=latent_image_ids,
                t=t
            )
            # Euler step in float32 for numerical stability, then cast back to
            # the model dtype.
            xt_tar = xt_tar.to(torch.float32)
            prev_sample = xt_tar + (t_im1 - t_i) * Vt_tar
            prev_sample = prev_sample.to(Vt_tar.dtype)
            xt_tar = prev_sample

    # n_min == 0 means the regular-sampling tail never ran (xt_tar would be
    # undefined), so return the flow-edited latent instead.
    out = zt_edit if n_min == 0 else xt_tar
    unpacked_out = pipe._unpack_latents(out, orig_height, orig_width, pipe.vae_scale_factor)
    return unpacked_out


@torch.no_grad()
def FlowEditFLUX_Slider_with_mask(
    pipe,
    scheduler,
    x_src,
    src_prompt: str,
    tar_prompt: str,
    tar_prompt_neg: str,
    mask: torch.Tensor,
    negative_prompt: str = "",
    strength: float = 1.0,
    T_steps: int = 28,
    n_avg: int = 1,
    src_guidance_scale: float = 1.5,
    tar_guidance_scale: float = 5.5,
    n_min: int = 0,
    n_max: int = 24,
    scale_mode: str = "slider",
):
    """
    FlowEdit with 3-prompt directional decomposition and mask-based local editing.

    This function applies edits only to the masked region, preserving the
    unmasked areas from the source image.

    Args:
        pipe: FluxPipeline
        scheduler: FlowMatchEulerDiscreteScheduler
        x_src: Source image latent (B, C, H, W)
        src_prompt: Source prompt describing the original image
        tar_prompt: Positive target prompt (e.g., "a severely decayed building")
        tar_prompt_neg: Negative target prompt (e.g., "a new building")
        mask: Binary mask tensor indicating edit region.
            Shape: (H, W), (1, H, W), (B, H, W), or (B, 1, H, W)
            Values: 1 = edit region, 0 = preserve original
        negative_prompt: Negative prompt for CFG (usually empty for Flux)
        strength: Edit intensity strength (0.0 = tar_neg direction, 1.0 = tar_pos direction)
        T_steps: Total number of timesteps
        n_avg: Number of velocity field averaging iterations
        src_guidance_scale: Guidance strength for source prompt
        tar_guidance_scale: Guidance strength for target prompts
        n_min: Number of final steps using regular sampling
        n_max: Maximum number of steps to apply flow editing
        scale_mode: Scaling method - "slider" (default), "interp", "step", or "cfg"

    Returns:
        Edited latent tensor with edits applied only in masked region
    """
    device = x_src.device
    # Pixel-space dimensions recovered from the latent shape.
    orig_height = x_src.shape[2] * pipe.vae_scale_factor
    orig_width = x_src.shape[3] * pipe.vae_scale_factor
    num_channels_latents = pipe.transformer.config.in_channels // 4
    pipe.check_inputs(
        prompt=src_prompt,
        prompt_2=None,
        height=orig_height,
        width=orig_width,
        callback_on_step_end_tensor_inputs=None,
        max_sequence_length=512,
    )
    # Prepare latents (passing latents=x_src keeps the source content; this
    # call also produces the latent image ids FLUX needs).
    x_src, latent_src_image_ids = pipe.prepare_latents(
        batch_size=x_src.shape[0],
        num_channels_latents=num_channels_latents,
        height=orig_height,
        width=orig_width,
        dtype=x_src.dtype,
        device=x_src.device,
        generator=None,
        latents=x_src
    )
    # Pack (B, C, H, W) into FLUX's (B, seq, dim) token layout.
    x_src_packed = pipe._pack_latents(
        x_src, x_src.shape[0], num_channels_latents, x_src.shape[2], x_src.shape[3]
    )
    latent_image_ids = latent_src_image_ids
    # Prepare mask for packed latent format (same (B, seq, 1)-style layout as
    # x_src_packed; see prepare_mask_for_flux).
    mask_packed = prepare_mask_for_flux(
        mask=mask,
        target_height=x_src.shape[2],
        target_width=x_src.shape[3],
        device=device,
        dtype=x_src.dtype,
    )
    # Prepare timesteps with resolution-dependent shift (mu).
    sigmas = np.linspace(1.0, 1 / T_steps, T_steps)
    image_seq_len = x_src_packed.shape[1]
    mu = calculate_shift(
        image_seq_len,
        scheduler.config.base_image_seq_len,
        scheduler.config.max_image_seq_len,
        scheduler.config.base_shift,
        scheduler.config.max_shift,
    )
    timesteps, T_steps = retrieve_timesteps(
        scheduler,
        T_steps,
        device,
        timesteps=None,
        sigmas=sigmas,
        mu=mu,
    )
    # NOTE(review): num_warmup_steps is computed but never used below.
    num_warmup_steps = max(len(timesteps) - T_steps * pipe.scheduler.order, 0)
    pipe._num_timesteps = len(timesteps)
    # Encode prompts (3 prompts). FLUX encode_prompt returns
    # (prompt_embeds, pooled_prompt_embeds, text_ids).
    (
        src_prompt_embeds,
        src_pooled_prompt_embeds,
        src_text_ids,
    ) = pipe.encode_prompt(
        prompt=src_prompt,
        prompt_2=None,
        device=device,
    )
    (
        tar_pos_prompt_embeds,
        tar_pos_pooled_prompt_embeds,
        tar_pos_text_ids,
    ) = pipe.encode_prompt(
        prompt=tar_prompt,
        prompt_2=None,
        device=device,
    )
    (
        tar_neg_prompt_embeds,
        tar_neg_pooled_prompt_embeds,
        tar_neg_text_ids,
    ) = pipe.encode_prompt(
        prompt=tar_prompt_neg,
        prompt_2=None,
        device=device,
    )
    # Handle guidance.
    # For cfg mode, strength is applied to tar_guidance_scale instead of to
    # the velocity decomposition.
    effective_tar_guidance = tar_guidance_scale * strength if scale_mode == "cfg" else tar_guidance_scale
    if pipe.transformer.config.guidance_embeds:
        src_guidance = torch.tensor([src_guidance_scale], device=device)
        src_guidance = src_guidance.expand(x_src_packed.shape[0])
        tar_pos_guidance = torch.tensor([effective_tar_guidance], device=device)
        tar_pos_guidance = tar_pos_guidance.expand(x_src_packed.shape[0])
        tar_neg_guidance = torch.tensor([effective_tar_guidance], device=device)
        tar_neg_guidance = tar_neg_guidance.expand(x_src_packed.shape[0])
    else:
        src_guidance = None
        tar_pos_guidance = None
        tar_neg_guidance = None
    # Initialize ODE: zt_edit = x_src
    zt_edit = x_src_packed.clone()
    # Main editing loop
    for i, t in tqdm(enumerate(timesteps), total=len(timesteps), desc=f"FlowEdit-Slider-Mask (strength={strength:.2f})"):
        # Skip early (high-noise) steps beyond n_max.
        if T_steps - i > n_max:
            continue
        scheduler._init_step_index(t)
        t_i = scheduler.sigmas[scheduler.step_index]
        # NOTE(review): `i < len(timesteps)` is always true inside this loop,
        # so the else branch is dead code. This relies on scheduler.sigmas
        # having one more entry than timesteps (trailing 0) — confirm.
        if i < len(timesteps):
            t_im1 = scheduler.sigmas[scheduler.step_index + 1]
        else:
            t_im1 = t_i
        if T_steps - i > n_min:
            # Flow-based editing phase
            V_delta_s_avg = torch.zeros_like(x_src_packed)
            for k in range(n_avg):
                # Forward noise
                fwd_noise = torch.randn_like(x_src_packed).to(x_src_packed.device)
                # Source trajectory: linear interpolation toward noise at level t_i.
                zt_src = (1 - t_i) * x_src_packed + t_i * fwd_noise
                # Target trajectory (with offset preservation): shares the same
                # noise realization as the source trajectory.
                zt_tar = zt_edit + zt_src - x_src_packed
                # 3-prompt velocity computation
                Vt_src = calc_v_flux(
                    pipe,
                    latents=zt_src,
                    prompt_embeds=src_prompt_embeds,
                    pooled_prompt_embeds=src_pooled_prompt_embeds,
                    guidance=src_guidance,
                    text_ids=src_text_ids,
                    latent_image_ids=latent_image_ids,
                    t=t
                )
                Vt_pos = calc_v_flux(
                    pipe,
                    latents=zt_tar,
                    prompt_embeds=tar_pos_prompt_embeds,
                    pooled_prompt_embeds=tar_pos_pooled_prompt_embeds,
                    guidance=tar_pos_guidance,
                    text_ids=tar_pos_text_ids,
                    latent_image_ids=latent_image_ids,
                    t=t
                )
                Vt_neg = calc_v_flux(
                    pipe,
                    latents=zt_tar,
                    prompt_embeds=tar_neg_prompt_embeds,
                    pooled_prompt_embeds=tar_neg_pooled_prompt_embeds,
                    guidance=tar_neg_guidance,
                    text_ids=tar_neg_text_ids,
                    latent_image_ids=latent_image_ids,
                    t=t
                )
                # Directional decomposition:
                #   V_steer = pure edit direction, V_fid = base change.
                V_steer = Vt_pos - Vt_neg
                V_fid = Vt_neg - Vt_src
                # V_delta_s computation depends on scale_mode
                if scale_mode == "slider":
                    V_delta_s = V_fid + strength * V_steer
                elif scale_mode == "interp":
                    # Interp mode: FlowEdit-based interpolation
                    V_delta_full = Vt_pos - Vt_src
                    V_final = Vt_src + strength * V_delta_full
                    V_delta_s = V_final  # This is actually V_final, not a delta
                elif scale_mode == "step":
                    V_delta_s = V_fid + V_steer  # equivalent to strength=1
                elif scale_mode == "cfg":
                    V_delta_s = V_fid + V_steer  # equivalent to strength=1
                else:
                    raise ValueError(f"Unknown scale_mode: {scale_mode}")
                # Running average over n_avg noise samples.
                V_delta_s_avg += (1 / n_avg) * V_delta_s
            # Propagate ODE (without mask first)
            zt_edit = zt_edit.to(torch.float32)
            if scale_mode == "step":
                # Step mode: scale the step size dt by `strength` instead of
                # scaling the velocity.
                zt_edit_new = zt_edit + strength * (t_im1 - t_i) * V_delta_s_avg
            elif scale_mode == "interp":
                # Interp mode: V_delta_s_avg is actually V_final, use directly
                zt_edit_new = zt_edit + (t_im1 - t_i) * V_delta_s_avg
            else:
                zt_edit_new = zt_edit + (t_im1 - t_i) * V_delta_s_avg
            # ============================================
            # MASK APPLICATION (LocalBlend style):
            # Apply mask to the RESULT, not the velocity
            # zt_edit = x_src + mask * (zt_edit_new - x_src)
            # This forces unmasked regions to stay at source
            # ============================================
            zt_edit = x_src_packed + mask_packed * (zt_edit_new - x_src_packed)
            zt_edit = zt_edit.to(V_delta_s_avg.dtype)
        else:
            # Regular sampling for last n_min steps
            if i == T_steps - n_min:
                # Initialize SDEDIT-style generation phase: re-noise the edited
                # latent at the current noise level (only once, on entry).
                fwd_noise = torch.randn_like(x_src_packed).to(x_src_packed.device)
                xt_src = scale_noise(scheduler, x_src_packed, t, noise=fwd_noise)
                xt_tar = zt_edit + xt_src - x_src_packed
            # Interpolate between neg and pos embeddings so the tail phase
            # respects the requested strength.
            interp_prompt_embeds = (1 - strength) * tar_neg_prompt_embeds + strength * tar_pos_prompt_embeds
            interp_pooled_embeds = (1 - strength) * tar_neg_pooled_prompt_embeds + strength * tar_pos_pooled_prompt_embeds
            Vt_tar = calc_v_flux(
                pipe,
                latents=xt_tar,
                prompt_embeds=interp_prompt_embeds,
                pooled_prompt_embeds=interp_pooled_embeds,
                guidance=tar_pos_guidance,
                text_ids=tar_pos_text_ids,
                latent_image_ids=latent_image_ids,
                t=t
            )
            xt_tar = xt_tar.to(torch.float32)
            xt_tar_new = xt_tar + (t_im1 - t_i) * Vt_tar
            # LocalBlend style mask application for n_min phase
            xt_tar = x_src_packed + mask_packed * (xt_tar_new - x_src_packed)
            xt_tar = xt_tar.to(Vt_tar.dtype)
    # n_min == 0 -> the tail phase never ran and xt_tar is undefined.
    out = zt_edit if n_min == 0 else xt_tar
    unpacked_out = pipe._unpack_latents(out, orig_height, orig_width, pipe.vae_scale_factor)
    return unpacked_out


# ============================================
# SD3 Slider Implementation
# ============================================
@torch.no_grad()
def FlowEditSD3_Slider(
    pipe,
    scheduler,
    x_src,
    src_prompt: str,
    tar_prompt: str,
    tar_prompt_neg: str,
    negative_prompt: str = "",
    strength: float = 1.0,
    T_steps: int = 50,
    n_avg: int = 1,
    src_guidance_scale: float = 3.5,
    tar_guidance_scale: float = 13.5,
    n_min: int = 0,
    n_max: int = 33,
    scale_mode: str = "slider",
    normalize_v_dir: bool = False,
    v_dir_target_norm: float = 1.0,
    log_vectors: bool = False,
    log_output_dir: Optional[str] = None,
):
    """
    FlowSlider for SD3 with
    3-prompt directional decomposition.

    Uses 6-way CFG batching for efficient computation:
    - [src_uncond, src_cond, tar_pos_uncond, tar_pos_cond, tar_neg_uncond, tar_neg_cond]

    Args:
        pipe: StableDiffusion3Pipeline
        scheduler: Scheduler (typically FlowMatchEulerDiscreteScheduler)
        x_src: Source image latent (B, C, H, W)
        src_prompt: Source prompt describing the original image
        tar_prompt: Positive target prompt (e.g., "a severely decayed building")
        tar_prompt_neg: Negative target prompt (e.g., "a new building")
        negative_prompt: Negative prompt for CFG (usually empty)
        strength: Edit intensity strength (0.0 = tar_neg direction, 1.0 = tar_pos direction)
        T_steps: Total number of timesteps (default: 50 for SD3)
        n_avg: Number of velocity field averaging iterations
        src_guidance_scale: Guidance strength for source prompt (default: 3.5 for SD3)
        tar_guidance_scale: Guidance strength for target prompts (default: 13.5 for SD3)
        n_min: Number of final steps using regular sampling
        n_max: Maximum number of steps to apply flow editing (default: 33 for SD3)
        scale_mode: Scaling method - "slider" (default), "interp", "step", "cfg", or "direct"
            - "slider": Scale the direction vector V_delta_s = V_fid + strength * V_steer
            - "direct": Scale the full velocity difference
              V_delta_s = strength * (V_pos - V_src) without decomposition
        normalize_v_dir: If True, normalize V_steer to v_dir_target_norm before scaling
        v_dir_target_norm: Target L2 norm for V_steer normalization
        log_vectors: If True, record vector statistics
        log_output_dir: Output directory for vector logs

    Returns:
        Edited latent tensor
    """
    from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import retrieve_timesteps

    # Validate log_vectors arguments
    if log_vectors and log_output_dir is None:
        raise ValueError("log_output_dir must be specified when log_vectors=True")
    # Initialize logging variables (previous-step state for cosine/delta stats).
    stats_list = [] if log_vectors else None
    prev_V_steer = None
    prev_zt_edit = None
    device = x_src.device
    # Retrieve timesteps
    timesteps, T_steps = retrieve_timesteps(scheduler, T_steps, device, timesteps=None)
    # NOTE(review): num_warmup_steps is computed but never used below.
    num_warmup_steps = max(len(timesteps) - T_steps * scheduler.order, 0)
    pipe._num_timesteps = len(timesteps)
    # ============================================
    # Encode prompts (3 prompts with CFG)
    # ============================================
    # Source prompt
    pipe._guidance_scale = src_guidance_scale
    (
        src_prompt_embeds,
        src_negative_prompt_embeds,
        src_pooled_prompt_embeds,
        src_negative_pooled_prompt_embeds,
    ) = pipe.encode_prompt(
        prompt=src_prompt,
        prompt_2=None,
        prompt_3=None,
        negative_prompt=negative_prompt,
        do_classifier_free_guidance=pipe.do_classifier_free_guidance,
        device=device,
    )
    # Target positive prompt
    pipe._guidance_scale = tar_guidance_scale
    (
        tar_pos_prompt_embeds,
        tar_pos_negative_prompt_embeds,
        tar_pos_pooled_prompt_embeds,
        tar_pos_negative_pooled_prompt_embeds,
    ) = pipe.encode_prompt(
        prompt=tar_prompt,
        prompt_2=None,
        prompt_3=None,
        negative_prompt=negative_prompt,
        do_classifier_free_guidance=pipe.do_classifier_free_guidance,
        device=device,
    )
    # Target negative prompt
    (
        tar_neg_prompt_embeds,
        tar_neg_negative_prompt_embeds,
        tar_neg_pooled_prompt_embeds,
        tar_neg_negative_pooled_prompt_embeds,
    ) = pipe.encode_prompt(
        prompt=tar_prompt_neg,
        prompt_2=None,
        prompt_3=None,
        negative_prompt=negative_prompt,
        do_classifier_free_guidance=pipe.do_classifier_free_guidance,
        device=device,
    )
    # ============================================
    # Prepare 6-way CFG embeddings
    # [src_uncond, src_cond, tar_pos_uncond, tar_pos_cond, tar_neg_uncond, tar_neg_cond]
    # ============================================
    all_prompt_embeds = torch.cat([
        src_negative_prompt_embeds,      # src_uncond
        src_prompt_embeds,               # src_cond
        tar_pos_negative_prompt_embeds,  # tar_pos_uncond
        tar_pos_prompt_embeds,           # tar_pos_cond
        tar_neg_negative_prompt_embeds,  # tar_neg_uncond
        tar_neg_prompt_embeds,           # tar_neg_cond
    ], dim=0)
    all_pooled_prompt_embeds = torch.cat([
        src_negative_pooled_prompt_embeds,
        src_pooled_prompt_embeds,
        tar_pos_negative_pooled_prompt_embeds,
        tar_pos_pooled_prompt_embeds,
        tar_neg_negative_pooled_prompt_embeds,
        tar_neg_pooled_prompt_embeds,
    ], dim=0)
    # Initialize ODE: zt_edit = x_src
    zt_edit = x_src.clone()
    # ============================================
    # Main editing loop
    # ============================================
    for i, t in tqdm(enumerate(timesteps), total=len(timesteps), desc=f"SD3-FlowSlider (strength={strength:.2f})"):
        # Skip early (high-noise) steps beyond n_max.
        if T_steps - i > n_max:
            continue
        # SD3 timesteps are in [0, 1000]; convert to sigma-like [0, 1].
        t_i = t / 1000
        if i + 1 < len(timesteps):
            t_im1 = timesteps[i + 1] / 1000
        else:
            t_im1 = torch.zeros_like(t_i).to(t_i.device)
        if T_steps - i > n_min:
            # Flow-based editing phase
            V_delta_s_avg = torch.zeros_like(x_src)
            for k in range(n_avg):
                # Forward noise
                fwd_noise = torch.randn_like(x_src).to(x_src.device)
                # Source trajectory
                zt_src = (1 - t_i) * x_src + t_i * fwd_noise
                # Target trajectory (with offset preservation)
                zt_tar = zt_edit + zt_src - x_src
                # ============================================
                # 6-way CFG batched computation
                # ============================================
                # Latents: [zt_src, zt_src, zt_tar, zt_tar, zt_tar, zt_tar]
                all_latents = torch.cat([
                    zt_src, zt_src,  # src_uncond, src_cond
                    zt_tar, zt_tar,  # tar_pos_uncond, tar_pos_cond
                    zt_tar, zt_tar,  # tar_neg_uncond, tar_neg_cond
                ])
                # Timestep broadcast
                timestep_batch = t.expand(all_latents.shape[0])
                # Single transformer call
                with torch.no_grad():
                    noise_pred_all = pipe.transformer(
                        hidden_states=all_latents,
                        timestep=timestep_batch,
                        encoder_hidden_states=all_prompt_embeds,
                        pooled_projections=all_pooled_prompt_embeds,
                        joint_attention_kwargs=None,
                        return_dict=False,
                    )[0]
                # Split into 6 parts
                (
                    src_noise_uncond,
                    src_noise_cond,
                    tar_pos_noise_uncond,
                    tar_pos_noise_cond,
                    tar_neg_noise_uncond,
                    tar_neg_noise_cond,
                ) = noise_pred_all.chunk(6)
                # Apply CFG to get 3 velocities
                Vt_src = src_noise_uncond + src_guidance_scale * (src_noise_cond - src_noise_uncond)
                Vt_pos = tar_pos_noise_uncond + tar_guidance_scale * (tar_pos_noise_cond - tar_pos_noise_uncond)
                Vt_neg = tar_neg_noise_uncond + tar_guidance_scale * (tar_neg_noise_cond - tar_neg_noise_uncond)
                # ============================================
                # Directional decomposition (same as FLUX)
                # ============================================
                V_steer = Vt_pos - Vt_neg
                V_fid = Vt_neg - Vt_src
                # V_delta_s computation depends on scale_mode
                if scale_mode == "slider":
                    if normalize_v_dir:
                        # Rescale V_steer to a fixed L2 norm so `strength` has a
                        # comparable effect across timesteps.
                        v_dir_norm = V_steer.norm(dim=-1, keepdim=True).mean()
                        V_steer_scaled = V_steer * (v_dir_target_norm / (v_dir_norm + 1e-8))
                    else:
                        V_steer_scaled = V_steer
                    V_delta_s = V_fid + strength * V_steer_scaled
                elif scale_mode == "direct":
                    # Direct mode: scale the full velocity difference without decomposition
                    V_delta_full = Vt_pos - Vt_src
                    V_delta_s = strength * V_delta_full
                elif scale_mode == "interp":
                    V_delta_full = Vt_pos - Vt_src
                    V_final = Vt_src + strength * V_delta_full
                    V_delta_s = V_final
                elif scale_mode == "step":
                    V_delta_s = V_fid + V_steer
                elif scale_mode == "cfg":
                    V_delta_s = V_fid + V_steer
                else:
                    raise ValueError(f"Unknown scale_mode: {scale_mode}")
                V_delta_s_avg += (1 / n_avg) * V_delta_s
            # ============================================
            # Vector Logging (if enabled)
            # ============================================
            if log_vectors:
                step_stats = compute_vector_stats(
                    V_fid=V_fid,
                    V_steer=V_steer,
                    V_delta_s=V_delta_s_avg,
                    zt_edit=zt_edit,
                    prev_V_steer=prev_V_steer,
                    prev_zt_edit=prev_zt_edit,
                )
                step_stats["timestep"] = t_i.item() if hasattr(t_i, 'item') else float(t_i)
                if normalize_v_dir and scale_mode == "slider":
                    step_stats["normalize_v_dir"] = True
                    step_stats["v_dir_target_norm"] = v_dir_target_norm
                    step_stats["v_dir_original_norm"] = V_steer.norm(dim=-1).mean().item()
                stats_list.append(step_stats)
                prev_V_steer = V_steer.clone()
                prev_zt_edit = zt_edit.clone()
            # Propagate ODE (float32 for the Euler update, then cast back).
            zt_edit = zt_edit.to(torch.float32)
            if scale_mode == "step":
                # Step mode: scale the step size dt by `strength`.
                zt_edit = zt_edit + strength * (t_im1 - t_i) * V_delta_s_avg
            else:
                zt_edit = zt_edit + (t_im1 - t_i) * V_delta_s_avg
            zt_edit = zt_edit.to(V_delta_s_avg.dtype)
        else:
            # Regular sampling for last n_min steps
            if i == T_steps - n_min:
                # Initialize SDEDIT-style generation phase (only once, on entry).
                fwd_noise = torch.randn_like(x_src).to(x_src.device)
                xt_src = scale_noise(scheduler, x_src, t, noise=fwd_noise)
                xt_tar = zt_edit + xt_src - x_src
            # For final steps, use interpolated target
            interp_prompt_embeds = (1 - strength) * tar_neg_prompt_embeds + strength * tar_pos_prompt_embeds
            interp_pooled_embeds = (1 - strength) * tar_neg_pooled_prompt_embeds + strength * tar_pos_pooled_prompt_embeds
            # 2-way CFG for interpolated target
            interp_all_embeds = torch.cat([tar_pos_negative_prompt_embeds, interp_prompt_embeds], dim=0)
            interp_all_pooled = torch.cat([tar_pos_negative_pooled_prompt_embeds, interp_pooled_embeds], dim=0)
            interp_latents = torch.cat([xt_tar, xt_tar])
            timestep_batch = t.expand(2)
            with torch.no_grad():
                noise_pred_interp = pipe.transformer(
                    hidden_states=interp_latents,
                    timestep=timestep_batch,
                    encoder_hidden_states=interp_all_embeds,
                    pooled_projections=interp_all_pooled,
                    joint_attention_kwargs=None,
                    return_dict=False,
                )[0]
            interp_uncond, interp_cond = noise_pred_interp.chunk(2)
            Vt_tar = interp_uncond + tar_guidance_scale * (interp_cond - interp_uncond)
            xt_tar = xt_tar.to(torch.float32)
            prev_sample = xt_tar + (t_im1 - t_i) * Vt_tar
            prev_sample = prev_sample.to(Vt_tar.dtype)
            xt_tar = prev_sample
    # n_min == 0 -> the tail phase never ran and xt_tar is undefined.
    out = zt_edit if n_min == 0 else xt_tar
    # ============================================
    # Save and visualize vector statistics
    # ============================================
    if log_vectors and stats_list:
        stats_path = save_vector_stats(stats_list, log_output_dir, strength)
        plot_vector_stats(stats_path, log_output_dir)
        print(f"Vector statistics saved to {log_output_dir}")
    return out


# ============================================
# Z-Image Slider Implementation
# ============================================
@torch.no_grad()
def FlowEditZImage_Slider(
    pipe,
    scheduler,
    x_src,
    src_prompt: str,
    tar_prompt: str,
    tar_prompt_neg: str,
    negative_prompt: str = "",
    strength:
float = 1.0,
    T_steps: int = 28,
    n_avg: int = 1,
    src_guidance_scale: float = 2.0,
    tar_guidance_scale: float = 6.0,
    n_min: int = 0,
    n_max: int = 20,
    max_sequence_length: int = 512,
    scale_mode: str = "slider",
    normalize_v_dir: bool = False,
    v_dir_target_norm: float = 1.0,
    log_vectors: bool = False,
    log_output_dir: Optional[str] = None,
):
    """
    FlowSlider for Z-Image with 3-prompt directional decomposition.

    Uses 6-way CFG with list-based processing (Z-Image specific):
    - [src_uncond, src_cond, tar_pos_uncond, tar_pos_cond, tar_neg_uncond, tar_neg_cond]

    Args:
        pipe: ZImagePipeline
        scheduler: Scheduler (typically FlowMatchEulerDiscreteScheduler)
        x_src: Source image latent (B, C, H, W)
        src_prompt: Source prompt describing the original image
        tar_prompt: Positive target prompt (e.g., "a severely decayed building")
        tar_prompt_neg: Negative target prompt (e.g., "a new building")
        negative_prompt: Negative prompt for CFG (usually empty)
        strength: Edit intensity strength (0.0 = tar_neg direction, 1.0 = tar_pos direction)
        T_steps: Total number of timesteps (default: 28 for Z-Image)
        n_avg: Number of velocity field averaging iterations
        src_guidance_scale: Guidance strength for source prompt (default: 2.0 for Z-Image)
        tar_guidance_scale: Guidance strength for target prompts (default: 6.0 for Z-Image)
        n_min: Number of final steps using regular sampling
        n_max: Maximum number of steps to apply flow editing (default: 20 for Z-Image)
        max_sequence_length: Maximum prompt token length (default: 512)
        scale_mode: Scaling method - "slider" (default), "interp", "step", "cfg", or "direct"
            - "slider": Scale the direction vector V_delta_s = V_fid + strength * V_steer
            - "direct": Scale the full velocity difference
              V_delta_s = strength * (V_pos - V_src) without decomposition
        normalize_v_dir: If True, normalize V_steer to v_dir_target_norm before scaling
        v_dir_target_norm: Target L2 norm for V_steer normalization
        log_vectors: If True, record vector statistics
        log_output_dir: Output directory for vector logs

    Returns:
        Edited latent tensor
    """
    from FlowEdit_utils import calculate_shift

    # Validate log_vectors arguments
    if log_vectors and log_output_dir is None:
        raise ValueError("log_output_dir must be specified when log_vectors=True")
    # Initialize logging variables (previous-step state for cosine/delta stats).
    stats_list = [] if log_vectors else None
    prev_V_steer = None
    prev_zt_edit = None
    device = x_src.device
    # ============================================
    # Timestep preparation (Z-Image specific)
    # ============================================
    # Z-Image patchifies 2x2, so the token count is (H/2)*(W/2).
    image_seq_len = (x_src.shape[2] // 2) * (x_src.shape[3] // 2)
    mu = calculate_shift(
        image_seq_len,
        scheduler.config.get("base_image_seq_len", 256),
        scheduler.config.get("max_image_seq_len", 4096),
        scheduler.config.get("base_shift", 0.5),
        scheduler.config.get("max_shift", 1.15),
    )
    # NOTE(review): sigma_min is forced to 0.0 here so the final sigma reaches
    # exactly zero — confirm this is intended for the scheduler in use.
    scheduler.sigma_min = 0.0
    timesteps, T_steps = retrieve_timesteps(
        scheduler,
        T_steps,
        device,
        sigmas=None,
        mu=mu,
    )
    # ============================================
    # Encode prompts (3 prompts with CFG)
    # Z-Image returns List[Tensor] format
    # ============================================
    # Source prompt
    src_prompt_embeds, src_negative_prompt_embeds = pipe.encode_prompt(
        prompt=src_prompt,
        device=device,
        do_classifier_free_guidance=True,
        negative_prompt=negative_prompt,
        max_sequence_length=max_sequence_length,
    )
    # Target positive prompt
    tar_pos_prompt_embeds, tar_pos_negative_prompt_embeds = pipe.encode_prompt(
        prompt=tar_prompt,
        device=device,
        do_classifier_free_guidance=True,
        negative_prompt=negative_prompt,
        max_sequence_length=max_sequence_length,
    )
    # Target negative prompt
    tar_neg_prompt_embeds, tar_neg_negative_prompt_embeds = pipe.encode_prompt(
        prompt=tar_prompt_neg,
        device=device,
        do_classifier_free_guidance=True,
        negative_prompt=negative_prompt,
        max_sequence_length=max_sequence_length,
    )
    # Extract embeddings from list format (encode_prompt may return either a
    # list of per-sample tensors or a plain tensor).
    src_neg_emb = src_negative_prompt_embeds[0] if isinstance(src_negative_prompt_embeds, list) else src_negative_prompt_embeds
    src_pos_emb = src_prompt_embeds[0] if isinstance(src_prompt_embeds, list) else src_prompt_embeds
    tar_pos_neg_emb = tar_pos_negative_prompt_embeds[0] if isinstance(tar_pos_negative_prompt_embeds, list) else tar_pos_negative_prompt_embeds
    tar_pos_pos_emb = tar_pos_prompt_embeds[0] if isinstance(tar_pos_prompt_embeds, list) else tar_pos_prompt_embeds
    tar_neg_neg_emb = tar_neg_negative_prompt_embeds[0] if isinstance(tar_neg_negative_prompt_embeds, list) else tar_neg_negative_prompt_embeds
    tar_neg_pos_emb = tar_neg_prompt_embeds[0] if isinstance(tar_neg_prompt_embeds, list) else tar_neg_prompt_embeds
    # 6-way prompt embeddings list:
    # [src_uncond, src_cond, tar_pos_uncond, tar_pos_cond, tar_neg_uncond, tar_neg_cond]
    prompt_embeds_list = [
        src_neg_emb,      # src_uncond
        src_pos_emb,      # src_cond
        tar_pos_neg_emb,  # tar_pos_uncond
        tar_pos_pos_emb,  # tar_pos_cond
        tar_neg_neg_emb,  # tar_neg_uncond
        tar_neg_pos_emb,  # tar_neg_cond
    ]
    # Initialize ODE: zt_edit = x_src
    zt_edit = x_src.clone()
    # ============================================
    # Main editing loop
    # ============================================
    for i, t in tqdm(enumerate(timesteps), total=len(timesteps), desc=f"ZImage-FlowSlider (strength={strength:.2f})"):
        # Skip early (high-noise) steps beyond n_max.
        if T_steps - i > n_max:
            continue
        # Get timestep values from scheduler sigmas
        scheduler._init_step_index(t)
        t_i = scheduler.sigmas[scheduler.step_index]
        if scheduler.step_index + 1 < len(scheduler.sigmas):
            t_im1 = scheduler.sigmas[scheduler.step_index + 1]
        else:
            t_im1 = torch.zeros_like(t_i)
        if T_steps - i > n_min:
            # Flow-based editing phase
            V_delta_s_avg = torch.zeros_like(x_src)
            for k in range(n_avg):
                # Forward noise
                fwd_noise = torch.randn_like(x_src).to(device)
                # Source trajectory
                zt_src = (1 - t_i) * x_src + t_i * fwd_noise
                # Target trajectory (with offset preservation)
                zt_tar = zt_edit + zt_src - x_src
                # ============================================
                # 6-way CFG with list-based processing (Z-Image specific)
                # ============================================
                # Prepare latents list:
                # [src_uncond, src_cond, tar_pos_uncond, tar_pos_cond, tar_neg_uncond, tar_neg_cond]
                # Z-Image expects List[(C, 1, H, W)] format
                transformer_dtype = pipe.transformer.dtype
                latents_list = [
                    zt_src.squeeze(0).unsqueeze(1).to(transformer_dtype),  # src_uncond
                    zt_src.squeeze(0).unsqueeze(1).to(transformer_dtype),  # src_cond
                    zt_tar.squeeze(0).unsqueeze(1).to(transformer_dtype),  # tar_pos_uncond
                    zt_tar.squeeze(0).unsqueeze(1).to(transformer_dtype),  # tar_pos_cond
                    zt_tar.squeeze(0).unsqueeze(1).to(transformer_dtype),  # tar_neg_uncond
                    zt_tar.squeeze(0).unsqueeze(1).to(transformer_dtype),  # tar_neg_cond
                ]
                # Z-Image timestep format: (1000 - t) / 1000
                timestep_zimage = (1000 - t) / 1000
                timestep_batch = timestep_zimage.expand(len(latents_list))
                # Single transformer call with list input
                with torch.no_grad():
                    noise_pred_list = pipe.transformer(
                        latents_list,
                        timestep_batch,
                        prompt_embeds_list,
                        return_dict=False,
                    )[0]
                # Apply sign inversion and squeeze frame dimension (Z-Image specific)
                noise_pred_list = [-pred.squeeze(1) for pred in noise_pred_list]
                # Split into 6 predictions
                (
                    src_noise_uncond,
                    src_noise_cond,
                    tar_pos_noise_uncond,
                    tar_pos_noise_cond,
                    tar_neg_noise_uncond,
                    tar_neg_noise_cond,
                ) = noise_pred_list
                # Apply CFG to get 3 velocities
                Vt_src = src_noise_uncond + src_guidance_scale * (src_noise_cond - src_noise_uncond)
                Vt_pos = tar_pos_noise_uncond + tar_guidance_scale * (tar_pos_noise_cond - tar_pos_noise_uncond)
                Vt_neg = tar_neg_noise_uncond + tar_guidance_scale * (tar_neg_noise_cond - tar_neg_noise_uncond)
                # ============================================
                # Directional decomposition (same as FLUX/SD3)
                # ============================================
                V_steer = Vt_pos - Vt_neg
                V_fid = Vt_neg - Vt_src
                # V_delta_s computation depends on scale_mode
                if scale_mode == "slider":
                    if normalize_v_dir:
                        # Rescale V_steer to a fixed L2 norm so `strength` has a
                        # comparable effect across timesteps.
                        v_dir_norm = V_steer.norm(dim=-1, keepdim=True).mean()
                        V_steer_scaled = V_steer * (v_dir_target_norm / (v_dir_norm + 1e-8))
                    else:
                        V_steer_scaled = V_steer
                    V_delta_s = V_fid + strength * V_steer_scaled
                elif scale_mode == "direct":
                    # Direct mode: scale the full velocity difference without decomposition
                    V_delta_full = Vt_pos - Vt_src
                    V_delta_s = strength * V_delta_full
                elif scale_mode == "interp":
                    V_delta_full = Vt_pos - Vt_src
                    V_final = Vt_src + strength * V_delta_full
                    V_delta_s = V_final
                elif scale_mode == "step":
                    V_delta_s = V_fid + V_steer
                elif scale_mode == "cfg":
                    V_delta_s = V_fid + V_steer
                else:
                    raise ValueError(f"Unknown scale_mode: {scale_mode}")
                # Add batch dimension back for accumulation
                V_delta_s_avg += (1 / n_avg) * V_delta_s.unsqueeze(0)
            # ============================================
            # Vector Logging (if enabled)
            # ============================================
            if log_vectors:
                step_stats = compute_vector_stats(
                    V_fid=V_fid.unsqueeze(0),
                    V_steer=V_steer.unsqueeze(0),
                    V_delta_s=V_delta_s_avg,
                    zt_edit=zt_edit,
                    prev_V_steer=prev_V_steer,
                    prev_zt_edit=prev_zt_edit,
                )
                step_stats["timestep"] = t_i.item() if hasattr(t_i, 'item') else float(t_i)
                if normalize_v_dir and scale_mode == "slider":
                    step_stats["normalize_v_dir"] = True
                    step_stats["v_dir_target_norm"] = v_dir_target_norm
                    step_stats["v_dir_original_norm"] = V_steer.norm(dim=-1).mean().item()
                stats_list.append(step_stats)
                prev_V_steer = V_steer.unsqueeze(0).clone()
                prev_zt_edit = zt_edit.clone()
            # Propagate ODE (float32 for the Euler update, then cast back).
            zt_edit = zt_edit.to(torch.float32)
            if scale_mode == "step":
                # Step mode: scale the step size dt by `strength`.
                zt_edit = zt_edit + strength * (t_im1 - t_i) * V_delta_s_avg
            else:
                zt_edit = zt_edit + (t_im1 - t_i) * V_delta_s_avg
            zt_edit = zt_edit.to(V_delta_s_avg.dtype)
        else:
            # Regular sampling for last n_min steps
            if i == T_steps - n_min:
                # Initialize SDEDIT-style generation phase (only once, on entry).
                fwd_noise = torch.randn_like(x_src).to(device)
                xt_src = scale_noise(scheduler, x_src, t, noise=fwd_noise)
                xt_tar = zt_edit + xt_src - x_src
            # For final steps, use interpolated target embedding
            interp_emb = (1 - strength) * tar_neg_pos_emb + strength * tar_pos_pos_emb
            # 2-way CFG for interpolated target
            transformer_dtype = pipe.transformer.dtype
            latents_list = [
                xt_tar.squeeze(0).unsqueeze(1).to(transformer_dtype),  # uncond
                xt_tar.squeeze(0).unsqueeze(1).to(transformer_dtype),  # cond
            ]
            prompt_embeds_2way = [tar_pos_neg_emb, interp_emb]
            timestep_zimage = (1000 - t) / 1000
            timestep_batch = timestep_zimage.expand(2)
            with torch.no_grad():
                noise_pred_list = pipe.transformer(
                    latents_list,
                    timestep_batch,
                    prompt_embeds_2way,
                    return_dict=False,
                )[0]
            # Apply sign inversion and squeeze
            noise_pred_list = [-pred.squeeze(1) for pred in noise_pred_list]
            interp_uncond, interp_cond = noise_pred_list
            Vt_tar = interp_uncond + tar_guidance_scale * (interp_cond - interp_uncond)
            xt_tar = xt_tar.to(torch.float32)
            prev_sample = xt_tar + (t_im1 - t_i) * Vt_tar.unsqueeze(0)
            prev_sample = prev_sample.to(Vt_tar.dtype)
            xt_tar = prev_sample
    # n_min == 0 -> the tail phase never ran and xt_tar is undefined.
    out = zt_edit if n_min == 0 else xt_tar
    # ============================================
    # Save and visualize vector statistics
    # ============================================
    if log_vectors and stats_list:
        stats_path = save_vector_stats(stats_list, log_output_dir, strength)
        plot_vector_stats(stats_path, log_output_dir)
        print(f"Vector statistics saved to {log_output_dir}")
    return out