# Source file: event_retrieval/create_caption_using_blip2.py
# Hosting-page residue (not part of the program): sanskar753's picture,
# commit "Upload folder using huggingface_hub", revision 02d3a85 (verified).
import os
# allocated gpus
# os.environ['CUDA_VISIBLE_DEVICES'] = '1,5'
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import torch
from lavis.models import load_model_and_preprocess
from PIL import Image
from complex_image_search.utils.display_utils import display_image_and_text
import matplotlib.pyplot as plt
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
import torchvision.transforms.functional as F
# Arguments passed to LAVIS's ``load_model_and_preprocess`` as
# ``(name, model_type)`` for every value the ``model_type`` switch in ``main``
# may take.
MODEL_SPECS = {
    "instruct_blip_flan_t5": ("blip2_t5_instruct", "flant5xl"),
    "instruct_blip_vicuna": ("blip2_vicuna_instruct", "vicuna7b"),
    "blip2_flan_t5_caption": ("blip2_t5", "pretrain_flant5xl"),
}

# Scene image used by the optional captioning / visualisation experiments
# listed at the end of ``main``.
SCENE_IMAGE_DIR = "/fs/scratch/rb_bd_dlp_rng-dl01_cr_AIM_employees/AIM_105/Complex_Image_Search/images"
SCENE_IMAGE_FILE = "DS-CN_13R7C_20180509_130050_f000550_fc00248514_4d87dc.png"
# Other scene images in the same directory that were tried:
#   LB-UH_104_20180310_084704_f000545_fc00011793_4d87dc.png
#   LB-UH_104_20180310_123432_f000110_fc00283857_4d87dc.png
#   DS-CN_13R7C_20180508_142445_f000550_fc00191778_4d87dc.png
#   DS-CN_13R7C_20180509_130050_f000550_fc00248514_4d87dc.png
#   DS-CN_13R7C_20180517_115758_f000000_fc00020673_4d87dc.png
#   DS-CN_13R7C_20180517_130432_f000770_fc00188181_4d87dc.png
#   DS-CN_13R7C_20180518_134040_f000550_fc00152871_4d87dc.png
#   DS-CN_13R7C_20180518_141444_f000000_fc00235740_4d87dc.png
#   DS-CN_13R7C_20180519_130651_f000440_fc00519180_4d87dc.png
#   /home/gea1tv/Deploy/complex_image_search/rgb_example.jpg  (opened as plain RGB)

# Road-sign images that ``main`` feeds to the model, in processing order.
SIGN_IMAGE_DIR = "/home/gea1tv/Deploy/complex_image_search/images"
SIGN_IMAGE_FILES = (
    "h_way_sign.png",
    "end_h_way_sign.png",
    "calm_zone_sign.png",
    "end_calm_zone_sign.png",
    "30_sign.png",
    "end_30_sign.png",
)
# Image sets from earlier experiments (same directory):
#   chinchila.png / shiba.png / flamingo.png with prompts
#       "This is a chinchilla. They are mainly found in Chile."
#       "This is a shiba. They are very popular in Japan."
#       "This is "
#   underground.png / congress.png / Soulomes.png with prompts
#       "Output: Underground" / "Output: Congress ave 400" / "Output:"

# Prompts for the optional in-context-learning experiment (see the end of
# ``main``): four prompts that already contain their answer, followed by the
# open prompt for the query image.
FEW_SHOT_PROMPTS = (
    "Question: {start of zone or end of zone} Short answer: start of zone.",
    "Question: {start of zone or end of zone} Short answer: end of zone.",
    "Question: {start of zone or end of zone} Short answer: start of zone.",
    "Question: {start of zone or end of zone} Short answer: end of zone.",
    "Question: {start of zone or end of zone} Short answer:",
)
# Few-shot prompt sets tried before the one above (examples 1-4, then query):
#   "This is a start of a highway sign" / "This is an end of a highway sign" /
#   "This is a start of a traffic calming zone sign" /
#   "This is an end of a traffic calming zone sign" /
#   "What is the meaning of this sign?"
#
#   "This road sign indicates the start of a zone." /
#   "This road sign indicates the end of a zone." (x2 each) /
#   "This road sign"  or  "This road sign indicates"  or
#   "Question: {What does this road sign indicates} Short answer:"
#
#   "Output: This road sign indicates the beginning of something." /
#   "Output: This road sign indicates the end of something." (x2 each) /
#   "Output: This road sign"
#
#   "Question: start of sign or end of sign? Answer: start of sign." /
#   "Question: start of sign or end of sign? Answer: end of sign." (x2 each) /
#   "Question: start of sign or end of sign? Answer:"
#
#   "Output: start of." / "Output: end of." (x2 each) / "Output:"


def resolve_model_spec(model_type):
    """Return the LAVIS ``(name, model_type)`` pair for *model_type*.

    Args:
        model_type: One of the keys of ``MODEL_SPECS``.

    Returns:
        A 2-tuple of strings to pass as ``name`` and ``model_type`` to
        ``load_model_and_preprocess``.

    Raises:
        ValueError: If *model_type* is not a key of ``MODEL_SPECS``. Without
            this check an unknown value would only surface later as a
            ``NameError`` on the never-assigned model variable.
    """
    try:
        return MODEL_SPECS[model_type]
    except KeyError:
        raise ValueError(
            f"Unknown model_type {model_type!r}; expected one of: "
            f"{', '.join(sorted(MODEL_SPECS))}"
        ) from None


