# bidr-relight / src/plotting.py
# Uploaded by maxhuber ("Upload 14 files"), commit 3336231 (verified).
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import hsv_to_rgb
from mpl_toolkits.mplot3d import Axes3D
from src.image_util import normalized_linear_to_srgb
# ============================================================================
# 3D Scatter Plotting
# ============================================================================
def plot_3d_scatter(pixels, colors, ax, title, axis_labels, point_size=1, alpha=0.3):
    """Draw a 3-column point cloud onto an existing 3D axis.

    Args:
        pixels: Array indexed as ``pixels[:, k]``; columns 0, 1 and 2 go to
            the x, y and z axes.
        colors: Forwarded unchanged as the scatter ``c=`` argument.
        ax: Axis that supports ``scatter`` and ``set_zlabel`` (a 3D axis).
        title: Axis title (font size 12).
        axis_labels: Sequence whose first three entries label x, y, z
            (font size 10).
        point_size: Marker size (``s=``).
        alpha: Marker opacity.
    """
    xs, ys, zs = (pixels[:, k] for k in range(3))
    # Rasterize the (potentially large) point cloud so vector exports stay small.
    ax.scatter(xs, ys, zs, c=colors, s=point_size, alpha=alpha, rasterized=True)
    label_setters = (ax.set_xlabel, ax.set_ylabel, ax.set_zlabel)
    for position, setter in enumerate(label_setters):
        setter(axis_labels[position], fontsize=10)
    ax.set_title(title, fontsize=12)
def plot_rgb_space(pixels, colors, title="RGB Space", save_path=None, alpha=0.3, point_size=1):
    """Scatter pixels inside a fixed [0, 1]^3 RGB cube on a new 10x8 figure.

    Points are drawn with ``plot_3d_scatter``. A large black marker at
    (0, 0, 0) is added as a reference. The figure is then passed to
    ``_save_or_show``, which saves it to ``save_path`` or shows it.
    Points outside [0, 1] on any axis fall outside the fixed limits.
    """
    figure = plt.figure(figsize=(10, 8))
    axis3d = figure.add_subplot(111, projection='3d')
    plot_3d_scatter(pixels, colors, axis3d, title, ["R", "G", "B"], point_size, alpha)
    # Origin marker; depthshade off keeps it solid black regardless of depth.
    axis3d.scatter([0], [0], [0], c="black", s=100, marker="o", label="Origin", depthshade=False)
    for set_limits in (axis3d.set_xlim, axis3d.set_ylim, axis3d.set_zlim):
        set_limits([0, 1])
    plt.tight_layout()
    _save_or_show(save_path)
def plot_log_rgb_space(pixels, colors, title="Log RGB Space", save_path=None, alpha=0.3, point_size=1):
    """Scatter pixels in natural-log RGB space on a new 10x8 figure.

    A 1e-6 offset is added before ``np.log``, so zero-valued channels land at
    about -13.8 instead of -inf. The figure is passed to ``_save_or_show``.
    """
    epsilon = 1e-6
    log_values = np.log(pixels + epsilon)
    figure = plt.figure(figsize=(10, 8))
    axis3d = figure.add_subplot(111, projection='3d')
    log_axis_names = ["log(R)", "log(G)", "log(B)"]
    plot_3d_scatter(log_values, colors, axis3d, title, log_axis_names, point_size, alpha)
    plt.tight_layout()
    _save_or_show(save_path)
