Tenbatsu24 commited on
Commit
69eec85
Β·
1 Parent(s): 3cd97ec

add: vitb16 support and more user inputs.

Browse files
Files changed (1) hide show
  1. app.py +143 -63
app.py CHANGED
@@ -14,42 +14,51 @@ from sklearn.decomposition import PCA
14
  IMAGENET_MEAN = [0.485, 0.456, 0.406]
15
  IMAGENET_STD = [0.229, 0.224, 0.225]
16
 
17
- IMAGE_SIZE = 672
18
  PATCH_SIZE = 16
19
  PCA_COMPONENTS = 3
20
 
21
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
22
 
23
  MODEL_IDS = {
24
- "DiNO": "OK-AI/dino-vits16-pretrain-in1k",
25
- "iBOT": "OK-AI/ibot-vits16-pretrain-in1k",
26
- "LeJEPA": "OK-AI/lejepa-vits16-pretrain-in1k",
 
 
 
 
 
 
 
27
  }
28
- MODEL_NAMES = list(MODEL_IDS.keys()) # fixed order
 
29
 
30
  # ── model loading (cached) ────────────────────────────────────────────────────
31
 
32
  _model_cache: dict[str, torch.nn.Module] = {}
33
 
34
 
35
- def get_model(name: str) -> torch.nn.Module:
36
- if name not in _model_cache:
 
37
  model = AutoModel.from_pretrained(
38
- MODEL_IDS[name],
 
39
  trust_remote_code=True,
40
  )
41
  model.eval().to(DEVICE)
42
- _model_cache[name] = model
43
- return _model_cache[name]
44
 
45
 
46
  # ── image helpers ─────────────────────────────────────────────────────────────
47
 
48
 
49
  def resize_image_for_patches(
50
- image: Image.Image,
51
- image_size: int = IMAGE_SIZE,
52
- patch_size: int = PATCH_SIZE,
53
  ) -> torch.Tensor:
54
  """Resize so height = image_size and width is patch-aligned,
55
  preserving aspect ratio. Returns (1, 3, H, W) float tensor."""
@@ -71,12 +80,12 @@ def preprocess(image_tensor: torch.Tensor) -> torch.Tensor:
71
  ).unsqueeze(0)
72
 
73
 
74
- def pad_to_square(img: Image.Image) -> Image.Image:
75
  """Letterbox/pillarbox img onto a square canvas with a dark background.
76
  Ensures all output images share the same dimensions so the Gradio row
77
  never reflows or stretches when aspect ratios differ."""
78
  w, h = img.size
79
- size = max(w, h)
80
  canvas = Image.new("RGB", (size, size), color=(18, 18, 18))
