| | import os |
| | import torch |
| | import numpy as np |
| | import lightning.pytorch as pl |
| | import gradio as gr |
| | import imageio |
| | import random |
| | import matplotlib.pyplot as plt |
| | import cv2 |
| | import skdim |
| |
|
| | from torch.utils.data import Dataset, DataLoader |
| |
|
| | from PIL import Image |
| | from matplotlib import cm |
| | from safetensors.torch import save_file, load_file |
| | from sklearn.cluster import AgglomerativeClustering |
| | from sklearn.manifold import TSNE |
| | from sklearn.neighbors import KDTree |
| | from sklearn.preprocessing import StandardScaler |
| |
|
| | from minimal_script import EmbeddingNetwork, closest_interval, adj_size, PLModule |
| |
|
| |
|
class PredictDataset(Dataset):
    """Dataset yielding normalised image tensors together with their paths.

    Image files are discovered (non-recursively) in ``data_dir`` by file
    extension and kept in sorted filename order.

    Args:
        data_dir: Directory to scan for images.
        sample: Optional upper bound on the number of images to keep. When
            truthy, a random subset of at most ``sample`` paths is drawn;
            when falsy, every discovered image is kept in sorted order.
    """

    # Leading dots so that only real suffixes match (a file literally named
    # "fakejpg" is not an image).
    EXTENSIONS = ('.jpg', '.jpeg', '.png', '.tif', '.webp')

    def __init__(self, data_dir, sample=None):
        self.image_paths = [
            os.path.join(data_dir, fname)
            for fname in sorted(os.listdir(data_dir))
            if fname.lower().endswith(self.EXTENSIONS)
        ]
        if sample:
            # random.sample raises ValueError when asked for more items than
            # the population holds, so cap the request at what was found.
            count = min(sample, len(self.image_paths))
            self.image_paths = random.sample(self.image_paths, count)

    def __len__(self):
        """Return the number of image paths held by the dataset."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load the image at ``idx`` and return ``(tensor, path)``.

        The image is read with imageio, permuted to put its last axis first,
        passed through ``adj_size(..., 1024)`` and ``closest_interval``, and
        finally rescaled with ``2 * (x / 255) - 1`` (0 -> -1, 255 -> 1).
        """
        path = self.image_paths[idx]
        image = imageio.v3.imread(path).copy()
        # NOTE(review): permute(2, 0, 1) needs a 3-D array; presumably inputs
        # are always colour (H, W, C) images - single-channel files would
        # fail here. Confirm against the data.
        image = torch.from_numpy(image).permute(2, 0, 1)
        processed = closest_interval(adj_size(image, 1024))
        processed = 2 * (processed / 255) - 1
        return processed.detach(), path