# ============================================================================
# Image & Color Space Visualization
# ============================================================================
def plot_img_rgb_logrgb(axs, norm_content_img, norm_style_img, log_content_img, log_style_img,
                        log_cluster_bin_masks=None, log_cluster_dark_points=None,
                        log_cluster_bright_points=None):
    """Fill a dashboard of axes with images, RGB / log-RGB scatters and overlays.

    Keys read from ``axs``:

    * "style_img" and "content_img";
    * "style_rgb", "content_rgb", "style_log_rgb" and "content_log_rgb";
    * "mixed_rgb" and "mixed_log_rgb";
    * "clustered_content_log_rgb", only when cluster masks are supplied.

    Up to 5000 pixels per image are sampled at random (without replacement)
    for the scatter panels. Scatter colours come from the sRGB-converted
    images divided by 255.

    Args:
        axs: Mapping from panel name to axis (3D axes for the scatter panels).
        norm_content_img, norm_style_img: Images passed to
            ``normalized_linear_to_srgb`` and reshaped to (-1, 3) rows.
            Presumably linear values in [0, 1]; TODO confirm.
        log_content_img, log_style_img: Log-RGB images, reshaped to (-1, 3).
            Assumed to have the same pixel count as the matching normalized
            image.
        log_cluster_bin_masks: Optional sequence (or array) of per-cluster
            masks over the content image. Nothing is clustered when it is
            None or empty.
        log_cluster_dark_points, log_cluster_bright_points: Optional
            per-cluster points forwarded to ``_plot_clusters``.
    """
    content_srgb = normalized_linear_to_srgb(norm_content_img)
    style_srgb = normalized_linear_to_srgb(norm_style_img)

    # Images
    axs["style_img"].imshow(style_srgb)
    axs["style_img"].set_title("Style", fontsize=12)
    axs["style_img"].axis("off")
    axs["content_img"].imshow(content_srgb)
    axs["content_img"].set_title("Content", fontsize=12)
    axs["content_img"].axis("off")

    # Flatten to one row per pixel.
    num_samples = 5000
    log_content_flat = log_content_img.reshape(-1, 3)
    log_style_flat = log_style_img.reshape(-1, 3)
    content_flat = norm_content_img.reshape(-1, 3)
    style_flat = norm_style_img.reshape(-1, 3)
    # Assumes the sRGB helper returns 0-255 values; TODO confirm.
    content_color_flat = content_srgb.reshape(-1, 3) / 255.0
    style_color_flat = style_srgb.reshape(-1, 3) / 255.0

    def _sample_indices(num_pixels):
        # Random subset without replacement when large, every row otherwise.
        if num_pixels > num_samples:
            return np.random.choice(num_pixels, num_samples, replace=False)
        return np.arange(num_pixels)

    # Always bound, so the clustering panel also works for small images.
    indices = _sample_indices(len(log_content_flat))
    # Same-size images share the content sample. Otherwise the style image is
    # sampled separately so its indices stay in range.
    if len(log_style_flat) == len(log_content_flat):
        style_indices = indices
    else:
        style_indices = _sample_indices(len(log_style_flat))

    log_content_sampled = log_content_flat[indices]
    content_sampled = content_flat[indices]
    content_color_sampled = content_color_flat[indices]
    log_style_sampled = log_style_flat[style_indices]
    style_sampled = style_flat[style_indices]
    style_color_sampled = style_color_flat[style_indices]

    # RGB Space
    plot_3d_scatter(style_sampled, style_color_sampled, axs["style_rgb"],
                    "Style RGB Space", ["R", "G", "B"], 2, 0.3)
    plot_3d_scatter(content_sampled, content_color_sampled, axs["content_rgb"],
                    "Content RGB Space", ["R", "G", "B"], 2, 0.3)

    # Log-RGB Space
    plot_3d_scatter(log_style_sampled, style_color_sampled, axs["style_log_rgb"],
                    "Style Log-RGB Space", ["Log(R)", "Log(G)", "Log(B)"], 2, 0.3)
    plot_3d_scatter(log_content_sampled, content_color_sampled, axs["content_log_rgb"],
                    "Content Log-RGB Space", ["Log(R)", "Log(G)", "Log(B)"], 2, 0.3)

    # Overlays
    _plot_overlay(axs["mixed_rgb"], style_sampled, content_sampled,
                  ["Red", "Green", "Blue"], "RGB Comparison")
    _plot_overlay(axs["mixed_log_rgb"], log_style_sampled, log_content_sampled,
                  ["log(Red)", "log(Green)", "log(Blue)"], "Log-RGB Comparison")

    # Clustering visualization. An explicit None/len check is used because
    # plain truthiness raises for numpy arrays of masks.
    if log_cluster_bin_masks is not None and len(log_cluster_bin_masks) > 0:
        _plot_clusters(axs["clustered_content_log_rgb"], log_content_sampled,
                       log_cluster_bin_masks, log_cluster_dark_points,
                       log_cluster_bright_points, indices)
def plot_content_log_chroma(axs, log_chroma_content, content_bit_depth, norm_content_img):
    """Show the content's log-chromaticity image plus a sampled 3D scatter of it.

    Image panel (``axs["content_projected_img"]``):
        ``exp(log_chroma_content)`` is divided by ``2**content_bit_depth - 1``,
        scaled to 0-255, clipped and cast to uint8.

    Scatter panel (``axs["content_projected_log_rgb"]``):
        Up to 5000 randomly chosen log-chromaticity rows, coloured by the
        matching rows of ``norm_content_img``. Those rows are used as colours
        directly, so they presumably lie in [0, 1]; TODO confirm, since
        matplotlib rejects RGB colours outside that range.
    """
    full_scale = 2**content_bit_depth - 1
    linear = np.exp(log_chroma_content).astype(np.float32)
    display = np.clip(linear / full_scale * 255.0, 0, 255).astype(np.uint8)

    image_ax = axs["content_projected_img"]
    image_ax.imshow(display)
    image_ax.set_title("Content's Log Chromaticity Image", fontsize=12)
    image_ax.axis("off")

    sample_cap = 5000
    chroma_rows = log_chroma_content.reshape(-1, 3)
    color_rows = norm_content_img.reshape(-1, 3)
    total_rows = len(color_rows)
    if total_rows > sample_cap:
        picked = np.random.choice(total_rows, sample_cap, replace=False)
        color_rows = color_rows[picked]
        chroma_rows = chroma_rows[picked]

    plot_3d_scatter(chroma_rows, color_rows, axs["content_projected_log_rgb"],
                    "Content's Log Chromaticity Normalized Log RGB",
                    ["Log(Red)", "Log(Green)", "Log(Blue)"], 2, 0.3)
