Tenbatsu24 commited on
Commit Β·
d647dfd
1
Parent(s): 69eec85
add: vitb16 support and more user inputs.
Browse files
app.py
CHANGED
|
@@ -132,7 +132,6 @@ def run(pil_image: Image.Image, epoch: str, weight_type: str, image_size: int):
|
|
| 132 |
image_size = int(image_size)
|
| 133 |
pending_img = Image.new("RGB", (image_size, image_size), color=(18, 18, 18))
|
| 134 |
|
| 135 |
-
# 6 total positions: ViT-S [dino, ibot, lejepa], ViT-B [dino, ibot, lejepa]
|
| 136 |
results = [pending_img] * 6
|
| 137 |
yield tuple(results)
|
| 138 |
|
|
@@ -144,7 +143,6 @@ def run(pil_image: Image.Image, epoch: str, weight_type: str, image_size: int):
|
|
| 144 |
for model_key in MODEL_KEYS:
|
| 145 |
repo_id = MODEL_IDS[arch][model_key]
|
| 146 |
|
| 147 |
-
# LeJEPA only supports student weights
|
| 148 |
current_weight = "student" if model_key == "LeJEPA" else weight_type
|
| 149 |
revision = f"{epoch}/{current_weight}"
|
| 150 |
|
|
@@ -153,7 +151,6 @@ def run(pil_image: Image.Image, epoch: str, weight_type: str, image_size: int):
|
|
| 153 |
results[idx] = pca_vis(model, image_tensor, image_size)
|
| 154 |
except Exception as e:
|
| 155 |
print(f"Error processing {repo_id} ({revision}): {e}")
|
| 156 |
-
# Create an error placeholder card if a model/revision download fails
|
| 157 |
error_canvas = Image.new("RGB", (image_size, image_size), color=(40, 20, 20))
|
| 158 |
results[idx] = error_canvas
|
| 159 |
|
|
@@ -189,7 +186,6 @@ CSS = """
|
|
| 189 |
color: #374151;
|
| 190 |
padding: 0.25rem 0;
|
| 191 |
}
|
| 192 |
-
/* Ensure strict rigid layouts for outputs to avoid layout shifting */
|
| 193 |
.output-col {
|
| 194 |
display: flex !important;
|
| 195 |
flex-direction: column !important;
|
|
@@ -204,14 +200,22 @@ CSS = """
|
|
| 204 |
max-height: 350px !important;
|
| 205 |
width: 100% !important;
|
| 206 |
}
|
| 207 |
-
.subtitle-row a, .model-label a {
|
| 208 |
color: inherit;
|
| 209 |
text-decoration: underline;
|
| 210 |
text-decoration-color: #d1d5db;
|
| 211 |
}
|
| 212 |
-
.model-label a:hover, .subtitle-row a:hover {
|
| 213 |
text-decoration-color: currentColor;
|
| 214 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 215 |
footer { display: none !important; }
|
| 216 |
"""
|
| 217 |
|
|
@@ -273,28 +277,43 @@ with gr.Blocks(css=CSS, title="SSL ViT PCA Visualiser") as demo:
|
|
| 273 |
gr.HTML('<div class="arch-header">ViT-S/16 Grid</div>')
|
| 274 |
with gr.Row(equal_height=True):
|
| 275 |
with gr.Column(elem_classes="output-col"):
|
| 276 |
-
gr.HTML(
|
|
|
|
| 277 |
out_dino_s = gr.Image(show_label=False, interactive=False)
|
| 278 |
with gr.Column(elem_classes="output-col"):
|
| 279 |
-
gr.HTML(
|
|
|
|
| 280 |
out_ibot_s = gr.Image(show_label=False, interactive=False)
|
| 281 |
with gr.Column(elem_classes="output-col"):
|
| 282 |
-
gr.HTML(
|
|
|
|
| 283 |
out_lejepa_s = gr.Image(show_label=False, interactive=False)
|
| 284 |
|
| 285 |
# ββ ViT-B/16 Row ββ
|
| 286 |
gr.HTML('<div class="arch-header">ViT-B/16 Grid</div>')
|
| 287 |
with gr.Row(equal_height=True):
|
| 288 |
with gr.Column(elem_classes="output-col"):
|
| 289 |
-
gr.HTML(
|
|
|
|
| 290 |
out_dino_b = gr.Image(show_label=False, interactive=False)
|
| 291 |
with gr.Column(elem_classes="output-col"):
|
| 292 |
-
gr.HTML(
|
|
|
|
| 293 |
out_ibot_b = gr.Image(show_label=False, interactive=False)
|
| 294 |
with gr.Column(elem_classes="output-col"):
|
| 295 |
-
gr.HTML(
|
|
|
|
| 296 |
out_lejepa_b = gr.Image(show_label=False, interactive=False)
|
| 297 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 298 |
# Wire outputs orderly following the exact resolution pattern tracking inside the `run` loop
|
| 299 |
output_targets = [
|
| 300 |
out_dino_s, out_ibot_s, out_lejepa_s,
|
|
|
|
| 132 |
image_size = int(image_size)
|
| 133 |
pending_img = Image.new("RGB", (image_size, image_size), color=(18, 18, 18))
|
| 134 |
|
|
|
|
| 135 |
results = [pending_img] * 6
|
| 136 |
yield tuple(results)
|
| 137 |
|
|
|
|
| 143 |
for model_key in MODEL_KEYS:
|
| 144 |
repo_id = MODEL_IDS[arch][model_key]
|
| 145 |
|
|
|
|
| 146 |
current_weight = "student" if model_key == "LeJEPA" else weight_type
|
| 147 |
revision = f"{epoch}/{current_weight}"
|
| 148 |
|
|
|
|
| 151 |
results[idx] = pca_vis(model, image_tensor, image_size)
|
| 152 |
except Exception as e:
|
| 153 |
print(f"Error processing {repo_id} ({revision}): {e}")
|
|
|
|
| 154 |
error_canvas = Image.new("RGB", (image_size, image_size), color=(40, 20, 20))
|
| 155 |
results[idx] = error_canvas
|
| 156 |
|
|
|
|
| 186 |
color: #374151;
|
| 187 |
padding: 0.25rem 0;
|
| 188 |
}
|
|
|
|
| 189 |
.output-col {
|
| 190 |
display: flex !important;
|
| 191 |
flex-direction: column !important;
|
|
|
|
| 200 |
max-height: 350px !important;
|
| 201 |
width: 100% !important;
|
| 202 |
}
|
| 203 |
+
.subtitle-row a, .model-label a, .custom-footer a {
|
| 204 |
color: inherit;
|
| 205 |
text-decoration: underline;
|
| 206 |
text-decoration-color: #d1d5db;
|
| 207 |
}
|
| 208 |
+
.model-label a:hover, .subtitle-row a:hover, .custom-footer a:hover {
|
| 209 |
text-decoration-color: currentColor;
|
| 210 |
}
|
| 211 |
+
.custom-footer {
|
| 212 |
+
text-align: center;
|
| 213 |
+
margin-top: 2.5rem;
|
| 214 |
+
padding-top: 1rem;
|
| 215 |
+
border-top: 1px solid #e5e7eb;
|
| 216 |
+
font-size: 0.8rem;
|
| 217 |
+
color: #9ca3af;
|
| 218 |
+
}
|
| 219 |
footer { display: none !important; }
|
| 220 |
"""
|
| 221 |
|
|
|
|
| 277 |
gr.HTML('<div class="arch-header">ViT-S/16 Grid</div>')
|
| 278 |
with gr.Row(equal_height=True):
|
| 279 |
with gr.Column(elem_classes="output-col"):
|
| 280 |
+
gr.HTML(
|
| 281 |
+
f'<div class="model-label"><a href="https://huggingface.co/{MODEL_IDS["ViT-S/16"]["DiNO"]}" target="_blank">DiNO (S/16)</a></div>')
|
| 282 |
out_dino_s = gr.Image(show_label=False, interactive=False)
|
| 283 |
with gr.Column(elem_classes="output-col"):
|
| 284 |
+
gr.HTML(
|
| 285 |
+
f'<div class="model-label"><a href="https://huggingface.co/{MODEL_IDS["ViT-S/16"]["iBOT"]}" target="_blank">iBOT (S/16)</a></div>')
|
| 286 |
out_ibot_s = gr.Image(show_label=False, interactive=False)
|
| 287 |
with gr.Column(elem_classes="output-col"):
|
| 288 |
+
gr.HTML(
|
| 289 |
+
f'<div class="model-label"><a href="https://huggingface.co/{MODEL_IDS["ViT-S/16"]["LeJEPA"]}" target="_blank">LeJEPA (S/16)</a></div>')
|
| 290 |
out_lejepa_s = gr.Image(show_label=False, interactive=False)
|
| 291 |
|
| 292 |
# ββ ViT-B/16 Row ββ
|
| 293 |
gr.HTML('<div class="arch-header">ViT-B/16 Grid</div>')
|
| 294 |
with gr.Row(equal_height=True):
|
| 295 |
with gr.Column(elem_classes="output-col"):
|
| 296 |
+
gr.HTML(
|
| 297 |
+
f'<div class="model-label"><a href="https://huggingface.co/{MODEL_IDS["ViT-B/16"]["DiNO"]}" target="_blank">DiNO (B/16)</a></div>')
|
| 298 |
out_dino_b = gr.Image(show_label=False, interactive=False)
|
| 299 |
with gr.Column(elem_classes="output-col"):
|
| 300 |
+
gr.HTML(
|
| 301 |
+
f'<div class="model-label"><a href="https://huggingface.co/{MODEL_IDS["ViT-B/16"]["iBOT"]}" target="_blank">iBOT (B/16)</a></div>')
|
| 302 |
out_ibot_b = gr.Image(show_label=False, interactive=False)
|
| 303 |
with gr.Column(elem_classes="output-col"):
|
| 304 |
+
gr.HTML(
|
| 305 |
+
f'<div class="model-label"><a href="https://huggingface.co/{MODEL_IDS["ViT-B/16"]["LeJEPA"]}" target="_blank">LeJEPA (B/16)</a></div>')
|
| 306 |
out_lejepa_b = gr.Image(show_label=False, interactive=False)
|
| 307 |
|
| 308 |
+
# Custom Clean Footer layout containing links to organization and codebase
|
| 309 |
+
gr.HTML("""
|
| 310 |
+
<div class="custom-footer">
|
| 311 |
+
Models: <a href="https://huggingface.co/OK-AI" target="_blank">OK-AI on HuggingFace</a>
|
| 312 |
+
Β·
|
| 313 |
+
Code: <a href="https://github.com/Open-Knowledge-AI/lite_ssl" target="_blank">lite_ssl Github</a>
|
| 314 |
+
</div>
|
| 315 |
+
""")
|
| 316 |
+
|
| 317 |
# Wire outputs orderly following the exact resolution pattern tracking inside the `run` loop
|
| 318 |
output_targets = [
|
| 319 |
out_dino_s, out_ibot_s, out_lejepa_s,
|