# File size: 11,647 Bytes
# 02d3a85
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
import os

# allocated gpus
# CUDA_VISIBLE_DEVICES is assigned here, before `import torch` below, so that
# torch only enumerates the listed device(s); "cuda:0" later in the script then
# refers to the first visible device. Keep this ordering.
# os.environ['CUDA_VISIBLE_DEVICES'] = '1,5'
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

# NOTE(review): plt, transforms, InterpolationMode, F and display_image_and_text
# are only referenced from commented-out experiment code in this script —
# confirm they are not needed before removing them.
import torch
from lavis.models import load_model_and_preprocess
from PIL import Image
from complex_image_search.utils.display_utils import display_image_and_text
import matplotlib.pyplot as plt
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
import torchvision.transforms.functional as F

# Maps this experiment's model label to the (name, model_type) pair passed to
# LAVIS' `load_model_and_preprocess`.
MODEL_SPECS = {
    "instruct_blip_flan_t5": ("blip2_t5_instruct", "flant5xl"),
    "instruct_blip_vicuna": ("blip2_vicuna_instruct", "vicuna7b"),
    "blip2_flan_t5_caption": ("blip2_t5", "pretrain_flant5xl"),
}

# Reference road-sign images, queried in this order. Judging by the file
# names, each "start" sign is followed by its matching "end" sign.
SIGN_IMAGE_DIR = "/home/gea1tv/Deploy/complex_image_search/images"
SIGN_IMAGE_FILES = (
    "h_way_sign.png",
    "end_h_way_sign.png",
    "calm_zone_sign.png",
    "end_calm_zone_sign.png",
    "30_sign.png",
    "end_30_sign.png",
)

DEFAULT_PROMPT = "Can you describe the image in details?"


def resolve_model_spec(model_type):
    """Return the LAVIS ``(name, model_type)`` pair for an experiment label.

    Raises:
        ValueError: if ``model_type`` is not a key of ``MODEL_SPECS``, so an
            unsupported label fails immediately instead of leaving the model
            undefined.
    """
    try:
        return MODEL_SPECS[model_type]
    except KeyError:
        raise ValueError(
            f"Unknown model_type {model_type!r}; expected one of {sorted(MODEL_SPECS)}"
        ) from None


def open_sign_images(image_dir=SIGN_IMAGE_DIR, file_names=SIGN_IMAGE_FILES):
    """Open each reference image from ``image_dir`` and return RGB copies.

    Called before the (slow) model load so a missing file fails fast. The
    source file handle is closed once the RGB copy has been made.
    """
    images = []
    for file_name in file_names:
        with Image.open(os.path.join(image_dir, file_name)) as source:
            images.append(source.convert("RGB"))
    return images


def load_model(model_type, device):
    """Load the LAVIS model for ``model_type`` in eval mode on ``device``.

    Returns:
        ``(model, vis_processors)``; the text processors are discarded.
    """
    name, lavis_model_type = resolve_model_spec(model_type)
    model, vis_processors, _ = load_model_and_preprocess(
        name=name, model_type=lavis_model_type, is_eval=True, device=device
    )
    return model, vis_processors


def preprocess_images(vis_processors, images, device):
    """Apply the ``"eval"`` visual processor to each image, add a leading
    batch dimension of size 1, and move the result to ``device``."""
    return [vis_processors["eval"](image).unsqueeze(0).to(device) for image in images]


def main(model_type="instruct_blip_flan_t5", prompt=DEFAULT_PROMPT):
    """Ask the selected model ``prompt`` about each reference sign image and
    print every answer returned by ``model.generate``."""
    print(f"Total number of available gpus: {torch.cuda.device_count()}")

    # Validate the label before touching the disk or loading weights.
    resolve_model_spec(model_type)

    sign_images = open_sign_images()

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    print("Loading " + model_type + " model...\n")
    model, vis_processors = load_model(model_type, device)
    print("Finished Loading model.\n")

    for image in preprocess_images(vis_processors, sign_images, device):
        answer = model.generate({"image": image, "prompt": prompt})
        print(answer)


if __name__ == "__main__":
    main()