def plot_transformed_img_logrgb(axs, tf_log_img, log_img, bit_depth):
    """Display a transformed log image and compare its log-RGB cloud with the original.

    Panels (keys of ``axs``):

    * "tf_content_img": ``exp(tf_log_img)`` divided by ``2**bit_depth - 1``,
      clipped to [0, 1] and converted by ``normalized_linear_to_srgb``.
    * "tf_content_log_rgb": up to 5000 sampled transformed log-RGB rows,
      coloured by the displayed image divided by 255.
    * "mixed_tf_log_rgb": the same sample of transformed rows (green) and
      ``log_img`` rows (blue), with a legend.

    One set of random indices selects from both ``tf_log_img`` and
    ``log_img``, so they are assumed to have equal pixel counts.
    """
    full_scale = 2**bit_depth - 1
    linear = np.exp(tf_log_img).astype(np.float32)
    display_img = normalized_linear_to_srgb(np.clip(linear / full_scale, 0.0, 1.0))

    img_ax = axs["tf_content_img"]
    img_ax.imshow(display_img)
    img_ax.set_title("Transformed Content Image", fontsize=12)
    img_ax.axis("off")

    limit = 5000
    tf_rows = tf_log_img.reshape(-1, 3)
    rgb_rows = display_img.reshape(-1, 3) / 255.0
    orig_rows = log_img.reshape(-1, 3)
    if len(tf_rows) > limit:
        chosen = np.random.choice(len(tf_rows), limit, replace=False)
        tf_rows, rgb_rows, orig_rows = tf_rows[chosen], rgb_rows[chosen], orig_rows[chosen]

    # Transformed log-RGB
    plot_3d_scatter(tf_rows, rgb_rows, axs["tf_content_log_rgb"],
                    "Transformed Content Log RGB", ["Log(Red)", "Log(Green)", "Log(Blue)"], 2, 0.3)

    # Overlay comparison
    cmp_ax = axs["mixed_tf_log_rgb"]
    clouds = ((tf_rows, "green", "Transformed Content"), (orig_rows, "blue", "Original Content"))
    for rows, colour, name in clouds:
        cmp_ax.scatter(rows[:, 0], rows[:, 1], rows[:, 2], c=colour, s=2, alpha=0.2, label=name)
    axis_text = ((cmp_ax.set_xlabel, "log(Red)"), (cmp_ax.set_ylabel, "log(Green)"),
                 (cmp_ax.set_zlabel, "log(Blue)"))
    for setter, text in axis_text:
        setter(text, fontsize=10)
    cmp_ax.set_title("Log-RGB Comparison", fontsize=12)
    cmp_ax.legend()
# ============================================================================
# Chromaticity & Clustering
# ============================================================================
def plot_log_chroma_plane_pre_clustering(log_chroma_content, isd_map, content_img, content_bit_depth):
    """Project log-chromaticity pixels onto the plane orthogonal to the mean ISD.

    An orthonormal basis (u, v) is built for the plane perpendicular to the
    normalized mean of ``isd_map``. Up to 200000 randomly sampled pixels are
    scattered in that 2D frame in two panels:

    * left: coloured by ``content_img`` scaled by
      ``2**content_bit_depth - 1``, clipped to [0, 1], converted to sRGB and
      divided by 255;
    * right: coloured by ``exp`` of each point re-embedded in 3D, divided by
      the per-channel maximum over the sample.

    ``log_chroma_content``, ``content_img`` and ``isd_map`` are all reshaped
    to ``(H * W, 3)``, i.e. channels-last. NOTE(review): the ISD
    visualisation helpers in this module index ISD maps channels-first;
    confirm which layout callers pass here.
    """
    H, W, _ = log_chroma_content.shape
    n_pix = H * W
    chroma_all = log_chroma_content.reshape(n_pix, 3)

    cap = 200000
    if len(chroma_all) > cap:
        pick = np.random.choice(len(chroma_all), cap, replace=False)
        chroma = chroma_all[pick]
        raw_colors = content_img.reshape(n_pix, 3)[pick]
    else:
        chroma = chroma_all
        raw_colors = content_img.reshape(n_pix, 3)
    full_scale = 2**content_bit_depth - 1
    srgb_colors = normalized_linear_to_srgb(np.clip(raw_colors / full_scale, 0, 1)) / 255.0

    # Orthonormal in-plane basis. The helper axis switches from x to y when
    # the normal is nearly parallel to x, keeping the Gram-Schmidt step stable.
    normal = isd_map.reshape(n_pix, 3).mean(axis=0)
    normal = normal / np.linalg.norm(normal)
    helper = np.array([1.0, 0.0, 0.0]) if abs(normal[0]) < 0.9 else np.array([0.0, 1.0, 0.0])
    u = helper - np.dot(helper, normal) * normal
    u = u / np.linalg.norm(u)
    v = np.cross(normal, u)
    coords_2d = np.column_stack((np.dot(chroma, u), np.dot(chroma, v)))

    # Re-embed the 2D coordinates in 3D and map exp() to per-channel [0, 1].
    in_plane = coords_2d[:, 0:1] * u + coords_2d[:, 1:2] * v
    exp_in_plane = np.exp(in_plane)
    projected_rgb = np.clip(exp_in_plane / np.max(exp_in_plane, axis=0, keepdims=True), 0, 1)

    fig, panels = plt.subplots(1, 2, figsize=(20, 8))
    panel_specs = ((srgb_colors, "Colored by Original RGB Values"),
                   (projected_rgb, "Colored by Projected Chromaticity"))
    for ax, (colours, heading) in zip(panels, panel_specs):
        ax.scatter(coords_2d[:, 0], coords_2d[:, 1], c=colours, s=5, alpha=0.6, rasterized=True)
        ax.set_xlabel("Chromaticity Dimension 1", fontsize=12)
        ax.set_ylabel("Chromaticity Dimension 2", fontsize=12)
        ax.set_title(heading, fontsize=14)
        ax.grid(True, alpha=0.3)
        ax.set_aspect("equal", adjustable="box")

    fig.suptitle("Log Chromaticity Plane (Pre-Clustering)", fontsize=16, y=1.02)
    plt.tight_layout()
    plt.show()
