Tenbatsu24 commited on
Commit
d647dfd
Β·
1 Parent(s): 69eec85

add: vitb16 support and more user inputs.

Browse files
Files changed (1) hide show
  1. app.py +31 -12
app.py CHANGED
@@ -132,7 +132,6 @@ def run(pil_image: Image.Image, epoch: str, weight_type: str, image_size: int):
132
  image_size = int(image_size)
133
  pending_img = Image.new("RGB", (image_size, image_size), color=(18, 18, 18))
134
 
135
- # 6 total positions: ViT-S [dino, ibot, lejepa], ViT-B [dino, ibot, lejepa]
136
  results = [pending_img] * 6
137
  yield tuple(results)
138
 
@@ -144,7 +143,6 @@ def run(pil_image: Image.Image, epoch: str, weight_type: str, image_size: int):
144
  for model_key in MODEL_KEYS:
145
  repo_id = MODEL_IDS[arch][model_key]
146
 
147
- # LeJEPA only supports student weights
148
  current_weight = "student" if model_key == "LeJEPA" else weight_type
149
  revision = f"{epoch}/{current_weight}"
150
 
@@ -153,7 +151,6 @@ def run(pil_image: Image.Image, epoch: str, weight_type: str, image_size: int):
153
  results[idx] = pca_vis(model, image_tensor, image_size)
154
  except Exception as e:
155
  print(f"Error processing {repo_id} ({revision}): {e}")
156
- # Create an error placeholder card if a model/revision download fails
157
  error_canvas = Image.new("RGB", (image_size, image_size), color=(40, 20, 20))
158
  results[idx] = error_canvas
159
 
@@ -189,7 +186,6 @@ CSS = """
189
  color: #374151;
190
  padding: 0.25rem 0;
191
  }
192
- /* Ensure strict rigid layouts for outputs to avoid layout shifting */
193
  .output-col {
194
  display: flex !important;
195
  flex-direction: column !important;
@@ -204,14 +200,22 @@ CSS = """
204
  max-height: 350px !important;
205
  width: 100% !important;
206
  }
207
- .subtitle-row a, .model-label a {
208
  color: inherit;
209
  text-decoration: underline;
210
  text-decoration-color: #d1d5db;
211
  }
212
- .model-label a:hover, .subtitle-row a:hover {
213
  text-decoration-color: currentColor;
214
  }
 
 
 
 
 
 
 
 
215
  footer { display: none !important; }
216
  """
217
 
@@ -273,28 +277,43 @@ with gr.Blocks(css=CSS, title="SSL ViT PCA Visualiser") as demo:
273
  gr.HTML('<div class="arch-header">ViT-S/16 Grid</div>')
274
  with gr.Row(equal_height=True):
275
  with gr.Column(elem_classes="output-col"):
276
- gr.HTML('<div class="model-label">DiNO (S/16)</div>')
 
277
  out_dino_s = gr.Image(show_label=False, interactive=False)
278
  with gr.Column(elem_classes="output-col"):
279
- gr.HTML('<div class="model-label">iBOT (S/16)</div>')
 
280
  out_ibot_s = gr.Image(show_label=False, interactive=False)
281
  with gr.Column(elem_classes="output-col"):
282
- gr.HTML('<div class="model-label">LeJEPA (S/16)</div>')
 
283
  out_lejepa_s = gr.Image(show_label=False, interactive=False)
284
 
285
  # ── ViT-B/16 Row ──
286
  gr.HTML('<div class="arch-header">ViT-B/16 Grid</div>')
287
  with gr.Row(equal_height=True):
288
  with gr.Column(elem_classes="output-col"):
289
- gr.HTML('<div class="model-label">DiNO (B/16)</div>')
 
290
  out_dino_b = gr.Image(show_label=False, interactive=False)
291
  with gr.Column(elem_classes="output-col"):
292
- gr.HTML('<div class="model-label">iBOT (B/16)</div>')
 
293
  out_ibot_b = gr.Image(show_label=False, interactive=False)
294
  with gr.Column(elem_classes="output-col"):
295
- gr.HTML('<div class="model-label">LeJEPA (B/16)</div>')
 
296
  out_lejepa_b = gr.Image(show_label=False, interactive=False)
297
 
 
 
 
 
 
 
 
 
 
298
  # Wire outputs orderly following the exact resolution pattern tracking inside the `run` loop
299
  output_targets = [
300
  out_dino_s, out_ibot_s, out_lejepa_s,
 
132
  image_size = int(image_size)
133
  pending_img = Image.new("RGB", (image_size, image_size), color=(18, 18, 18))
134
 
 
135
  results = [pending_img] * 6
136
  yield tuple(results)
137
 
 
143
  for model_key in MODEL_KEYS:
144
  repo_id = MODEL_IDS[arch][model_key]
145
 
 
146
  current_weight = "student" if model_key == "LeJEPA" else weight_type
147
  revision = f"{epoch}/{current_weight}"
148
 
 
151
  results[idx] = pca_vis(model, image_tensor, image_size)
152
  except Exception as e:
153
  print(f"Error processing {repo_id} ({revision}): {e}")
 
154
  error_canvas = Image.new("RGB", (image_size, image_size), color=(40, 20, 20))
155
  results[idx] = error_canvas
156
 
 
186
  color: #374151;
187
  padding: 0.25rem 0;
188
  }
 
189
  .output-col {
190
  display: flex !important;
191
  flex-direction: column !important;
 
200
  max-height: 350px !important;
201
  width: 100% !important;
202
  }
203
+ .subtitle-row a, .model-label a, .custom-footer a {
204
  color: inherit;
205
  text-decoration: underline;
206
  text-decoration-color: #d1d5db;
207
  }
208
+ .model-label a:hover, .subtitle-row a:hover, .custom-footer a:hover {
209
  text-decoration-color: currentColor;
210
  }
211
+ .custom-footer {
212
+ text-align: center;
213
+ margin-top: 2.5rem;
214
+ padding-top: 1rem;
215
+ border-top: 1px solid #e5e7eb;
216
+ font-size: 0.8rem;
217
+ color: #9ca3af;
218
+ }
219
  footer { display: none !important; }
220
  """
