Fgdfgfthgr committed on
Commit
51fedf6
·
verified ·
1 Parent(s): a006890

Upload gallery_review.py

Browse files
Files changed (1) hide show
  1. gallery_review.py +316 -0
gallery_review.py ADDED
@@ -0,0 +1,316 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import numpy as np
4
+ import lightning.pytorch as pl
5
+ import gradio as gr
6
+ import imageio
7
+ import random
8
+ import matplotlib.pyplot as plt
9
+ import cv2
10
+
11
+ from torch.utils.data import Dataset, DataLoader
12
+
13
+ from PIL import Image
14
+ from matplotlib import cm
15
+
16
+ from minimal_script import EmbeddingNetworkSmall, closest_interval, down_to_1k
17
+ from sklearn.cluster import AgglomerativeClustering
18
+ from sklearn.manifold import TSNE
19
+ from sklearn.neighbors import KDTree
20
+
21
+
22
class PLModule(pl.LightningModule):
    """Thin Lightning wrapper around EmbeddingNetworkSmall, used for inference.

    `predict_step` pairs each batch's embeddings with the image paths that
    the dataset yields alongside the images.
    """

    def __init__(self):
        super().__init__()
        self.save_hyperparameters()
        self.network = EmbeddingNetworkSmall()

    def forward(self, x):
        # Delegate straight to the wrapped embedding network.
        return self.network(x)

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        # batch is (images, paths); return (embeddings, paths) so callers
        # can match each embedding back to its source file.
        embeddings = self.forward(batch[0])
        return embeddings, batch[1]
34
+
35
+
36
class PredictDataset(Dataset):
    """Dataset of images from one directory, preprocessed for embedding.

    Args:
        data_dir: Directory scanned (non-recursively) for image files with
            extensions jpg/jpeg/png/tif/webp (case-insensitive).
        sample: If truthy, keep only a random subset of at most this many
            paths. The request is clamped to the number of images found,
            so asking for more images than exist no longer raises
            ValueError from random.sample.
    """

    def __init__(self, data_dir, sample=None):
        extensions = ('jpg', 'jpeg', 'png', 'tif', 'webp')
        # str.endswith accepts a tuple, so no inner any() loop is needed.
        self.image_paths = [
            os.path.join(data_dir, fname)
            for fname in sorted(os.listdir(data_dir))
            if fname.lower().endswith(extensions)
        ]
        if sample:
            # Clamp: random.sample raises ValueError if the requested
            # sample size exceeds the population.
            self.image_paths = random.sample(
                self.image_paths, min(sample, len(self.image_paths))
            )

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        path = self.image_paths[idx]
        # .copy() so the tensor does not share imageio's read-only buffer.
        image = imageio.v3.imread(path).copy()
        image = torch.from_numpy(image).permute(2, 0, 1)  # HWC -> CHW
        processed = closest_interval(down_to_1k(image, 1024))
        # Scale uint8 [0, 255] into [-1, 1].
        processed = 2 * (processed / 255) - 1
        return processed.detach(), path