def plot_log_chroma_plane_post_clustering(log_chroma_content, isd_map, bin_masks, bin_radius):
    """Scatter the mean-ISD log-chromaticity plane, coloured by cluster ID.

    Each pixel gets an ID: 0 if it is in no mask, otherwise the 1-based index
    of the last mask in ``bin_masks`` containing it (later masks overwrite
    earlier ones). Pixels are projected onto the plane orthogonal to the
    normalized mean ISD, and up to 200000 random samples are plotted.

    Args:
        log_chroma_content: Array of shape (H, W, 3).
        isd_map: Reshaped to (H * W, 3) to compute the mean ISD.
            NOTE(review): other helpers in this module index ISD maps
            channels-first; confirm the layout callers pass here.
        bin_masks: Sequence of masks, each flattening to H * W entries.
            Non-bool masks are cast to bool (nonzero means member).
        bin_radius: Only used in the title.
    """
    H, W, _ = log_chroma_content.shape
    log_chroma_flat = log_chroma_content.reshape(H * W, 3)
    cluster_ids = np.zeros(H * W, dtype=int)
    for bin_id, mask in enumerate(bin_masks, start=1):
        # Cast so 0/1 integer masks select pixels instead of fancy-indexing
        # rows 0 and 1.
        cluster_ids[np.asarray(mask, dtype=bool).ravel()] = bin_id

    # Sample
    num_samples = 200000
    if len(log_chroma_flat) > num_samples:
        indices = np.random.choice(len(log_chroma_flat), num_samples, replace=False)
        sampled_chroma = log_chroma_flat[indices]
        sampled_clusters = cluster_ids[indices]
    else:
        sampled_chroma = log_chroma_flat
        sampled_clusters = cluster_ids

    # Project to 2D: orthonormal basis (u, v) of the plane perpendicular to
    # the mean ISD. The helper axis avoids near-parallel degeneracy.
    mean_isd = isd_map.reshape(H * W, 3).mean(axis=0)
    mean_isd = mean_isd / np.linalg.norm(mean_isd)
    arbitrary = np.array([1.0, 0.0, 0.0]) if abs(mean_isd[0]) < 0.9 else np.array([0.0, 1.0, 0.0])
    u = arbitrary - np.dot(arbitrary, mean_isd) * mean_isd
    u = u / np.linalg.norm(u)
    v = np.cross(mean_isd, u)
    coords_2d = np.zeros((len(sampled_chroma), 2))
    coords_2d[:, 0] = np.dot(sampled_chroma, u)
    coords_2d[:, 1] = np.dot(sampled_chroma, v)

    fig, ax = plt.subplots(figsize=(12, 10))
    num_clusters = len(bin_masks)
    # IDs span 0..num_clusters, so reserve one colour per ID (including the
    # unassigned 0). Colour limits are pinned at half-integers so each integer
    # ID lands in its own colour bin, whichever IDs happen to be sampled.
    # plt.get_cmap replaces plt.cm.get_cmap, which Matplotlib 3.9 removed.
    num_colors = num_clusters + 1
    cmap = plt.get_cmap("tab20" if num_colors <= 20 else "hsv", num_colors)
    scatter = ax.scatter(coords_2d[:, 0], coords_2d[:, 1], c=sampled_clusters,
                         cmap=cmap, vmin=-0.5, vmax=num_clusters + 0.5,
                         s=5, alpha=0.6, rasterized=True)
    cbar = plt.colorbar(scatter, ax=ax)
    cbar.set_label("Cluster ID", fontsize=12)
    ax.set_xlabel("Chromaticity Dimension 1", fontsize=12)
    ax.set_ylabel("Chromaticity Dimension 2", fontsize=12)
    ax.set_title(f"Log Chromaticity Plane (Post-Clustering)\n{num_clusters} clusters with radius={bin_radius}", fontsize=14)
    ax.grid(True, alpha=0.3)
    ax.set_aspect("equal", adjustable="box")
    plt.tight_layout()
    plt.show()
