| import gradio as gr | |
| import torch | |
| import clip | |
| from PIL import Image | |
| import numpy as np | |
# Device the CLIP model is loaded onto.
device = "cpu"
# Load the RN50x64 CLIP checkpoint. clip.load returns the model together with
# the image preprocessing transform that matches it; `itm` applies that
# transform to the composited PIL image before calling the model.
model, preprocess = clip.load(name="RN50x64", device=device)
def img_process(img1, img2, location_width, location_height, size_width, size_height, threshold=240):
    """Paste an object image onto a 600x400 copy of a background image.

    The object is resized to ``(size_width, size_height)`` and pasted with its
    top-left corner at ``(location_width, location_height)``. This applies
    regardless of the object image's mode.

    Transparency of the pasted object:
    * If the object image has its own transparency (an alpha band, or a
      palette image with a transparency entry), that alpha is used as the
      paste mask.
    * Otherwise, every pixel whose R, G and B values are all greater than
      ``threshold`` is made fully transparent. This removes a (near-)white
      background.

    Args:
        img1: Path (or file object) of the object image.
        img2: Path (or file object) of the background image.
        location_width: x offset, in pixels, of the object on the canvas.
        location_height: y offset, in pixels, of the object on the canvas.
        size_width: Width, in pixels, the object is resized to.
        size_height: Height, in pixels, the object is resized to.
        threshold: Channel value above which a pixel counts as white
            background. Only used for images without their own alpha.

    Returns:
        The composited RGBA ``PIL.Image.Image`` of size 600x400.
    """
    # PIL's resize/paste need integer coordinates; UI widgets may hand back
    # floats, so coerce explicitly.
    target_size = (int(size_width), int(size_height))
    offset = (int(location_width), int(location_height))

    # `with` closes the underlying files; convert()/resize() return
    # independent in-memory images, so they remain valid after the block.
    with Image.open(img1) as src:
        has_alpha = "A" in src.getbands() or (
            src.mode == "P" and "transparency" in src.info
        )
        if has_alpha:
            obj = src.convert("RGBA").resize(target_size)
        else:
            # Resize in RGB first, then add an (opaque) alpha band to edit.
            obj = src.convert("RGB").resize(target_size).convert("RGBA")

    with Image.open(img2) as bg_src:
        background = bg_src.convert("RGBA").resize((600, 400))

    if not has_alpha:
        pixels = np.array(obj)  # writable copy, shape (H, W, 4)
        is_white = (pixels[..., :3] > threshold).all(axis=-1)
        pixels[is_white, 3] = 0  # make near-white pixels fully transparent
        obj = Image.fromarray(pixels)

    background.paste(obj, offset, mask=obj.getchannel("A"))
    return background
def itm(obj, back, location_width, location_height, size_width, size_height, is_obj, pos_obj, neg_obj, is_attr, pos_attr, neg_attr):
    """Score a composited image against a positive and a negative caption with CLIP.

    The object image is pasted onto the background via ``img_process``.
    Two captions are then built:

    * positive: ``"a photo of {pos_attr} {pos_obj}"``
    * negative: the same template, using ``neg_obj`` instead of ``pos_obj``
      when ``is_obj`` is set and ``neg_attr`` instead of ``pos_attr`` when
      ``is_attr`` is set.

    If neither flag is set, both captions are identical and therefore receive
    equal scores.

    Returns:
        A 5-tuple ``(positive caption, positive probability,
        negative caption, negative probability, composited PIL image)``.
        The probabilities are the softmax of CLIP's image-to-text logits over
        the two captions.
    """
    composite = img_process(obj, back, location_width, location_height, size_width, size_height)

    obj_prompt = neg_obj if is_obj else pos_obj
    attr_prompt = neg_attr if is_attr else pos_attr
    pos_text = f"a photo of {pos_attr} {pos_obj}"
    neg_text = f"a photo of {attr_prompt} {obj_prompt}"

    # Inputs must be on the same device the model was loaded onto.
    image_tensor = preprocess(composite).unsqueeze(0).to(device)
    tokens = clip.tokenize([pos_text, neg_text]).to(device)

    with torch.no_grad():
        logits_per_image, _ = model(image_tensor, tokens)
        probs = logits_per_image.softmax(dim=-1).cpu().numpy()
    print("Label probs:", probs)
    return pos_text, probs[0][0], neg_text, probs[0][1], composite
# Gradio UI: object/background inputs and caption controls at the top, the
# Submit button, then the composited image and the two CLIP scores.
with gr.Blocks() as demo:
    gr.Markdown("<h1><center>VL-Checklist Demo</center></h1>")
    gr.Markdown("""
Tips:
- In this demo, you can change the object and attribute of object in the text prompt, and you can also change the size and location of the object.
- Please upload an object image with white background.
- The model we used in the demo is CLIP.
""")

    with gr.Row():
        # Left: object image plus its placement and size on the canvas.
        with gr.Column():
            img_obj = gr.Image(
                value='sample/apple.png',
                type="filepath",
                label='object_img(Plz input an object with white background)',
            )
            loc_w = gr.Slider(maximum=500, label='location_width', step=1)
            loc_h = gr.Slider(maximum=300, label='location_height', step=1)
            s_w = gr.Number(value=200, precision=0, label='size_width')
            s_h = gr.Number(value=200, precision=0, label='size_height')
            gr.Markdown("Click **Submit** to get the output!")
        # Right: background image and the caption controls.
        with gr.Column():
            img_back = gr.Image(value='sample/back.jpg', type="filepath", label='background_img')
            is_obj = gr.Checkbox(value=True, label='Does negative prompt change the object?')
            pos_obj = gr.Textbox(value='apple', label='positive object')
            neg_obj = gr.Textbox(value='dog', label='negative object')
            is_attr = gr.Checkbox(value=False, label='Does negative prompt change the attribute?')
            pos_attr = gr.Textbox(value='red', label='positive attribute')
            neg_attr = gr.Textbox(value='green', label='negative attribute')

    with gr.Row():
        btn = gr.Button("Submit", variant="primary")

    with gr.Row():
        with gr.Column():
            img_output = gr.Image(type="pil", label='output_img')
        with gr.Column():
            pos_prom = gr.Textbox(label='Positive prompt')
            pos_s = gr.Textbox(label='Positive score')
            neg_prom = gr.Textbox(label='Negative prompt')
            neg_s = gr.Textbox(label='Negative score')

    # The examples table and the Submit button both drive `itm` with the same
    # inputs (in `itm`'s parameter order) and fill the same outputs.
    itm_inputs = [
        img_obj, img_back, loc_w, loc_h, s_w, s_h,
        is_obj, pos_obj, neg_obj, is_attr, pos_attr, neg_attr,
    ]
    itm_outputs = [pos_prom, pos_s, neg_prom, neg_s, img_output]

    with gr.Row():
        gr.Examples(
            examples=[
                ['sample/apple.png', 'sample/back.jpg', 50, 50, 200, 200, True, 'apple', 'dog', False, 'red', 'green'],
                ['sample/banana.jpg', 'sample/back.jpg', 300, 200, 200, 200, True, 'bananas', 'peaches', False, 'yellow', 'green'],
            ],
            inputs=itm_inputs,
            outputs=itm_outputs,
            fn=itm,
            cache_examples=True,
        )

    btn.click(fn=itm, inputs=itm_inputs, outputs=itm_outputs)

demo.launch()