Spaces:
Running on Zero
Running on Zero
Update app.py
Browse files
app.py
CHANGED
|
@@ -68,8 +68,8 @@ def get_processed_features_dino(num_patches, img,use_dummy):
|
|
| 68 |
desc = aggre_net(features_dino)
|
| 69 |
norms = torch.linalg.norm(desc, dim=1, keepdim=True)
|
| 70 |
desc = desc / (norms + 1e-8)
|
| 71 |
-
desc = desc.cpu()
|
| 72 |
-
del batch, features_dino
|
| 73 |
gc.collect()
|
| 74 |
torch.cuda.empty_cache()
|
| 75 |
return desc # shape [1, C, num_patches, num_patches]
|
|
@@ -114,6 +114,7 @@ def update_features(
|
|
| 114 |
num_patches,
|
| 115 |
use_dummy
|
| 116 |
):
|
|
|
|
| 117 |
torch.cuda.empty_cache()
|
| 118 |
"""
|
| 119 |
Given a PIL image, returns:
|
|
@@ -138,6 +139,8 @@ def on_select(
|
|
| 138 |
or_src_img: Image,
|
| 139 |
sel: gr.SelectData
|
| 140 |
):
|
|
|
|
|
|
|
| 141 |
# Convert to numpy arrays
|
| 142 |
src_arr = np.array(or_src_img)
|
| 143 |
tgt_arr = np.array(or_tgt_img)
|
|
@@ -194,6 +197,9 @@ extractor_vit.model.eval()
|
|
| 194 |
|
| 195 |
upsampler = nn.Upsample(size=(target_res, target_res), mode='bilinear', align_corners=False)
|
| 196 |
|
|
|
|
|
|
|
|
|
|
| 197 |
# ─── Build Gradio UI ──────────────────────────────────────────────
|
| 198 |
with gr.Blocks() as demo:
|
| 199 |
# Hidden states to hold features
|
|
|
|
| 68 |
desc = aggre_net(features_dino)
|
| 69 |
norms = torch.linalg.norm(desc, dim=1, keepdim=True)
|
| 70 |
desc = desc / (norms + 1e-8)
|
| 71 |
+
desc = desc.cpu()
|
| 72 |
+
# del batch, features_dino
|
| 73 |
gc.collect()
|
| 74 |
torch.cuda.empty_cache()
|
| 75 |
return desc # shape [1, C, num_patches, num_patches]
|
|
|
|
| 114 |
num_patches,
|
| 115 |
use_dummy
|
| 116 |
):
|
| 117 |
+
gc.collect()
|
| 118 |
torch.cuda.empty_cache()
|
| 119 |
"""
|
| 120 |
Given a PIL image, returns:
|
|
|
|
| 139 |
or_src_img: Image,
|
| 140 |
sel: gr.SelectData
|
| 141 |
):
|
| 142 |
+
gc.collect()
|
| 143 |
+
torch.cuda.empty_cache()
|
| 144 |
# Convert to numpy arrays
|
| 145 |
src_arr = np.array(or_src_img)
|
| 146 |
tgt_arr = np.array(or_tgt_img)
|
|
|
|
| 197 |
|
| 198 |
upsampler = nn.Upsample(size=(target_res, target_res), mode='bilinear', align_corners=False)
|
| 199 |
|
| 200 |
+
gc.collect()
|
| 201 |
+
torch.cuda.empty_cache()
|
| 202 |
+
|
| 203 |
# ─── Build Gradio UI ──────────────────────────────────────────────
|
| 204 |
with gr.Blocks() as demo:
|
| 205 |
# Hidden states to hold features
|