Spaces:
Sleeping
Sleeping
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| from matplotlib.colors import hsv_to_rgb | |
| from mpl_toolkits.mplot3d import Axes3D | |
| from src.image_util import normalized_linear_to_srgb | |
| # ============================================================================ | |
| # 3D Scatter Plotting | |
| # ============================================================================ | |
def plot_3d_scatter(pixels, colors, ax, title, axis_labels, point_size=1, alpha=0.3):
    """Draw a rasterized 3D point cloud onto an already-created 3D axis.

    Args:
        pixels: Array indexed as ``pixels[:, k]``; columns 0, 1 and 2 become the
            x, y and z coordinates.
        colors: Forwarded unchanged as the scatter ``c`` argument.
        ax: Target axis; must provide ``set_zlabel`` (i.e. a 3D axis).
        title: Axis title (font size 12).
        axis_labels: Sequence whose first three entries label x, y and z.
        point_size: Marker size passed as ``s``.
        alpha: Marker opacity.
    """
    ax.scatter(
        *(pixels[:, col] for col in range(3)),
        c=colors,
        s=point_size,
        alpha=alpha,
        rasterized=True,
    )
    label_setters = (ax.set_xlabel, ax.set_ylabel, ax.set_zlabel)
    for position, set_label in enumerate(label_setters):
        set_label(axis_labels[position], fontsize=10)
    ax.set_title(title, fontsize=12)
def plot_rgb_space(pixels, colors, title="RGB Space", save_path=None, alpha=0.3, point_size=1):
    """Scatter pixels inside the unit RGB cube with a marker at the origin.

    Each axis is fixed to [0, 1]. The figure is then saved to ``save_path``,
    or shown when ``save_path`` is falsy (via ``_save_or_show``).
    """
    fig = plt.figure(figsize=(10, 8))
    ax = fig.add_subplot(111, projection='3d')
    plot_3d_scatter(pixels, colors, ax, title, ["R", "G", "B"], point_size, alpha)

    # Large black dot anchoring the (0, 0, 0) corner of the cube.
    # NOTE(review): the "Origin" label is never shown; no legend is drawn here.
    ax.scatter([0], [0], [0], c="black", s=100, marker="o", label="Origin", depthshade=False)

    for set_limits in (ax.set_xlim, ax.set_ylim, ax.set_zlim):
        set_limits([0, 1])

    plt.tight_layout()
    _save_or_show(save_path)
def plot_log_rgb_space(pixels, colors, title="Log RGB Space", save_path=None, alpha=0.3, point_size=1):
    """Scatter pixels after a natural-log transform, then save or show the figure.

    A 1e-6 offset is added before ``np.log`` so zero-valued channels stay finite.
    """
    epsilon = 1e-6
    log_points = np.log(pixels + epsilon)

    figure = plt.figure(figsize=(10, 8))
    axis3d = figure.add_subplot(111, projection='3d')
    plot_3d_scatter(
        log_points, colors, axis3d, title,
        ["log(R)", "log(G)", "log(B)"], point_size, alpha,
    )
    plt.tight_layout()
    _save_or_show(save_path)
| # ============================================================================ | |
| # Image & Color Space Visualization | |
| # ============================================================================ | |
def _sample_row_indices(num_rows, num_samples):
    """Return indices of at most ``num_samples`` distinct rows out of ``num_rows``."""
    if num_rows > num_samples:
        return np.random.choice(num_rows, num_samples, replace=False)
    # Small inputs are used whole. Returning an explicit index array (rather than
    # nothing) keeps later fancy-indexing, such as cluster-mask lookups, valid.
    return np.arange(num_rows)


