Spaces:
Runtime error
Runtime error
| import os | |
| # allocated gpus | |
| # os.environ['CUDA_VISIBLE_DEVICES'] = '1,5' | |
| os.environ['CUDA_VISIBLE_DEVICES'] = '0' | |
| import torch | |
| from lavis.models import load_model_and_preprocess | |
| from PIL import Image | |
| from complex_image_search.utils.display_utils import display_image_and_text | |
| import matplotlib.pyplot as plt | |
| from torchvision import transforms | |
| from torchvision.transforms.functional import InterpolationMode | |
| import torchvision.transforms.functional as F | |
# Maps this script's short model key to the (name, model_type) pair passed to
# LAVIS `load_model_and_preprocess`.
MODEL_SPECS = {
    "instruct_blip_flan_t5": ("blip2_t5_instruct", "flant5xl"),
    "instruct_blip_vicuna": ("blip2_vicuna_instruct", "vicuna7b"),
    "blip2_flan_t5_caption": ("blip2_t5", "pretrain_flant5xl"),
}

DEFAULT_MODEL_TYPE = "instruct_blip_flan_t5"

# Directory holding the road-sign images queried by `main`.
SIGN_IMAGE_DIR = "/home/gea1tv/Deploy/complex_image_search/images"

# Road-sign images, in the order their answers are printed.
SIGN_IMAGE_FILES = (
    "h_way_sign.png",
    "end_h_way_sign.png",
    "calm_zone_sign.png",
    "end_calm_zone_sign.png",
    "30_sign.png",
    "end_30_sign.png",
)

# Prompt sent to the model together with every image.
DEFAULT_PROMPT = "Can you describe the image in details?"


def resolve_model_spec(model_type):
    """Return the ``(name, model_type)`` pair LAVIS expects for *model_type*.

    Args:
        model_type: A key of ``MODEL_SPECS``.

    Returns:
        The ``(name, model_type)`` tuple to pass to ``load_model_and_preprocess``.

    Raises:
        ValueError: If *model_type* is not a key of ``MODEL_SPECS``.
    """
    try:
        return MODEL_SPECS[model_type]
    except KeyError:
        raise ValueError(
            f"Unknown model_type {model_type!r}; "
            f"expected one of {sorted(MODEL_SPECS)}"
        ) from None


def load_model(model_type, device):
    """Load a LAVIS model in eval mode on *device*.

    Returns:
        ``(model, vis_processors)`` as produced by ``load_model_and_preprocess``
        (its third return value is discarded).

    Raises:
        ValueError: If *model_type* is not a key of ``MODEL_SPECS``.
    """
    name, lavis_model_type = resolve_model_spec(model_type)
    model, vis_processors, _ = load_model_and_preprocess(
        name=name, model_type=lavis_model_type, is_eval=True, device=device
    )
    return model, vis_processors


def _open_rgb(path):
    """Open the image at *path*, convert it to RGB, and close the file handle."""
    with Image.open(path) as raw:
        # convert() returns a new, fully loaded image, so closing `raw` is safe.
        return raw.convert("RGB")


def main(
    model_type=DEFAULT_MODEL_TYPE,
    prompt=DEFAULT_PROMPT,
    image_dir=SIGN_IMAGE_DIR,
    image_files=SIGN_IMAGE_FILES,
):
    """Ask the model *prompt* about each image and print every answer.

    Args:
        model_type: A key of ``MODEL_SPECS`` selecting which model to load.
        prompt: Text prompt sent with every image.
        image_dir: Directory containing the images.
        image_files: File names (relative to *image_dir*), processed in order.

    Raises:
        ValueError: If *model_type* is not a key of ``MODEL_SPECS``.
    """
    # Validate the model key before any slow work happens.
    resolve_model_spec(model_type)
    print("Total number of available gpus: " + str(torch.cuda.device_count()))

    # Open every image before the (slow) model load so a missing file fails fast.
    images = [_open_rgb(os.path.join(image_dir, name)) for name in image_files]

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("Loading " + model_type + " model...\n")
    model, vis_processors = load_model(model_type, device)
    print("Finished Loading model.\n")

    eval_processor = vis_processors["eval"]
    for image in images:
        # unsqueeze(0) adds a leading batch dimension of size one.
        batch = eval_processor(image).unsqueeze(0).to(device)
        answer = model.generate({"image": batch, "prompt": prompt})
        print(answer)


if __name__ == "__main__":
    main()