def plot_cluster_spatial_distribution(bin_masks, content_img, content_bit_depth):
    """Show the content image next to a per-pixel cluster-ID map.

    Pixel IDs: 0 when the pixel is in no mask, otherwise the 1-based index of
    the last mask in ``bin_masks`` that contains it.

    Args:
        bin_masks: Sequence of (H, W) masks. Non-bool masks are cast to bool
            (nonzero means member).
        content_img: (H, W, 3) linear image. It is divided by
            ``2**content_bit_depth - 1``, clipped to [0, 1] and converted to
            sRGB for display.
        content_bit_depth: Bit depth used for that normalization.
    """
    H, W, _ = content_img.shape
    cluster_img = np.zeros((H, W), dtype=int)
    for bin_id, mask in enumerate(bin_masks, start=1):
        # Cast so 0/1 integer masks select pixels instead of fancy-indexing rows.
        cluster_img[np.asarray(mask, dtype=bool)] = bin_id

    norm_content = np.clip(content_img / (2**content_bit_depth - 1), 0, 1)
    srgb_content = normalized_linear_to_srgb(norm_content)

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 8))
    ax1.imshow(srgb_content)
    ax1.set_title("Original Content Image", fontsize=14)
    ax1.axis("off")

    num_clusters = len(bin_masks)
    # One colour per ID 0..num_clusters, with limits pinned at half-integers
    # so every ID maps to its own colour. plt.get_cmap replaces
    # plt.cm.get_cmap, which Matplotlib 3.9 removed. Nearest-neighbour
    # interpolation keeps resampling from blending label values.
    num_colors = num_clusters + 1
    cmap = plt.get_cmap("tab20" if num_colors <= 20 else "hsv", num_colors)
    im = ax2.imshow(cluster_img, cmap=cmap, vmin=-0.5, vmax=num_clusters + 0.5,
                    interpolation="nearest")
    ax2.set_title(f"Material Clusters ({num_clusters} clusters)", fontsize=14)
    ax2.axis("off")
    cbar = plt.colorbar(im, ax=ax2, fraction=0.046, pad=0.04)
    cbar.set_label("Cluster ID", fontsize=12)
    plt.tight_layout()
    plt.show()
def plot_log_chroma_plane_posterized(log_chroma_original, log_chroma_posterized,
                                     isd_map, content_img, content_bit_depth, levels):
    """Compare original and posterized log chromaticity on the mean-ISD plane.

    One set of up to 100000 random pixel indices is applied to the original,
    posterized and colour arrays alike, so points correspond one-to-one
    across the panels. Both versions are projected onto the plane orthogonal
    to the normalized mean of ``isd_map``. Three panels are drawn:

    * the original, coloured by sRGB ``content_img``;
    * the posterized version, with the same colours;
    * an overlay of both (blue = original, red = posterized).

    ``levels`` appears only in the second panel's title. All inputs are
    reshaped to ``(H * W, 3)`` (channels-last). NOTE(review): confirm the
    ``isd_map`` layout matches this.
    """
    H, W, _ = log_chroma_original.shape
    n_pix = H * W
    sample_size = 100000
    originals = log_chroma_original.reshape(n_pix, 3)
    posterized = log_chroma_posterized.reshape(n_pix, 3)
    raw_colors = content_img.reshape(n_pix, 3)
    if len(originals) > sample_size:
        rows = np.random.choice(len(originals), sample_size, replace=False)
        originals, posterized, raw_colors = originals[rows], posterized[rows], raw_colors[rows]

    full_scale = 2**content_bit_depth - 1
    point_colors = normalized_linear_to_srgb(np.clip(raw_colors / full_scale, 0, 1)) / 255.0

    # Orthonormal in-plane basis perpendicular to the mean ISD.
    normal = isd_map.reshape(n_pix, 3).mean(axis=0)
    normal = normal / np.linalg.norm(normal)
    helper = np.array([1.0, 0.0, 0.0]) if abs(normal[0]) < 0.9 else np.array([0.0, 1.0, 0.0])
    u = helper - np.dot(helper, normal) * normal
    u = u / np.linalg.norm(u)
    v = np.cross(normal, u)

    def to_plane(points):
        # 2D coordinates of each 3D row in the (u, v) frame.
        return np.column_stack((np.dot(points, u), np.dot(points, v)))

    orig_2d = to_plane(originals)
    post_2d = to_plane(posterized)

    fig, (ax_orig, ax_post, ax_both) = plt.subplots(1, 3, figsize=(24, 7))
    colored_panels = ((ax_orig, orig_2d, "Original"),
                      (ax_post, post_2d, f"Posterized ({levels} levels)"))
    for ax, pts, heading in colored_panels:
        ax.scatter(pts[:, 0], pts[:, 1], c=point_colors, s=3, alpha=0.5, rasterized=True)
        ax.set_xlabel("Chromaticity Dimension 1", fontsize=12)
        ax.set_ylabel("Chromaticity Dimension 2", fontsize=12)
        ax.set_title(heading, fontsize=14)
        ax.grid(True, alpha=0.3)
        ax.set_aspect("equal", adjustable="box")

    ax_both.scatter(orig_2d[:, 0], orig_2d[:, 1], c='blue', s=2, alpha=0.3,
                    label='Original', rasterized=True)
    ax_both.scatter(post_2d[:, 0], post_2d[:, 1], c='red', s=2, alpha=0.3,
                    label='Posterized', rasterized=True)
    ax_both.set_xlabel("Chromaticity Dimension 1", fontsize=12)
    ax_both.set_ylabel("Chromaticity Dimension 2", fontsize=12)
    ax_both.set_title("Overlay Comparison", fontsize=14)
    ax_both.legend()
    ax_both.grid(True, alpha=0.3)
    ax_both.set_aspect("equal", adjustable="box")

    fig.suptitle("Effect of Posterization on Log Chromaticity Plane", fontsize=16, y=1.02)
    plt.tight_layout()
    plt.show()