def plot_img_rgb_logrgb(axs, norm_content_img, norm_style_img, log_content_img, log_style_img,
                        log_cluster_bin_masks=None, log_cluster_dark_points=None,
                        log_cluster_bright_points=None):
    """Draw the content/style images plus their RGB and log-RGB point clouds.

    Up to 5000 pixels are sampled from each image. Content and style are
    sampled independently, so the two images may have different pixel counts.
    Panels filled (keys of ``axs``): "style_img", "content_img", "style_rgb",
    "content_rgb", "style_log_rgb", "content_log_rgb", "mixed_rgb",
    "mixed_log_rgb", and "clustered_content_log_rgb" when cluster masks are
    given.

    Args:
        axs: Mapping from panel name to axis (3D axes for the scatter panels).
        norm_content_img, norm_style_img: Normalized linear images, flattened
            as (-1, 3) and converted with ``normalized_linear_to_srgb``.
        log_content_img, log_style_img: Log-space versions of the same images.
        log_cluster_bin_masks: Optional sequence of per-cluster masks over the
            full content image; each is flattened and indexed with the
            sampled content indices.
        log_cluster_dark_points, log_cluster_bright_points: Optional
            per-cluster 3-vectors drawn as highlighted markers.
    """
    num_samples = 5000

    content_srgb = normalized_linear_to_srgb(norm_content_img)
    style_srgb = normalized_linear_to_srgb(norm_style_img)

    # Images
    axs["style_img"].imshow(style_srgb)
    axs["style_img"].set_title("Style", fontsize=12)
    axs["style_img"].axis("off")
    axs["content_img"].imshow(content_srgb)
    axs["content_img"].set_title("Content", fontsize=12)
    axs["content_img"].axis("off")

    # Flatten to one row per pixel. The /255.0 presumes normalized_linear_to_srgb
    # returns 8-bit-range values (TODO confirm).
    log_content_flat = log_content_img.reshape(-1, 3)
    log_style_flat = log_style_img.reshape(-1, 3)
    content_flat = norm_content_img.reshape(-1, 3)
    style_flat = norm_style_img.reshape(-1, 3)
    content_color_flat = content_srgb.reshape(-1, 3) / 255.0
    style_color_flat = style_srgb.reshape(-1, 3) / 255.0

    # `indices` is always bound. The original left it undefined for small images,
    # which crashed the clustering panel with a NameError.
    indices = _sample_row_indices(len(log_content_flat), num_samples)
    style_indices = _sample_row_indices(len(log_style_flat), num_samples)

    log_content_sampled = log_content_flat[indices]
    content_sampled = content_flat[indices]
    content_color_sampled = content_color_flat[indices]
    log_style_sampled = log_style_flat[style_indices]
    style_sampled = style_flat[style_indices]
    style_color_sampled = style_color_flat[style_indices]

    # RGB Space
    plot_3d_scatter(style_sampled, style_color_sampled, axs["style_rgb"],
                    "Style RGB Space", ["R", "G", "B"], 2, 0.3)
    plot_3d_scatter(content_sampled, content_color_sampled, axs["content_rgb"],
                    "Content RGB Space", ["R", "G", "B"], 2, 0.3)

    # Log-RGB Space
    plot_3d_scatter(log_style_sampled, style_color_sampled, axs["style_log_rgb"],
                    "Style Log-RGB Space", ["Log(R)", "Log(G)", "Log(B)"], 2, 0.3)
    plot_3d_scatter(log_content_sampled, content_color_sampled, axs["content_log_rgb"],
                    "Content Log-RGB Space", ["Log(R)", "Log(G)", "Log(B)"], 2, 0.3)

    # Overlays
    _plot_overlay(axs["mixed_rgb"], style_sampled, content_sampled,
                  ["Red", "Green", "Blue"], "RGB Comparison")
    _plot_overlay(axs["mixed_log_rgb"], log_style_sampled, log_content_sampled,
                  ["log(Red)", "log(Green)", "log(Blue)"], "Log-RGB Comparison")

    # Clustering visualization. Use an explicit None/len check so a stacked
    # ndarray of masks does not raise on ambiguous truthiness.
    if log_cluster_bin_masks is not None and len(log_cluster_bin_masks) > 0:
        _plot_clusters(axs["clustered_content_log_rgb"], log_content_sampled,
                       log_cluster_bin_masks, log_cluster_dark_points,
                       log_cluster_bright_points, indices)
def plot_content_log_chroma(axs, log_chroma_content, content_bit_depth, norm_content_img):
    """Display the content's log-chromaticity image and a sampled 3D scatter of it.

    The image panel exponentiates ``log_chroma_content`` back to linear values,
    rescales them from ``[0, 2**content_bit_depth - 1]`` to 8-bit and clips.
    The scatter panel plots up to 5000 randomly chosen pixels in log space,
    colored by the matching rows of ``norm_content_img`` as-is (no sRGB
    conversion is applied to those colors).

    Args:
        axs: Mapping with keys "content_projected_img" and
            "content_projected_log_rgb".
        log_chroma_content: Log-chromaticity image, flattened as (-1, 3).
        content_bit_depth: Bit depth of the original content.
        norm_content_img: Normalized content image, flattened as (-1, 3);
            presumably in [0, 1] so its rows are valid RGB colors.
    """
    full_scale = 2**content_bit_depth - 1
    linear = np.exp(log_chroma_content).astype(np.float32)
    display_img = np.clip(linear / full_scale * 255.0, 0, 255).astype(np.uint8)

    image_ax = axs["content_projected_img"]
    image_ax.imshow(display_img)
    image_ax.set_title("Content's Log Chromaticity Image", fontsize=12)
    image_ax.axis("off")

    max_points = 5000
    chroma_rows = log_chroma_content.reshape(-1, 3)
    color_rows = norm_content_img.reshape(-1, 3)
    if len(color_rows) > max_points:
        pick = np.random.choice(len(color_rows), max_points, replace=False)
        chroma_rows, color_rows = chroma_rows[pick], color_rows[pick]

    plot_3d_scatter(chroma_rows, color_rows, axs["content_projected_log_rgb"],
                    "Content's Log Chromaticity Normalized Log RGB",
                    ["Log(Red)", "Log(Green)", "Log(Blue)"], 2, 0.3)