56
+
57
+
58
def explore_embedding_space(embeddings, image_paths, model):
    """
    Create an interface for exploring N-dimensional image embeddings.

    Builds a Gradio Blocks app with one slider per embedding dimension;
    moving any slider queries a KDTree for the 8 nearest images to the
    slider-defined point and shows them with a gradient-saliency overlay
    and a baked-in distance caption.

    Args:
        embeddings: NumPy array of shape [B, N]
        image_paths: List of B image file paths
        model: callable mapping a preprocessed image batch to embeddings
            (gradients flow through it; caller moves it to the device).

    Returns:
        The gr.Blocks demo (not launched).
    """
    # Validate inputs
    assert len(embeddings) == len(image_paths), "Mismatch between embeddings and image paths"
    assert embeddings.ndim == 2, "Embeddings should be 2-dimensional"

    # Precompute min/max for each dimension (used as slider bounds)
    min_vals = embeddings.min(axis=0)
    max_vals = embeddings.max(axis=0)
    ranges = max_vals - min_vals

    # Build KDTree for efficient nearest neighbor search
    tree = KDTree(embeddings)

    # Create initial point (mean of embeddings)
    initial_point = embeddings.mean(axis=0).tolist()

    # Create slider components for each dimension; step is 1% of the
    # dimension's observed range.
    sliders = []
    for i in range(embeddings.shape[1]):
        slider = gr.Slider(
            float(min_vals[i]),
            float(max_vals[i]),
            value=float(initial_point[i]),
            step=float(ranges[i]) / 100,
            label=f"Dimension {i + 1}"
        )
        sliders.append(slider)

    def compute_gradient_heatmap(image_path):
        """Compute gradient heatmap for an image.

        Returns an HxW x3 float array in [0, 1]: the jet-colormapped,
        min-max normalized magnitude of d(||embedding||_2)/d(input).
        NOTE(review): assumes the file decodes to an RGB (H, W, 3) array —
        grayscale/RGBA inputs would break the permute; confirm upstream.
        """
        # Load and preprocess image (same pipeline as PredictDataset)
        img = imageio.v3.imread(image_path).copy()
        img = torch.from_numpy(img).permute(2, 0, 1)
        img_tensor = closest_interval(down_to_1k(img, 1024)).unsqueeze(0)
        img_tensor = 2*(img_tensor/255)-1
        img_tensor.requires_grad_(True)

        # Move to GPU if available. On CUDA, .to() yields a new (non-leaf)
        # tensor; autograd.grad below is taken w.r.t. this moved tensor,
        # which is the one actually fed to the model.
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        img_tensor = img_tensor.to(device)

        # Compute embedding and gradient of its summed L2 norm
        with torch.enable_grad():
            embd = model(img_tensor)
            norm = embd.norm(p=2, dim=1).sum()
            grad = torch.autograd.grad(norm, img_tensor, retain_graph=False)[0]

        # Compute per-pixel gradient magnitude across channels
        grad_mag = grad.squeeze(0).norm(dim=0).detach().cpu().numpy()

        # Normalize to [0, 1] and apply colormap
        grad_min, grad_max = grad_mag.min(), grad_mag.max()
        if grad_max > grad_min:
            grad_norm = (grad_mag - grad_min) / (grad_max - grad_min)
        else:
            grad_norm = grad_mag * 0  # Handle uniform case (zero gradient)

        heatmap = cm.jet(grad_norm)[..., :3]  # Use jet colormap, drop alpha
        return heatmap

    def overlay_heatmap(original_img, heatmap, alpha=0.6):
        """Overlay heatmap on original image via alpha blend (PIL)."""
        # Resize heatmap (at processed resolution) to match original image
        heatmap_img = Image.fromarray((heatmap * 255).astype(np.uint8))
        heatmap_img = heatmap_img.resize(original_img.size)

        # Convert original to RGBA and heatmap to RGBA
        #original_rgba = original_img.convert("RGBA")
        #heatmap_rgba = heatmap_img.convert("RGBA")

        # Blend images — both are RGB and the same size, as Image.blend requires
        blended = Image.blend(original_img, heatmap_img, alpha)
        return blended

    def get_overlay_image(image_path):
        """Get image with gradient-saliency overlay baked in."""
        img = Image.open(image_path).convert('RGB')
        heatmap = compute_gradient_heatmap(image_path)
        return overlay_heatmap(img, heatmap)
        #return img

    def add_caption_to_image(image, caption):
        """Add text caption centered on a black bar below the image."""
        # Convert to OpenCV-compatible ndarray format
        if isinstance(image, Image.Image):
            img = np.array(image)
        else:
            img = image.copy()

        # Add black bar at bottom
        bar_height = 30
        img = cv2.copyMakeBorder(img, 0, bar_height, 0, 0, cv2.BORDER_CONSTANT, value=[0, 0, 0])

        # Add white text, horizontally centered in the bar
        font = cv2.FONT_HERSHEY_SIMPLEX
        text_size = cv2.getTextSize(caption, font, 0.5, 1)[0]
        text_x = (img.shape[1] - text_size[0]) // 2
        text_y = img.shape[0] - 10
        cv2.putText(img, caption, (text_x, text_y), font, 0.5, (255, 255, 255), 1)

        return Image.fromarray(img)

    # Function to find nearby images: receives one value per slider
    def find_nearby_images(*point):
        point = np.array(point).reshape(1, -1)
        distances, indices = tree.query(point, k=8)
        # query returns 2-D (one row per query point); take the single row
        indices = indices[0]
        distances = distances[0]

        # Get paths and create overlay images (one backward pass each)
        paths = [image_paths[i] for i in indices]
        images_with_gradients = [get_overlay_image(p) for p in paths]

        # Create images with baked-in distance captions
        final_images = []
        for img, dist in zip(images_with_gradients, distances):
            caption = f"Dist: {dist:.2f}"
            final_img = add_caption_to_image(img, caption)
            final_images.append(final_img)

        warning = ""
        if distances[0] > 5.0:  # Warn if nearest image is far
            warning = "⚠️ Nearest image is far (distance={:.2f}). Consider adjusting sliders.".format(distances[0])

        return final_images, warning

    # Build interface
    with gr.Blocks() as demo:
        gr.Markdown("## N-Dimensional Embedding Space Explorer")
        gr.Markdown("Adjust sliders to navigate. Images show gradient of embedding norm w.r.t. input.")

        # Warning output
        warning = gr.Textbox(label="Status", interactive=False)

        # Gallery for images
        gallery = gr.Gallery(
            label="Nearest Images (Distance Ordered)",
            columns=4,
            object_fit="contain",
            height="auto",
            show_label=True,
        )

        # Create sliders in a compact row (components were built above,
        # outside the Blocks context, so render them explicitly here)
        with gr.Row():
            for slider in sliders:
                slider.render()

        # Connect slider changes to update function; every slider passes
        # the full slider set as inputs so the query point is complete
        for slider in sliders:
            slider.change(
                find_nearby_images,
                inputs=sliders,
                outputs=[gallery, warning]
            )

        # Initial trigger so the gallery is populated on page load
        demo.load(
            find_nearby_images,
            inputs=sliders,
            outputs=[gallery, warning]
        )

    return demo
