Spaces:
Running
Running
Refactor
Browse files
app.py
CHANGED
|
@@ -59,7 +59,7 @@ def get_image_features(model, image_dir):
|
|
| 59 |
|
| 60 |
loader = torch.utils.data.DataLoader(
|
| 61 |
dataset,
|
| 62 |
-
batch_size=
|
| 63 |
shuffle=False,
|
| 64 |
num_workers=4,
|
| 65 |
drop_last=False,
|
|
@@ -103,7 +103,8 @@ def text_encoder(text, tokenizer):
|
|
| 103 |
return jnp.expand_dims(embedding, axis=0)
|
| 104 |
|
| 105 |
|
| 106 |
-
|
|
|
|
| 107 |
image_features = []
|
| 108 |
for i, (images) in enumerate(tqdm(loader)):
|
| 109 |
images = images.permute(0, 2, 3, 1).numpy()
|
|
@@ -145,8 +146,32 @@ if query:
|
|
| 145 |
"dbmdz/bert-base-italian-xxl-uncased", cache_dir=None, use_fast=True
|
| 146 |
)
|
| 147 |
|
| 148 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 149 |
|
| 150 |
-
image_paths = find_image(query, dataset, tokenizer, image_features, n=
|
| 151 |
|
| 152 |
st.image(image_paths)
|
|
|
|
| 59 |
|
| 60 |
loader = torch.utils.data.DataLoader(
|
| 61 |
dataset,
|
| 62 |
+
batch_size=16,
|
| 63 |
shuffle=False,
|
| 64 |
num_workers=4,
|
| 65 |
drop_last=False,
|
|
|
|
| 103 |
return jnp.expand_dims(embedding, axis=0)
|
| 104 |
|
| 105 |
|
| 106 |
+
@st.cache
|
| 107 |
+
def precompute_image_features(model, loader):
|
| 108 |
image_features = []
|
| 109 |
for i, (images) in enumerate(tqdm(loader)):
|
| 110 |
images = images.permute(0, 2, 3, 1).numpy()
|
|
|
|
| 146 |
"dbmdz/bert-base-italian-xxl-uncased", cache_dir=None, use_fast=True
|
| 147 |
)
|
| 148 |
|
| 149 |
+
image_size = model.config.vision_config.image_size
|
| 150 |
+
|
| 151 |
+
val_preprocess = transforms.Compose(
|
| 152 |
+
[
|
| 153 |
+
Resize([image_size], interpolation=InterpolationMode.BICUBIC),
|
| 154 |
+
CenterCrop(image_size),
|
| 155 |
+
ToTensor(),
|
| 156 |
+
Normalize(
|
| 157 |
+
(0.48145466, 0.4578275, 0.40821073),
|
| 158 |
+
(0.26862954, 0.26130258, 0.27577711),
|
| 159 |
+
),
|
| 160 |
+
]
|
| 161 |
+
)
|
| 162 |
+
|
| 163 |
+
dataset = CustomDataSet("photos/", transform=val_preprocess)
|
| 164 |
+
|
| 165 |
+
loader = torch.utils.data.DataLoader(
|
| 166 |
+
dataset,
|
| 167 |
+
batch_size=16,
|
| 168 |
+
shuffle=False,
|
| 169 |
+
num_workers=2,
|
| 170 |
+
drop_last=False,
|
| 171 |
+
)
|
| 172 |
+
|
| 173 |
+
image_features = precompute_image_features(model, loader)
|
| 174 |
|
| 175 |
+
image_paths = find_image(query, dataset, tokenizer, image_features, n=2)
|
| 176 |
|
| 177 |
st.image(image_paths)
|