81
  canvas.paste(img, ((size - w) // 2, (size - h) // 2))
82
  return canvas
@@ -85,7 +94,7 @@ def pad_to_square(img: Image.Image) -> Image.Image:
85
  # ── PCA visualisation ─────────────────────────────────────────────────────────
86
 
87
 
88
- def pca_vis(model: torch.nn.Module, image_tensor: torch.Tensor) -> Image.Image:
89
  """Run image through model, PCA patch features β†’ square-padded RGB PIL image."""
90
  model_input = preprocess(image_tensor).to(DEVICE)
91
 
@@ -107,30 +116,49 @@ def pca_vis(model: torch.nn.Module, image_tensor: torch.Tensor) -> Image.Image:
107
 
108
  # nearest-neighbour upscale β†’ pad to square so all outputs are the same size
109
  upscaled = Image.fromarray(pca_array, mode="RGB").resize((W, H), Image.NEAREST)
110
- return pad_to_square(upscaled)
111
 
112
 
113
  # ── streaming inference ───────────────────────────────────────────────────────
114
 
115
- PENDING = Image.new("RGB", (IMAGE_SIZE, IMAGE_SIZE), color=(18, 18, 18))
116
-
117
 
118
- def run(pil_image: Image.Image):
119
  """
120
- Generator: yields (dino_out, ibot_out, lejepa_out) after each model
121
- finishes, so the UI updates one image at a time.
122
  """
123
  if pil_image is None:
124
  raise gr.Error("Please upload an image.")
125
 
 
 
 
 
 
 
 
126
  pil_image = pil_image.convert("RGB")
127
- image_tensor = resize_image_for_patches(pil_image)
128
- results = [PENDING, PENDING, PENDING]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
129
 
130
- for i, name in enumerate(MODEL_NAMES):
131
- model = get_model(name)
132
- results[i] = pca_vis(model, image_tensor)
133
- yield tuple(results)
134
 
135
 
136
  # ── UI ────────────────────────────────────────────────────────────────────────
@@ -146,6 +174,14 @@ CSS = """
146
  font-size: 0.9rem;
147
  padding-bottom: 1rem;
148
  }
 
 
 
 
 
 
 
 
149
  .model-label {
150
  text-align: center;
151
  font-weight: 600;
@@ -153,11 +189,20 @@ CSS = """
153
  color: #374151;
154
  padding: 0.25rem 0;
155
  }
 
156
  .output-col {
157
- display: flex;
158
- flex-direction: column;
159
- align-items: center;
160
- gap: 0.25rem;
 
 
 
 
 
 
 
 
161
  }
162
  .subtitle-row a, .model-label a {
163
  color: inherit;
@@ -171,7 +216,6 @@ footer { display: none !important; }
171
  """
172
 
173
  with gr.Blocks(css=CSS, title="SSL ViT PCA Visualiser") as demo:
174
-
175
  gr.HTML("""
176
  <div class="title-row">
177
  <h1 style="font-size:1.6rem; font-weight:700; margin:0;">
@@ -179,10 +223,8 @@ with gr.Blocks(css=CSS, title="SSL ViT PCA Visualiser") as demo:
179
  </h1>
180
  </div>
181
  <div class="subtitle-row">
182
- ViT-S/16 &nbsp;Β·&nbsp; ImageNet-1K pre-training &nbsp;Β·&nbsp;
183
- <a href="https://huggingface.co/OK-AI/dino-vits16-pretrain-in1k" target="_blank">DiNO</a> &nbsp;Β·&nbsp;
184
- <a href="https://huggingface.co/OK-AI/ibot-vits16-pretrain-in1k" target="_blank">iBOT</a> &nbsp;Β·&nbsp;
185
- <a href="https://huggingface.co/OK-AI/lejepa-vits16-pretrain-in1k" target="_blank">LeJEPA</a>
186
  </div>
187
  """)
188
 
@@ -193,49 +235,87 @@ with gr.Blocks(css=CSS, title="SSL ViT PCA Visualiser") as demo:
193
  label="Input image",
194
  show_label=True,
195
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
196
  run_btn = gr.Button("Visualise", variant="primary")
 
197
  gr.HTML("""
198
  <p style="font-size:0.8rem; color:#9ca3af; margin-top:0.5rem; line-height:1.5;">
199
- Image is resized to 672 px tall (patch-aligned, aspect preserved)
200
- before inference. PCA is fit on all patch tokens and projected to
201
  3 components, then scaled with sigmoid for colour display.
202
- Results appear as each model finishes.
203
- </p>
204
-
205
- <p style="font-size:0.75rem; color:#9ca3af; margin-top:0.25rem;">
206
- Models: <a href="https://huggingface.co/OK-AI" target="_blank">OK-AI on HuggingFace</a>
207
- &nbsp;Β·&nbsp;
208
- Code: <a href="https://github.com/Open-Knowledge-AI/lite_ssl" target="_blank">lite_ssl</a>
209
  </p>
210
  """)
211
 
212
  with gr.Column(scale=3):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
213
  with gr.Row(equal_height=True):
214
  with gr.Column(elem_classes="output-col"):
215
- gr.HTML('<div class="model-label"><a href="https://huggingface.co/OK-AI/dino-vits16-pretrain-in1k" target="_blank">DiNO</a></div>')
216
- out_dino = gr.Image(show_label=False, interactive=False)
217
  with gr.Column(elem_classes="output-col"):
218
- gr.HTML('<div class="model-label"><a href="https://huggingface.co/OK-AI/ibot-vits16-pretrain-in1k" target="_blank">iBOT</a></div>')
219
- out_ibot = gr.Image(show_label=False, interactive=False)
220
  with gr.Column(elem_classes="output-col"):
221
- gr.HTML('<div class="model-label"><a href="https://huggingface.co/OK-AI/lejepa-vits16-pretrain-in1k" target="_blank">LeJEPA</a></div>')
222
- out_lejepa = gr.Image(show_label=False, interactive=False)
 
 
 
 
 
 
223
 
224
  run_btn.click(
225
  fn=run,
226
- inputs=[input_image],
227
- outputs=[out_dino, out_ibot, out_lejepa],
228
- )
229
-
230
- gr.Examples(
231
- examples=[
232
- [f"examples/{f}"]
233
- for f in sorted(os.listdir("examples"))
234
- if f.lower().endswith((".jpg", ".jpeg", ".png", ".webp"))
235
- ],
236
- inputs=[input_image],
237
  )
238
 
 
 
 
 
 
 
 
 
 
239
 
240
  if __name__ == "__main__":
241
  demo.launch()
 
14
  IMAGENET_MEAN = [0.485, 0.456, 0.406]
15
  IMAGENET_STD = [0.229, 0.224, 0.225]
16
 
 
17
  PATCH_SIZE = 16
18
  PCA_COMPONENTS = 3
19
 
20
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
21
 
22
  MODEL_IDS = {
23
+ "ViT-S/16": {
24
+ "DiNO": "OK-AI/dino-vits16-pretrain-in1k",
25
+ "iBOT": "OK-AI/ibot-vits16-pretrain-in1k",
26
+ "LeJEPA": "OK-AI/lejepa-vits16-pretrain-in1k",
27
+ },
28
+ "ViT-B/16": {
29
+ "DiNO": "OK-AI/dino-vitb16-pretrain-in1k",
30
+ "iBOT": "OK-AI/ibot-vitb16-pretrain-in1k",
31
+ "LeJEPA": "OK-AI/lejepa-vitb16-pretrain-in1k",
32
+ }
33
  }
34
+
35
+ MODEL_KEYS = ["DiNO", "iBOT", "LeJEPA"]
36
 
37
  # ── model loading (cached) ────────────────────────────────────────────────────
38
 
39
  _model_cache: dict[str, torch.nn.Module] = {}
40
 
41
 
42
+ def get_model(repo_id: str, revision: str) -> torch.nn.Module:
43
+ cache_key = f"{repo_id}@{revision}"
44
+ if cache_key not in _model_cache:
45
  model = AutoModel.from_pretrained(
46
+ repo_id,
47
+ revision=revision,
48
  trust_remote_code=True,
49
  )
50
  model.eval().to(DEVICE)
51
+ _model_cache[cache_key] = model
52
+ return _model_cache[cache_key]
53
 
54
 
55
  # ── image helpers ─────────────────────────────────────────────────────────────
56
 
57
 
58
  def resize_image_for_patches(
59
+ image: Image.Image,
60
+ image_size: int,
61
+ patch_size: int = PATCH_SIZE,
62
  ) -> torch.Tensor:
63
  """Resize so height = image_size and width is patch-aligned,
64
  preserving aspect ratio. Returns (1, 3, H, W) float tensor."""
 
80
  ).unsqueeze(0)
81
 
82
 
83
+ def pad_to_square(img: Image.Image, canvas_size: int) -> Image.Image:
84
  """Letterbox/pillarbox img onto a square canvas with a dark background.
85
  Ensures all output images share the same dimensions so the Gradio row
86
  never reflows or stretches when aspect ratios differ."""
87
  w, h = img.size
88
+ size = max(w, h, canvas_size)
89
  canvas = Image.new("RGB", (size, size), color=(18, 18, 18))
90
  canvas.paste(img, ((size - w) // 2, (size - h) // 2))
91
  return canvas
 
94
  # ── PCA visualisation ─────────────────────────────────────────────────────────
95
 
96
 
97
+ def pca_vis(model: torch.nn.Module, image_tensor: torch.Tensor, canvas_size: int) -> Image.Image:
98
  """Run image through model, PCA patch features β†’ square-padded RGB PIL image."""
99
  model_input = preprocess(image_tensor).to(DEVICE)
100
 
 
116
 
117
  # nearest-neighbour upscale β†’ pad to square so all outputs are the same size
118
  upscaled = Image.fromarray(pca_array, mode="RGB").resize((W, H), Image.NEAREST)
119
+ return pad_to_square(upscaled, canvas_size)
120
 
121
 
122
  # ── streaming inference ───────────────────────────────────────────────────────
123
 
 
 
124
 
125
+ def run(pil_image: Image.Image, epoch: str, weight_type: str, image_size: int):
126
  """
127
+ Generator: yields updates sequentially across models and sizes.
 
128
  """
129
  if pil_image is None:
130
  raise gr.Error("Please upload an image.")
131
 
132
+ image_size = int(image_size)
133
+ pending_img = Image.new("RGB", (image_size, image_size), color=(18, 18, 18))
134
+
135
+ # 6 total positions: ViT-S [dino, ibot, lejepa], ViT-B [dino, ibot, lejepa]
136
+ results = [pending_img] * 6
137
+ yield tuple(results)
138
+
139
  pil_image = pil_image.convert("RGB")
140
+ image_tensor = resize_image_for_patches(pil_image, image_size)
141
+
142
+ idx = 0
143
+ for arch in ["ViT-S/16", "ViT-B/16"]:
144
+ for model_key in MODEL_KEYS:
145
+ repo_id = MODEL_IDS[arch][model_key]
146
+
147
+ # LeJEPA only supports student weights
148
+ current_weight = "student" if model_key == "LeJEPA" else weight_type
149
+ revision = f"{epoch}/{current_weight}"
150
+
151
+ try:
152
+ model = get_model(repo_id, revision)
153
+ results[idx] = pca_vis(model, image_tensor, image_size)
154
+ except Exception as e:
155
+ print(f"Error processing {repo_id} ({revision}): {e}")
156
+ # Create an error placeholder card if a model/revision download fails
157
+ error_canvas = Image.new("RGB", (image_size, image_size), color=(40, 20, 20))
158
+ results[idx] = error_canvas
159
 
160
+ yield tuple(results)
161
+ idx += 1
 
 
162
 
163
 
164
  # ── UI ────────────────────────────────────────────────────────────────────────
 
174
  font-size: 0.9rem;
175
  padding-bottom: 1rem;
176
  }
177
+ .arch-header {
178
+ font-size: 1.2rem;
179
+ font-weight: 700;
180
+ margin-top: 1rem;
181
+ padding-left: 0.5rem;
182
+ border-left: 4px solid #3b82f6;
183
+ color: #1f2937;
184
+ }
185
  .model-label {
186
  text-align: center;
187
  font-weight: 600;
 
189
  color: #374151;
190
  padding: 0.25rem 0;
191
  }
192
+ /* Ensure strict rigid layouts for outputs to avoid layout shifting */
193
  .output-col {
194
+ display: flex !important;
195
+ flex-direction: column !important;
196
+ align-items: center !important;
197
+ gap: 0.25rem !important;
198
+ flex: 1 1 0% !important;
199
+ min-width: 150px !important;
200
+ }
201
+ .output-col img {
202
+ aspect-ratio: 1 / 1 !important;
203
+ object-fit: contain !important;
204
+ max-height: 350px !important;
205
+ width: 100% !important;
206
  }
207
  .subtitle-row a, .model-label a {
208
  color: inherit;
 
216
  """
217
 
218
  with gr.Blocks(css=CSS, title="SSL ViT PCA Visualiser") as demo:
 
219
  gr.HTML("""
220
  <div class="title-row">
221
  <h1 style="font-size:1.6rem; font-weight:700; margin:0;">
 
223
  </h1>
224
  </div>
225
  <div class="subtitle-row">
226
+ ImageNet-1K pre-training &nbsp;Β·&nbsp;
227
+ <a href="https://huggingface.co/OK-AI" target="_blank">OK-AI Models</a>
 
 
228
  </div>
229
  """)
230
 
 
235
  label="Input image",
236
  show_label=True,
237
  )
238
+
239
+ with gr.Row():
240
+ opt_epoch = gr.Dropdown(
241
+ choices=["ep100", "ep300"],
242
+ value="ep300",
243
+ label="Epochs",
244
+ interactive=True
245
+ )
246
+ opt_weight = gr.Dropdown(
247
+ choices=["student", "teacher"],
248
+ value="teacher",
249
+ label="Weight Type",
250
+ info="LeJEPA always uses student",
251
+ interactive=True
252
+ )
253
+
254
+ opt_size = gr.Dropdown(
255
+ choices=["224", "448", "672", "1280"],
256
+ value="672",
257
+ label="Image Target Resolution",
258
+ interactive=True
259
+ )
260
+
261
  run_btn = gr.Button("Visualise", variant="primary")
262
+
263
  gr.HTML("""
264
  <p style="font-size:0.8rem; color:#9ca3af; margin-top:0.5rem; line-height:1.5;">
265
+ PCA is fit on all patch tokens and projected to
 
266
  3 components, then scaled with sigmoid for colour display.
267
+ Results stream seamlessly into view as individual variants complete.
 
 
 
 
 
 
268
  </p>
269
  """)
270
 
271
  with gr.Column(scale=3):
272
+ # ── ViT-S/16 Row ──
273
+ gr.HTML('<div class="arch-header">ViT-S/16 Grid</div>')
274
+ with gr.Row(equal_height=True):
275
+ with gr.Column(elem_classes="output-col"):
276
+ gr.HTML('<div class="model-label">DiNO (S/16)</div>')
277
+ out_dino_s = gr.Image(show_label=False, interactive=False)
278
+ with gr.Column(elem_classes="output-col"):
279
+ gr.HTML('<div class="model-label">iBOT (S/16)</div>')
280
+ out_ibot_s = gr.Image(show_label=False, interactive=False)
281
+ with gr.Column(elem_classes="output-col"):
282
+ gr.HTML('<div class="model-label">LeJEPA (S/16)</div>')
283
+ out_lejepa_s = gr.Image(show_label=False, interactive=False)
284
+
285
+ # ── ViT-B/16 Row ──
286
+ gr.HTML('<div class="arch-header">ViT-B/16 Grid</div>')
287
  with gr.Row(equal_height=True):
288
  with gr.Column(elem_classes="output-col"):
289
+ gr.HTML('<div class="model-label">DiNO (B/16)</div>')
290
+ out_dino_b = gr.Image(show_label=False, interactive=False)
291
  with gr.Column(elem_classes="output-col"):
292
+ gr.HTML('<div class="model-label">iBOT (B/16)</div>')
293
+ out_ibot_b = gr.Image(show_label=False, interactive=False)
294
  with gr.Column(elem_classes="output-col"):
295
+ gr.HTML('<div class="model-label">LeJEPA (B/16)</div>')
296
+ out_lejepa_b = gr.Image(show_label=False, interactive=False)
297
+
298
+ # Wire outputs orderly following the exact resolution pattern tracking inside the `run` loop
299
+ output_targets = [
300
+ out_dino_s, out_ibot_s, out_lejepa_s,
301
+ out_dino_b, out_ibot_b, out_lejepa_b
302
+ ]
303
 
304
  run_btn.click(
305
  fn=run,
306
+ inputs=[input_image, opt_epoch, opt_weight, opt_size],
307
+ outputs=output_targets,
 
 
 
 
 
 
 
 
 
308
  )
309
 
310
+ if os.path.exists("examples"):
311
+ gr.Examples(
312
+ examples=[
313
+ [f"examples/{f}", "ep300", "teacher", "672"]
314
+ for f in sorted(os.listdir("examples"))
315
+ if f.lower().endswith((".jpg", ".jpeg", ".png", ".webp"))
316
+ ],
317
+ inputs=[input_image, opt_epoch, opt_weight, opt_size],
318
+ )
319
 
320
  if __name__ == "__main__":
321
  demo.launch()