def plot_transformed_img_logrgb(axs, tf_log_img, log_img, bit_depth):
    """Show the transformed content image and compare its log-RGB cloud to the original.

    ``tf_log_img`` is exponentiated, divided by ``2**bit_depth - 1``, clipped to
    [0, 1] and converted with ``normalized_linear_to_srgb`` for display. Up to
    5000 pixels are sampled, using the same indices for the transformed and the
    original log image so the overlay compares matching pixels.

    Args:
        axs: Mapping with keys "tf_content_img", "tf_content_log_rgb" and
            "mixed_tf_log_rgb".
        tf_log_img: Transformed image in log space, flattened as (-1, 3).
        log_img: Original content image in log space, flattened as (-1, 3).
        bit_depth: Bit depth used to rescale the linear transformed image.
    """
    full_scale = 2**bit_depth - 1
    linear = np.exp(tf_log_img).astype(np.float32)
    srgb = normalized_linear_to_srgb(np.clip(linear / full_scale, 0.0, 1.0))

    image_ax = axs["tf_content_img"]
    image_ax.imshow(srgb)
    image_ax.set_title("Transformed Content Image", fontsize=12)
    image_ax.axis("off")

    max_points = 5000
    tf_rows = tf_log_img.reshape(-1, 3)
    # /255.0 presumes normalized_linear_to_srgb yields 8-bit-range values (TODO confirm).
    color_rows = srgb.reshape(-1, 3) / 255.0
    orig_rows = log_img.reshape(-1, 3)
    if len(tf_rows) > max_points:
        pick = np.random.choice(len(tf_rows), max_points, replace=False)
        tf_rows, color_rows, orig_rows = tf_rows[pick], color_rows[pick], orig_rows[pick]

    plot_3d_scatter(tf_rows, color_rows, axs["tf_content_log_rgb"],
                    "Transformed Content Log RGB", ["Log(Red)", "Log(Green)", "Log(Blue)"], 2, 0.3)

    mixed_ax = axs["mixed_tf_log_rgb"]
    clouds = (
        (tf_rows, "green", "Transformed Content"),
        (orig_rows, "blue", "Original Content"),
    )
    for points, shade, legend_name in clouds:
        mixed_ax.scatter(points[:, 0], points[:, 1], points[:, 2],
                         c=shade, s=2, alpha=0.2, label=legend_name)
    mixed_ax.set_xlabel("log(Red)", fontsize=10)
    mixed_ax.set_ylabel("log(Green)", fontsize=10)
    mixed_ax.set_zlabel("log(Blue)", fontsize=10)
    mixed_ax.set_title("Log-RGB Comparison", fontsize=12)
    mixed_ax.legend()
| # ============================================================================ | |
| # Chromaticity & Clustering | |
| # ============================================================================ | |
def plot_log_chroma_plane_pre_clustering(log_chroma_content, isd_map, content_img, content_bit_depth):
    """Show the log-chromaticity plane (orthogonal to the mean ISD) before clustering.

    Up to 200000 pixels are sampled and projected onto an orthonormal basis
    ``(u, v)`` of the plane whose normal is the normalized mean ISD vector.
    Two panels are drawn side by side:

    * points colored by their content color (divided by
      ``2**content_bit_depth - 1``, clipped, converted to sRGB, divided by 255);
    * points colored by ``exp`` of their in-plane 3D reconstruction,
      normalized per channel by that channel's maximum and clipped to [0, 1].

    Args:
        log_chroma_content: (H, W, 3) log-chromaticity image.
        isd_map: ISD vectors, reshaped here to (H * W, 3).
        content_img: Linear content image holding H * W pixels of 3 channels.
        content_bit_depth: Bit depth used to normalize ``content_img``.
    """
    # NOTE(review): reshape(H * W, 3) assumes isd_map is channel-last; the
    # visualize_isd_* helpers index it channel-first. Confirm the layout passed here.
    H, W, _ = log_chroma_content.shape
    n_pixels = H * W
    chroma_rows = log_chroma_content.reshape(n_pixels, 3)
    color_rows = content_img.reshape(n_pixels, 3)

    max_points = 200000
    if len(chroma_rows) > max_points:
        pick = np.random.choice(len(chroma_rows), max_points, replace=False)
        chroma_rows, color_rows = chroma_rows[pick], color_rows[pick]

    linear_colors = np.clip(color_rows / (2**content_bit_depth - 1), 0, 1)
    srgb_colors = normalized_linear_to_srgb(linear_colors) / 255.0

    # Orthonormal basis of the plane perpendicular to the mean ISD. The Gram-Schmidt
    # seed is an axis that is not nearly parallel to the normal.
    normal = isd_map.reshape(n_pixels, 3).mean(axis=0)
    normal = normal / np.linalg.norm(normal)
    seed = np.array([1.0, 0.0, 0.0]) if abs(normal[0]) < 0.9 else np.array([0.0, 1.0, 0.0])
    u = seed - np.dot(seed, normal) * normal
    u = u / np.linalg.norm(u)
    v = np.cross(normal, u)

    coords_2d = np.zeros((len(chroma_rows), 2))
    coords_2d[:, 0] = np.dot(chroma_rows, u)
    coords_2d[:, 1] = np.dot(chroma_rows, v)

    # Lift the 2D coordinates back into 3D log space, exponentiate, and normalize
    # each channel by its maximum to get displayable colors.
    in_plane = coords_2d[:, 0:1] * u + coords_2d[:, 1:2] * v
    lifted = np.exp(in_plane)
    projected_rgb = np.clip(lifted / np.max(lifted, axis=0, keepdims=True), 0, 1)

    fig, axes = plt.subplots(1, 2, figsize=(20, 8))
    panels = (
        (srgb_colors, "Colored by Original RGB Values"),
        (projected_rgb, "Colored by Projected Chromaticity"),
    )
    for ax, (point_colors, panel_title) in zip(axes, panels):
        ax.scatter(coords_2d[:, 0], coords_2d[:, 1], c=point_colors, s=5, alpha=0.6, rasterized=True)
        ax.set_xlabel("Chromaticity Dimension 1", fontsize=12)
        ax.set_ylabel("Chromaticity Dimension 2", fontsize=12)
        ax.set_title(panel_title, fontsize=14)
        ax.grid(True, alpha=0.3)
        ax.set_aspect("equal", adjustable="box")

    fig.suptitle("Log Chromaticity Plane (Pre-Clustering)", fontsize=16, y=1.02)
    plt.tight_layout()
    plt.show()
