Tenbatsu24 committed on
Commit ·
69eec85
1
Parent(s): 3cd97ec
add: vitb16 support and more user inputs.
Browse files
app.py
CHANGED
|
@@ -14,42 +14,51 @@ from sklearn.decomposition import PCA
|
|
| 14 |
IMAGENET_MEAN = [0.485, 0.456, 0.406]
|
| 15 |
IMAGENET_STD = [0.229, 0.224, 0.225]
|
| 16 |
|
| 17 |
-
IMAGE_SIZE = 672
|
| 18 |
PATCH_SIZE = 16
|
| 19 |
PCA_COMPONENTS = 3
|
| 20 |
|
| 21 |
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
| 22 |
|
| 23 |
MODEL_IDS = {
|
| 24 |
-
"
|
| 25 |
-
|
| 26 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
}
|
| 28 |
-
|
|
|
|
| 29 |
|
| 30 |
# ── model loading (cached) ────────────────────────────────────────────────────
|
| 31 |
|
| 32 |
_model_cache: dict[str, torch.nn.Module] = {}
|
| 33 |
|
| 34 |
|
| 35 |
-
def get_model(
|
| 36 |
-
|
|
|
|
| 37 |
model = AutoModel.from_pretrained(
|
| 38 |
-
|
|
|
|
| 39 |
trust_remote_code=True,
|
| 40 |
)
|
| 41 |
model.eval().to(DEVICE)
|
| 42 |
-
_model_cache[
|
| 43 |
-
return _model_cache[
|
| 44 |
|
| 45 |
|
| 46 |
# ── image helpers ─────────────────────────────────────────────────────────────
|
| 47 |
|
| 48 |
|
| 49 |
def resize_image_for_patches(
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
) -> torch.Tensor:
|
| 54 |
"""Resize so height = image_size and width is patch-aligned,
|
| 55 |
preserving aspect ratio. Returns (1, 3, H, W) float tensor."""
|
|
@@ -71,12 +80,12 @@ def preprocess(image_tensor: torch.Tensor) -> torch.Tensor:
|
|
| 71 |
).unsqueeze(0)
|
| 72 |
|
| 73 |
|
| 74 |
-
def pad_to_square(img: Image.Image) -> Image.Image:
|
| 75 |
"""Letterbox/pillarbox img onto a square canvas with a dark background.
|
| 76 |
Ensures all output images share the same dimensions so the Gradio row
|
| 77 |
never reflows or stretches when aspect ratios differ."""
|
| 78 |
w, h = img.size
|
| 79 |
-
size = max(w, h)
|
| 80 |
canvas = Image.new("RGB", (size, size), color=(18, 18, 18))
|
| 81 |
canvas.paste(img, ((size - w) // 2, (size - h) // 2))
|
| 82 |
return canvas
|
|
@@ -85,7 +94,7 @@ def pad_to_square(img: Image.Image) -> Image.Image:
|
|
| 85 |
# ── PCA visualisation ─────────────────────────────────────────────────────────
|
| 86 |
|
| 87 |
|
| 88 |
-
def pca_vis(model: torch.nn.Module, image_tensor: torch.Tensor) -> Image.Image:
|
| 89 |
"""Run image through model, PCA patch features → square-padded RGB PIL image."""
|
| 90 |
model_input = preprocess(image_tensor).to(DEVICE)
|
| 91 |
|
|
@@ -107,30 +116,49 @@ def pca_vis(model: torch.nn.Module, image_tensor: torch.Tensor) -> Image.Image:
|
|
| 107 |
|
| 108 |
# nearest-neighbour upscale → pad to square so all outputs are the same size
|
| 109 |
upscaled = Image.fromarray(pca_array, mode="RGB").resize((W, H), Image.NEAREST)
|
| 110 |
-
return pad_to_square(upscaled)
|
| 111 |
|
| 112 |
|
| 113 |
# ── streaming inference ───────────────────────────────────────────────────────
|
| 114 |
|
| 115 |
-
PENDING = Image.new("RGB", (IMAGE_SIZE, IMAGE_SIZE), color=(18, 18, 18))
|
| 116 |
-
|
| 117 |
|
| 118 |
-
def run(pil_image: Image.Image):
|
| 119 |
"""
|
| 120 |
-
Generator: yields
|
| 121 |
-
finishes, so the UI updates one image at a time.
|
| 122 |
"""
|
| 123 |
if pil_image is None:
|
| 124 |
raise gr.Error("Please upload an image.")
|
| 125 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 126 |
pil_image = pil_image.convert("RGB")
|
| 127 |
-
image_tensor = resize_image_for_patches(pil_image)
|
| 128 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 129 |
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
results[i] = pca_vis(model, image_tensor)
|
| 133 |
-
yield tuple(results)
|
| 134 |
|
| 135 |
|
| 136 |
# ── UI ────────────────────────────────────────────────────────────────────────
|
|
@@ -146,6 +174,14 @@ CSS = """
|
|
| 146 |
font-size: 0.9rem;
|
| 147 |
padding-bottom: 1rem;
|
| 148 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 149 |
.model-label {
|
| 150 |
text-align: center;
|
| 151 |
font-weight: 600;
|
|
@@ -153,11 +189,20 @@ CSS = """
|
|
| 153 |
color: #374151;
|
| 154 |
padding: 0.25rem 0;
|
| 155 |
}
|
|
|
|
| 156 |
.output-col {
|
| 157 |
-
display: flex;
|
| 158 |
-
flex-direction: column;
|
| 159 |
-
align-items: center;
|
| 160 |
-
gap: 0.25rem;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 161 |
}
|
| 162 |
.subtitle-row a, .model-label a {
|
| 163 |
color: inherit;
|
|
@@ -171,7 +216,6 @@ footer { display: none !important; }
|
|
| 171 |
"""
|
| 172 |
|
| 173 |
with gr.Blocks(css=CSS, title="SSL ViT PCA Visualiser") as demo:
|
| 174 |
-
|
| 175 |
gr.HTML("""
|
| 176 |
<div class="title-row">
|
| 177 |
<h1 style="font-size:1.6rem; font-weight:700; margin:0;">
|
|
@@ -179,10 +223,8 @@ with gr.Blocks(css=CSS, title="SSL ViT PCA Visualiser") as demo:
|
|
| 179 |
</h1>
|
| 180 |
</div>
|
| 181 |
<div class="subtitle-row">
|
| 182 |
-
|
| 183 |
-
<a href="https://huggingface.co/OK-AI
|
| 184 |
-
<a href="https://huggingface.co/OK-AI/ibot-vits16-pretrain-in1k" target="_blank">iBOT</a> Β·
|
| 185 |
-
<a href="https://huggingface.co/OK-AI/lejepa-vits16-pretrain-in1k" target="_blank">LeJEPA</a>
|
| 186 |
</div>
|
| 187 |
""")
|
| 188 |
|
|
@@ -193,49 +235,87 @@ with gr.Blocks(css=CSS, title="SSL ViT PCA Visualiser") as demo:
|
|
| 193 |
label="Input image",
|
| 194 |
show_label=True,
|
| 195 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 196 |
run_btn = gr.Button("Visualise", variant="primary")
|
|
|
|
| 197 |
gr.HTML("""
|
| 198 |
<p style="font-size:0.8rem; color:#9ca3af; margin-top:0.5rem; line-height:1.5;">
|
| 199 |
-
|
| 200 |
-
before inference. PCA is fit on all patch tokens and projected to
|
| 201 |
3 components, then scaled with sigmoid for colour display.
|
| 202 |
-
Results
|
| 203 |
-
</p>
|
| 204 |
-
|
| 205 |
-
<p style="font-size:0.75rem; color:#9ca3af; margin-top:0.25rem;">
|
| 206 |
-
Models: <a href="https://huggingface.co/OK-AI" target="_blank">OK-AI on HuggingFace</a>
|
| 207 |
-
Β·
|
| 208 |
-
Code: <a href="https://github.com/Open-Knowledge-AI/lite_ssl" target="_blank">lite_ssl</a>
|
| 209 |
</p>
|
| 210 |
""")
|
| 211 |
|
| 212 |
with gr.Column(scale=3):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 213 |
with gr.Row(equal_height=True):
|
| 214 |
with gr.Column(elem_classes="output-col"):
|
| 215 |
-
gr.HTML('<div class="model-label">
|
| 216 |
-
|
| 217 |
with gr.Column(elem_classes="output-col"):
|
| 218 |
-
gr.HTML('<div class="model-label">
|
| 219 |
-
|
| 220 |
with gr.Column(elem_classes="output-col"):
|
| 221 |
-
gr.HTML('<div class="model-label">
|
| 222 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 223 |
|
| 224 |
run_btn.click(
|
| 225 |
fn=run,
|
| 226 |
-
inputs=[input_image],
|
| 227 |
-
outputs=
|
| 228 |
-
)
|
| 229 |
-
|
| 230 |
-
gr.Examples(
|
| 231 |
-
examples=[
|
| 232 |
-
[f"examples/{f}"]
|
| 233 |
-
for f in sorted(os.listdir("examples"))
|
| 234 |
-
if f.lower().endswith((".jpg", ".jpeg", ".png", ".webp"))
|
| 235 |
-
],
|
| 236 |
-
inputs=[input_image],
|
| 237 |
)
|
| 238 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 239 |
|
| 240 |
if __name__ == "__main__":
|
| 241 |
demo.launch()
|
|
|
|
| 14 |
# Per-channel mean / std constants. The names indicate the usual ImageNet
# statistics (presumably consumed by preprocess() - confirm against its body).
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Default patch edge, in pixels, for resize_image_for_patches().
PATCH_SIZE = 16
# Number of principal components kept for the colour visualisation.
PCA_COMPONENTS = 3

# Prefer the GPU whenever torch reports one.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# architecture -> method -> Hugging Face repo id.
MODEL_IDS = {
    "ViT-S/16": {
        "DiNO": "OK-AI/dino-vits16-pretrain-in1k",
        "iBOT": "OK-AI/ibot-vits16-pretrain-in1k",
        "LeJEPA": "OK-AI/lejepa-vits16-pretrain-in1k",
    },
    "ViT-B/16": {
        "DiNO": "OK-AI/dino-vitb16-pretrain-in1k",
        "iBOT": "OK-AI/ibot-vitb16-pretrain-in1k",
        "LeJEPA": "OK-AI/lejepa-vitb16-pretrain-in1k",
    },
}

# Method order within each architecture; run() fills its output slots in
# this order.
MODEL_KEYS = ["DiNO", "iBOT", "LeJEPA"]
|
| 36 |
|
| 37 |
# ── model loading (cached) ────────────────────────────────────────────────────
|
| 38 |
|
| 39 |
# Loaded models keyed by "<repo_id>@<revision>", so each checkpoint is only
# loaded once per process.
_model_cache: dict[str, torch.nn.Module] = {}


def get_model(repo_id: str, revision: str) -> torch.nn.Module:
    """Return the model for ``repo_id`` at ``revision``, loading it on first use.

    Args:
        repo_id: Hugging Face repository id passed to ``AutoModel.from_pretrained``.
        revision: Repository revision to load (``run`` passes
            ``"<epoch>/<weight type>"``).

    Returns:
        The cached module. On a cache miss the model is loaded, switched to
        eval mode, moved to ``DEVICE`` and stored before being returned.
    """
    key = f"{repo_id}@{revision}"
    cached = _model_cache.get(key)
    if cached is not None:
        return cached

    loaded = AutoModel.from_pretrained(
        repo_id,
        revision=revision,
        trust_remote_code=True,
    )
    loaded.eval().to(DEVICE)
    _model_cache[key] = loaded
    return loaded
|
| 53 |
|
| 54 |
|
| 55 |
# ── image helpers ─────────────────────────────────────────────────────────────
|
| 56 |
|
| 57 |
|
| 58 |
def resize_image_for_patches(
|
| 59 |
+
image: Image.Image,
|
| 60 |
+
image_size: int,
|
| 61 |
+
patch_size: int = PATCH_SIZE,
|
| 62 |
) -> torch.Tensor:
|
| 63 |
"""Resize so height = image_size and width is patch-aligned,
|
| 64 |
preserving aspect ratio. Returns (1, 3, H, W) float tensor."""
|
|
|
|
| 80 |
).unsqueeze(0)
|
| 81 |
|
| 82 |
|
| 83 |
+
def pad_to_square(img: Image.Image, canvas_size: int) -> Image.Image:
    """Centre ``img`` on a dark square canvas (letterbox / pillarbox).

    The canvas side is the largest of the image width, the image height and
    ``canvas_size``. The image is never cropped or rescaled, so an image
    wider or taller than ``canvas_size`` produces a correspondingly larger
    canvas.

    Args:
        img: Image to pad.
        canvas_size: Minimum side length of the square canvas, in pixels.

    Returns:
        A new RGB image with ``img`` pasted in the centre over a
        ``(18, 18, 18)`` background.
    """
    width, height = img.size
    side = max(width, height, canvas_size)
    left = (side - width) // 2
    top = (side - height) // 2

    square = Image.new("RGB", (side, side), color=(18, 18, 18))
    square.paste(img, (left, top))
    return square
|
|
|
|
| 94 |
# ── PCA visualisation ─────────────────────────────────────────────────────────
|
| 95 |
|
| 96 |
|
| 97 |
+
def pca_vis(model: torch.nn.Module, image_tensor: torch.Tensor, canvas_size: int) -> Image.Image:
|
| 98 |
"""Run image through model, PCA patch features → square-padded RGB PIL image."""
|
| 99 |
model_input = preprocess(image_tensor).to(DEVICE)
|
| 100 |
|
|
|
|
| 116 |
|
| 117 |
# nearest-neighbour upscale → pad to square so all outputs are the same size
|
| 118 |
upscaled = Image.fromarray(pca_array, mode="RGB").resize((W, H), Image.NEAREST)
|
| 119 |
+
return pad_to_square(upscaled, canvas_size)
|
| 120 |
|
| 121 |
|
| 122 |
# ── streaming inference ───────────────────────────────────────────────────────
|
| 123 |
|
|
|
|
|
|
|
| 124 |
|
| 125 |
+
def run(pil_image: Image.Image, epoch: str, weight_type: str, image_size: int):
    """Stream PCA visualisations for every architecture / method pair.

    Generator used as the Gradio click handler. It first yields a tuple of
    placeholder images, then re-yields the whole tuple after each model has
    been processed, so the UI fills in one slot at a time.

    Args:
        pil_image: Uploaded image, or ``None`` when nothing was uploaded.
        epoch: Checkpoint epoch tag; forms the first part of the revision.
        weight_type: Weight variant; forms the second part of the revision.
            Ignored for LeJEPA, which always uses ``"student"``.
        image_size: Target resolution. Coerced with ``int()`` because the
            UI dropdown supplies its choices as strings.

    Yields:
        A tuple with one PIL image per slot, ordered ViT-S/16
        [DiNO, iBOT, LeJEPA] followed by ViT-B/16 [DiNO, iBOT, LeJEPA].

    Raises:
        gr.Error: If ``pil_image`` is ``None``.
    """
    if pil_image is None:
        raise gr.Error("Please upload an image.")

    image_size = int(image_size)
    square = (image_size, image_size)

    # Slot order must match the order of the output components wired to this
    # handler. Deriving the slot list once keeps the placeholder count and
    # the loop below in sync without a hand-maintained magic number.
    slots = [
        (arch, model_key)
        for arch in ("ViT-S/16", "ViT-B/16")
        for model_key in MODEL_KEYS
    ]

    pending_img = Image.new("RGB", square, color=(18, 18, 18))
    results = [pending_img] * len(slots)
    yield tuple(results)

    pil_image = pil_image.convert("RGB")
    image_tensor = resize_image_for_patches(pil_image, image_size)

    for idx, (arch, model_key) in enumerate(slots):
        repo_id = MODEL_IDS[arch][model_key]

        # LeJEPA only supports student weights
        current_weight = "student" if model_key == "LeJEPA" else weight_type
        revision = f"{epoch}/{current_weight}"

        try:
            model = get_model(repo_id, revision)
            results[idx] = pca_vis(model, image_tensor, image_size)
        except Exception as e:
            # Deliberately best-effort: one failing download or forward pass
            # must not abort the remaining slots, so log it and show an
            # error-tinted placeholder card instead.
            print(f"Error processing {repo_id} ({revision}): {e}")
            results[idx] = Image.new("RGB", square, color=(40, 20, 20))

        yield tuple(results)
|
|
|
|
|
|
|
| 162 |
|
| 163 |
|
| 164 |
# ── UI ────────────────────────────────────────────────────────────────────────
|
|
|
|
| 174 |
font-size: 0.9rem;
|
| 175 |
padding-bottom: 1rem;
|
| 176 |
}
|
| 177 |
+
.arch-header {
|
| 178 |
+
font-size: 1.2rem;
|
| 179 |
+
font-weight: 700;
|
| 180 |
+
margin-top: 1rem;
|
| 181 |
+
padding-left: 0.5rem;
|
| 182 |
+
border-left: 4px solid #3b82f6;
|
| 183 |
+
color: #1f2937;
|
| 184 |
+
}
|
| 185 |
.model-label {
|
| 186 |
text-align: center;
|
| 187 |
font-weight: 600;
|
|
|
|
| 189 |
color: #374151;
|
| 190 |
padding: 0.25rem 0;
|
| 191 |
}
|
| 192 |
+
/* Ensure strict rigid layouts for outputs to avoid layout shifting */
|
| 193 |
.output-col {
|
| 194 |
+
display: flex !important;
|
| 195 |
+
flex-direction: column !important;
|
| 196 |
+
align-items: center !important;
|
| 197 |
+
gap: 0.25rem !important;
|
| 198 |
+
flex: 1 1 0% !important;
|
| 199 |
+
min-width: 150px !important;
|
| 200 |
+
}
|
| 201 |
+
.output-col img {
|
| 202 |
+
aspect-ratio: 1 / 1 !important;
|
| 203 |
+
object-fit: contain !important;
|
| 204 |
+
max-height: 350px !important;
|
| 205 |
+
width: 100% !important;
|
| 206 |
}
|
| 207 |
.subtitle-row a, .model-label a {
|
| 208 |
color: inherit;
|
|
|
|
| 216 |
"""
|
| 217 |
|
| 218 |
with gr.Blocks(css=CSS, title="SSL ViT PCA Visualiser") as demo:
|
|
|
|
| 219 |
gr.HTML("""
|
| 220 |
<div class="title-row">
|
| 221 |
<h1 style="font-size:1.6rem; font-weight:700; margin:0;">
|
|
|
|
| 223 |
</h1>
|
| 224 |
</div>
|
| 225 |
<div class="subtitle-row">
|
| 226 |
+
ImageNet-1K pre-training Β·
|
| 227 |
+
<a href="https://huggingface.co/OK-AI" target="_blank">OK-AI Models</a>
|
|
|
|
|
|
|
| 228 |
</div>
|
| 229 |
""")
|
| 230 |
|
|
|
|
| 235 |
label="Input image",
|
| 236 |
show_label=True,
|
| 237 |
)
|
| 238 |
+
|
| 239 |
+
with gr.Row():
|
| 240 |
+
opt_epoch = gr.Dropdown(
|
| 241 |
+
choices=["ep100", "ep300"],
|
| 242 |
+
value="ep300",
|
| 243 |
+
label="Epochs",
|
| 244 |
+
interactive=True
|
| 245 |
+
)
|
| 246 |
+
opt_weight = gr.Dropdown(
|
| 247 |
+
choices=["student", "teacher"],
|
| 248 |
+
value="teacher",
|
| 249 |
+
label="Weight Type",
|
| 250 |
+
info="LeJEPA always uses student",
|
| 251 |
+
interactive=True
|
| 252 |
+
)
|
| 253 |
+
|
| 254 |
+
opt_size = gr.Dropdown(
|
| 255 |
+
choices=["224", "448", "672", "1280"],
|
| 256 |
+
value="672",
|
| 257 |
+
label="Image Target Resolution",
|
| 258 |
+
interactive=True
|
| 259 |
+
)
|
| 260 |
+
|
| 261 |
run_btn = gr.Button("Visualise", variant="primary")
|
| 262 |
+
|
| 263 |
gr.HTML("""
|
| 264 |
<p style="font-size:0.8rem; color:#9ca3af; margin-top:0.5rem; line-height:1.5;">
|
| 265 |
+
PCA is fit on all patch tokens and projected to
|
|
|
|
| 266 |
3 components, then scaled with sigmoid for colour display.
|
| 267 |
+
Results stream seamlessly into view as individual variants complete.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 268 |
</p>
|
| 269 |
""")
|
| 270 |
|
| 271 |
with gr.Column(scale=3):
|
| 272 |
+
# ── ViT-S/16 Row ──
|
| 273 |
+
gr.HTML('<div class="arch-header">ViT-S/16 Grid</div>')
|
| 274 |
+
with gr.Row(equal_height=True):
|
| 275 |
+
with gr.Column(elem_classes="output-col"):
|
| 276 |
+
gr.HTML('<div class="model-label">DiNO (S/16)</div>')
|
| 277 |
+
out_dino_s = gr.Image(show_label=False, interactive=False)
|
| 278 |
+
with gr.Column(elem_classes="output-col"):
|
| 279 |
+
gr.HTML('<div class="model-label">iBOT (S/16)</div>')
|
| 280 |
+
out_ibot_s = gr.Image(show_label=False, interactive=False)
|
| 281 |
+
with gr.Column(elem_classes="output-col"):
|
| 282 |
+
gr.HTML('<div class="model-label">LeJEPA (S/16)</div>')
|
| 283 |
+
out_lejepa_s = gr.Image(show_label=False, interactive=False)
|
| 284 |
+
|
| 285 |
+
# ── ViT-B/16 Row ──
|
| 286 |
+
gr.HTML('<div class="arch-header">ViT-B/16 Grid</div>')
|
| 287 |
with gr.Row(equal_height=True):
|
| 288 |
with gr.Column(elem_classes="output-col"):
|
| 289 |
+
gr.HTML('<div class="model-label">DiNO (B/16)</div>')
|
| 290 |
+
out_dino_b = gr.Image(show_label=False, interactive=False)
|
| 291 |
with gr.Column(elem_classes="output-col"):
|
| 292 |
+
gr.HTML('<div class="model-label">iBOT (B/16)</div>')
|
| 293 |
+
out_ibot_b = gr.Image(show_label=False, interactive=False)
|
| 294 |
with gr.Column(elem_classes="output-col"):
|
| 295 |
+
gr.HTML('<div class="model-label">LeJEPA (B/16)</div>')
|
| 296 |
+
out_lejepa_b = gr.Image(show_label=False, interactive=False)
|
| 297 |
+
|
| 298 |
+
# Wire outputs orderly following the exact resolution pattern tracking inside the `run` loop
|
| 299 |
+
output_targets = [
|
| 300 |
+
out_dino_s, out_ibot_s, out_lejepa_s,
|
| 301 |
+
out_dino_b, out_ibot_b, out_lejepa_b
|
| 302 |
+
]
|
| 303 |
|
| 304 |
run_btn.click(
|
| 305 |
fn=run,
|
| 306 |
+
inputs=[input_image, opt_epoch, opt_weight, opt_size],
|
| 307 |
+
outputs=output_targets,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 308 |
)
|
| 309 |
|
| 310 |
+
if os.path.exists("examples"):
|
| 311 |
+
gr.Examples(
|
| 312 |
+
examples=[
|
| 313 |
+
[f"examples/{f}", "ep300", "teacher", "672"]
|
| 314 |
+
for f in sorted(os.listdir("examples"))
|
| 315 |
+
if f.lower().endswith((".jpg", ".jpeg", ".png", ".webp"))
|
| 316 |
+
],
|
| 317 |
+
inputs=[input_image, opt_epoch, opt_weight, opt_size],
|
| 318 |
+
)
|
| 319 |
|
| 320 |
if __name__ == "__main__":
|
| 321 |
demo.launch()
|