| |
|
| |
|
def explore_embedding_space(embeddings, image_paths, model):
    """
    Build a Gradio interface for exploring N-dimensional image embeddings.

    One slider is created per embedding dimension, spanning that dimension's
    observed minimum..maximum and starting at its mean. Whenever a slider
    changes (and once on page load), the images whose embeddings are nearest
    to the slider position (up to 8) are shown in a gallery, each captioned
    with its distance from the query point.

    Args:
        embeddings: NumPy array of shape [B, N]
        image_paths: List of B image file paths
        model: Embedding model. Only the gradient heat-map helper uses it,
            and that helper is currently not wired into the gallery.

    Returns:
        The constructed ``gr.Blocks`` app. It is not launched here.

    Raises:
        AssertionError: If the number of embeddings and paths differ, or
            ``embeddings`` is not 2-dimensional.
        ValueError: If ``embeddings`` contains no rows.
    """
    assert len(embeddings) == len(image_paths), "Mismatch between embeddings and image paths"
    assert embeddings.ndim == 2, "Embeddings should be 2-dimensional"
    if len(embeddings) == 0:
        # min()/max() on an empty array fail with an opaque numpy message.
        raise ValueError("Cannot explore an empty set of embeddings")

    # Per-dimension bounds used for the slider ranges.
    min_vals = embeddings.min(axis=0)
    max_vals = embeddings.max(axis=0)
    ranges = max_vals - min_vals

    # Spatial index for nearest-neighbour lookups.
    tree = KDTree(embeddings)
    # KDTree.query rejects k larger than the number of indexed points.
    num_neighbors = min(8, len(embeddings))

    # Sliders start at the centroid of the embeddings.
    initial_point = embeddings.mean(axis=0).tolist()

    sliders = []
    for i in range(embeddings.shape[1]):
        span = float(ranges[i])
        slider = gr.Slider(
            float(min_vals[i]),
            float(max_vals[i]),
            value=float(initial_point[i]),
            # A constant dimension has zero range; a zero step is not a
            # usable slider increment, so fall back to a positive value.
            step=span / 100 if span > 0 else 1.0,
            label=f"Dimension {i + 1}",
        )
        sliders.append(slider)

    def compute_gradient_heatmap(image_path):
        """Compute gradient heatmap for an image.

        Returns an RGB float array (jet colour map) of the per-pixel gradient
        magnitude of the embedding's L2 norm with respect to the input.
        NOTE: not called anywhere in this function at present.
        """
        img = imageio.v3.imread(image_path).copy()
        img = torch.from_numpy(img).permute(2, 0, 1)
        img_tensor = closest_interval(adj_size(img, 1024)).unsqueeze(0)
        img_tensor = 2 * (img_tensor / 255) - 1

        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        # Move/cast first so the tensor that requires grad is the very one
        # fed to the model.
        img_tensor = img_tensor.to(device).to(torch.float16)
        img_tensor.requires_grad_(True)

        with torch.enable_grad():
            embd = model(img_tensor)
            norm = embd.norm(p=2, dim=1).sum()
            grad = torch.autograd.grad(norm, img_tensor, retain_graph=False)[0]

        # Collapse the channel axis into a single magnitude per pixel.
        grad_mag = grad.squeeze(0).norm(dim=0).detach().cpu().numpy()

        # Min-max normalise to [0, 1]; a flat gradient maps to all zeros.
        grad_min, grad_max = grad_mag.min(), grad_mag.max()
        if grad_max > grad_min:
            grad_norm = (grad_mag - grad_min) / (grad_max - grad_min)
        else:
            grad_norm = grad_mag * 0

        # cm.jet returns RGBA; drop the alpha channel.
        heatmap = cm.jet(grad_norm)[..., :3]
        return heatmap

    def overlay_heatmap(original_img, heatmap, alpha=0.4):
        """Overlay heatmap on original image.

        NOTE: not called anywhere in this function at present.
        """
        heatmap_img = Image.fromarray((heatmap * 255).astype(np.uint8))
        heatmap_img = heatmap_img.resize(original_img.size)
        blended = Image.blend(original_img, heatmap_img, alpha)
        return blended

    def get_overlay_image(image_path):
        """Get the gallery image for a path.

        Currently this is just the RGB image; no gradient overlay is applied.
        """
        # Image.open is lazy and keeps the file open; convert() produces an
        # independent in-memory image so the handle can be released at once.
        with Image.open(image_path) as source:
            img = source.convert('RGB')
        return img

    def add_caption_to_image(image, caption):
        """Add text caption to the bottom of an image."""
        if isinstance(image, Image.Image):
            img = np.array(image)
        else:
            img = image.copy()

        # Black bar appended below the image to hold the caption.
        bar_height = 30
        img = cv2.copyMakeBorder(img, 0, bar_height, 0, 0, cv2.BORDER_CONSTANT, value=[0, 0, 0])

        # Horizontally centre the text inside the bar.
        font = cv2.FONT_HERSHEY_SIMPLEX
        text_size = cv2.getTextSize(caption, font, 0.5, 1)[0]
        text_x = (img.shape[1] - text_size[0]) // 2
        text_y = img.shape[0] - 10
        cv2.putText(img, caption, (text_x, text_y), font, 0.5, (255, 255, 255), 1)

        return Image.fromarray(img)

    def find_nearby_images(*point):
        """Return captioned nearest images and a status message for ``point``."""
        query = np.array(point).reshape(1, -1)
        distances, indices = tree.query(query, k=num_neighbors)
        indices = indices[0]
        distances = distances[0]

        paths = [image_paths[i] for i in indices]
        nearby_images = [get_overlay_image(p) for p in paths]

        final_images = [
            add_caption_to_image(img, f"Dist: {dist:.2f}")
            for img, dist in zip(nearby_images, distances)
        ]

        status = ""
        if distances[0] > 5.0:
            status = f"⚠️ Nearest image is far (distance={distances[0]:.2f}). Consider adjusting sliders."

        return final_images, status

    with gr.Blocks() as demo:
        gr.Markdown("## N-Dimensional Embedding Space Explorer")
        gr.Markdown("Adjust sliders to navigate. Images show gradient of embedding norm w.r.t. input.")

        warning = gr.Textbox(label="Status", interactive=False)

        gallery = gr.Gallery(
            label="Nearest Images (Distance Ordered)",
            columns=4,
            object_fit="contain",
            height="auto",
            show_label=True,
        )

        # The sliders were created before the Blocks context existed, so
        # they have to be rendered into the layout explicitly.
        with gr.Row():
            for slider in sliders:
                slider.render()

        # Any slider movement re-queries using the values of all sliders.
        for slider in sliders:
            slider.change(
                find_nearby_images,
                inputs=sliders,
                outputs=[gallery, warning]
            )

        # Populate the gallery once when the page first loads.
        demo.load(
            find_nearby_images,
            inputs=sliders,
            outputs=[gallery, warning]
        )

    return demo
