try to unzip flickr.zip file, cosmetic changes to display results
Browse files
app.py
CHANGED
|
@@ -1,38 +1,59 @@
|
|
| 1 |
import streamlit as st
|
| 2 |
-
|
| 3 |
from main import *
|
| 4 |
from setup import *
|
| 5 |
|
| 6 |
from PIL import Image
|
| 7 |
|
| 8 |
-
def
|
|
|
|
|
|
|
|
|
|
| 9 |
|
| 10 |
# lorax = Image.open('img/Lorax.jpg')
|
| 11 |
# print(lorax.width, lorax.height)
|
| 12 |
# st.image(lorax, width = 250)
|
| 13 |
|
| 14 |
-
|
| 15 |
i = 0
|
| 16 |
-
for
|
| 17 |
-
for col in
|
| 18 |
-
image_name, comment = search_result[i]
|
| 19 |
-
col.image(
|
| 20 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
i = i + 1
|
| 22 |
return
|
| 23 |
|
|
|
|
| 24 |
|
| 25 |
-
|
| 26 |
|
| 27 |
-
st.
|
| 28 |
|
| 29 |
-
|
|
|
|
| 30 |
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
|
| 35 |
-
|
| 36 |
-
search_result = search2(search_request)
|
| 37 |
-
for item in search_result :
|
| 38 |
-
st.write(item)
|
|
|
|
| 1 |
import streamlit as st
|
| 2 |
+
|
| 3 |
from main import *
|
| 4 |
from setup import *
|
| 5 |
|
| 6 |
from PIL import Image
|
| 7 |
|
| 8 |
+
def show_result(search_request,
|
| 9 |
+
search_result,
|
| 10 |
+
img_dir,
|
| 11 |
+
container) :
|
| 12 |
|
| 13 |
# lorax = Image.open('img/Lorax.jpg')
|
| 14 |
# print(lorax.width, lorax.height)
|
| 15 |
# st.image(lorax, width = 250)
|
| 16 |
|
| 17 |
+
container.header("\"" +search_request+ "\" reminds me of :")
|
| 18 |
i = 0
|
| 19 |
+
for _ in range(0, 2):
|
| 20 |
+
for col in container.columns(2) :
|
| 21 |
+
image_name, comment, score = search_result[i]
|
| 22 |
+
col.image(img_dir + image_name, width = 300)
|
| 23 |
+
|
| 24 |
+
if score != '' :
|
| 25 |
+
sim_score = f"{float(100 * score):.3f}"
|
| 26 |
+
col.header(sim_score + " " +comment)
|
| 27 |
+
else :
|
| 28 |
+
col.header(comment)
|
| 29 |
i = i + 1
|
| 30 |
return
|
| 31 |
|
| 32 |
+
def show_landing() :
|
| 33 |
|
| 34 |
+
st.title('Find my pic!')
|
| 35 |
|
| 36 |
+
search_request = st.text_input('', 'Search ...')
|
| 37 |
|
| 38 |
+
action = st.container()
|
| 39 |
+
results = st.container()
|
| 40 |
|
| 41 |
+
if action.button('Find Relsease 1!') and os.path.exists('img'):
|
| 42 |
+
search_result = search1(search_request)
|
| 43 |
+
show_result(search_request,
|
| 44 |
+
search_result,
|
| 45 |
+
'img/',
|
| 46 |
+
results)
|
| 47 |
+
|
| 48 |
+
if action.button('Find Relsease 2!') and os.path.exists('flickr30k_images_SAVE') :
|
| 49 |
+
search_result = search2(search_request)
|
| 50 |
+
show_result(search_request,
|
| 51 |
+
search_result,
|
| 52 |
+
'flickr30k_images_SAVE/',
|
| 53 |
+
results)
|
| 54 |
+
return
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
downlad_images()
|
| 58 |
|
| 59 |
+
show_landing()
|
|
|
|
|
|
|
|
|
main.py
CHANGED
|
@@ -1,10 +1,9 @@
|
|
| 1 |
-
|
| 2 |
import random
|
| 3 |
import torch
|
| 4 |
-
|
| 5 |
from dataframe import *
|
| 6 |
from model import *
|
| 7 |
|
|
|
|
| 8 |
images = ["Girl.jpg",
|
| 9 |
"Cat In Hat.jpg",
|
| 10 |
"Cat In The Hat.jpg",
|
|
@@ -26,7 +25,7 @@ def search1(search_prompt : str):
|
|
| 26 |
Given a search_prompt, return an array of pictures to display
|
| 27 |
"""
|
| 28 |
|
| 29 |
-
return [ (images[i], images[i].split('.')[0]) for i in random.sample(range(len(images)), 4) ]
|
| 30 |
|
| 31 |
def search2(search_prompt : str) :
|
| 32 |
|
|
@@ -37,7 +36,7 @@ def search2(search_prompt : str) :
|
|
| 37 |
model_ID = "openai/clip-vit-base-patch32"
|
| 38 |
|
| 39 |
# Get model, processor & tokenizer
|
| 40 |
-
model,
|
| 41 |
|
| 42 |
image_data_df = get_image_data()
|
| 43 |
|
|
|
|
|
|
|
| 1 |
import random
|
| 2 |
import torch
|
|
|
|
| 3 |
from dataframe import *
|
| 4 |
from model import *
|
| 5 |
|
| 6 |
+
|
| 7 |
images = ["Girl.jpg",
|
| 8 |
"Cat In Hat.jpg",
|
| 9 |
"Cat In The Hat.jpg",
|
|
|
|
| 25 |
Given a search_prompt, return an array of pictures to display
|
| 26 |
"""
|
| 27 |
|
| 28 |
+
return [ (images[i], images[i].split('.')[0], '') for i in random.sample(range(len(images)), 4) ]
|
| 29 |
|
| 30 |
def search2(search_prompt : str) :
|
| 31 |
|
|
|
|
| 36 |
model_ID = "openai/clip-vit-base-patch32"
|
| 37 |
|
| 38 |
# Get model, processor & tokenizer
|
| 39 |
+
model, tokenizer = get_model_info(model_ID, device)
|
| 40 |
|
| 41 |
image_data_df = get_image_data()
|
| 42 |
|
model.py
CHANGED
|
@@ -6,15 +6,12 @@ from dataframe import *
|
|
| 6 |
def get_model_info(model_ID, device):
|
| 7 |
# Save the model to device
|
| 8 |
model = CLIPModel.from_pretrained(model_ID).to(device)
|
| 9 |
-
|
| 10 |
-
# Get the processor
|
| 11 |
-
processor = CLIPProcessor.from_pretrained(model_ID)
|
| 12 |
-
|
| 13 |
# Get the tokenizer
|
| 14 |
tokenizer = CLIPTokenizer.from_pretrained(model_ID)
|
| 15 |
|
| 16 |
# Return model, processor & tokenizer
|
| 17 |
-
return model,
|
| 18 |
|
| 19 |
|
| 20 |
def get_single_text_embedding(text, model, tokenizer, device):
|
|
@@ -25,8 +22,15 @@ def get_single_text_embedding(text, model, tokenizer, device):
|
|
| 25 |
|
| 26 |
return embedding_as_np
|
| 27 |
|
| 28 |
-
def
|
| 29 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
|
| 31 |
def get_top_N_images(query,
|
| 32 |
data,
|
|
@@ -57,4 +61,4 @@ def get_top_N_images(query,
|
|
| 57 |
"""
|
| 58 |
|
| 59 |
result_df = most_similar_articles[revevant_cols].reset_index()
|
| 60 |
-
return
|
|
|
|
| 6 |
def get_model_info(model_ID, device):
|
| 7 |
# Save the model to device
|
| 8 |
model = CLIPModel.from_pretrained(model_ID).to(device)
|
| 9 |
+
|
|
|
|
|
|
|
|
|
|
| 10 |
# Get the tokenizer
|
| 11 |
tokenizer = CLIPTokenizer.from_pretrained(model_ID)
|
| 12 |
|
| 13 |
# Return model, processor & tokenizer
|
| 14 |
+
return model, tokenizer
|
| 15 |
|
| 16 |
|
| 17 |
def get_single_text_embedding(text, model, tokenizer, device):
|
|
|
|
| 22 |
|
| 23 |
return embedding_as_np
|
| 24 |
|
| 25 |
+
def get_item_data(result, index) :
|
| 26 |
+
|
| 27 |
+
img_name = str(result['image_name'][index])
|
| 28 |
+
|
| 29 |
+
# TODO: add code to get the original comment
|
| 30 |
+
comment = str(result['comment'][index])
|
| 31 |
+
cos_sim = result['cos_sim'][index]
|
| 32 |
+
|
| 33 |
+
return (img_name, comment, cos_sim)
|
| 34 |
|
| 35 |
def get_top_N_images(query,
|
| 36 |
data,
|
|
|
|
| 61 |
"""
|
| 62 |
|
| 63 |
result_df = most_similar_articles[revevant_cols].reset_index()
|
| 64 |
+
return [get_item_data(result_df, i) for i in range(len(result_df))]
|
setup.py
CHANGED
|
@@ -1,33 +1,26 @@
|
|
| 1 |
|
| 2 |
import os
|
| 3 |
-
import
|
| 4 |
|
| 5 |
-
|
| 6 |
-
# from huggingface_hub.archive import unpack_archive
|
| 7 |
|
|
|
|
|
|
|
| 8 |
|
| 9 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
# # Specify the Google Drive link to the archive file
|
| 17 |
-
# archive_url = 'https://drive.google.com/uc?id=14QhofCbby053kWbVeWEBHCxOROQS-bjN'
|
| 18 |
-
|
| 19 |
-
# # Specify the destination directory within the Hugging Face space
|
| 20 |
-
# destination_dir = 'osanchik/flickr'
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
# # Replace with your desired destination directory
|
| 24 |
-
|
| 25 |
-
# # Construct the destination path
|
| 26 |
-
# destination_path = hf_hub_url(destination_dir)
|
| 27 |
-
# # Download the archive to the destination path
|
| 28 |
-
# cached_download(archive_url, destination_path)
|
| 29 |
-
# # Unpack the archive
|
| 30 |
-
# unpack_archive(destination_path, destination_dir)
|
| 31 |
-
# print(f"Archive unpacked to: {destination_dir}")
|
| 32 |
|
| 33 |
-
|
|
|
|
| 1 |
|
| 2 |
import os
|
| 3 |
+
import requests, zipfile, io
|
| 4 |
|
| 5 |
+
def downlad_images() :
|
|
|
|
| 6 |
|
| 7 |
+
img_dir = 'flickr30k_images'
|
| 8 |
+
zip_file = 'data/flickr.zip'
|
| 9 |
|
| 10 |
+
#TODO : zip_file_url?
|
| 11 |
+
zip_file_url = 'https://drive.google.com/open?id=14QhofCbby053kWbVeWEBHCxOROQS-bjN&authuser=0'
|
| 12 |
+
|
| 13 |
+
try :
|
| 14 |
+
if not os.path.exists(img_dir) :
|
| 15 |
+
if not os.path.exists(zip_file) :
|
| 16 |
+
r = requests.get(zip_file_url)
|
| 17 |
+
z = zipfile.ZipFile(io.BytesIO(r.content))
|
| 18 |
+
z.extractall(".")
|
| 19 |
|
| 20 |
+
else :
|
| 21 |
+
with zipfile.ZipFile(zip_file, 'r') as zip_ref:
|
| 22 |
+
zip_ref.extractall(".")
|
| 23 |
+
except :
|
| 24 |
+
print("Problems with image file download")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
|
| 26 |
+
return
|