def plot_log_chroma_plane_post_clustering(log_chroma_content, isd_map, bin_masks, bin_radius):
    """Plot the 2D log-chromaticity plane with points colored by cluster ID.

    Pixels are projected onto the plane orthogonal to the normalized mean ISD
    vector. Cluster IDs are 1-based in mask order: pixels selected by
    ``bin_masks[k]`` get ID ``k + 1``, and a later mask overrides an earlier one
    where they overlap. Pixels covered by no mask keep ID 0. Up to 200000
    pixels are sampled for plotting.

    Args:
        log_chroma_content: (H, W, 3) log-chromaticity image.
        isd_map: ISD vectors, reshaped here to (H * W, 3); their mean is the
            plane normal.
        bin_masks: Sequence of masks that index the flattened H * W pixels
            after ``ravel()`` (presumably boolean H x W arrays).
        bin_radius: Clustering radius; only shown in the title.
    """
    # NOTE(review): reshape(H * W, 3) assumes isd_map is channel-last; the
    # visualize_isd_* helpers index it channel-first. Confirm the layout passed here.
    H, W, _ = log_chroma_content.shape
    log_chroma_flat = log_chroma_content.reshape(H * W, 3)

    cluster_ids = np.zeros(H * W, dtype=int)
    for bin_id, mask in enumerate(bin_masks, start=1):
        cluster_ids[mask.ravel()] = bin_id

    # Sample
    num_samples = 200000
    if len(log_chroma_flat) > num_samples:
        indices = np.random.choice(len(log_chroma_flat), num_samples, replace=False)
        sampled_chroma = log_chroma_flat[indices]
        sampled_clusters = cluster_ids[indices]
    else:
        sampled_chroma = log_chroma_flat
        sampled_clusters = cluster_ids

    # Project to 2D on an orthonormal basis of the plane perpendicular to the mean ISD.
    mean_isd = isd_map.reshape(H * W, 3).mean(axis=0)
    mean_isd = mean_isd / np.linalg.norm(mean_isd)
    arbitrary = np.array([1.0, 0.0, 0.0]) if abs(mean_isd[0]) < 0.9 else np.array([0.0, 1.0, 0.0])
    u = arbitrary - np.dot(arbitrary, mean_isd) * mean_isd
    u = u / np.linalg.norm(u)
    v = np.cross(mean_isd, u)
    coords_2d = np.zeros((len(sampled_chroma), 2))
    coords_2d[:, 0] = np.dot(sampled_chroma, u)
    coords_2d[:, 1] = np.dot(sampled_chroma, v)

    fig, ax = plt.subplots(figsize=(12, 10))
    num_clusters = len(bin_masks)
    # IDs span 0..num_clusters, i.e. num_clusters + 1 labels. Give the colormap one
    # entry per label and centre each integer in its own bin; with only
    # num_clusters entries, the two highest IDs shared a color. plt.get_cmap
    # replaces matplotlib.cm.get_cmap, which was removed in Matplotlib 3.9.
    num_labels = num_clusters + 1
    cmap = plt.get_cmap("tab20" if num_labels <= 20 else "hsv", num_labels)
    scatter = ax.scatter(coords_2d[:, 0], coords_2d[:, 1], c=sampled_clusters,
                         cmap=cmap, vmin=-0.5, vmax=num_clusters + 0.5,
                         s=5, alpha=0.6, rasterized=True)
    cbar = plt.colorbar(scatter, ax=ax)
    cbar.set_label("Cluster ID", fontsize=12)
    ax.set_xlabel("Chromaticity Dimension 1", fontsize=12)
    ax.set_ylabel("Chromaticity Dimension 2", fontsize=12)
    ax.set_title(f"Log Chromaticity Plane (Post-Clustering)\n{num_clusters} clusters with radius={bin_radius}", fontsize=14)
    ax.grid(True, alpha=0.3)
    ax.set_aspect("equal", adjustable="box")
    plt.tight_layout()
    plt.show()
