RKP64 commited on
Commit
036b419
·
1 Parent(s): a565c22

Upload visual_foundation_models.py

Browse files
Files changed (1) hide show
  1. visual_foundation_models.py +1120 -0
visual_foundation_models.py ADDED
@@ -0,0 +1,1120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from diffusers import StableDiffusionPipeline, StableDiffusionInpaintPipeline, StableDiffusionInstructPix2PixPipeline
2
+ from diffusers import EulerAncestralDiscreteScheduler
3
+ from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, UniPCMultistepScheduler
4
+ from controlnet_aux import OpenposeDetector, MLSDdetector, HEDdetector
5
+
6
+ from transformers import AutoModelForCausalLM, AutoTokenizer, CLIPSegProcessor, CLIPSegForImageSegmentation
7
+ from transformers import pipeline, BlipProcessor, BlipForConditionalGeneration, BlipForQuestionAnswering
8
+ from transformers import AutoImageProcessor, UperNetForSemanticSegmentation
9
+
10
+ import os
11
+ import random
12
+ import torch
13
+ import cv2
14
+ import re
15
+ import uuid
16
+ from PIL import Image, ImageOps, ImageDraw, ImageFont
17
+ import numpy as np
18
+ import math
19
+ import inspect
20
+ import tempfile
21
+
22
+ from langchain.llms.openai import OpenAI
23
+
24
+ # Grounding DINO
25
+ import groundingdino.datasets.transforms as T
26
+ from groundingdino.models import build_model
27
+ from groundingdino.util import box_ops
28
+ from groundingdino.util.slconfig import SLConfig
29
+ from groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap
30
+
31
+ # segment anything
32
+ from segment_anything import build_sam, SamPredictor, SamAutomaticMaskGenerator
33
+ import matplotlib.pyplot as plt
34
+ import wget
35
+
36
+ def prompts(name, description):
37
+ def decorator(func):
38
+ func.name = name
39
+ func.description = description
40
+ return func
41
+
42
+ return decorator
43
+
44
+ def blend_gt2pt(old_image, new_image, sigma=0.15, steps=100):
45
+ new_size = new_image.size
46
+ old_size = old_image.size
47
+ easy_img = np.array(new_image)
48
+ gt_img_array = np.array(old_image)
49
+ pos_w = (new_size[0] - old_size[0]) // 2
50
+ pos_h = (new_size[1] - old_size[1]) // 2
51
+
52
+ kernel_h = cv2.getGaussianKernel(old_size[1], old_size[1] * sigma)
53
+ kernel_w = cv2.getGaussianKernel(old_size[0], old_size[0] * sigma)
54
+ kernel = np.multiply(kernel_h, np.transpose(kernel_w))
55
+
56
+ kernel[steps:-steps, steps:-steps] = 1
57
+ kernel[:steps, :steps] = kernel[:steps, :steps] / kernel[steps - 1, steps - 1]
58
+ kernel[:steps, -steps:] = kernel[:steps, -steps:] / kernel[steps - 1, -(steps)]
59
+ kernel[-steps:, :steps] = kernel[-steps:, :steps] / kernel[-steps, steps - 1]
60
+ kernel[-steps:, -steps:] = kernel[-steps:, -steps:] / kernel[-steps, -steps]
61
+ kernel = np.expand_dims(kernel, 2)
62
+ kernel = np.repeat(kernel, 3, 2)
63
+
64
+ weight = np.linspace(0, 1, steps)
65
+ top = np.expand_dims(weight, 1)
66
+ top = np.repeat(top, old_size[0] - 2 * steps, 1)
67
+ top = np.expand_dims(top, 2)
68
+ top = np.repeat(top, 3, 2)
69
+
70
+ weight = np.linspace(1, 0, steps)
71
+ down = np.expand_dims(weight, 1)
72
+ down = np.repeat(down, old_size[0] - 2 * steps, 1)
73
+ down = np.expand_dims(down, 2)
74
+ down = np.repeat(down, 3, 2)
75
+
76
+ weight = np.linspace(0, 1, steps)
77
+ left = np.expand_dims(weight, 0)
78
+ left = np.repeat(left, old_size[1] - 2 * steps, 0)
79
+ left = np.expand_dims(left, 2)
80
+ left = np.repeat(left, 3, 2)
81
+
82
+ weight = np.linspace(1, 0, steps)
83
+ right = np.expand_dims(weight, 0)
84
+ right = np.repeat(right, old_size[1] - 2 * steps, 0)
85
+ right = np.expand_dims(right, 2)
86
+ right = np.repeat(right, 3, 2)
87
+
88
+ kernel[:steps, steps:-steps] = top
89
+ kernel[-steps:, steps:-steps] = down
90
+ kernel[steps:-steps, :steps] = left
91
+ kernel[steps:-steps, -steps:] = right
92
+
93
+ pt_gt_img = easy_img[pos_h:pos_h + old_size[1], pos_w:pos_w + old_size[0]]
94
+ gaussian_gt_img = kernel * gt_img_array + (1 - kernel) * pt_gt_img # gt img with blur img
95
+ gaussian_gt_img = gaussian_gt_img.astype(np.int64)
96
+ easy_img[pos_h:pos_h + old_size[1], pos_w:pos_w + old_size[0]] = gaussian_gt_img
97
+ gaussian_img = Image.fromarray(easy_img)
98
+ return gaussian_img
99
+
100
+ def get_new_image_name(org_img_name, func_name="update"):
101
+ head_tail = os.path.split(org_img_name)
102
+ head = head_tail[0]
103
+ tail = head_tail[1]
104
+ name_split = tail.split('.')[0].split('_')
105
+ this_new_uuid = str(uuid.uuid4())[0:4]
106
+ if len(name_split) == 1:
107
+ most_org_file_name = name_split[0]
108
+ recent_prev_file_name = name_split[0]
109
+ new_file_name = '{}_{}_{}_{}.png'.format(this_new_uuid, func_name, recent_prev_file_name, most_org_file_name)
110
+ else:
111
+ assert len(name_split) == 4
112
+ most_org_file_name = name_split[3]
113
+ recent_prev_file_name = name_split[0]
114
+ new_file_name = '{}_{}_{}_{}.png'.format(this_new_uuid, func_name, recent_prev_file_name, most_org_file_name)
115
+ return os.path.join(head, new_file_name)
116
+
117
+ def seed_everything(seed):
118
+ random.seed(seed)
119
+ np.random.seed(seed)
120
+ torch.manual_seed(seed)
121
+ torch.cuda.manual_seed_all(seed)
122
+ return seed
123
+
124
+ class InstructPix2Pix:
125
+ def __init__(self, device):
126
+ print(f"Initializing InstructPix2Pix to {device}")
127
+ self.device = device
128
+ self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
129
+ self.pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained("timbrooks/instruct-pix2pix",
130
+ safety_checker=None,
131
+ torch_dtype=self.torch_dtype).to(device)
132
+ self.pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(self.pipe.scheduler.config)
133
+
134
+ @prompts(name="Instruct Image Using Text",
135
+ description="useful when you want to the style of the image to be like the text. "
136
+ "like: make it look like a painting. or make it like a robot. "
137
+ "The input to this tool should be a comma separated string of two, "
138
+ "representing the image_path and the text. ")
139
+ def inference(self, inputs):
140
+ """Change style of image."""
141
+ print("===>Starting InstructPix2Pix Inference")
142
+ image_path, text = inputs.split(",")[0], ','.join(inputs.split(',')[1:])
143
+ original_image = Image.open(image_path)
144
+ image = self.pipe(text, image=original_image, num_inference_steps=40, image_guidance_scale=1.2).images[0]
145
+ updated_image_path = get_new_image_name(image_path, func_name="pix2pix")
146
+ image.save(updated_image_path)
147
+ print(f"\nProcessed InstructPix2Pix, Input Image: {image_path}, Instruct Text: {text}, "
148
+ f"Output Image: {updated_image_path}")
149
+ return updated_image_path
150
+
151
+
152
+ class Text2Image:
153
+ def __init__(self, device):
154
+ print(f"Initializing Text2Image to {device}")
155
+ self.device = device
156
+ self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
157
+ self.pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5",
158
+ torch_dtype=self.torch_dtype)
159
+ self.pipe.to(device)
160
+ self.a_prompt = 'best quality, extremely detailed'
161
+ self.n_prompt = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, ' \
162
+ 'fewer digits, cropped, worst quality, low quality'
163
+
164
+ @prompts(name="Generate Image From User Input Text",
165
+ description="useful when you want to generate an image from a user input text and save it to a file. "
166
+ "like: generate an image of an object or something, or generate an image that includes some objects. "
167
+ "The input to this tool should be a string, representing the text used to generate image. ")
168
+ def inference(self, text):
169
+ image_filename = os.path.join('image', f"{str(uuid.uuid4())[:8]}.png")
170
+ prompt = text + ', ' + self.a_prompt
171
+ image = self.pipe(prompt, negative_prompt=self.n_prompt).images[0]
172
+ image.save(image_filename)
173
+ print(
174
+ f"\nProcessed Text2Image, Input Text: {text}, Output Image: {image_filename}")
175
+ return image_filename
176
+
177
+
178
+ class ImageCaptioning:
179
+ def __init__(self, device):
180
+ print(f"Initializing ImageCaptioning to {device}")
181
+ self.device = device
182
+ self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
183
+ self.processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
184
+ self.model = BlipForConditionalGeneration.from_pretrained(
185
+ "Salesforce/blip-image-captioning-base", torch_dtype=self.torch_dtype).to(self.device)
186
+
187
+ @prompts(name="Get Photo Description",
188
+ description="useful when you want to know what is inside the photo. receives image_path as input. "
189
+ "The input to this tool should be a string, representing the image_path. ")
190
+ def inference(self, image_path):
191
+ inputs = self.processor(Image.open(image_path), return_tensors="pt").to(self.device, self.torch_dtype)
192
+ out = self.model.generate(**inputs)
193
+ captions = self.processor.decode(out[0], skip_special_tokens=True)
194
+ print(f"\nProcessed ImageCaptioning, Input Image: {image_path}, Output Text: {captions}")
195
+ return captions
196
+
197
+
198
+ class Image2Canny:
199
+ def __init__(self, device):
200
+ print("Initializing Image2Canny")
201
+ self.low_threshold = 100
202
+ self.high_threshold = 200
203
+
204
+ @prompts(name="Edge Detection On Image",
205
+ description="useful when you want to detect the edge of the image. "
206
+ "like: detect the edges of this image, or canny detection on image, "
207
+ "or perform edge detection on this image, or detect the canny image of this image. "
208
+ "The input to this tool should be a string, representing the image_path")
209
+ def inference(self, inputs):
210
+ image = Image.open(inputs)
211
+ image = np.array(image)
212
+ canny = cv2.Canny(image, self.low_threshold, self.high_threshold)
213
+ canny = canny[:, :, None]
214
+ canny = np.concatenate([canny, canny, canny], axis=2)
215
+ canny = Image.fromarray(canny)
216
+ updated_image_path = get_new_image_name(inputs, func_name="edge")
217
+ canny.save(updated_image_path)
218
+ print(f"\nProcessed Image2Canny, Input Image: {inputs}, Output Text: {updated_image_path}")
219
+ return updated_image_path
220
+
221
+
222
+ class CannyText2Image:
223
+ def __init__(self, device):
224
+ print(f"Initializing CannyText2Image to {device}")
225
+ self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
226
+ self.controlnet = ControlNetModel.from_pretrained("fusing/stable-diffusion-v1-5-controlnet-canny",
227
+ torch_dtype=self.torch_dtype)
228
+ self.pipe = StableDiffusionControlNetPipeline.from_pretrained(
229
+ "runwayml/stable-diffusion-v1-5", controlnet=self.controlnet, safety_checker=None,
230
+ torch_dtype=self.torch_dtype)
231
+ self.pipe.scheduler = UniPCMultistepScheduler.from_config(self.pipe.scheduler.config)
232
+ self.pipe.to(device)
233
+ self.seed = -1
234
+ self.a_prompt = 'best quality, extremely detailed'
235
+ self.n_prompt = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, ' \
236
+ 'fewer digits, cropped, worst quality, low quality'
237
+
238
+ @prompts(name="Generate Image Condition On Canny Image",
239
+ description="useful when you want to generate a new real image from both the user description and a canny image."
240
+ " like: generate a real image of a object or something from this canny image,"
241
+ " or generate a new real image of a object or something from this edge image. "
242
+ "The input to this tool should be a comma separated string of two, "
243
+ "representing the image_path and the user description. ")
244
+ def inference(self, inputs):
245
+ image_path, instruct_text = inputs.split(",")[0], ','.join(inputs.split(',')[1:])
246
+ image = Image.open(image_path)
247
+ self.seed = random.randint(0, 65535)
248
+ seed_everything(self.seed)
249
+ prompt = f'{instruct_text}, {self.a_prompt}'
250
+ image = self.pipe(prompt, image, num_inference_steps=20, eta=0.0, negative_prompt=self.n_prompt,
251
+ guidance_scale=9.0).images[0]
252
+ updated_image_path = get_new_image_name(image_path, func_name="canny2image")
253
+ image.save(updated_image_path)
254
+ print(f"\nProcessed CannyText2Image, Input Canny: {image_path}, Input Text: {instruct_text}, "
255
+ f"Output Text: {updated_image_path}")
256
+ return updated_image_path
257
+
258
+
259
+ class Image2Line:
260
+ def __init__(self, device):
261
+ print("Initializing Image2Line")
262
+ self.detector = MLSDdetector.from_pretrained('lllyasviel/ControlNet')
263
+
264
+ @prompts(name="Line Detection On Image",
265
+ description="useful when you want to detect the straight line of the image. "
266
+ "like: detect the straight lines of this image, or straight line detection on image, "
267
+ "or perform straight line detection on this image, or detect the straight line image of this image. "
268
+ "The input to this tool should be a string, representing the image_path")
269
+ def inference(self, inputs):
270
+ image = Image.open(inputs)
271
+ mlsd = self.detector(image)
272
+ updated_image_path = get_new_image_name(inputs, func_name="line-of")
273
+ mlsd.save(updated_image_path)
274
+ print(f"\nProcessed Image2Line, Input Image: {inputs}, Output Line: {updated_image_path}")
275
+ return updated_image_path
276
+
277
+
278
+ class LineText2Image:
279
+ def __init__(self, device):
280
+ print(f"Initializing LineText2Image to {device}")
281
+ self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
282
+ self.controlnet = ControlNetModel.from_pretrained("fusing/stable-diffusion-v1-5-controlnet-mlsd",
283
+ torch_dtype=self.torch_dtype)
284
+ self.pipe = StableDiffusionControlNetPipeline.from_pretrained(
285
+ "runwayml/stable-diffusion-v1-5", controlnet=self.controlnet, safety_checker=None,
286
+ torch_dtype=self.torch_dtype
287
+ )
288
+ self.pipe.scheduler = UniPCMultistepScheduler.from_config(self.pipe.scheduler.config)
289
+ self.pipe.to(device)
290
+ self.seed = -1
291
+ self.a_prompt = 'best quality, extremely detailed'
292
+ self.n_prompt = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, ' \
293
+ 'fewer digits, cropped, worst quality, low quality'
294
+
295
+ @prompts(name="Generate Image Condition On Line Image",
296
+ description="useful when you want to generate a new real image from both the user description "
297
+ "and a straight line image. "
298
+ "like: generate a real image of a object or something from this straight line image, "
299
+ "or generate a new real image of a object or something from this straight lines. "
300
+ "The input to this tool should be a comma separated string of two, "
301
+ "representing the image_path and the user description. ")
302
+ def inference(self, inputs):
303
+ image_path, instruct_text = inputs.split(",")[0], ','.join(inputs.split(',')[1:])
304
+ image = Image.open(image_path)
305
+ self.seed = random.randint(0, 65535)
306
+ seed_everything(self.seed)
307
+ prompt = f'{instruct_text}, {self.a_prompt}'
308
+ image = self.pipe(prompt, image, num_inference_steps=20, eta=0.0, negative_prompt=self.n_prompt,
309
+ guidance_scale=9.0).images[0]
310
+ updated_image_path = get_new_image_name(image_path, func_name="line2image")
311
+ image.save(updated_image_path)
312
+ print(f"\nProcessed LineText2Image, Input Line: {image_path}, Input Text: {instruct_text}, "
313
+ f"Output Text: {updated_image_path}")
314
+ return updated_image_path
315
+
316
+
317
+ class Image2Hed:
318
+ def __init__(self, device):
319
+ print("Initializing Image2Hed")
320
+ self.detector = HEDdetector.from_pretrained('lllyasviel/ControlNet')
321
+
322
+ @prompts(name="Hed Detection On Image",
323
+ description="useful when you want to detect the soft hed boundary of the image. "
324
+ "like: detect the soft hed boundary of this image, or hed boundary detection on image, "
325
+ "or perform hed boundary detection on this image, or detect soft hed boundary image of this image. "
326
+ "The input to this tool should be a string, representing the image_path")
327
+ def inference(self, inputs):
328
+ image = Image.open(inputs)
329
+ hed = self.detector(image)
330
+ updated_image_path = get_new_image_name(inputs, func_name="hed-boundary")
331
+ hed.save(updated_image_path)
332
+ print(f"\nProcessed Image2Hed, Input Image: {inputs}, Output Hed: {updated_image_path}")
333
+ return updated_image_path
334
+
335
+
336
+ class HedText2Image:
337
+ def __init__(self, device):
338
+ print(f"Initializing HedText2Image to {device}")
339
+ self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
340
+ self.controlnet = ControlNetModel.from_pretrained("fusing/stable-diffusion-v1-5-controlnet-hed",
341
+ torch_dtype=self.torch_dtype)
342
+ self.pipe = StableDiffusionControlNetPipeline.from_pretrained(
343
+ "runwayml/stable-diffusion-v1-5", controlnet=self.controlnet, safety_checker=None,
344
+ torch_dtype=self.torch_dtype
345
+ )
346
+ self.pipe.scheduler = UniPCMultistepScheduler.from_config(self.pipe.scheduler.config)
347
+ self.pipe.to(device)
348
+ self.seed = -1
349
+ self.a_prompt = 'best quality, extremely detailed'
350
+ self.n_prompt = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, ' \
351
+ 'fewer digits, cropped, worst quality, low quality'
352
+
353
+ @prompts(name="Generate Image Condition On Soft Hed Boundary Image",
354
+ description="useful when you want to generate a new real image from both the user description "
355
+ "and a soft hed boundary image. "
356
+ "like: generate a real image of a object or something from this soft hed boundary image, "
357
+ "or generate a new real image of a object or something from this hed boundary. "
358
+ "The input to this tool should be a comma separated string of two, "
359
+ "representing the image_path and the user description")
360
+ def inference(self, inputs):
361
+ image_path, instruct_text = inputs.split(",")[0], ','.join(inputs.split(',')[1:])
362
+ image = Image.open(image_path)
363
+ self.seed = random.randint(0, 65535)
364
+ seed_everything(self.seed)
365
+ prompt = f'{instruct_text}, {self.a_prompt}'
366
+ image = self.pipe(prompt, image, num_inference_steps=20, eta=0.0, negative_prompt=self.n_prompt,
367
+ guidance_scale=9.0).images[0]
368
+ updated_image_path = get_new_image_name(image_path, func_name="hed2image")
369
+ image.save(updated_image_path)
370
+ print(f"\nProcessed HedText2Image, Input Hed: {image_path}, Input Text: {instruct_text}, "
371
+ f"Output Image: {updated_image_path}")
372
+ return updated_image_path
373
+
374
+
375
+ class Image2Scribble:
376
+ def __init__(self, device):
377
+ print("Initializing Image2Scribble")
378
+ self.detector = HEDdetector.from_pretrained('lllyasviel/ControlNet')
379
+
380
+ @prompts(name="Sketch Detection On Image",
381
+ description="useful when you want to generate a scribble of the image. "
382
+ "like: generate a scribble of this image, or generate a sketch from this image, "
383
+ "detect the sketch from this image. "
384
+ "The input to this tool should be a string, representing the image_path")
385
+ def inference(self, inputs):
386
+ image = Image.open(inputs)
387
+ scribble = self.detector(image, scribble=True)
388
+ updated_image_path = get_new_image_name(inputs, func_name="scribble")
389
+ scribble.save(updated_image_path)
390
+ print(f"\nProcessed Image2Scribble, Input Image: {inputs}, Output Scribble: {updated_image_path}")
391
+ return updated_image_path
392
+
393
+
394
+ class ScribbleText2Image:
395
+ def __init__(self, device):
396
+ print(f"Initializing ScribbleText2Image to {device}")
397
+ self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
398
+ self.controlnet = ControlNetModel.from_pretrained("fusing/stable-diffusion-v1-5-controlnet-scribble",
399
+ torch_dtype=self.torch_dtype)
400
+ self.pipe = StableDiffusionControlNetPipeline.from_pretrained(
401
+ "runwayml/stable-diffusion-v1-5", controlnet=self.controlnet, safety_checker=None,
402
+ torch_dtype=self.torch_dtype
403
+ )
404
+ self.pipe.scheduler = UniPCMultistepScheduler.from_config(self.pipe.scheduler.config)
405
+ self.pipe.to(device)
406
+ self.seed = -1
407
+ self.a_prompt = 'best quality, extremely detailed'
408
+ self.n_prompt = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, ' \
409
+ 'fewer digits, cropped, worst quality, low quality'
410
+
411
+ @prompts(name="Generate Image Condition On Sketch Image",
412
+ description="useful when you want to generate a new real image from both the user description and "
413
+ "a scribble image or a sketch image. "
414
+ "The input to this tool should be a comma separated string of two, "
415
+ "representing the image_path and the user description")
416
+ def inference(self, inputs):
417
+ image_path, instruct_text = inputs.split(",")[0], ','.join(inputs.split(',')[1:])
418
+ image = Image.open(image_path)
419
+ self.seed = random.randint(0, 65535)
420
+ seed_everything(self.seed)
421
+ prompt = f'{instruct_text}, {self.a_prompt}'
422
+ image = self.pipe(prompt, image, num_inference_steps=20, eta=0.0, negative_prompt=self.n_prompt,
423
+ guidance_scale=9.0).images[0]
424
+ updated_image_path = get_new_image_name(image_path, func_name="scribble2image")
425
+ image.save(updated_image_path)
426
+ print(f"\nProcessed ScribbleText2Image, Input Scribble: {image_path}, Input Text: {instruct_text}, "
427
+ f"Output Image: {updated_image_path}")
428
+ return updated_image_path
429
+
430
+
431
+ class Image2Pose:
432
+ def __init__(self, device):
433
+ print("Initializing Image2Pose")
434
+ self.detector = OpenposeDetector.from_pretrained('lllyasviel/ControlNet')
435
+
436
+ @prompts(name="Pose Detection On Image",
437
+ description="useful when you want to detect the human pose of the image. "
438
+ "like: generate human poses of this image, or generate a pose image from this image. "
439
+ "The input to this tool should be a string, representing the image_path")
440
+ def inference(self, inputs):
441
+ image = Image.open(inputs)
442
+ pose = self.detector(image)
443
+ updated_image_path = get_new_image_name(inputs, func_name="human-pose")
444
+ pose.save(updated_image_path)
445
+ print(f"\nProcessed Image2Pose, Input Image: {inputs}, Output Pose: {updated_image_path}")
446
+ return updated_image_path
447
+
448
+
449
+ class PoseText2Image:
450
+ def __init__(self, device):
451
+ print(f"Initializing PoseText2Image to {device}")
452
+ self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
453
+ self.controlnet = ControlNetModel.from_pretrained("fusing/stable-diffusion-v1-5-controlnet-openpose",
454
+ torch_dtype=self.torch_dtype)
455
+ self.pipe = StableDiffusionControlNetPipeline.from_pretrained(
456
+ "runwayml/stable-diffusion-v1-5", controlnet=self.controlnet, safety_checker=None,
457
+ torch_dtype=self.torch_dtype)
458
+ self.pipe.scheduler = UniPCMultistepScheduler.from_config(self.pipe.scheduler.config)
459
+ self.pipe.to(device)
460
+ self.num_inference_steps = 20
461
+ self.seed = -1
462
+ self.unconditional_guidance_scale = 9.0
463
+ self.a_prompt = 'best quality, extremely detailed'
464
+ self.n_prompt = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit,' \
465
+ ' fewer digits, cropped, worst quality, low quality'
466
+
467
+ @prompts(name="Generate Image Condition On Pose Image",
468
+ description="useful when you want to generate a new real image from both the user description "
469
+ "and a human pose image. "
470
+ "like: generate a real image of a human from this human pose image, "
471
+ "or generate a new real image of a human from this pose. "
472
+ "The input to this tool should be a comma separated string of two, "
473
+ "representing the image_path and the user description")
474
+ def inference(self, inputs):
475
+ image_path, instruct_text = inputs.split(",")[0], ','.join(inputs.split(',')[1:])
476
+ image = Image.open(image_path)
477
+ self.seed = random.randint(0, 65535)
478
+ seed_everything(self.seed)
479
+ prompt = f'{instruct_text}, {self.a_prompt}'
480
+ image = self.pipe(prompt, image, num_inference_steps=20, eta=0.0, negative_prompt=self.n_prompt,
481
+ guidance_scale=9.0).images[0]
482
+ updated_image_path = get_new_image_name(image_path, func_name="pose2image")
483
+ image.save(updated_image_path)
484
+ print(f"\nProcessed PoseText2Image, Input Pose: {image_path}, Input Text: {instruct_text}, "
485
+ f"Output Image: {updated_image_path}")
486
+ return updated_image_path
487
+
488
+
489
+ class SegText2Image:
490
+ def __init__(self, device):
491
+ print(f"Initializing SegText2Image to {device}")
492
+ self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
493
+ self.controlnet = ControlNetModel.from_pretrained("fusing/stable-diffusion-v1-5-controlnet-seg",
494
+ torch_dtype=self.torch_dtype)
495
+ self.pipe = StableDiffusionControlNetPipeline.from_pretrained(
496
+ "runwayml/stable-diffusion-v1-5", controlnet=self.controlnet, safety_checker=None,
497
+ torch_dtype=self.torch_dtype)
498
+ self.pipe.scheduler = UniPCMultistepScheduler.from_config(self.pipe.scheduler.config)
499
+ self.pipe.to(device)
500
+ self.seed = -1
501
+ self.a_prompt = 'best quality, extremely detailed'
502
+ self.n_prompt = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit,' \
503
+ ' fewer digits, cropped, worst quality, low quality'
504
+
505
+ @prompts(name="Generate Image Condition On Segmentations",
506
+ description="useful when you want to generate a new real image from both the user description and segmentations. "
507
+ "like: generate a real image of a object or something from this segmentation image, "
508
+ "or generate a new real image of a object or something from these segmentations. "
509
+ "The input to this tool should be a comma separated string of two, "
510
+ "representing the image_path and the user description")
511
+ def inference(self, inputs):
512
+ image_path, instruct_text = inputs.split(",")[0], ','.join(inputs.split(',')[1:])
513
+ image = Image.open(image_path)
514
+ self.seed = random.randint(0, 65535)
515
+ seed_everything(self.seed)
516
+ prompt = f'{instruct_text}, {self.a_prompt}'
517
+ image = self.pipe(prompt, image, num_inference_steps=20, eta=0.0, negative_prompt=self.n_prompt,
518
+ guidance_scale=9.0).images[0]
519
+ updated_image_path = get_new_image_name(image_path, func_name="segment2image")
520
+ image.save(updated_image_path)
521
+ print(f"\nProcessed SegText2Image, Input Seg: {image_path}, Input Text: {instruct_text}, "
522
+ f"Output Image: {updated_image_path}")
523
+ return updated_image_path
524
+
525
+
526
+ class Image2Depth:
527
+ def __init__(self, device):
528
+ print("Initializing Image2Depth")
529
+ self.depth_estimator = pipeline('depth-estimation')
530
+
531
+ @prompts(name="Predict Depth On Image",
532
+ description="useful when you want to detect depth of the image. like: generate the depth from this image, "
533
+ "or detect the depth map on this image, or predict the depth for this image. "
534
+ "The input to this tool should be a string, representing the image_path")
535
+ def inference(self, inputs):
536
+ image = Image.open(inputs)
537
+ depth = self.depth_estimator(image)['depth']
538
+ depth = np.array(depth)
539
+ depth = depth[:, :, None]
540
+ depth = np.concatenate([depth, depth, depth], axis=2)
541
+ depth = Image.fromarray(depth)
542
+ updated_image_path = get_new_image_name(inputs, func_name="depth")
543
+ depth.save(updated_image_path)
544
+ print(f"\nProcessed Image2Depth, Input Image: {inputs}, Output Depth: {updated_image_path}")
545
+ return updated_image_path
546
+
547
+
548
+ class DepthText2Image:
549
+ def __init__(self, device):
550
+ print(f"Initializing DepthText2Image to {device}")
551
+ self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
552
+ self.controlnet = ControlNetModel.from_pretrained(
553
+ "fusing/stable-diffusion-v1-5-controlnet-depth", torch_dtype=self.torch_dtype)
554
+ self.pipe = StableDiffusionControlNetPipeline.from_pretrained(
555
+ "runwayml/stable-diffusion-v1-5", controlnet=self.controlnet, safety_checker=None,
556
+ torch_dtype=self.torch_dtype)
557
+ self.pipe.scheduler = UniPCMultistepScheduler.from_config(self.pipe.scheduler.config)
558
+ self.pipe.to(device)
559
+ self.seed = -1
560
+ self.a_prompt = 'best quality, extremely detailed'
561
+ self.n_prompt = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit,' \
562
+ ' fewer digits, cropped, worst quality, low quality'
563
+
564
+ @prompts(name="Generate Image Condition On Depth",
565
+ description="useful when you want to generate a new real image from both the user description and depth image. "
566
+ "like: generate a real image of a object or something from this depth image, "
567
+ "or generate a new real image of a object or something from the depth map. "
568
+ "The input to this tool should be a comma separated string of two, "
569
+ "representing the image_path and the user description")
570
+ def inference(self, inputs):
571
+ image_path, instruct_text = inputs.split(",")[0], ','.join(inputs.split(',')[1:])
572
+ image = Image.open(image_path)
573
+ self.seed = random.randint(0, 65535)
574
+ seed_everything(self.seed)
575
+ prompt = f'{instruct_text}, {self.a_prompt}'
576
+ image = self.pipe(prompt, image, num_inference_steps=20, eta=0.0, negative_prompt=self.n_prompt,
577
+ guidance_scale=9.0).images[0]
578
+ updated_image_path = get_new_image_name(image_path, func_name="depth2image")
579
+ image.save(updated_image_path)
580
+ print(f"\nProcessed DepthText2Image, Input Depth: {image_path}, Input Text: {instruct_text}, "
581
+ f"Output Image: {updated_image_path}")
582
+ return updated_image_path
583
+
584
+
585
+ class Image2Normal:
586
+ def __init__(self, device):
587
+ print("Initializing Image2Normal")
588
+ self.depth_estimator = pipeline("depth-estimation", model="Intel/dpt-hybrid-midas")
589
+ self.bg_threhold = 0.4
590
+
591
+ @prompts(name="Predict Normal Map On Image",
592
+ description="useful when you want to detect norm map of the image. "
593
+ "like: generate normal map from this image, or predict normal map of this image. "
594
+ "The input to this tool should be a string, representing the image_path")
595
+ def inference(self, inputs):
596
+ image = Image.open(inputs)
597
+ original_size = image.size
598
+ image = self.depth_estimator(image)['predicted_depth'][0]
599
+ image = image.numpy()
600
+ image_depth = image.copy()
601
+ image_depth -= np.min(image_depth)
602
+ image_depth /= np.max(image_depth)
603
+ x = cv2.Sobel(image, cv2.CV_32F, 1, 0, ksize=3)
604
+ x[image_depth < self.bg_threhold] = 0
605
+ y = cv2.Sobel(image, cv2.CV_32F, 0, 1, ksize=3)
606
+ y[image_depth < self.bg_threhold] = 0
607
+ z = np.ones_like(x) * np.pi * 2.0
608
+ image = np.stack([x, y, z], axis=2)
609
+ image /= np.sum(image ** 2.0, axis=2, keepdims=True) ** 0.5
610
+ image = (image * 127.5 + 127.5).clip(0, 255).astype(np.uint8)
611
+ image = Image.fromarray(image)
612
+ image = image.resize(original_size)
613
+ updated_image_path = get_new_image_name(inputs, func_name="normal-map")
614
+ image.save(updated_image_path)
615
+ print(f"\nProcessed Image2Normal, Input Image: {inputs}, Output Depth: {updated_image_path}")
616
+ return updated_image_path
617
+
618
+
619
+ class NormalText2Image:
620
+ def __init__(self, device):
621
+ print(f"Initializing NormalText2Image to {device}")
622
+ self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
623
+ self.controlnet = ControlNetModel.from_pretrained(
624
+ "fusing/stable-diffusion-v1-5-controlnet-normal", torch_dtype=self.torch_dtype)
625
+ self.pipe = StableDiffusionControlNetPipeline.from_pretrained(
626
+ "runwayml/stable-diffusion-v1-5", controlnet=self.controlnet, safety_checker=None,
627
+ torch_dtype=self.torch_dtype)
628
+ self.pipe.scheduler = UniPCMultistepScheduler.from_config(self.pipe.scheduler.config)
629
+ self.pipe.to(device)
630
+ self.seed = -1
631
+ self.a_prompt = 'best quality, extremely detailed'
632
+ self.n_prompt = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit,' \
633
+ ' fewer digits, cropped, worst quality, low quality'
634
+
635
+ @prompts(name="Generate Image Condition On Normal Map",
636
+ description="useful when you want to generate a new real image from both the user description and normal map. "
637
+ "like: generate a real image of a object or something from this normal map, "
638
+ "or generate a new real image of a object or something from the normal map. "
639
+ "The input to this tool should be a comma separated string of two, "
640
+ "representing the image_path and the user description")
641
+ def inference(self, inputs):
642
+ image_path, instruct_text = inputs.split(",")[0], ','.join(inputs.split(',')[1:])
643
+ image = Image.open(image_path)
644
+ self.seed = random.randint(0, 65535)
645
+ seed_everything(self.seed)
646
+ prompt = f'{instruct_text}, {self.a_prompt}'
647
+ image = self.pipe(prompt, image, num_inference_steps=20, eta=0.0, negative_prompt=self.n_prompt,
648
+ guidance_scale=9.0).images[0]
649
+ updated_image_path = get_new_image_name(image_path, func_name="normal2image")
650
+ image.save(updated_image_path)
651
+ print(f"\nProcessed NormalText2Image, Input Normal: {image_path}, Input Text: {instruct_text}, "
652
+ f"Output Image: {updated_image_path}")
653
+ return updated_image_path
654
+
655
+
656
+ class VisualQuestionAnswering:
657
+ def __init__(self, device):
658
+ print(f"Initializing VisualQuestionAnswering to {device}")
659
+ self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
660
+ self.device = device
661
+ self.processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")
662
+ self.model = BlipForQuestionAnswering.from_pretrained(
663
+ "Salesforce/blip-vqa-base", torch_dtype=self.torch_dtype).to(self.device)
664
+
665
+ @prompts(name="Answer Question About The Image",
666
+ description="useful when you need an answer for a question based on an image. "
667
+ "like: what is the background color of the last image, how many cats in this figure, what is in this figure. "
668
+ "The input to this tool should be a comma separated string of two, representing the image_path and the question")
669
+ def inference(self, inputs):
670
+ image_path, question = inputs.split(",")[0], ','.join(inputs.split(',')[1:])
671
+ raw_image = Image.open(image_path).convert('RGB')
672
+ inputs = self.processor(raw_image, question, return_tensors="pt").to(self.device, self.torch_dtype)
673
+ out = self.model.generate(**inputs)
674
+ answer = self.processor.decode(out[0], skip_special_tokens=True)
675
+ print(f"\nProcessed VisualQuestionAnswering, Input Image: {image_path}, Input Question: {question}, "
676
+ f"Output Answer: {answer}")
677
+ return answer
678
+
679
+
680
+ class Segmenting:
681
+ def __init__(self, device):
682
+ print(f"Inintializing Segmentation to {device}")
683
+ self.device = device
684
+ self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
685
+ self.model_checkpoint_path = os.path.join("checkpoints", "sam")
686
+
687
+ self.download_parameters()
688
+ self.sam = build_sam(checkpoint=self.model_checkpoint_path).to(device)
689
+ self.sam_predictor = SamPredictor(self.sam)
690
+ self.mask_generator = SamAutomaticMaskGenerator(self.sam)
691
+
692
+ def download_parameters(self):
693
+ url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
694
+ if not os.path.exists(self.model_checkpoint_path):
695
+ wget.download(url, out=self.model_checkpoint_path)
696
+
697
+ def show_mask(self, mask, ax, random_color=False):
698
+ if random_color:
699
+ color = np.concatenate([np.random.random(3), np.array([1])], axis=0)
700
+ else:
701
+ color = np.array([30 / 255, 144 / 255, 255 / 255, 1])
702
+ h, w = mask.shape[-2:]
703
+ mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
704
+ ax.imshow(mask_image)
705
+
706
+ def show_box(self, box, ax, label):
707
+ x0, y0 = box[0], box[1]
708
+ w, h = box[2] - box[0], box[3] - box[1]
709
+ ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0, 0, 0, 0), lw=2))
710
+ ax.text(x0, y0, label)
711
+
712
+ def get_mask_with_boxes(self, image_pil, image, boxes_filt):
713
+
714
+ size = image_pil.size
715
+ H, W = size[1], size[0]
716
+ for i in range(boxes_filt.size(0)):
717
+ boxes_filt[i] = boxes_filt[i] * torch.Tensor([W, H, W, H])
718
+ boxes_filt[i][:2] -= boxes_filt[i][2:] / 2
719
+ boxes_filt[i][2:] += boxes_filt[i][:2]
720
+
721
+ boxes_filt = boxes_filt.cpu()
722
+ transformed_boxes = self.sam_predictor.transform.apply_boxes_torch(boxes_filt, image.shape[:2]).to(self.device)
723
+
724
+ masks, _, _ = self.sam_predictor.predict_torch(
725
+ point_coords=None,
726
+ point_labels=None,
727
+ boxes=transformed_boxes.to(self.device),
728
+ multimask_output=False,
729
+ )
730
+ return masks
731
+
732
+ def segment_image_with_boxes(self, image_pil, image_path, boxes_filt, pred_phrases):
733
+
734
+ image = cv2.imread(image_path)
735
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
736
+ self.sam_predictor.set_image(image)
737
+
738
+ masks = self.get_mask_with_boxes(image_pil, image, boxes_filt)
739
+
740
+ # draw output image
741
+ plt.figure(figsize=(10, 10))
742
+ plt.imshow(image)
743
+ for mask in masks:
744
+ self.show_mask(mask.cpu().numpy(), plt.gca(), random_color=True)
745
+
746
+ updated_image_path = get_new_image_name(image_path, func_name="segmentation")
747
+ plt.axis('off')
748
+ plt.savefig(
749
+ updated_image_path,
750
+ bbox_inches="tight", dpi=300, pad_inches=0.0
751
+ )
752
+ return updated_image_path
753
+
754
+ @prompts(name="Segment the Image",
755
+ description="useful when you want to segment all the part of the image, but not segment a certain object."
756
+ "like: segment all the object in this image, or generate segmentations on this image, "
757
+ "or segment the image,"
758
+ "or perform segmentation on this image, "
759
+ "or segment all the object in this image."
760
+ "The input to this tool should be a string, representing the image_path")
761
+ def inference_all(self, image_path):
762
+ image = cv2.imread(image_path)
763
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
764
+ masks = self.mask_generator.generate(image)
765
+ plt.figure(figsize=(20, 20))
766
+ plt.imshow(image)
767
+ if len(masks) == 0:
768
+ return
769
+ sorted_anns = sorted(masks, key=(lambda x: x['area']), reverse=True)
770
+ ax = plt.gca()
771
+ ax.set_autoscale_on(False)
772
+ polygons = []
773
+ color = []
774
+ for ann in sorted_anns:
775
+ m = ann['segmentation']
776
+ img = np.ones((m.shape[0], m.shape[1], 3))
777
+ color_mask = np.random.random((1, 3)).tolist()[0]
778
+ for i in range(3):
779
+ img[:, :, i] = color_mask[i]
780
+ ax.imshow(np.dstack((img, m)))
781
+
782
+ updated_image_path = get_new_image_name(image_path, func_name="segment-image")
783
+ plt.axis('off')
784
+ plt.savefig(
785
+ updated_image_path,
786
+ bbox_inches="tight", dpi=300, pad_inches=0.0
787
+ )
788
+ return updated_image_path
789
+
790
+
791
+ class Text2Box:
792
+ def __init__(self, device):
793
+ print(f"Initializing ObjectDetection to {device}")
794
+ self.device = device
795
+ self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
796
+ self.model_checkpoint_path = os.path.join("checkpoints", "groundingdino")
797
+ self.model_config_path = os.path.join("checkpoints", "grounding_config.py")
798
+ self.download_parameters()
799
+ self.box_threshold = 0.3
800
+ self.text_threshold = 0.25
801
+ self.grounding = (self.load_model()).to(self.device)
802
+
803
+ def download_parameters(self):
804
+ url = "https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth"
805
+ if not os.path.exists(self.model_checkpoint_path):
806
+ wget.download(url, out=self.model_checkpoint_path)
807
+ config_url = "https://raw.githubusercontent.com/IDEA-Research/GroundingDINO/main/groundingdino/config/GroundingDINO_SwinT_OGC.py"
808
+ if not os.path.exists(self.model_config_path):
809
+ wget.download(config_url, out=self.model_config_path)
810
+
811
+ def load_image(self, image_path):
812
+ # load image
813
+ image_pil = Image.open(image_path).convert("RGB") # load image
814
+
815
+ transform = T.Compose(
816
+ [
817
+ T.RandomResize([512], max_size=1333),
818
+ T.ToTensor(),
819
+ T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
820
+ ]
821
+ )
822
+ image, _ = transform(image_pil, None) # 3, h, w
823
+ return image_pil, image
824
+
825
+ def load_model(self):
826
+ args = SLConfig.fromfile(self.model_config_path)
827
+ args.device = self.device
828
+ model = build_model(args)
829
+ checkpoint = torch.load(self.model_checkpoint_path, map_location="cpu")
830
+ load_res = model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False)
831
+ print(load_res)
832
+ _ = model.eval()
833
+ return model
834
+
835
+ def get_grounding_boxes(self, image, caption, with_logits=True):
836
+ caption = caption.lower()
837
+ caption = caption.strip()
838
+ if not caption.endswith("."):
839
+ caption = caption + "."
840
+ image = image.to(self.device)
841
+ with torch.no_grad():
842
+ outputs = self.grounding(image[None], captions=[caption])
843
+ logits = outputs["pred_logits"].cpu().sigmoid()[0] # (nq, 256)
844
+ boxes = outputs["pred_boxes"].cpu()[0] # (nq, 4)
845
+ logits.shape[0]
846
+
847
+ # filter output
848
+ logits_filt = logits.clone()
849
+ boxes_filt = boxes.clone()
850
+ filt_mask = logits_filt.max(dim=1)[0] > self.box_threshold
851
+ logits_filt = logits_filt[filt_mask] # num_filt, 256
852
+ boxes_filt = boxes_filt[filt_mask] # num_filt, 4
853
+ logits_filt.shape[0]
854
+
855
+ # get phrase
856
+ tokenlizer = self.grounding.tokenizer
857
+ tokenized = tokenlizer(caption)
858
+ # build pred
859
+ pred_phrases = []
860
+ for logit, box in zip(logits_filt, boxes_filt):
861
+ pred_phrase = get_phrases_from_posmap(logit > self.text_threshold, tokenized, tokenlizer)
862
+ if with_logits:
863
+ pred_phrases.append(pred_phrase + f"({str(logit.max().item())[:4]})")
864
+ else:
865
+ pred_phrases.append(pred_phrase)
866
+
867
+ return boxes_filt, pred_phrases
868
+
869
+ def plot_boxes_to_image(self, image_pil, tgt):
870
+ H, W = tgt["size"]
871
+ boxes = tgt["boxes"]
872
+ labels = tgt["labels"]
873
+ assert len(boxes) == len(labels), "boxes and labels must have same length"
874
+
875
+ draw = ImageDraw.Draw(image_pil)
876
+ mask = Image.new("L", image_pil.size, 0)
877
+ mask_draw = ImageDraw.Draw(mask)
878
+
879
+ # draw boxes and masks
880
+ for box, label in zip(boxes, labels):
881
+ # from 0..1 to 0..W, 0..H
882
+ box = box * torch.Tensor([W, H, W, H])
883
+ # from xywh to xyxy
884
+ box[:2] -= box[2:] / 2
885
+ box[2:] += box[:2]
886
+ # random color
887
+ color = tuple(np.random.randint(0, 255, size=3).tolist())
888
+ # draw
889
+ x0, y0, x1, y1 = box
890
+ x0, y0, x1, y1 = int(x0), int(y0), int(x1), int(y1)
891
+
892
+ draw.rectangle([x0, y0, x1, y1], outline=color, width=6)
893
+ # draw.text((x0, y0), str(label), fill=color)
894
+
895
+ font = ImageFont.load_default()
896
+ if hasattr(font, "getbbox"):
897
+ bbox = draw.textbbox((x0, y0), str(label), font)
898
+ else:
899
+ w, h = draw.textsize(str(label), font)
900
+ bbox = (x0, y0, w + x0, y0 + h)
901
+ # bbox = draw.textbbox((x0, y0), str(label))
902
+ draw.rectangle(bbox, fill=color)
903
+ draw.text((x0, y0), str(label), fill="white")
904
+
905
+ mask_draw.rectangle([x0, y0, x1, y1], fill=255, width=2)
906
+
907
+ return image_pil, mask
908
+
909
+ @prompts(name="Detect the Give Object",
910
+ description="useful when you only want to detect or find out given objects in the picture"
911
+ "The input to this tool should be a comma separated string of two, "
912
+ "representing the image_path, the text description of the object to be found")
913
+ def inference(self, inputs):
914
+ image_path, det_prompt = inputs.split(",")
915
+ print(f"image_path={image_path}, text_prompt={det_prompt}")
916
+ image_pil, image = self.load_image(image_path)
917
+
918
+ boxes_filt, pred_phrases = self.get_grounding_boxes(image, det_prompt)
919
+
920
+ size = image_pil.size
921
+ pred_dict = {
922
+ "boxes": boxes_filt,
923
+ "size": [size[1], size[0]], # H,W
924
+ "labels": pred_phrases, }
925
+
926
+ image_with_box = self.plot_boxes_to_image(image_pil, pred_dict)[0]
927
+
928
+ updated_image_path = get_new_image_name(image_path, func_name="detect-something")
929
+ updated_image = image_with_box.resize(size)
930
+ updated_image.save(updated_image_path)
931
+ print(
932
+ f"\nProcessed ObejectDetecting, Input Image: {image_path}, Object to be Detect {det_prompt}, "
933
+ f"Output Image: {updated_image_path}")
934
+ return updated_image_path
935
+
936
+
937
+ class Inpainting:
938
+ def __init__(self, device):
939
+ self.device = device
940
+ self.revision = 'fp16' if 'cuda' in self.device else None
941
+ self.torch_dtype = torch.float16 if 'cuda' in self.device else torch.float32
942
+
943
+ self.inpaint = StableDiffusionInpaintPipeline.from_pretrained(
944
+ "runwayml/stable-diffusion-inpainting", revision=self.revision, torch_dtype=self.torch_dtype).to(device)
945
+
946
+ def __call__(self, prompt, image, mask_image, height=512, width=512, num_inference_steps=50):
947
+ update_image = self.inpaint(prompt=prompt, image=image.resize((width, height)),
948
+ mask_image=mask_image.resize((width, height)), height=height, width=width,
949
+ num_inference_steps=num_inference_steps).images[0]
950
+ return update_image
951
+
952
+
953
+ class InfinityOutPainting:
954
+ template_model = True # Add this line to show this is a template model.
955
+ def __init__(self, ImageCaptioning, Inpainting, VisualQuestionAnswering):
956
+ self.ImageCaption = ImageCaptioning
957
+ self.inpaint = Inpainting
958
+ self.ImageVQA = VisualQuestionAnswering
959
+ self.a_prompt = 'best quality, extremely detailed'
960
+ self.n_prompt = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, ' \
961
+ 'fewer digits, cropped, worst quality, low quality'
962
+
963
+ def get_BLIP_vqa(self, image, question):
964
+ inputs = self.ImageVQA.processor(image, question, return_tensors="pt").to(self.ImageVQA.device,
965
+ self.ImageVQA.torch_dtype)
966
+ out = self.ImageVQA.model.generate(**inputs)
967
+ answer = self.ImageVQA.processor.decode(out[0], skip_special_tokens=True)
968
+ print(f"\nProcessed VisualQuestionAnswering, Input Question: {question}, Output Answer: {answer}")
969
+ return answer
970
+
971
+ def get_BLIP_caption(self, image):
972
+ inputs = self.ImageCaption.processor(image, return_tensors="pt").to(self.ImageCaption.device,
973
+ self.ImageCaption.torch_dtype)
974
+ out = self.ImageCaption.model.generate(**inputs)
975
+ BLIP_caption = self.ImageCaption.processor.decode(out[0], skip_special_tokens=True)
976
+ return BLIP_caption
977
+
978
+ def get_imagine_caption(self, image, imagine):
979
+ BLIP_caption = self.get_BLIP_caption(image)
980
+ caption = BLIP_caption
981
+ print(f'Prompt: {caption}')
982
+ return caption
983
+
984
+ def resize_image(self, image, max_size=1000000, multiple=8):
985
+ aspect_ratio = image.size[0] / image.size[1]
986
+ new_width = int(math.sqrt(max_size * aspect_ratio))
987
+ new_height = int(new_width / aspect_ratio)
988
+ new_width, new_height = new_width - (new_width % multiple), new_height - (new_height % multiple)
989
+ return image.resize((new_width, new_height))
990
+
991
+ def dowhile(self, original_img, tosize, expand_ratio, imagine, usr_prompt):
992
+ old_img = original_img
993
+ while (old_img.size != tosize):
994
+ prompt = self.check_prompt(usr_prompt) if usr_prompt else self.get_imagine_caption(old_img, imagine)
995
+ crop_w = 15 if old_img.size[0] != tosize[0] else 0
996
+ crop_h = 15 if old_img.size[1] != tosize[1] else 0
997
+ old_img = ImageOps.crop(old_img, (crop_w, crop_h, crop_w, crop_h))
998
+ temp_canvas_size = (expand_ratio * old_img.width if expand_ratio * old_img.width < tosize[0] else tosize[0],
999
+ expand_ratio * old_img.height if expand_ratio * old_img.height < tosize[1] else tosize[
1000
+ 1])
1001
+ temp_canvas, temp_mask = Image.new("RGB", temp_canvas_size, color="white"), Image.new("L", temp_canvas_size,
1002
+ color="white")
1003
+ x, y = (temp_canvas.width - old_img.width) // 2, (temp_canvas.height - old_img.height) // 2
1004
+ temp_canvas.paste(old_img, (x, y))
1005
+ temp_mask.paste(0, (x, y, x + old_img.width, y + old_img.height))
1006
+ resized_temp_canvas, resized_temp_mask = self.resize_image(temp_canvas), self.resize_image(temp_mask)
1007
+ image = self.inpaint(prompt=prompt, image=resized_temp_canvas, mask_image=resized_temp_mask,
1008
+ height=resized_temp_canvas.height, width=resized_temp_canvas.width,
1009
+ num_inference_steps=50).resize(
1010
+ (temp_canvas.width, temp_canvas.height), Image.ANTIALIAS)
1011
+ image = blend_gt2pt(old_img, image)
1012
+ old_img = image
1013
+ return old_img
1014
+
1015
+ @prompts(name="Extend An Image",
1016
+ description="useful when you need to extend an image into a larger image."
1017
+ "like: extend the image into a resolution of 2048x1024, extend the image into 2048x1024. "
1018
+ "The input to this tool should be a comma separated string of two, representing the image_path and the resolution of widthxheight")
1019
+ def inference(self, inputs):
1020
+ image_path, resolution = inputs.split(',')
1021
+ width, height = resolution.split('x')
1022
+ tosize = (int(width), int(height))
1023
+ image = Image.open(image_path)
1024
+ image = ImageOps.crop(image, (10, 10, 10, 10))
1025
+ out_painted_image = self.dowhile(image, tosize, 4, True, False)
1026
+ updated_image_path = get_new_image_name(image_path, func_name="outpainting")
1027
+ out_painted_image.save(updated_image_path)
1028
+ print(f"\nProcessed InfinityOutPainting, Input Image: {image_path}, Input Resolution: {resolution}, "
1029
+ f"Output Image: {updated_image_path}")
1030
+ return updated_image_path
1031
+
1032
+
1033
+ class ObjectSegmenting:
1034
+ template_model = True # Add this line to show this is a template model.
1035
+
1036
+ def __init__(self, Text2Box: Text2Box, Segmenting: Segmenting):
1037
+ # self.llm = OpenAI(temperature=0)
1038
+ self.grounding = Text2Box
1039
+ self.sam = Segmenting
1040
+
1041
+ @prompts(name="Segment the given object",
1042
+ description="useful when you only want to segment the certain objects in the picture"
1043
+ "according to the given text"
1044
+ "like: segment the cat,"
1045
+ "or can you segment an obeject for me"
1046
+ "The input to this tool should be a comma separated string of two, "
1047
+ "representing the image_path, the text description of the object to be found")
1048
+ def inference(self, inputs):
1049
+ image_path, det_prompt = inputs.split(",")
1050
+ print(f"image_path={image_path}, text_prompt={det_prompt}")
1051
+ image_pil, image = self.grounding.load_image(image_path)
1052
+ boxes_filt, pred_phrases = self.grounding.get_grounding_boxes(image, det_prompt)
1053
+ updated_image_path = self.sam.segment_image_with_boxes(image_pil, image_path, boxes_filt, pred_phrases)
1054
+ print(
1055
+ f"\nProcessed ObejectSegmenting, Input Image: {image_path}, Object to be Segment {det_prompt}, "
1056
+ f"Output Image: {updated_image_path}")
1057
+ return updated_image_path
1058
+
1059
+
1060
+ class ImageEditing:
1061
+ template_model = True
1062
+
1063
+ def __init__(self, Text2Box: Text2Box, Segmenting: Segmenting, Inpainting: Inpainting):
1064
+ print(f"Initializing ImageEditing")
1065
+ self.sam = Segmenting
1066
+ self.grounding = Text2Box
1067
+ self.inpaint = Inpainting
1068
+
1069
+ def pad_edge(self, mask, padding):
1070
+ # mask Tensor [H,W]
1071
+ mask = mask.numpy()
1072
+ true_indices = np.argwhere(mask)
1073
+ mask_array = np.zeros_like(mask, dtype=bool)
1074
+ for idx in true_indices:
1075
+ padded_slice = tuple(slice(max(0, i - padding), i + padding + 1) for i in idx)
1076
+ mask_array[padded_slice] = True
1077
+ new_mask = (mask_array * 255).astype(np.uint8)
1078
+ # new_mask
1079
+ return new_mask
1080
+
1081
+ @prompts(name="Remove Something From The Photo",
1082
+ description="useful when you want to remove and object or something from the photo "
1083
+ "from its description or location. "
1084
+ "The input to this tool should be a comma separated string of two, "
1085
+ "representing the image_path and the object need to be removed. ")
1086
+ def inference_remove(self, inputs):
1087
+ image_path, to_be_removed_txt = inputs.split(",")[0], ','.join(inputs.split(',')[1:])
1088
+ return self.inference_replace_sam(f"{image_path},{to_be_removed_txt},background")
1089
+
1090
+ @prompts(name="Replace Something From The Photo",
1091
+ description="useful when you want to replace an object from the object description or "
1092
+ "location with another object from its description. "
1093
+ "The input to this tool should be a comma separated string of three, "
1094
+ "representing the image_path, the object to be replaced, the object to be replaced with ")
1095
+ def inference_replace_sam(self, inputs):
1096
+ image_path, to_be_replaced_txt, replace_with_txt = inputs.split(",")
1097
+
1098
+ print(f"image_path={image_path}, to_be_replaced_txt={to_be_replaced_txt}")
1099
+ image_pil, image = self.grounding.load_image(image_path)
1100
+ boxes_filt, pred_phrases = self.grounding.get_grounding_boxes(image, to_be_replaced_txt)
1101
+ image = cv2.imread(image_path)
1102
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
1103
+ self.sam.sam_predictor.set_image(image)
1104
+ masks = self.sam.get_mask_with_boxes(image_pil, image, boxes_filt)
1105
+ mask = torch.sum(masks, dim=0).unsqueeze(0)
1106
+ mask = torch.where(mask > 0, True, False)
1107
+ mask = mask.squeeze(0).squeeze(0).cpu() # tensor
1108
+
1109
+ mask = self.pad_edge(mask, padding=20) # numpy
1110
+ mask_image = Image.fromarray(mask)
1111
+
1112
+ updated_image = self.inpaint(prompt=replace_with_txt, image=image_pil,
1113
+ mask_image=mask_image)
1114
+ updated_image_path = get_new_image_name(image_path, func_name="replace-something")
1115
+ updated_image = updated_image.resize(image_pil.size)
1116
+ updated_image.save(updated_image_path)
1117
+ print(
1118
+ f"\nProcessed ImageEditing, Input Image: {image_path}, Replace {to_be_replaced_txt} to {replace_with_txt}, "
1119
+ f"Output Image: {updated_image_path}")
1120
+ return updated_image_path