"""Caption a set of road-sign images with a BLIP-2 / InstructBLIP model (LAVIS).

Run as a script: every image named in ``SIGN_IMAGE_NAMES`` is loaded from
``SIGN_IMAGE_DIR``, the model selected by ``MODEL_TYPE`` is loaded, and the
model's answer to ``PROMPT`` is printed for each image, in order.
"""
import os

# Restrict this process to a single GPU. This is set before torch is imported
# so the CUDA runtime sees the restricted device list when it initializes.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

import torch  # noqa: E402
from lavis.models import load_model_and_preprocess  # noqa: E402
from PIL import Image  # noqa: E402

# Not used by main(); kept for ad-hoc visualization of inputs and answers.
from complex_image_search.utils.display_utils import display_image_and_text  # noqa: E402,F401
import matplotlib.pyplot as plt  # noqa: E402,F401
from torchvision import transforms  # noqa: E402,F401
from torchvision.transforms.functional import InterpolationMode  # noqa: E402,F401
import torchvision.transforms.functional as F  # noqa: E402,F401

# Maps this script's model key to the (name, model_type) arguments passed to
# lavis.models.load_model_and_preprocess.
MODEL_SPECS = {
    "instruct_blip_flan_t5": ("blip2_t5_instruct", "flant5xl"),
    "instruct_blip_vicuna": ("blip2_vicuna_instruct", "vicuna7b"),
    "blip2_flan_t5_caption": ("blip2_t5", "pretrain_flant5xl"),
}

# Which entry of MODEL_SPECS main() loads.
MODEL_TYPE = "instruct_blip_flan_t5"

# Prompt sent together with every image.
PROMPT = "Can you describe the image in details?"

# Road-sign images captioned by main(), in output order.
SIGN_IMAGE_DIR = "/home/gea1tv/Deploy/complex_image_search/images"
SIGN_IMAGE_NAMES = (
    "h_way_sign.png",
    "end_h_way_sign.png",
    "calm_zone_sign.png",
    "end_calm_zone_sign.png",
    "30_sign.png",
    "end_30_sign.png",
)

# Scene image used in earlier experiments (those loaded it as grayscale
# converted back to RGB). main() does not read or caption it.
SCENE_IMAGE_PATH = os.path.join(
    "/fs/scratch/rb_bd_dlp_rng-dl01_cr_AIM_employees/AIM_105/Complex_Image_Search/images",
    "DS-CN_13R7C_20180509_130050_f000550_fc00248514_4d87dc.png",
)


def load_model(model_type, device):
    """Load a LAVIS model and its image preprocessors in eval mode.

    Args:
        model_type: a key of ``MODEL_SPECS``.
        device: the ``torch.device`` to place the model on.

    Returns:
        ``(model, vis_processors)`` as returned by
        ``load_model_and_preprocess``; the text processors are discarded.

    Raises:
        ValueError: if ``model_type`` is not a key of ``MODEL_SPECS``.
    """
    try:
        name, lavis_model_type = MODEL_SPECS[model_type]
    except KeyError:
        raise ValueError(
            f"Unknown model_type {model_type!r}; "
            f"expected one of {sorted(MODEL_SPECS)}"
        ) from None
    model, vis_processors, _ = load_model_and_preprocess(
        name=name, model_type=lavis_model_type, is_eval=True, device=device
    )
    return model, vis_processors


def load_rgb_images(directory, names):
    """Open each file ``names[i]`` inside ``directory`` and return it as RGB.

    Each file is opened with a context manager, so its handle is closed once
    the RGB copy has been made.
    """
    images = []
    for name in names:
        with Image.open(os.path.join(directory, name)) as image:
            images.append(image.convert("RGB"))
    return images


def main():
    """Caption every sign image with ``PROMPT`` and print each answer."""
    print("Total number of available gpus: " + str(torch.cuda.device_count()))

    # Read the inputs first, so a missing file fails before the model load.
    sign_images = load_rgb_images(SIGN_IMAGE_DIR, SIGN_IMAGE_NAMES)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    print("Loading " + MODEL_TYPE + " model...\n")
    model, vis_processors = load_model(MODEL_TYPE, device)
    print("Finished Loading model.\n")

    preprocess = vis_processors["eval"]
    for image in sign_images:
        # unsqueeze(0) adds a leading dimension of size 1 (a single-image batch).
        tensor = preprocess(image).unsqueeze(0).to(device)
        answer = model.generate({"image": tensor, "prompt": PROMPT})
        print(answer)


if __name__ == "__main__":
    main()