def plot_cluster_spatial_distribution(bin_masks, content_img, content_bit_depth):
    """Show the content image next to a map of each pixel's cluster ID.

    Cluster IDs are 1-based in mask order, and a later mask overrides an
    earlier one where they overlap. Pixels covered by no mask keep ID 0.

    Args:
        bin_masks: Sequence of masks indexing the (H, W) image grid
            (presumably boolean H x W arrays).
        content_img: (H, W, 3) linear content image.
        content_bit_depth: Bit depth used to normalize ``content_img`` before
            sRGB conversion.
    """
    H, W, _ = content_img.shape
    cluster_img = np.zeros((H, W), dtype=int)
    for bin_id, mask in enumerate(bin_masks, start=1):
        cluster_img[mask] = bin_id

    norm_content = np.clip(content_img / (2**content_bit_depth - 1), 0, 1)
    srgb_content = normalized_linear_to_srgb(norm_content)

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 8))
    ax1.imshow(srgb_content)
    ax1.set_title("Original Content Image", fontsize=14)
    ax1.axis("off")

    num_clusters = len(bin_masks)
    # IDs span 0..num_clusters, i.e. num_clusters + 1 labels. Give the colormap one
    # entry per label and centre each integer in its own bin so adjacent IDs never
    # share a color. plt.get_cmap replaces the removed matplotlib.cm.get_cmap.
    num_labels = num_clusters + 1
    cmap = plt.get_cmap("tab20" if num_labels <= 20 else "hsv", num_labels)
    # Nearest interpolation keeps resampled pixels on exact label values instead of
    # blending neighbouring IDs into colors of unrelated clusters.
    im = ax2.imshow(cluster_img, cmap=cmap, vmin=-0.5, vmax=num_clusters + 0.5,
                    interpolation="nearest")
    ax2.set_title(f"Material Clusters ({num_clusters} clusters)", fontsize=14)
    ax2.axis("off")
    cbar = plt.colorbar(im, ax=ax2, fraction=0.046, pad=0.04)
    cbar.set_label("Cluster ID", fontsize=12)
    plt.tight_layout()
    plt.show()
def plot_log_chroma_plane_posterized(log_chroma_original, log_chroma_posterized,
                                     isd_map, content_img, content_bit_depth, levels):
    """Compare original and posterized log-chromaticity on the 2D ISD-orthogonal plane.

    Up to 100000 pixels are sampled, using the same indices for both versions.
    Each version is projected onto an orthonormal basis of the plane
    perpendicular to the normalized mean ISD. Three panels are drawn: original
    points, posterized points (both colored by the sampled content pixels in
    sRGB), and a blue/red overlay of the two.

    Args:
        log_chroma_original, log_chroma_posterized: (H, W, 3) log-chromaticity images.
        isd_map: ISD vectors, reshaped here to (H * W, 3).
        content_img: Linear content image holding H * W pixels of 3 channels.
        content_bit_depth: Bit depth used to normalize ``content_img``.
        levels: Posterization level count; only shown in a panel title.
    """
    # NOTE(review): reshape(H * W, 3) assumes isd_map is channel-last; the
    # visualize_isd_* helpers index it channel-first. Confirm the layout passed here.
    H, W, _ = log_chroma_original.shape
    n_pixels = H * W
    orig_rows = log_chroma_original.reshape(n_pixels, 3)
    post_rows = log_chroma_posterized.reshape(n_pixels, 3)
    color_rows = content_img.reshape(n_pixels, 3)

    max_points = 100000
    if len(orig_rows) > max_points:
        pick = np.random.choice(len(orig_rows), max_points, replace=False)
        orig_rows, post_rows, color_rows = orig_rows[pick], post_rows[pick], color_rows[pick]

    display_colors = normalized_linear_to_srgb(
        np.clip(color_rows / (2**content_bit_depth - 1), 0, 1)
    ) / 255.0

    normal = isd_map.reshape(n_pixels, 3).mean(axis=0)
    normal = normal / np.linalg.norm(normal)
    seed = np.array([1.0, 0.0, 0.0]) if abs(normal[0]) < 0.9 else np.array([0.0, 1.0, 0.0])
    u = seed - np.dot(seed, normal) * normal
    u = u / np.linalg.norm(u)
    v = np.cross(normal, u)

    def to_plane(points):
        # Coordinates of each 3D point along the in-plane basis vectors (u, v).
        coords = np.zeros((len(points), 2))
        coords[:, 0] = np.dot(points, u)
        coords[:, 1] = np.dot(points, v)
        return coords

    orig_2d = to_plane(orig_rows)
    post_2d = to_plane(post_rows)

    fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(24, 7))

    colored_panels = (
        (ax1, orig_2d, "Original"),
        (ax2, post_2d, f"Posterized ({levels} levels)"),
    )
    for ax, coords, panel_title in colored_panels:
        ax.scatter(coords[:, 0], coords[:, 1], c=display_colors, s=3, alpha=0.5, rasterized=True)
        ax.set_xlabel("Chromaticity Dimension 1", fontsize=12)
        ax.set_ylabel("Chromaticity Dimension 2", fontsize=12)
        ax.set_title(panel_title, fontsize=14)
        ax.grid(True, alpha=0.3)
        ax.set_aspect("equal", adjustable="box")

    for coords, shade, legend_name in ((orig_2d, 'blue', 'Original'), (post_2d, 'red', 'Posterized')):
        ax3.scatter(coords[:, 0], coords[:, 1], c=shade, s=2, alpha=0.3,
                    label=legend_name, rasterized=True)
    ax3.set_xlabel("Chromaticity Dimension 1", fontsize=12)
    ax3.set_ylabel("Chromaticity Dimension 2", fontsize=12)
    ax3.set_title("Overlay Comparison", fontsize=14)
    ax3.legend()
    ax3.grid(True, alpha=0.3)
    ax3.set_aspect("equal", adjustable="box")

    fig.suptitle("Effect of Posterization on Log Chromaticity Plane", fontsize=16, y=1.02)
    plt.tight_layout()
    plt.show()