# ============================================================================
# ISD Visualization
# ============================================================================
def visualize_isd_as_direction(isd_map, save_path="isd_direction.png"):
    """Render each pixel's ISD direction as an HSV colour, then save or show it.

    Encoding uses channels-first ``isd_map``, with channels 0, 1 and 2 as
    x, y and z:

    * hue = azimuth ``atan2(y, x)``, mapped from [-pi, pi] to [0, 1];
    * saturation = elevation ``arcsin(z)`` (z clipped to [-1, 1]), mapped from
      [-pi/2, pi/2] to [0, 1];
    * value = 1.

    Treating z as the sine of the elevation is only meaningful for
    unit-length vectors.

    No colorbar is drawn. The displayed image is RGB, so a scalar colorbar
    would only reflect raw channel values, not the hue/saturation encoding.

    Returns:
        The (H, W, 3) RGB array that was displayed.

    Raises:
        ValueError: If ``isd_map`` is not 3-dimensional.
    """
    if isd_map.ndim != 3:
        raise ValueError(f"expected a 3-D (C, H, W) ISD map, got shape {isd_map.shape}")
    x, y, z = isd_map[0], isd_map[1], isd_map[2]

    azimuth = np.arctan2(y, x)
    hue = (azimuth + np.pi) / (2 * np.pi)
    elevation = np.arcsin(np.clip(z, -1, 1))
    saturation = (elevation + np.pi / 2) / np.pi
    value = np.ones_like(hue)

    hsv = np.stack([hue, saturation, value], axis=-1)
    rgb = hsv_to_rgb(hsv)

    plt.figure(figsize=(12, 8))
    plt.imshow(rgb)
    plt.title("ISD Direction Map (Color = Direction in 3D space)")
    plt.axis("off")
    plt.tight_layout()
    _save_or_show(save_path)
    return rgb
def visualize_isd_components(isd_map, save_path="isd_components.png"):
    """Show ISD channels 0, 1 and 2 as three min-max scaled panels.

    Each ``isd_map[i]`` (channels-first indexing) is rescaled to [0, 1]. A
    1e-8 term guards against division by zero for a constant channel. Each
    panel uses the "RdBu_r" colormap, and its colorbar label records the raw
    [min, max] of that channel. The figure is passed to ``_save_or_show``.
    """
    fig, panels = plt.subplots(1, 3, figsize=(15, 4))
    names = ["ISD_R", "ISD_G", "ISD_B"]
    for channel_idx, (panel, name) in enumerate(zip(panels, names)):
        channel = isd_map[channel_idx]
        lo = channel.min()
        hi = channel.max()
        scaled = (channel - lo) / (hi - lo + 1e-8)
        image = panel.imshow(scaled, cmap="RdBu_r")
        panel.set_title(name)
        panel.axis("off")
        plt.colorbar(image, ax=panel, label=f"Range: [{lo:.3f}, {hi:.3f}]")
    plt.tight_layout()
    _save_or_show(save_path)
def visualize_isd_magnitude(isd_map, save_path="isd_magnitude.png"):
    """Plot the per-pixel Euclidean norm of the ISD vectors, taken over axis 0.

    Unit-normalized ISDs give a map close to 1.0 everywhere; the title
    reports the mean and standard deviation. After ``_save_or_show``, a
    warning is printed if the mean norm differs from 1.0 by more than 0.1.
    """
    norms = np.sqrt(np.sum(np.square(isd_map), axis=0))
    mean_norm = norms.mean()
    std_norm = norms.std()

    plt.figure(figsize=(10, 8))
    plt.imshow(norms, cmap="viridis")
    plt.colorbar(label="ISD Magnitude")
    plt.title(f"ISD Magnitude (should ≈1.0)\nMean: {mean_norm:.4f}, Std: {std_norm:.4f}")
    plt.axis("off")
    plt.tight_layout()
    _save_or_show(save_path)

    if abs(mean_norm - 1.0) > 0.1:
        print("⚠️ Warning: ISD vectors may not be properly normalized!")