| |
|
| |
|
| |
|
def _report_intrinsic_dimension(predictions):
    """Print intrinsic-dimension estimates of ``predictions`` from several skdim estimators."""
    estimators = [skdim.id.TwoNN(), skdim.id.CorrInt(), skdim.id.DANCo()]
    results = {}
    for est in estimators:
        est.fit(predictions)
        results[type(est).__name__] = est.dimension_

    print("Intrinsic Dimension Estimates:")
    for name, dim in results.items():
        print(f"{name}: {dim:.2f}")


def _save_diagnostic_plots(predictions):
    """Write ``Norms.png`` (per-feature bar chart) and ``Groups.png`` (t-SNE scatter) to the working directory."""
    row_norms = np.linalg.norm(predictions, axis=1)
    average_norms = np.mean(np.abs(predictions), axis=0)

    plt.figure(figsize=(8, 5))
    plt.bar(range(predictions.shape[1]), average_norms, color='skyblue')
    plt.xlabel('Feature Index (C)')
    plt.ylabel('Average Norm')
    plt.title('Average Norm for Each Feature (Column)')
    plt.xticks(range(predictions.shape[1]))
    plt.savefig('Norms.png')
    # pyplot keeps every figure alive until it is closed explicitly.
    plt.close()

    plt.figure(figsize=(8, 6))
    # t-SNE requires perplexity < n_samples; the default of 30 would make
    # small image sets fail, so cap it (unchanged for 31+ samples).
    perplexity = min(30.0, max(1.0, len(predictions) - 1))
    tsne = TSNE(n_components=2, random_state=42, perplexity=perplexity)
    reduced_data = tsne.fit_transform(predictions)
    plt.scatter(reduced_data[:, 0], reduced_data[:, 1], c=row_norms, cmap='viridis', s=50, edgecolor='k', label="Data Points")
    plt.colorbar(label='Norm Value')
    plt.xlabel('Feature 1')
    plt.ylabel('Feature 2')
    plt.title('Scatter Plot of Data Points and Average Norm')
    plt.legend()
    plt.grid(True)
    plt.axis('equal')
    plt.savefig('Groups.png')
    plt.close()