| # ============================================================================ | |
| # ISD Visualization | |
| # ============================================================================ | |
def visualize_isd_as_direction(isd_map, save_path="isd_direction.png"):
    """Encode each pixel's ISD direction as a color, then save or show the map.

    Azimuth ``atan2(y, x)`` becomes hue and elevation ``arcsin(z)`` becomes
    saturation, with value fixed at 1, so equal colors mean equal directions.
    ``z`` is clipped to [-1, 1] before ``arcsin``. The encoding is only
    meaningful for unit-length vectors.

    Args:
        isd_map: Channel-first array of shape (C, H, W). Channels 0, 1 and 2
            are read as the x, y and z components.
        save_path: Output path; a falsy value shows the figure instead.

    Returns:
        The (H, W, 3) RGB array that was displayed.

    Raises:
        ValueError: If ``isd_map`` is not 3-dimensional.
    """
    if isd_map.ndim != 3:
        raise ValueError(f"isd_map must be channel-first (C, H, W); got shape {isd_map.shape}")
    x, y, z = isd_map[0], isd_map[1], isd_map[2]

    hue = (np.arctan2(y, x) + np.pi) / (2 * np.pi)
    elevation = np.arcsin(np.clip(z, -1, 1))
    saturation = (elevation + np.pi / 2) / np.pi
    value = np.ones_like(hue)
    rgb = hsv_to_rgb(np.stack([hue, saturation, value], axis=-1))

    plt.figure(figsize=(12, 8))
    plt.imshow(rgb)
    plt.title("ISD Direction Map (Color = Direction in 3D space)")
    # No colorbar: the image is RGB, so a colorbar would only show the default
    # scalar colormap over [0, 1], which has no relation to this HSV encoding.
    plt.axis("off")
    plt.tight_layout()
    _save_or_show(save_path)
    return rgb
def visualize_isd_components(isd_map, save_path="isd_components.png"):
    """Show ISD channels 0, 1 and 2 as separate min-max normalized heat maps.

    Each panel's colorbar label records that channel's original [min, max]
    range. The figure is then saved to ``save_path``, or shown when it is falsy.
    """
    _, axes = plt.subplots(1, 3, figsize=(15, 4))
    for channel, (ax, name) in enumerate(zip(axes, ["ISD_R", "ISD_G", "ISD_B"])):
        plane = isd_map[channel]
        lo, hi = plane.min(), plane.max()
        # The 1e-8 guard keeps a constant channel from dividing by zero.
        scaled = (plane - lo) / (hi - lo + 1e-8)
        image = ax.imshow(scaled, cmap="RdBu_r")
        ax.set_title(name)
        ax.axis("off")
        plt.colorbar(image, ax=ax, label=f"Range: [{lo:.3f}, {hi:.3f}]")
    plt.tight_layout()
    _save_or_show(save_path)