def visualize_isd_arrow_field(isd_map, original_img, stride=50, save_path="isd_arrows.png"):
    """Overlay ISD vectors as arrows on the image, one every ``stride`` pixels.

    If the image's first dimension is 3, it is treated as channels-first and
    transposed to channels-last for display. Arrows start on a grid offset by
    ``stride // 2``. Each arrow's x/y extent is the ISD's component 0/1 times
    ``stride * 0.4``. Its colour maps component 2 from [-1, 1] onto "RdBu_r".
    ``isd_map`` is indexed as ``isd_map[:, y, x]`` (channels-first). The
    figure is passed to ``_save_or_show``.
    """
    if original_img.shape[0] == 3:
        display_img = np.transpose(original_img, (1, 2, 0))
    else:
        display_img = original_img
    height, width = display_img.shape[:2]

    fig, ax = plt.subplots(figsize=(14, 10))
    ax.imshow(display_img)

    start = stride // 2
    head_size = stride * 0.15
    for row in np.arange(start, height, stride):
        for col in np.arange(start, width, stride):
            vec = isd_map[:, row, col]
            rgba = plt.cm.RdBu_r((vec[2] + 1) / 2)
            ax.arrow(col, row, vec[0] * stride * 0.4, vec[1] * stride * 0.4,
                     head_width=head_size, head_length=head_size,
                     fc=rgba, ec=rgba, alpha=0.7, width=1.5)

    ax.set_title("ISD Vector Field")
    ax.axis("off")
    plt.tight_layout()
    _save_or_show(save_path)
# ============================================================================
# Utilities
# ============================================================================
def plot_bin_masks(bin_mask: np.ndarray):
    """Show a single mask in grayscale on a borderless 6x6 grey figure.

    Despite the plural name, this draws one mask per call.
    """
    canvas_fig = plt.figure(figsize=(6, 6), facecolor="gray")
    # Axes rectangle [left, bottom, width, height] covering the whole figure.
    canvas = canvas_fig.add_axes([0, 0, 1, 1])
    canvas.imshow(bin_mask, cmap="gray")
    canvas.axis("off")
    plt.show()
def plot_plane(axs, normal, point, bounds, alpha=0.3, color="red"):
    """Draw the plane through ``point`` with normal ``normal`` on each axis in ``axs``.

    The plane is sampled on a 20x20 grid spanning the bounds of the two
    "free" axes. The remaining coordinate is solved from
    ``n . (p - point) = 0``. The solved axis is the one whose normal
    component has the largest magnitude (ties go to x, then y), which keeps
    the division well conditioned. A zero-length ``normal`` produces NaNs.

    Args:
        axs: Iterable of 3D axes; ``plot_surface`` is called on each.
        normal: 3-vector, normalized internally.
        point: 3-vector that the plane passes through.
        bounds: ``((x_min, x_max), (y_min, y_max), (z_min, z_max))``.
        alpha, color: Forwarded to ``plot_surface``.
    """
    unit = np.array(normal) / np.linalg.norm(normal)
    origin = np.array(point)
    a, b, c = unit
    x0, y0, z0 = origin
    (x_min, x_max), (y_min, y_max), (z_min, z_max) = bounds

    comps = (a, b, c)
    anchor = (x0, y0, z0)
    limits = ((x_min, x_max), (y_min, y_max), (z_min, z_max))

    # argmax returns the first maximum, matching the x > y > z tie precedence.
    solved = int(np.argmax([abs(a), abs(b), abs(c)]))
    first, second = (k for k in range(3) if k != solved)
    grid_first, grid_second = np.meshgrid(
        np.linspace(limits[first][0], limits[first][1], 20),
        np.linspace(limits[second][0], limits[second][1], 20),
    )
    solved_grid = anchor[solved] - (
        comps[first] * (grid_first - anchor[first])
        + comps[second] * (grid_second - anchor[second])
    ) / comps[solved]

    surface = [None, None, None]
    surface[first] = grid_first
    surface[second] = grid_second
    surface[solved] = solved_grid
    for ax in axs:
        ax.plot_surface(surface[0], surface[1], surface[2], alpha=alpha, color=color)