229
+
230
+
231
+
232
def generate_embeddings(image_folder, mode, model):
    """Embed up to 1000 images from *image_folder* and open a Gradio UI.

    mode == 'Grouping': cluster the embeddings, show diagnostic matplotlib
    plots (blocking plt.show() calls), then launch a cluster-browser UI.
    mode == 'Explore': launch the slider-based embedding-space explorer.
    Any other mode value silently does nothing after prediction.

    Args:
        image_folder: directory of images (sampled down to 1000).
        mode: 'Grouping' or 'Explore'.
        model: PLModule used by the Lightning trainer for prediction.
    """
    predict_dataset = PredictDataset(image_folder, 1000)
    predict_loader = DataLoader(predict_dataset, batch_size=1, num_workers=5, pin_memory=True)
    # NOTE(review): accelerator="gpu" hard-requires CUDA here even though
    # other code paths fall back to CPU — confirm intended.
    trainer = pl.Trainer(accelerator="gpu", logger=False, enable_checkpointing=False)
    predictions_0 = trainer.predict(model, predict_loader)
    # predictions_0 is a list of (embeddings, paths) per batch
    predictions = torch.cat([pred[0] for pred in predictions_0], dim=0).numpy()
    paths = []
    for pred in predictions_0:
        for i in pred[1]:
            paths.append(i)
    if mode == 'Grouping':
        labels = cluster_embeddings(predictions)

        row_norms = np.linalg.norm(predictions, axis=1)
        average_norms = np.mean(np.abs(predictions), axis=0)
        # Bar chart: mean absolute activation per embedding dimension
        plt.figure(figsize=(8, 5))
        plt.bar(range(predictions.shape[1]), average_norms, color='skyblue')
        plt.xlabel('Feature Index (C)')
        plt.ylabel('Average Norm')
        plt.title(f'Average Norm for Each Feature (Column)')
        plt.xticks(range(predictions.shape[1]))
        plt.show()

        # 2-D t-SNE projection colored by each embedding's L2 norm
        plt.figure(figsize=(8, 6))
        tsne = TSNE(n_components=2, random_state=42)
        reduced_data = tsne.fit_transform(predictions)
        plt.scatter(reduced_data[:, 0], reduced_data[:, 1], c=row_norms, cmap='viridis', s=50, edgecolor='k', label="Data Points")
        plt.colorbar(label='Norm Value')
        plt.xlabel('Feature 1')
        plt.ylabel('Feature 2')
        plt.title(f'Scatter Plot of Data Points and Average Norm')
        plt.legend()
        plt.grid(True)
        plt.axis('equal')
        plt.show()

        # List unique clusters
        unique_clusters = np.unique(labels)
        # Gradio UI: pick a cluster, view a sample of its images
        with gr.Blocks() as demo:
            gr.Markdown("## Explore Image Clusters by Style")

            # Dropdown for selecting a cluster
            cluster_selector = gr.Dropdown(choices=unique_clusters.tolist(), label="Select Cluster to Explore")

            # Gallery to display images
            image_gallery = gr.Gallery(label="Sample Images from Selected Cluster")


            # Gradio Interface for Cluster Exploration
            def explore_clusters(cluster_idx):
                # Find images that belong to the selected cluster
                cluster_images = [paths[i] for i in range(len(labels)) if labels[i] == cluster_idx]
                # Load and return images
                images = [Image.open(img_path) for img_path in cluster_images[:50]]  # Show a sample of 50 images
                return images

            # Update function for the gallery
            cluster_selector.change(fn=explore_clusters, inputs=cluster_selector, outputs=image_gallery)

        demo.launch()
    elif mode == 'Explore':
        demo = explore_embedding_space(predictions, paths, model.to('cuda'))
        demo.launch()
296
+
297
+
298
+ # Apply Agglomerative Clustering
299
def cluster_embeddings(predictions, distance_threshold=6.0):
    """Group embedding vectors with Ward-linkage agglomerative clustering.

    Args:
        predictions: array of shape [B, N], one embedding per row.
        distance_threshold: clusters are merged until the linkage distance
            would exceed this value, so the cluster count is data-driven
            (n_clusters=None).

    Returns:
        Integer cluster label for each row of *predictions*.
    """
    clusterer = AgglomerativeClustering(
        n_clusters=None,
        linkage='ward',
        distance_threshold=distance_threshold,
    )
    return clusterer.fit_predict(predictions)
307
+
308
+
309
+
310
if __name__ == '__main__':
    # Directory of images to embed — placeholder string, must be edited
    # before running.
    folder = 'Enter Images folder name here'
    #folder = 'images_for_style_embedding'
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # NOTE(review): `device` is computed but never used; the checkpoint is
    # loaded onto Lightning's default device — confirm whether a .to(device)
    # was intended here.
    model = PLModule.load_from_checkpoint('Final_8.ckpt')
    # 'Grouping' or 'Explore'
    generate_embeddings(folder, 'Grouping', model)