def load_rgb_image(path, grayscale_first=False):
    """Open the image file at *path* and return it converted to RGB.

    Args:
        path: Filesystem path of the image.
        grayscale_first: When true, convert to mode ``"L"`` before converting
            to ``"RGB"`` (the result has three channels but no colour).

    Returns:
        A new PIL image in mode ``"RGB"``; it does not depend on the opened
        file, which is released when the ``with`` block exits.
    """
    with Image.open(path) as source:
        image = source.convert("L") if grayscale_first else source
        return image.convert("RGB")


def main():
    """Run the selected BLIP-2 / InstructBLIP model on the road-sign images.

    Loads the scene image and the six sign images, loads the model chosen by
    ``model_type``, then sends each sign image together with ``prompt`` to
    ``generate`` and prints every answer.
    """
    print(f"Total number of available gpus: {torch.cuda.device_count()}")

    model_type = "instruct_blip_flan_t5"
    # Resolve first so a misspelled model_type fails before any heavy loading.
    lavis_name, lavis_model_type = resolve_model_spec(model_type)

    # Scene image: stripped of colour, then expanded back to three channels.
    # Only the optional experiments at the end of this function consume it.
    rgb_image = load_rgb_image(
        os.path.join(SCENE_IMAGE_DIR, SCENE_IMAGE_FILE), grayscale_first=True
    )
    sign_images = [
        load_rgb_image(os.path.join(SIGN_IMAGE_DIR, file_name))
        for file_name in SIGN_IMAGE_FILES
    ]

    # set up your device
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Load the model selected by model_type together with its image processors.
    print(f"Loading {model_type} model...\n")
    original_model, vis_processors, _ = load_model_and_preprocess(
        name=lavis_name, model_type=lavis_model_type, is_eval=True, device=device
    )
    # original_model = original_model.float()
    print("Finished Loading model.\n")

    eval_processor = vis_processors["eval"]

    def to_batch(image):
        """Preprocess *image*, add a leading dimension and move it to device."""
        return eval_processor(image).unsqueeze(0).to(device)

    # rgb_image = F.crop(rgb_image, 384 * 0, 384 * 1, 384, 384)
    img = to_batch(rgb_image)  # input for the optional scene-image experiments
    sign_batches = [to_batch(image) for image in sign_images]

    # Optional: preview the resized scene image instead of running the model.
    # resized_image = transforms.Resize((384, 384), interpolation=InterpolationMode.BICUBIC)(rgb_image)
    # fig = plt.figure(figsize=(10, 5))
    # plt.imshow(resized_image, cmap='gray')
    # plt.show()
    # exit(1)

    prompt = "Can you describe the image in details?"
    # Other prompts tried:
    #   "Question: {What is the meaning of this sign} Short answer:"
    #   "Can you describe the image in details focusing on road signs?"
    #   "A short image description:"
    #   "a photo of"
    #   "The following road sign is detected in the image: a do not enter sign. Write a detailed description of the image using the detected road sign."
    #   "Describe in detail all the road signs in the image including their meaning and their locations"
    #   "Road signs that indicate the end of something, such as the end of a specific road condition or traffic regulation, often use specific visual features to convey their meaning. For example: Strikethrough: A striking visual feature is a diagonal line that crosses through the symbol or text indicating what is ending. For instance, if a sign indicates the end of a no-passing zone, the symbol of a no-passing zone with a diagonal line across it could be used. Question: {Does this road sign indicates the end of somthing?} Answer:"
    #   "Question: {Does this road sign have a diagonal stripe across it?} Short answer:"
    #   "What is the meaning of this road sign?"
    #   "Question: {What is the meaning of this road sign} Short answer:"
    #   "Question: {What is the meaning of this road sign} Answer:"
    #   "Can you explain the meaning of this road sign?"
    #   "Question: {start of or end of} Answer:"
    #   "Choose the correct option to the following question: what is the meaning of this road sign? Options: (a) start of a zone (b) end of a zone. Answer:"
    #   "Choose the correct option to the following question: what is the meaning of this road sign? Options: (a) start of 30 km/h minimal speed limit (b) end of 30 km/h minimal speed limit. Answer:"
    #   "Choose the correct option to the following question: what is the meaning of this road sign? Options: (a) start of a highway (b) end of a highway. Answer:"
    #   "Choose the correct option to the following question: what is the meaning of this road sign? Options: (a) start of a priority road (b) end of a priority road. Answer:"

    # Ask the same prompt about every sign image and print each answer.
    for sign_batch in sign_batches:
        answer = original_model.generate({"image": sign_batch, "prompt": prompt})
        print(answer)

    # Optional: scene-image captioning variants.
    # answer = original_model.generate({"image": img, "prompt": prompt}, use_nucleus_sampling=True, top_p=0.9, temperature=1)
    # answer = original_model.generate({"image": img, "prompt": prompt})
    # answer = original_model.generate({"image": img})

    # Optional: follow-up question on a single sign (index 4 or 5).
    # answer2 = original_model.generate({"image": sign_batches[4], "prompt": prompt + " " + answer[0] + ". Question: why?"})
    # print(answer2)

    # Optional: try in-context learning. Examples are ordered signs 1, 2, 4, 3
    # and the query is sign 5 (or sign 6); prompts are reordered the same way.
    # order = (0, 1, 3, 2, 4)  # use (0, 1, 3, 2, 5) to query the sixth sign
    # answer = original_model.in_context_learning_generate({
    #     "image": torch.cat([sign_batches[i] for i in order]),
    #     "prompt": [FEW_SHOT_PROMPTS[i] for i in (0, 1, 3, 2, 4)],
    # })
    # print(answer)

    # Optional: create a figure with the image and the generated text.
    # fig = display_image_and_text(Image.open(os.path.join(SCENE_IMAGE_DIR, SCENE_IMAGE_FILE)), answer)
    # fig = display_image_and_text(rgb_image, answer, prompt)
    # plt.show()


if __name__ == "__main__":
    main()