Spaces:
Running on Zero
Running on Zero
Update app.py
Browse files
app.py
CHANGED
|
@@ -3,6 +3,7 @@ import os
|
|
| 3 |
import spaces
|
| 4 |
import gradio as gr
|
| 5 |
import numpy as np
|
|
|
|
| 6 |
import torch
|
| 7 |
import torch.nn as nn
|
| 8 |
from PIL import Image, ImageDraw
|
|
@@ -54,21 +55,22 @@ def resize(img, target_res=224, resize=True, to_pil=True, edge=False, sampling_f
|
|
| 54 |
return canvas
|
| 55 |
|
| 56 |
# ─── Feature extraction ──────────────────────────────────────────
|
| 57 |
-
@spaces.GPU(duration=
|
| 58 |
def get_processed_features_dino(num_patches, img,use_dummy):
|
| 59 |
-
batch = extractor_vit.preprocess_pil(img)
|
| 60 |
-
features_dino = extractor_vit.extract_descriptors(batch.to(extractor_vit.device), layer=11, facet='token') \
|
| 61 |
-
.permute(0,1,3,2) \
|
| 62 |
-
.reshape(1, -1, num_patches, num_patches)
|
| 63 |
-
# Project + normalize
|
| 64 |
with torch.no_grad():
|
|
|
|
|
|
|
|
|
|
|
|
|
| 65 |
if use_dummy == "DINOv2":
|
| 66 |
desc = aggre_net_dummy(features_dino)
|
| 67 |
else:
|
| 68 |
desc = aggre_net(features_dino)
|
| 69 |
norms = torch.linalg.norm(desc, dim=1, keepdim=True)
|
| 70 |
desc = desc / (norms + 1e-8)
|
| 71 |
-
desc = desc.cpu()
|
|
|
|
|
|
|
| 72 |
torch.cuda.empty_cache()
|
| 73 |
return desc # shape [1, C, num_patches, num_patches]
|
| 74 |
|
|
@@ -86,7 +88,6 @@ def get_sim(
|
|
| 86 |
y, x = coord # row, col
|
| 87 |
|
| 88 |
# Upsample both feature maps to [1, C, img_size, img_size]
|
| 89 |
-
upsampler = nn.Upsample(size=(img_size, img_size), mode='bilinear', align_corners=False)
|
| 90 |
src_ft = upsampler(feat1) # [1, C, img_size, img_size]
|
| 91 |
trg_ft = upsampler(feat2)
|
| 92 |
|
|
@@ -176,7 +177,7 @@ def reload_img(
|
|
| 176 |
|
| 177 |
|
| 178 |
# ─── Configuration ───────────────────────────────────────────────
|
| 179 |
-
num_patches =
|
| 180 |
target_res = num_patches * 14
|
| 181 |
ckpt_file = "./ckpts/dino_spair_0300.pth"
|
| 182 |
|
|
@@ -188,6 +189,11 @@ aggre_net.load_pretrained_weights(torch.load(ckpt_file, map_location=device))
|
|
| 188 |
aggre_net_dummy = DummyAggregationNetwork()
|
| 189 |
extractor_vit = ViTExtractor('dinov2_vitb14', stride=14, device=device)
|
| 190 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 191 |
# ─── Build Gradio UI ──────────────────────────────────────────────
|
| 192 |
with gr.Blocks() as demo:
|
| 193 |
# Hidden states to hold features
|
|
|
|
| 3 |
import spaces
|
| 4 |
import gradio as gr
|
| 5 |
import numpy as np
|
| 6 |
+
import gc
|
| 7 |
import torch
|
| 8 |
import torch.nn as nn
|
| 9 |
from PIL import Image, ImageDraw
|
|
|
|
| 55 |
return canvas
|
| 56 |
|
| 57 |
# ─── Feature extraction ──────────────────────────────────────────
|
| 58 |
+
@spaces.GPU(duration=0)
|
| 59 |
def get_processed_features_dino(num_patches, img,use_dummy):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 60 |
with torch.no_grad():
|
| 61 |
+
batch = extractor_vit.preprocess_pil(img)
|
| 62 |
+
features_dino = extractor_vit.extract_descriptors(batch.to(extractor_vit.device), layer=11, facet='token') \
|
| 63 |
+
.permute(0,1,3,2) \
|
| 64 |
+
.reshape(1, -1, num_patches, num_patches)
|
| 65 |
if use_dummy == "DINOv2":
|
| 66 |
desc = aggre_net_dummy(features_dino)
|
| 67 |
else:
|
| 68 |
desc = aggre_net(features_dino)
|
| 69 |
norms = torch.linalg.norm(desc, dim=1, keepdim=True)
|
| 70 |
desc = desc / (norms + 1e-8)
|
| 71 |
+
desc = desc.cpu().detach()
|
| 72 |
+
del batch, features_dino
|
| 73 |
+
gc.collect()
|
| 74 |
torch.cuda.empty_cache()
|
| 75 |
return desc # shape [1, C, num_patches, num_patches]
|
| 76 |
|
|
|
|
| 88 |
y, x = coord # row, col
|
| 89 |
|
| 90 |
# Upsample both feature maps to [1, C, img_size, img_size]
|
|
|
|
| 91 |
src_ft = upsampler(feat1) # [1, C, img_size, img_size]
|
| 92 |
trg_ft = upsampler(feat2)
|
| 93 |
|
|
|
|
| 177 |
|
| 178 |
|
| 179 |
# ─── Configuration ───────────────────────────────────────────────
|
| 180 |
+
num_patches = 30
|
| 181 |
target_res = num_patches * 14
|
| 182 |
ckpt_file = "./ckpts/dino_spair_0300.pth"
|
| 183 |
|
|
|
|
| 189 |
aggre_net_dummy = DummyAggregationNetwork()
|
| 190 |
extractor_vit = ViTExtractor('dinov2_vitb14', stride=14, device=device)
|
| 191 |
|
| 192 |
+
aggre_net = aggre_net.eval()
|
| 193 |
+
extractor_vit.model.eval()
|
| 194 |
+
|
| 195 |
+
upsampler = nn.Upsample(size=(target_res, target_res), mode='bilinear', align_corners=False)
|
| 196 |
+
|
| 197 |
# ─── Build Gradio UI ──────────────────────────────────────────────
|
| 198 |
with gr.Blocks() as demo:
|
| 199 |
# Hidden states to hold features
|