221
 
 
277
  gr.HTML('<div class="arch-header">ViT-S/16 Grid</div>')
278
  with gr.Row(equal_height=True):
279
  with gr.Column(elem_classes="output-col"):
280
+ gr.HTML(
281
+ f'<div class="model-label"><a href="https://huggingface.co/{MODEL_IDS["ViT-S/16"]["DiNO"]}" target="_blank">DiNO (S/16)</a></div>')
282
  out_dino_s = gr.Image(show_label=False, interactive=False)
283
  with gr.Column(elem_classes="output-col"):
284
+ gr.HTML(
285
+ f'<div class="model-label"><a href="https://huggingface.co/{MODEL_IDS["ViT-S/16"]["iBOT"]}" target="_blank">iBOT (S/16)</a></div>')
286
  out_ibot_s = gr.Image(show_label=False, interactive=False)
287
  with gr.Column(elem_classes="output-col"):
288
+ gr.HTML(
289
+ f'<div class="model-label"><a href="https://huggingface.co/{MODEL_IDS["ViT-S/16"]["LeJEPA"]}" target="_blank">LeJEPA (S/16)</a></div>')
290
  out_lejepa_s = gr.Image(show_label=False, interactive=False)
291
 
292
  # ── ViT-B/16 Row ──
293
  gr.HTML('<div class="arch-header">ViT-B/16 Grid</div>')
294
  with gr.Row(equal_height=True):
295
  with gr.Column(elem_classes="output-col"):
296
+ gr.HTML(
297
+ f'<div class="model-label"><a href="https://huggingface.co/{MODEL_IDS["ViT-B/16"]["DiNO"]}" target="_blank">DiNO (B/16)</a></div>')
298
  out_dino_b = gr.Image(show_label=False, interactive=False)
299
  with gr.Column(elem_classes="output-col"):
300
+ gr.HTML(
301
+ f'<div class="model-label"><a href="https://huggingface.co/{MODEL_IDS["ViT-B/16"]["iBOT"]}" target="_blank">iBOT (B/16)</a></div>')
302
  out_ibot_b = gr.Image(show_label=False, interactive=False)
303
  with gr.Column(elem_classes="output-col"):
304
+ gr.HTML(
305
+ f'<div class="model-label"><a href="https://huggingface.co/{MODEL_IDS["ViT-B/16"]["LeJEPA"]}" target="_blank">LeJEPA (B/16)</a></div>')
306
  out_lejepa_b = gr.Image(show_label=False, interactive=False)
307
 
308
+ # Custom Clean Footer layout containing links to organization and codebase
309
+ gr.HTML("""
310
+ <div class="custom-footer">
311
+ Models: <a href="https://huggingface.co/OK-AI" target="_blank">OK-AI on HuggingFace</a>
312
+ &nbsp;Β·&nbsp;
313
+ Code: <a href="https://github.com/Open-Knowledge-AI/lite_ssl" target="_blank">lite_ssl Github</a>
314
+ </div>
315
+ """)
316
+
317
  # Wire outputs orderly following the exact resolution pattern tracking inside the `run` loop
318
  output_targets = [
319
  out_dino_s, out_ibot_s, out_lejepa_s,