feat: add segmentation
Browse files- app.py +25 -14
- lib/utils/model.py +53 -4
- pages/losses.py +42 -33
app.py
CHANGED
|
@@ -1,7 +1,7 @@
|
|
| 1 |
import streamlit as st
|
| 2 |
-
from st_pages import Page, show_pages, add_page_title, Section
|
| 3 |
-
from lib.utils.model import get_model, get_similarities
|
| 4 |
-
from lib.utils.timer import timer
|
| 5 |
|
| 6 |
add_page_title()
|
| 7 |
|
|
@@ -23,23 +23,31 @@ caption = st.text_input('Description Input')
|
|
| 23 |
|
| 24 |
images = st.file_uploader('Upload images', accept_multiple_files=True)
|
| 25 |
if images is not None:
|
| 26 |
-
|
| 27 |
-
st.image(images)
|
| 28 |
|
| 29 |
st.header('Options')
|
| 30 |
st.subheader('Ranks', help='How many predictions the model is allowed to make')
|
| 31 |
|
| 32 |
-
ranks = st.slider('slider_ranks', min_value=1, max_value=10,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
|
| 34 |
-
button = st.button('Match most similar', disabled=len(images) == 0 or caption == '')
|
| 35 |
|
| 36 |
if button:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
st.header('Results')
|
| 38 |
with st.spinner('Loading model'):
|
| 39 |
model = get_model()
|
| 40 |
|
| 41 |
-
st.text(
|
| 42 |
-
|
|
|
|
| 43 |
time = timer()
|
| 44 |
with st.spinner('Computing and ranking similarities'):
|
| 45 |
with timer() as t:
|
|
@@ -47,15 +55,16 @@ if button:
|
|
| 47 |
elapsed = t()
|
| 48 |
|
| 49 |
indices = similarities.argsort(descending=True).cpu().tolist()[:ranks]
|
| 50 |
-
|
| 51 |
c1, c2, c3 = st.columns(3)
|
| 52 |
with c1:
|
| 53 |
st.subheader('Rank')
|
| 54 |
with c2:
|
| 55 |
st.subheader('Image')
|
| 56 |
with c3:
|
| 57 |
-
st.subheader('Cosine Similarity',
|
| 58 |
-
|
|
|
|
| 59 |
for i, idx in enumerate(indices):
|
| 60 |
c1, c2, c3 = st.columns(3)
|
| 61 |
with c1:
|
|
@@ -72,5 +81,7 @@ with st.sidebar:
|
|
| 72 |
|
| 73 |
st.subheader('Useful Links')
|
| 74 |
st.markdown('[arXiv: Cross-Modal Implicit Relation Reasoning and Aligning for Text-to-Image Person Retrieval](https://arxiv.org/abs/2303.12501)')
|
| 75 |
-
st.markdown(
|
| 76 |
-
|
|
|
|
|
|
|
|
|
| 1 |
import streamlit as st
|
| 2 |
+
from st_pages import Page, show_pages, add_page_title, Section
|
| 3 |
+
from lib.utils.model import get_model, get_similarities, get_detr, segment_images
|
| 4 |
+
from lib.utils.timer import timer
|
| 5 |
|
| 6 |
add_page_title()
|
| 7 |
|
|
|
|
| 23 |
|
| 24 |
images = st.file_uploader('Upload images', accept_multiple_files=True)
|
| 25 |
if images is not None:
|
| 26 |
+
|
| 27 |
+
st.image(images) # type: ignore
|
| 28 |
|
| 29 |
st.header('Options')
|
| 30 |
st.subheader('Ranks', help='How many predictions the model is allowed to make')
|
| 31 |
|
| 32 |
+
ranks = st.slider('slider_ranks', min_value=1, max_value=10,
|
| 33 |
+
label_visibility='collapsed', value=5)
|
| 34 |
+
do_segment = st.checkbox('Segment images with DETR', value=False)
|
| 35 |
+
button = st.button('Match most similar', disabled=len(
|
| 36 |
+
images) == 0 or caption == '')
|
| 37 |
|
|
|
|
| 38 |
|
| 39 |
if button:
|
| 40 |
+
if do_segment:
|
| 41 |
+
detr, processor = get_detr()
|
| 42 |
+
images = segment_images(detr, processor, images)
|
| 43 |
+
|
| 44 |
st.header('Results')
|
| 45 |
with st.spinner('Loading model'):
|
| 46 |
model = get_model()
|
| 47 |
|
| 48 |
+
st.text(
|
| 49 |
+
f'IRRA model loaded with {sum(p.numel() for p in model.parameters()) / 1e6:.0f}M parameters')
|
| 50 |
+
|
| 51 |
time = timer()
|
| 52 |
with st.spinner('Computing and ranking similarities'):
|
| 53 |
with timer() as t:
|
|
|
|
| 55 |
elapsed = t()
|
| 56 |
|
| 57 |
indices = similarities.argsort(descending=True).cpu().tolist()[:ranks]
|
| 58 |
+
|
| 59 |
c1, c2, c3 = st.columns(3)
|
| 60 |
with c1:
|
| 61 |
st.subheader('Rank')
|
| 62 |
with c2:
|
| 63 |
st.subheader('Image')
|
| 64 |
with c3:
|
| 65 |
+
st.subheader('Cosine Similarity',
|
| 66 |
+
help='Due to the nature of the SDM loss, the higher the similarity, the more similar the match is')
|
| 67 |
+
|
| 68 |
for i, idx in enumerate(indices):
|
| 69 |
c1, c2, c3 = st.columns(3)
|
| 70 |
with c1:
|
|
|
|
| 81 |
|
| 82 |
st.subheader('Useful Links')
|
| 83 |
st.markdown('[arXiv: Cross-Modal Implicit Relation Reasoning and Aligning for Text-to-Image Person Retrieval](https://arxiv.org/abs/2303.12501)')
|
| 84 |
+
st.markdown(
|
| 85 |
+
'[IRRA implementation (Pytorch Lightning + Transformers)](https://github.com/grostaco/modern-IRRA)')
|
| 86 |
+
st.markdown(
|
| 87 |
+
'[IRRA implementation (PyTorch)](https://github.com/anosorae/IRRA/tree/main)')
|
lib/utils/model.py
CHANGED
|
@@ -1,15 +1,20 @@
|
|
| 1 |
-
import streamlit as st
|
| 2 |
import yaml
|
| 3 |
import torch
|
| 4 |
import torch.nn.functional as F
|
| 5 |
|
|
|
|
|
|
|
| 6 |
from lib.IRRA.tokenizer import tokenize, SimpleTokenizer
|
| 7 |
from lib.IRRA.image import prepare_images
|
| 8 |
from lib.IRRA.model.build import build_model, IRRA
|
|
|
|
|
|
|
| 9 |
|
| 10 |
from easydict import EasyDict
|
| 11 |
|
| 12 |
-
|
|
|
|
| 13 |
def get_model():
|
| 14 |
args = yaml.load(open('model/configs.yaml'), Loader=yaml.FullLoader)
|
| 15 |
args = EasyDict(args)
|
|
@@ -17,7 +22,51 @@ def get_model():
|
|
| 17 |
|
| 18 |
model = build_model(args)
|
| 19 |
|
| 20 |
-
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
|
| 22 |
def get_similarities(text: str, images: list[str], model: IRRA) -> torch.Tensor:
|
| 23 |
tokenizer = SimpleTokenizer()
|
|
@@ -30,5 +79,5 @@ def get_similarities(text: str, images: list[str], model: IRRA) -> torch.Tensor:
|
|
| 30 |
|
| 31 |
image_feats = F.normalize(image_feats, p=2, dim=1)
|
| 32 |
text_feats = F.normalize(text_feats, p=2, dim=1)
|
| 33 |
-
|
| 34 |
return text_feats @ image_feats.t()
|
|
|
|
| 1 |
+
import streamlit as st
|
| 2 |
import yaml
|
| 3 |
import torch
|
| 4 |
import torch.nn.functional as F
|
| 5 |
|
| 6 |
+
from transformers import DetrImageProcessor, DetrForObjectDetection
|
| 7 |
+
|
| 8 |
from lib.IRRA.tokenizer import tokenize, SimpleTokenizer
|
| 9 |
from lib.IRRA.image import prepare_images
|
| 10 |
from lib.IRRA.model.build import build_model, IRRA
|
| 11 |
+
from PIL import Image
|
| 12 |
+
from pathlib import Path
|
| 13 |
|
| 14 |
from easydict import EasyDict
|
| 15 |
|
| 16 |
+
|
| 17 |
+
@st.cache_resource
|
| 18 |
def get_model():
|
| 19 |
args = yaml.load(open('model/configs.yaml'), Loader=yaml.FullLoader)
|
| 20 |
args = EasyDict(args)
|
|
|
|
| 22 |
|
| 23 |
model = build_model(args)
|
| 24 |
|
| 25 |
+
return model
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
@st.cache_resource
def get_detr():
    """Load and cache the DETR object detector and its matching image processor.

    Returns:
        A ``(model, processor)`` pair: a ``DetrForObjectDetection`` and the
        ``DetrImageProcessor`` that prepares inputs for it.
    """
    # Same checkpoint/revision for both pieces; "no_timm" avoids the timm
    # dependency for the ResNet-50 backbone.
    checkpoint = "facebook/detr-resnet-50"
    revision = "no_timm"

    model = DetrForObjectDetection.from_pretrained(checkpoint, revision=revision)
    processor = DetrImageProcessor.from_pretrained(checkpoint, revision=revision)

    return model, processor
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def segment_images(model, processor, images: list[str]):
    """Crop 'person' detections out of each image using a DETR detector.

    Args:
        model: a ``DetrForObjectDetection`` instance (see ``get_detr``).
        processor: the matching ``DetrImageProcessor``.
        images: image file paths or file-like objects accepted by ``PIL.Image.open``.

    Returns:
        POSIX-style file paths of the cropped person images written under
        ``segments/``.

    NOTE(review): crops from previous runs are never deleted, so stale files
    may linger in ``segments/`` — confirm whether the directory should be
    cleared first.
    """
    segments = []
    crop_id = 0  # running detection index used to name output files ('id' shadowed a builtin)

    out_dir = Path('segments')
    out_dir.mkdir(exist_ok=True)

    for image in images:
        image = Image.open(image)

        inputs = processor(images=image, return_tensors="pt")
        # Pure inference: disable autograd to avoid building a graph.
        with torch.no_grad():
            outputs = model(**inputs)

        # PIL's size is (width, height); DETR post-processing expects (height, width).
        target_sizes = torch.tensor([image.size[::-1]])
        results = processor.post_process_object_detection(
            outputs, target_sizes=target_sizes, threshold=0.9)[0]

        for label, box in zip(results["labels"], results["boxes"]):
            box = [round(i, 2) for i in box.tolist()]
            label = model.config.id2label[label.item()]

            # Keep only 'person' detections larger than 70px on each side.
            if label == 'person' and box[2] - box[0] > 70 and box[3] - box[1] > 70:
                file = out_dir / f'img_{crop_id}.jpg'
                image.crop(box).save(file)
                segments.append(file.as_posix())

            # Increment per detection (not per saved crop), matching the
            # original numbering scheme so filenames stay unique.
            crop_id += 1

    return segments
|
| 69 |
+
|
| 70 |
|
| 71 |
def get_similarities(text: str, images: list[str], model: IRRA) -> torch.Tensor:
|
| 72 |
tokenizer = SimpleTokenizer()
|
|
|
|
| 79 |
|
| 80 |
image_feats = F.normalize(image_feats, p=2, dim=1)
|
| 81 |
text_feats = F.normalize(text_feats, p=2, dim=1)
|
| 82 |
+
|
| 83 |
return text_feats @ image_feats.t()
|
pages/losses.py
CHANGED
|
@@ -4,36 +4,45 @@ from st_pages import add_indentation
|
|
| 4 |
add_indentation()
|
| 5 |
|
| 6 |
st.title('Loss functions')
|
| 7 |
-
st.
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
''')
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
st.
|
| 39 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
# Page: mathematical writeup of the loss functions used to align textual and
# visual features (SDM, IRR/MLM, ID), rendered as expandable sections.
add_indentation()

st.title('Loss functions')
st.markdown('In order to align textual and visual features, multiple loss functions are employed. '
            'The most notable loss function was proposed in [arXiv: Cross-Modal Implicit Relation Reasoning and Aligning for Text-to-Image Person Retrieval](https://arxiv.org/abs/2303.12501) '
            'with the introduction of the SDM loss and the usage of the IRR (Implicit Reason Relations) loss.')
# Similarity Distribution Matching loss: prose + LaTeX derivation.
with st.expander('SDM Loss'):
    st.markdown('''
    The similarity distribution matching (SDM) loss, which is the KL divergence
    of the image to text and text to image to the label distribution.

    We define $f^v$ and $f^t$ to be the global representation of the visual and textual features respectively.
    The cosine similarity $sim(u, v) = \\frac{u \\cdot v}{|u||v|}$ will be used to compute the probability of the labels.

    We define $y_{i, j}=1$ if the visual feature $f^v_i$ matches the textual feature $f^t_j$, else $y_{i, j}=0$.
    The predicted label distribution can be formulated by''')
    st.latex(r'''
    p_{i} = \sigma(sim(f^v_i, f^t))
    ''')

    st.markdown('''
    We can define the image to text loss as
    ''')

    st.latex(r'''
    \mathcal{L}_{i2t} = KL(\mathbf{p_i} || \mathbf{q_i})
    ''')

    st.markdown('Where $\\mathbf{q_i}$, the true probability distribution, is defined as')

    st.latex(r'''
    q_{i, j} = \frac{y_{i, j}}{\sum_{k=1}^{N} y_{i, k}}
    ''')

    st.markdown('It should be noted that the reason this computation is needed is because there could be multiple correct labels.')

    st.markdown('The SDM loss can be formulated as')
    st.latex(r'''
    \mathcal{L}_{sdm} = \mathcal{L}_{i2t} + \mathcal{L}_{t2i}
    ''')

# NOTE(review): the two sections below are placeholders — content still to be
# written.
with st.expander('IRR (MLM) Loss'):
    ...
with st.expander('ID Loss'):
    ...
|