Spaces:
Running
Running
fix streamlit issues and update localization
Browse files- localization.py +121 -51
- modeling_hybrid_clip.py +3 -1
- requirements.txt +1 -1
localization.py
CHANGED
|
@@ -1,7 +1,7 @@
|
|
| 1 |
import streamlit as st
|
| 2 |
from text2image import get_model, get_tokenizer, get_image_transform
|
| 3 |
from utils import text_encoder
|
| 4 |
-
from
|
| 5 |
from PIL import Image
|
| 6 |
from jax import numpy as jnp
|
| 7 |
import pandas as pd
|
|
@@ -13,30 +13,34 @@ import jax
|
|
| 13 |
import gc
|
| 14 |
|
| 15 |
|
| 16 |
-
preprocess =
|
| 17 |
-
[
|
| 18 |
-
transforms.ToTensor(),
|
| 19 |
-
transforms.Normalize(
|
| 20 |
-
(0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)
|
| 21 |
-
),
|
| 22 |
-
]
|
| 23 |
-
)
|
| 24 |
|
| 25 |
|
| 26 |
-
def
|
| 27 |
-
|
| 28 |
-
|
|
|
|
| 29 |
image = image.resize(new_size, Image.ANTIALIAS)
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
|
| 34 |
|
| 35 |
def image_encoder(image, model):
|
| 36 |
image = np.transpose(image, (0, 2, 3, 1))
|
| 37 |
features = model.get_image_features(image)
|
| 38 |
-
|
| 39 |
-
|
|
|
|
| 40 |
|
| 41 |
|
| 42 |
def gen_image_batch(image_url, image_size=224, pixel_size=10):
|
|
@@ -44,64 +48,130 @@ def gen_image_batch(image_url, image_size=224, pixel_size=10):
|
|
| 44 |
|
| 45 |
image_batch = []
|
| 46 |
masks = []
|
|
|
|
|
|
|
|
|
|
| 47 |
image_raw = requests.get(image_url, stream=True).raw
|
| 48 |
image = Image.open(image_raw).convert("RGB")
|
| 49 |
-
image =
|
| 50 |
-
gray = np.ones_like(image) *
|
| 51 |
-
mask = np.ones_like(image)
|
| 52 |
|
| 53 |
image_batch.append(image)
|
| 54 |
masks.append(mask)
|
|
|
|
|
|
|
|
|
|
| 55 |
|
| 56 |
-
for i in range(0,
|
| 57 |
-
for j in range(i
|
| 58 |
m = mask.copy()
|
| 59 |
-
m[:
|
| 60 |
-
m[min(j
|
| 61 |
neg_m = 1 - m
|
| 62 |
-
image_batch.append(image * m + gray * neg_m)
|
| 63 |
masks.append(m)
|
|
|
|
|
|
|
| 64 |
|
| 65 |
-
for i in range(0,
|
| 66 |
-
for j in range(i
|
| 67 |
m = mask.copy()
|
| 68 |
-
m[:, :
|
| 69 |
-
m[:, min(j
|
| 70 |
neg_m = 1 - m
|
| 71 |
-
image_batch.append(image * m + gray * neg_m)
|
| 72 |
masks.append(m)
|
|
|
|
|
|
|
| 73 |
|
| 74 |
-
return image_batch, masks
|
| 75 |
|
| 76 |
|
| 77 |
def get_heatmap(image_url, text, pixel_size=10, iterations=3):
|
| 78 |
-
tokenizer = get_tokenizer()
|
| 79 |
model = get_model()
|
| 80 |
image_size = model.config.vision_config.image_size
|
| 81 |
-
text_embedding = text_encoder(text, model, tokenizer)
|
| 82 |
-
images, masks = gen_image_batch(
|
| 83 |
-
image_url, image_size=image_size, pixel_size=pixel_size
|
| 84 |
-
)
|
| 85 |
|
|
|
|
| 86 |
input_image = images[0].copy()
|
| 87 |
-
images = np.stack([preprocess(image) for image in images], axis=0)
|
| 88 |
-
image_embeddings = jnp.asarray(image_encoder(images, model))
|
| 89 |
-
|
| 90 |
-
sims = []
|
| 91 |
-
scores = []
|
| 92 |
-
mask_val = jnp.zeros_like(masks[0])
|
| 93 |
-
|
| 94 |
-
for e, m in zip(image_embeddings, masks):
|
| 95 |
-
sim = jnp.matmul(e, text_embedding.T)
|
| 96 |
-
sims.append(sim)
|
| 97 |
-
if len(sims) > 1:
|
| 98 |
-
scores.append(sim * m)
|
| 99 |
-
mask_val += 1 - m
|
| 100 |
|
| 101 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 102 |
for i in range(iterations):
|
| 103 |
score = jnp.clip(score - jnp.mean(score), 0, jnp.inf)
|
|
|
|
| 104 |
score = (score - jnp.min(score)) / (jnp.max(score) - jnp.min(score))
|
|
|
|
|
|
|
|
|
|
| 105 |
return np.asarray(score), input_image
|
| 106 |
|
| 107 |
|
|
@@ -144,7 +214,7 @@ def app():
|
|
| 144 |
with col2:
|
| 145 |
pixel_size = st.selectbox("Pixel Size", options=range(10, 26, 5), index=2)
|
| 146 |
|
| 147 |
-
iterations = st.selectbox("Refinement Steps", options=range(
|
| 148 |
|
| 149 |
compute = st.button("LOCATE")
|
| 150 |
|
|
|
|
| 1 |
import streamlit as st
|
| 2 |
from text2image import get_model, get_tokenizer, get_image_transform
|
| 3 |
from utils import text_encoder
|
| 4 |
+
from transformers import AutoProcessor
|
| 5 |
from PIL import Image
|
| 6 |
from jax import numpy as jnp
|
| 7 |
import pandas as pd
|
|
|
|
| 13 |
import gc
|
| 14 |
|
| 15 |
|
| 16 |
+
preprocess = AutoProcessor.from_pretrained("clip-italian/clip-italian")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
|
| 18 |
|
| 19 |
+
def resize_longer(image, longer_size=224):
|
| 20 |
+
old_size = image.size
|
| 21 |
+
ratio = float(longer_size) / max(old_size)
|
| 22 |
+
new_size = tuple([int(x * ratio) for x in old_size])
|
| 23 |
image = image.resize(new_size, Image.ANTIALIAS)
|
| 24 |
+
return image
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def pad_to_square(image):
|
| 28 |
+
(a,b)=image.shape[:2]
|
| 29 |
+
if a<b:
|
| 30 |
+
ah = (b - a) // 2
|
| 31 |
+
padding=((ah,b - a -ah), (0,0), (0,0))
|
| 32 |
+
else:
|
| 33 |
+
bh = (a - b) // 2
|
| 34 |
+
padding=((0,0), (bh,a-b-bh), (0,0))
|
| 35 |
+
return np.pad(image, padding,mode='constant',constant_values=127)
|
| 36 |
|
| 37 |
|
| 38 |
def image_encoder(image, model):
|
| 39 |
image = np.transpose(image, (0, 2, 3, 1))
|
| 40 |
features = model.get_image_features(image)
|
| 41 |
+
feature_norms = jnp.linalg.norm(features, axis=-1, keepdims=True)
|
| 42 |
+
features = features / feature_norms
|
| 43 |
+
return features, feature_norms
|
| 44 |
|
| 45 |
|
| 46 |
def gen_image_batch(image_url, image_size=224, pixel_size=10):
|
|
|
|
| 48 |
|
| 49 |
image_batch = []
|
| 50 |
masks = []
|
| 51 |
+
is_vertical = []
|
| 52 |
+
is_horizontal = []
|
| 53 |
+
|
| 54 |
image_raw = requests.get(image_url, stream=True).raw
|
| 55 |
image = Image.open(image_raw).convert("RGB")
|
| 56 |
+
image = np.array(resize_longer(image, longer_size=image_size))
|
| 57 |
+
gray = np.ones_like(image) * 127
|
| 58 |
+
mask = np.ones_like(image[:,:,:1])
|
| 59 |
|
| 60 |
image_batch.append(image)
|
| 61 |
masks.append(mask)
|
| 62 |
+
is_vertical.append(True)
|
| 63 |
+
is_horizontal.append(True)
|
| 64 |
+
|
| 65 |
|
| 66 |
+
for i in range(0, image.shape[0] // pixel_size + 1):
|
| 67 |
+
for j in range(i+1, image.shape[0] // pixel_size + 2):
|
| 68 |
m = mask.copy()
|
| 69 |
+
m[:min(i*pixel_size, image_size), :] = 0
|
| 70 |
+
m[min(j*pixel_size, image_size):, :] = 0
|
| 71 |
neg_m = 1 - m
|
| 72 |
+
image_batch.append(image.copy() * m + gray * neg_m)
|
| 73 |
masks.append(m)
|
| 74 |
+
is_vertical.append(False)
|
| 75 |
+
is_horizontal.append(True)
|
| 76 |
|
| 77 |
+
for i in range(0, image.shape[1] // pixel_size + 1):
|
| 78 |
+
for j in range(i+1, image.shape[1] // pixel_size + 2):
|
| 79 |
m = mask.copy()
|
| 80 |
+
m[:, :min(i*pixel_size, image_size)] = 0
|
| 81 |
+
m[:, min(j*pixel_size, image_size):] = 0
|
| 82 |
neg_m = 1 - m
|
| 83 |
+
image_batch.append(image.copy() * m + gray * neg_m)
|
| 84 |
masks.append(m)
|
| 85 |
+
is_vertical.append(True)
|
| 86 |
+
is_horizontal.append(False)
|
| 87 |
|
| 88 |
+
return image_batch, masks, is_vertical, is_horizontal
|
| 89 |
|
| 90 |
|
| 91 |
def get_heatmap(image_url, text, pixel_size=10, iterations=3):
|
| 92 |
+
# tokenizer = get_tokenizer()
|
| 93 |
model = get_model()
|
| 94 |
image_size = model.config.vision_config.image_size
|
|
|
|
|
|
|
|
|
|
|
|
|
| 95 |
|
| 96 |
+
images, masks, vertical, horizontal = gen_image_batch(image_url, pixel_size=pixel_size)
|
| 97 |
input_image = images[0].copy()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 98 |
|
| 99 |
+
inputs = preprocess(text=[text], images=images, return_tensors="np")
|
| 100 |
+
|
| 101 |
+
image_embeddings, embedding_norms = image_encoder(inputs['pixel_values'], model)
|
| 102 |
+
text_embedding = model.get_text_features(inputs["input_ids"], inputs["attention_mask"])[0]
|
| 103 |
+
text_embedding = text_embedding / jnp.linalg.norm(text_embedding, axis=-1, keepdims=True)
|
| 104 |
+
|
| 105 |
+
vertical_scores = jnp.zeros((masks[0].shape[1], 512))
|
| 106 |
+
vertical_masks = jnp.zeros((masks[0].shape[1], 1))
|
| 107 |
+
horizontal_scores = jnp.zeros((masks[0].shape[0], 512))
|
| 108 |
+
horizontal_masks = jnp.zeros((masks[0].shape[0], 1))
|
| 109 |
+
|
| 110 |
+
for e, n, m, v, h in zip(image_embeddings, embedding_norms, masks, vertical, horizontal):
|
| 111 |
+
# sim = (jnp.matmul(e, text_embedding.T)) # + 1) / 2
|
| 112 |
+
|
| 113 |
+
# sim = jax.nn.relu(sim)
|
| 114 |
+
|
| 115 |
+
# if full_sim is None:
|
| 116 |
+
# full_sim = sim
|
| 117 |
+
# sim = jax.nn.relu(sim - full_sim)
|
| 118 |
+
emb = jnp.expand_dims(e, axis=0) * n
|
| 119 |
+
|
| 120 |
+
if v:
|
| 121 |
+
vm = jnp.any(m, axis=0)
|
| 122 |
+
vertical_scores = vertical_scores + (emb * vm) #/ jnp.mean(vm)
|
| 123 |
+
vertical_masks = vertical_masks + vm #/ jnp.mean(vm)
|
| 124 |
+
if h:
|
| 125 |
+
hm = jnp.any(m, axis=1)
|
| 126 |
+
horizontal_scores = horizontal_scores + (emb * hm) #/ jnp.mean(hm)
|
| 127 |
+
horizontal_masks = horizontal_masks + hm #/ jnp.mean(hm)
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
embs_1 = jnp.expand_dims((vertical_scores), axis=0) * jnp.expand_dims(jnp.abs(horizontal_scores), axis=1)
|
| 131 |
+
embs_2 = jnp.expand_dims(jnp.abs(vertical_scores), axis=0) * jnp.expand_dims((horizontal_scores), axis=1)
|
| 132 |
+
full_embs = jnp.minimum(embs_1, embs_2)
|
| 133 |
+
mask_sum = jnp.expand_dims(vertical_masks, axis=0) * jnp.expand_dims(horizontal_masks, axis=1)
|
| 134 |
+
|
| 135 |
+
print(full_embs.shape)
|
| 136 |
+
|
| 137 |
+
#full_embs = full_embs / jnp.linalg.norm(full_embs, axis=-1, keepdims=True)
|
| 138 |
+
full_embs = (full_embs / mask_sum)
|
| 139 |
+
|
| 140 |
+
orig_shape = full_embs.shape
|
| 141 |
+
sims = jnp.matmul(jnp.reshape(full_embs, (-1, 512)), text_embedding.T)
|
| 142 |
+
sims = jnp.reshape(sims, (*orig_shape[:2], 1))
|
| 143 |
+
#sims = jax.nn.relu(sims)
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
# mean_vertical_scores = vertical_scores / vertical_masks
|
| 151 |
+
# mean_horizontal_scores = horizontal_scores / horizontal_masks
|
| 152 |
+
|
| 153 |
+
# print(mean_vertical_score)
|
| 154 |
+
# print(mean_horizontal_score)
|
| 155 |
+
|
| 156 |
+
# score = jnp.matmul(mean_vertical_scores, mean_horizontal_scores.T)
|
| 157 |
+
|
| 158 |
+
#mask = jnp.matmul(vertical_masks, horizontal_scores.T)
|
| 159 |
+
#score = score / mask
|
| 160 |
+
|
| 161 |
+
score = sims # jnp.expand_dims(score.T, axis=-1)
|
| 162 |
+
#score = jax.nn.relu(score) / jnp.max(jnp.abs(score))
|
| 163 |
+
|
| 164 |
+
#score = jax.nn.relu(score - sims[0])
|
| 165 |
+
|
| 166 |
+
# score = jnp.square(score)
|
| 167 |
+
|
| 168 |
for i in range(iterations):
|
| 169 |
score = jnp.clip(score - jnp.mean(score), 0, jnp.inf)
|
| 170 |
+
|
| 171 |
score = (score - jnp.min(score)) / (jnp.max(score) - jnp.min(score))
|
| 172 |
+
|
| 173 |
+
print(jnp.min(score), jnp.max(score))
|
| 174 |
+
|
| 175 |
return np.asarray(score), input_image
|
| 176 |
|
| 177 |
|
|
|
|
| 214 |
with col2:
|
| 215 |
pixel_size = st.selectbox("Pixel Size", options=range(10, 26, 5), index=2)
|
| 216 |
|
| 217 |
+
iterations = st.selectbox("Refinement Steps", options=range(1, 6, 1), index=0)
|
| 218 |
|
| 219 |
compute = st.button("LOCATE")
|
| 220 |
|
modeling_hybrid_clip.py
CHANGED
|
@@ -136,8 +136,10 @@ class FlaxHybridCLIP(FlaxPreTrainedModel):
|
|
| 136 |
):
|
| 137 |
if input_shape is None:
|
| 138 |
input_shape = ((1, 1), (1, config.vision_config.image_size, config.vision_config.image_size, 3))
|
|
|
|
|
|
|
| 139 |
|
| 140 |
-
module = self.module_class(config=config, dtype=dtype, **kwargs)
|
| 141 |
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
|
| 142 |
|
| 143 |
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
|
|
|
|
| 136 |
):
|
| 137 |
if input_shape is None:
|
| 138 |
input_shape = ((1, 1), (1, config.vision_config.image_size, config.vision_config.image_size, 3))
|
| 139 |
+
|
| 140 |
+
print(kwargs)
|
| 141 |
|
| 142 |
+
module = self.module_class(config=config, dtype=dtype) # , **kwargs)
|
| 143 |
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
|
| 144 |
|
| 145 |
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
|
requirements.txt
CHANGED
|
@@ -8,4 +8,4 @@ stqdm
|
|
| 8 |
pandas
|
| 9 |
requests
|
| 10 |
psutil
|
| 11 |
-
streamlit
|
|
|
|
| 8 |
pandas
|
| 9 |
requests
|
| 10 |
psutil
|
| 11 |
+
streamlit
|