def _build_cluster_demo(paths, labels):
    """Build a Gradio app with a cluster dropdown and a gallery of up to 50 images from the chosen cluster."""
    unique_clusters = np.unique(labels)

    with gr.Blocks() as demo:
        gr.Markdown("## Explore Image Clusters by Style")

        cluster_selector = gr.Dropdown(choices=unique_clusters.tolist(), label="Select Cluster to Explore")
        image_gallery = gr.Gallery(label="Sample Images from Selected Cluster")

        def explore_clusters(cluster_idx):
            cluster_images = [p for p, label in zip(paths, labels) if label == cluster_idx]
            images = []
            for img_path in cluster_images[:50]:
                # Image.open is lazy and holds the file open; copy() loads
                # the pixels so the handle can be closed right away.
                with Image.open(img_path) as handle:
                    images.append(handle.copy())
            return images

        cluster_selector.change(fn=explore_clusters, inputs=cluster_selector, outputs=image_gallery)

    return demo


def generate_embeddings(image_folder, mode, model, sample=5000):
    """Embed the images in a folder and launch an interactive Gradio view.

    Args:
        image_folder: Directory scanned for images by ``PredictDataset``.
        mode: ``'Grouping'`` prints intrinsic-dimension estimates, clusters
            the embeddings, saves ``Norms.png`` / ``Groups.png`` and launches
            a per-cluster gallery. ``'Explore'`` launches the slider-based
            embedding explorer.
        model: Lightning module passed to ``Trainer.predict``. Each
            prediction is indexed as ``(embeddings, paths)``.
        sample: Maximum number of images to embed (forwarded to
            ``PredictDataset``).

    Raises:
        ValueError: If ``mode`` is not a supported value, or no embeddings
            were produced.
    """
    # Reject bad modes before paying for the whole prediction pass.
    if mode not in ('Grouping', 'Explore'):
        raise ValueError(f"Unknown mode {mode!r}; expected 'Grouping' or 'Explore'")

    predict_dataset = PredictDataset(image_folder, sample)
    predict_loader = DataLoader(predict_dataset, batch_size=1, num_workers=5, pin_memory=True)
    trainer = pl.Trainer(accelerator="gpu", logger=False, enable_checkpointing=False, precision="16-mixed")
    batches = trainer.predict(model, predict_loader)
    if not batches:
        raise ValueError(f"No embeddings were produced for folder {image_folder!r}")

    stacked = torch.cat([batch[0] for batch in batches], dim=0)
    # Mixed-precision prediction may presumably yield half-precision outputs;
    # upcast so the numpy norm computations do not run in 16-bit.
    if stacked.dtype in (torch.float16, torch.bfloat16):
        stacked = stacked.float()
    predictions = stacked.numpy()
    paths = [path for batch in batches for path in batch[1]]

    if mode == 'Grouping':
        _report_intrinsic_dimension(predictions)
        labels = cluster_embeddings(predictions)
        _save_diagnostic_plots(predictions)
        demo = _build_cluster_demo(paths, labels)
        demo.launch()
    else:
        demo = explore_embedding_space(predictions, paths, model.to('cuda').to(torch.float16))
        demo.launch()
| |
|
| |
|
| | |
def cluster_embeddings(predictions, distance_threshold=32.0):
    """Assign a cluster label to every embedding row.

    Runs Ward-linkage agglomerative clustering with no fixed cluster count;
    the number of clusters is determined by ``distance_threshold``.

    Args:
        predictions: Embedding matrix, one row per sample.
        distance_threshold: Linkage distance threshold handed to
            ``AgglomerativeClustering``.

    Returns:
        The cluster labels produced by ``fit_predict``, one per row.
    """
    clusterer = AgglomerativeClustering(
        linkage='ward',
        distance_threshold=distance_threshold,
        n_clusters=None,
    )
    return clusterer.fit_predict(predictions)
| |
|
| |
|
| |
|
if __name__ == '__main__':
    # Replace this placeholder with the directory holding the images to embed.
    folder = 'Enter Images folder name here'
    if not os.path.isdir(folder):
        # Fail with a clear message instead of an os.listdir traceback from
        # deep inside the dataset constructor.
        raise SystemExit(f"Image folder not found: {folder!r}. Set `folder` to an existing directory.")

    model = PLModule()
    # The checkpoint is loaded into the module's `network` attribute, not
    # into the Lightning module itself.
    state_dict = load_file("Style_Embedder_v3.safetensors")
    model.network.load_state_dict(state_dict)

    generate_embeddings(folder, 'Grouping', model)
| |
|