def calculate_shared_limits(data_arrays, padding=0.1):
    """Compute padded x/y/z axis limits covering every point in ``data_arrays``.

    Args:
        data_arrays: Either a single (N, 3+) array, or a list/tuple of such
            arrays that are stacked row-wise with ``np.vstack``.
        padding: Fraction of each axis' data range added on both sides.

    Returns:
        ``(x_limits, y_limits, z_limits)``. Each is a two-element list
        ``[min - padding * range, max + padding * range]`` computed from
        columns 0, 1 and 2 respectively.
    """
    # Tuples of arrays are stacked too, not only lists; a bare tuple has no
    # .min(axis=...) and would otherwise fail.
    if isinstance(data_arrays, (list, tuple)):
        all_data = np.vstack(data_arrays)
    else:
        all_data = data_arrays
    mins = all_data.min(axis=0)
    maxs = all_data.max(axis=0)
    ranges = maxs - mins
    x_limits = [mins[0] - padding * ranges[0], maxs[0] + padding * ranges[0]]
    y_limits = [mins[1] - padding * ranges[1], maxs[1] + padding * ranges[1]]
    z_limits = [mins[2] - padding * ranges[2], maxs[2] + padding * ranges[2]]
    return x_limits, y_limits, z_limits
def plane_view_from_normal(normal):
    """Return Matplotlib 3D view angles ``(elev, azim)``, in degrees, for ``normal``.

    After normalizing the normal:

    * elevation is ``arcsin`` of its z component, i.e. its angle above the
      x-y plane;
    * azimuth is ``atan2(y, x)``, i.e. its angle in the x-y plane measured
      from +x.

    These presumably feed ``ax.view_init`` to place the camera along the
    normal.
    """
    unit = np.array(normal) / np.linalg.norm(normal)
    nx, ny, nz = unit
    return np.degrees(np.arcsin(nz)), np.degrees(np.arctan2(ny, nx))
# ============================================================================
# Private Helpers
# ============================================================================
def _save_or_show(save_path):
    """Save the current figure to ``save_path`` and close it, or display it.

    A falsy ``save_path`` (None or "") means show interactively. Saved
    figures use dpi=150 with a tight bounding box, and the path is printed.
    """
    if not save_path:
        plt.show()
        return
    plt.savefig(save_path, dpi=150, bbox_inches="tight")
    print(f"Saved to {save_path}")
    plt.close()
def _plot_overlay(ax, data1, data2, labels, title):
    """Overlay two 3D point clouds on ``ax`` for comparison.

    ``data1`` is drawn in green with the legend label "Style"; ``data2`` in
    blue labelled "Original". The axes then get ``labels[0..2]`` (font size
    10), ``title`` (font size 12) and a legend.
    """
    clouds = ((data1, "green", "Style"), (data2, "blue", "Original"))
    for cloud, colour, legend_name in clouds:
        ax.scatter(cloud[:, 0], cloud[:, 1], cloud[:, 2],
                   c=colour, s=2, alpha=0.2, label=legend_name)
    for setter, text in zip((ax.set_xlabel, ax.set_ylabel, ax.set_zlabel), labels[:3]):
        setter(text, fontsize=10)
    ax.set_title(title, fontsize=12)
    ax.legend()
def _plot_clusters(ax, log_content_sampled, bin_masks, dark_points, bright_points, indices):
    """Scatter sampled log-RGB points coloured by cluster, plus per-cluster markers.

    Args:
        ax: 3D axis to draw on.
        log_content_sampled: Sampled log-RGB rows, indexed as ``[:, 0..2]``.
        bin_masks: Sequence of per-cluster masks over the full image. Each is
            cast to bool, flattened and indexed with ``indices`` to find which
            sampled rows belong to the cluster.
        dark_points, bright_points: Optional per-cluster 3-vectors. When
            given, they are drawn as larger markers with black / red edges.
        indices: Flat pixel indices that produced ``log_content_sampled``.
            ``None`` means the rows are every pixel, in order.
    """
    num_clusters = len(bin_masks)
    # plt.get_cmap replaces plt.cm.get_cmap, which Matplotlib 3.9 removed.
    # The lut is kept >= 1 so an empty mask list does not request 0 colours.
    cmap = plt.get_cmap("tab20" if num_clusters <= 20 else "hsv", max(num_clusters, 1))

    for i, mask in enumerate(bin_masks):
        # Cast so 0/1 integer masks select rows instead of fancy-indexing 0/1.
        mask_flat = np.asarray(mask, dtype=bool).ravel()
        selected = mask_flat if indices is None else mask_flat[indices]
        ax.scatter(log_content_sampled[selected, 0],
                   log_content_sampled[selected, 1],
                   log_content_sampled[selected, 2],
                   c=[cmap(i)], s=2, alpha=0.1)

    for i in range(num_clusters):
        if dark_points is not None:
            ax.scatter(dark_points[i][0], dark_points[i][1], dark_points[i][2],
                       c=[cmap(i)], edgecolors="black", linewidth=0.5, s=20, alpha=1.0)
        if bright_points is not None:
            ax.scatter(bright_points[i][0], bright_points[i][1], bright_points[i][2],
                       c=[cmap(i)], edgecolors="red", linewidth=0.5, s=20, alpha=1.0)

    ax.set_xlabel("log(Red)", fontsize=10)
    ax.set_ylabel("log(Green)", fontsize=10)
    ax.set_zlabel("log(Blue)", fontsize=10)
    ax.set_title("Content Log-RGB Clustered", fontsize=12)