def visualize_isd_magnitude(isd_map, save_path="isd_magnitude.png"):
    """Plot the per-pixel Euclidean norm of the ISD vectors as a sanity check.

    The norm is taken over axis 0 (the channel axis of a channel-first map).
    Unit-normalized ISDs should give a flat map near 1.0. After saving or
    showing the figure, a warning is printed if the mean norm differs from 1.0
    by more than 0.1.
    """
    squared = isd_map**2
    magnitude = np.sqrt(squared.sum(axis=0))
    mean_mag = magnitude.mean()
    std_mag = magnitude.std()

    plt.figure(figsize=(10, 8))
    plt.imshow(magnitude, cmap="viridis")
    plt.colorbar(label="ISD Magnitude")
    plt.title(f"ISD Magnitude (should ≈1.0)\nMean: {mean_mag:.4f}, Std: {std_mag:.4f}")
    plt.axis("off")
    plt.tight_layout()
    _save_or_show(save_path)

    looks_unnormalized = abs(mean_mag - 1.0) > 0.1
    if looks_unnormalized:
        print("⚠️ Warning: ISD vectors may not be properly normalized!")
def visualize_isd_arrow_field(isd_map, original_img, stride=50, save_path="isd_arrows.png"):
    """Overlay a sparse grid of ISD arrows on ``original_img``.

    Arrows start every ``stride`` pixels (offset by ``stride // 2``). Their
    x/y extent is ISD channels 0 and 1 scaled by ``0.4 * stride``. Their color
    maps channel 2 from [-1, 1] onto the RdBu_r colormap.

    Args:
        isd_map: Channel-first ISD array indexed as ``isd_map[:, y, x]``.
        original_img: Image to draw under the arrows. It is transposed from
            channel-first when its first dimension is 3.
            NOTE(review): an (H, W, 3) image with H == 3 would be misread.
        stride: Grid spacing in pixels; also scales arrow length and head size.
        save_path: Output path; a falsy value shows the figure instead.
    """
    if original_img.shape[0] == 3:
        original_img = np.transpose(original_img, (1, 2, 0))
    height, width = original_img.shape[:2]

    fig, ax = plt.subplots(figsize=(14, 10))
    ax.imshow(original_img)

    rows = np.arange(stride // 2, height, stride)
    cols = np.arange(stride // 2, width, stride)
    head_size = stride * 0.15
    for row in rows:
        for col in cols:
            vec = isd_map[:, row, col]
            shade = plt.cm.RdBu_r((vec[2] + 1) / 2)
            ax.arrow(
                col, row,
                vec[0] * stride * 0.4,
                vec[1] * stride * 0.4,
                head_width=head_size, head_length=head_size,
                fc=shade, ec=shade, alpha=0.7, width=1.5,
            )

    ax.set_title("ISD Vector Field")
    ax.axis("off")
    plt.tight_layout()
    _save_or_show(save_path)
| # ============================================================================ | |
| # Utilities | |
| # ============================================================================ | |
def plot_bin_masks(bin_mask: np.ndarray):
    """Display a single binary mask full-bleed on a gray 6x6-inch figure."""
    figure = plt.figure(figsize=(6, 6), facecolor="gray")
    # Axes rectangle [left, bottom, width, height] covering the whole figure.
    canvas = figure.add_axes([0, 0, 1, 1])
    canvas.imshow(bin_mask, cmap="gray")
    canvas.axis("off")
    plt.show()
def plot_plane(axs, normal, point, bounds, alpha=0.3, color="red"):
    """Draw the plane through ``point`` with normal ``normal`` on every axis in ``axs``.

    The plane equation is solved for a coordinate whose normal component has
    the largest magnitude (preferring z, then y, then x), so the division
    never uses the smallest component. A 20x20 grid spans the ``bounds`` of
    the other two coordinates.

    Args:
        axs: Iterable of 3D axes; each receives ``plot_surface(X, Y, Z, ...)``.
        normal: Plane normal (3 components; normalized internally).
        point: A point on the plane (3 components).
        bounds: ``((x_min, x_max), (y_min, y_max), (z_min, z_max))``.
        alpha: Surface opacity.
        color: Surface color.
    """
    a, b, c = np.array(normal) / np.linalg.norm(normal)
    x0, y0, z0 = np.array(point)
    (x_lo, x_hi), (y_lo, y_hi), (z_lo, z_hi) = bounds
    steps = 20

    z_dominant = abs(c) > abs(a) and abs(c) > abs(b)
    if z_dominant:
        X, Y = np.meshgrid(np.linspace(x_lo, x_hi, steps), np.linspace(y_lo, y_hi, steps))
        Z = z0 - (a * (X - x0) + b * (Y - y0)) / c
    elif abs(b) > abs(a):
        X, Z = np.meshgrid(np.linspace(x_lo, x_hi, steps), np.linspace(z_lo, z_hi, steps))
        Y = y0 - (a * (X - x0) + c * (Z - z0)) / b
    else:
        Y, Z = np.meshgrid(np.linspace(y_lo, y_hi, steps), np.linspace(z_lo, z_hi, steps))
        X = x0 - (b * (Y - y0) + c * (Z - z0)) / a

    for target in axs:
        target.plot_surface(X, Y, Z, alpha=alpha, color=color)
def calculate_shared_limits(data_arrays, padding=0.1):
    """Compute padded x/y/z limits that enclose every point in ``data_arrays``.

    Args:
        data_arrays: A single 2D array whose columns 0, 1, 2 are x, y, z, or a
            list/tuple of such arrays (stacked with ``np.vstack``). Tuples were
            previously passed through unstacked and failed on ``.min``.
        padding: Fraction of each axis' data range added on both ends.

    Returns:
        ``(x_limits, y_limits, z_limits)``, each a two-element ``[low, high]`` list.
    """
    if isinstance(data_arrays, (list, tuple)):
        all_data = np.vstack(data_arrays)
    else:
        all_data = np.asarray(data_arrays)
    mins = all_data.min(axis=0)
    maxs = all_data.max(axis=0)
    ranges = maxs - mins
    return tuple(
        [mins[axis] - padding * ranges[axis], maxs[axis] + padding * ranges[axis]]
        for axis in range(3)
    )
def plane_view_from_normal(normal):
    """Return Matplotlib 3D view angles ``(elev, azim)`` in degrees for ``normal``.

    Elevation is the angle of the unit normal above the xy-plane
    (``arcsin(nz)``). Azimuth is its heading within the xy-plane, measured from
    +x toward +y (``atan2(ny, nx)``). ``normal`` must have exactly 3 components.
    """
    nx, ny, nz = np.array(normal) / np.linalg.norm(normal)
    elevation_rad = np.arcsin(nz)
    azimuth_rad = np.arctan2(ny, nx)
    return np.degrees(elevation_rad), np.degrees(azimuth_rad)
| # ============================================================================ | |
| # Private Helpers | |
| # ============================================================================ | |
def _save_or_show(save_path):
    """Write the current figure to ``save_path`` and close it, or display it if no path is given."""
    if not save_path:
        plt.show()
        return
    plt.savefig(save_path, dpi=150, bbox_inches="tight")
    print(f"Saved to {save_path}")
    plt.close()
def _plot_overlay(ax, data1, data2, labels, title):
    """Overlay two 3D point clouds on ``ax`` with a legend.

    ``data1`` is drawn green and labelled "Style"; ``data2`` is drawn blue and
    labelled "Original". ``labels`` supplies the x, y and z axis titles.
    """
    clouds = ((data1, "green", "Style"), (data2, "blue", "Original"))
    for points, shade, legend_name in clouds:
        ax.scatter(points[:, 0], points[:, 1], points[:, 2],
                   c=shade, s=2, alpha=0.2, label=legend_name)
    for index, set_label in enumerate((ax.set_xlabel, ax.set_ylabel, ax.set_zlabel)):
        set_label(labels[index], fontsize=10)
    ax.set_title(title, fontsize=12)
    ax.legend()
def _plot_clusters(ax, log_content_sampled, bin_masks, dark_points, bright_points, indices):
    """Scatter sampled content points colored by cluster, plus per-cluster anchor markers.

    Args:
        ax: 3D axis to draw on.
        log_content_sampled: Sampled log-RGB points; row ``j`` corresponds to
            flat pixel ``indices[j]``.
        bin_masks: Sequence of per-cluster masks over the full image. Each is
            flattened and indexed with ``indices`` to select that cluster's
            sampled points.
        dark_points, bright_points: Optional per-cluster points indexable as
            ``[i][0..2]``. They are drawn with black and red edges respectively.
        indices: Flat pixel indices that produced ``log_content_sampled``.
    """
    num_clusters = len(bin_masks)
    # plt.get_cmap replaces matplotlib.cm.get_cmap (plt.cm.get_cmap), which was
    # removed in Matplotlib 3.9. Integer arguments to cmap(i) index its entries.
    cmap = plt.get_cmap("tab20" if num_clusters <= 20 else "hsv", num_clusters)

    for i in range(num_clusters):
        bin_mask_sampled = bin_masks[i].ravel()[indices]
        members = log_content_sampled[bin_mask_sampled]
        ax.scatter(members[:, 0], members[:, 1], members[:, 2],
                   c=[cmap(i)], s=2, alpha=0.1)

    for i in range(num_clusters):
        if dark_points is not None:
            ax.scatter(dark_points[i][0], dark_points[i][1], dark_points[i][2],
                       c=[cmap(i)], edgecolors="black", linewidth=0.5, s=20, alpha=1.0)
        if bright_points is not None:
            ax.scatter(bright_points[i][0], bright_points[i][1], bright_points[i][2],
                       c=[cmap(i)], edgecolors="red", linewidth=0.5, s=20, alpha=1.0)

    ax.set_xlabel("log(Red)", fontsize=10)
    ax.set_ylabel("log(Green)", fontsize=10)
    ax.set_zlabel("log(Blue)", fontsize=10)
    ax.set_title("Content Log-RGB Clustered", fontsize=12)