Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import random | |
| import os | |
| import pickle | |
| from PIL import Image | |
| # global raw_image_path,candidate_image1_path,candidate_image2_path | |
class Process:
    """Mutable state shared between the Gradio callbacks in this file.

    All fields are plain attributes that the callbacks read and overwrite.
    The containers are created per instance in ``__init__`` so that two
    ``Process`` objects never share the same list or dict (the previous
    class-level ``[]`` / ``{}`` defaults were shared by every instance).
    """

    def __init__(self):
        # Label ("Image1" / "Image2") of the candidate currently considered
        # correct; the callbacks set it to None after a wrong answer.
        self.gt_image = ""
        # References to the currently displayed images. The loaders index
        # them as [data_idx, obj_idx] for "ocl" and as a dict with a "name"
        # key otherwise.
        self.raw_image_path = ""
        self.candidate_image1_path = ""
        self.candidate_image2_path = ""
        # Unpickled OCL annotation data (only assigned for the "ocl" dataset).
        self.pkl_data = None
        # Candidate pools filled by load_data_and_produce_list.
        self.positive_cand = []
        self.negative_cand = []
        self.positive1_cand = []
        self.positive2_cand = []
        self.positive_common_cand = []
        # Position inside the two-concept round schedule (values 0..7).
        self.schedule = 0
        # Maps dropdown labels such as "Chain_0" to concept names.
        self.idx_to_chain = {}
# Single module-level state object used by the callbacks in this file.
# (A `global` statement is a no-op at module scope, so none is needed here.)
process = Process()
def load_data_and_produce_list(dataset, exp_mode, concept_choices):
    """Load the annotation pickle for *dataset* and fill the candidate pools.

    The concept is looked up via ``process.idx_to_chain[concept_choices]``;
    in two-concept mode it is a ``"first-second"`` pair split on ``"-"``.

    Results are written to the module-level ``process`` object:

    * ``"One concept"`` mode sets ``positive_cand`` and ``negative_cand``.
    * Any other mode sets ``positive1_cand``, ``positive2_cand``,
      ``positive_common_cand`` and ``negative_cand``.
    * For ``"ocl"`` the unpickled data is also stored in ``pkl_data`` and
      pool entries are ``[data_idx, obj_idx]`` pairs; for ``"hmdb"`` pool
      entries are the sample dicts themselves, with ``"name"`` prefixed by
      the image directory.

    Args:
        dataset: ``"ocl"`` or ``"hmdb"``; any other value is a no-op.
        exp_mode: ``"One concept"`` selects single-concept pools.
        concept_choices: dropdown label, a key of ``process.idx_to_chain``.

    Raises:
        KeyError: if *concept_choices* is not in ``process.idx_to_chain``.
        OSError: if the pickle file cannot be opened.
    """
    single_concept = exp_mode == "One concept"

    if dataset == "ocl":
        # TODO (carried over from the original source, which gave no details)
        attr_name = ['wooden', 'metal', 'flying', 'ripe', 'fresh', 'natural', 'cooked', 'painted', 'rusty', 'furry']
        attr2idx = {item: idx for idx, item in enumerate(attr_name)}
        # NOTE(security): pickle.load can execute arbitrary code; this file
        # must come from a trusted source.
        with open("Data/updated_OCL_test_data.pkl", "rb") as f:
            data = pickle.load(f)
        process.pkl_data = data

        concept = process.idx_to_chain[concept_choices]
        negative_cand = []

        if single_concept:
            target_idx = attr2idx[concept]
            positive_cand = []
            for data_idx, each_data in enumerate(data):
                objects = each_data["objects"]
                # The first object carrying the attribute represents the sample.
                for obj_idx, each_obj in enumerate(objects):
                    if target_idx in each_obj['attr']:
                        positive_cand.append([data_idx, obj_idx])
                        break
                else:
                    # No object matched: keep a random object as a negative
                    # (capped at 200). Samples without objects are skipped,
                    # since there is nothing to pick from.
                    if objects and len(negative_cand) < 200:
                        negative_cand.append([data_idx, random.randrange(len(objects))])
            process.positive_cand = positive_cand
            process.negative_cand = negative_cand
        else:
            selected_concept_group = concept.split("-")
            first_idx = attr2idx[selected_concept_group[0]]
            second_idx = attr2idx[selected_concept_group[1]]
            positive1_cand = []
            positive2_cand = []
            positive_common_cand = []
            for data_idx, each_data in enumerate(data):
                objects = each_data['objects']
                # The first object carrying either attribute decides the pool.
                for obj_idx, each_obj in enumerate(objects):
                    has_first = first_idx in each_obj['attr']
                    has_second = second_idx in each_obj['attr']
                    if not (has_first or has_second):
                        continue
                    if has_first and has_second:
                        positive_common_cand.append([data_idx, obj_idx])
                    elif has_first:
                        positive1_cand.append([data_idx, obj_idx])
                    else:
                        positive2_cand.append([data_idx, obj_idx])
                    break
                else:
                    if objects and len(negative_cand) < 200:
                        negative_cand.append([data_idx, random.randrange(len(objects))])
            process.positive1_cand = positive1_cand
            process.positive2_cand = positive2_cand
            process.positive_common_cand = positive_common_cand
            process.negative_cand = negative_cand

    elif dataset == "hmdb":
        image_dir = "Data/refined_HMDB"
        # NOTE(security): pickle.load can execute arbitrary code; this file
        # must come from a trusted source.
        with open("Data/refined_HMDB.pkl", "rb") as f:
            data = pickle.load(f)

        concept = process.idx_to_chain[concept_choices]
        negative_cand = []

        if single_concept:
            positive_cand = []
            for each_data in data:
                each_data['name'] = os.path.join(image_dir, each_data['name'])
                if concept in each_data["label"]:
                    positive_cand.append(each_data)
                else:
                    negative_cand.append(each_data)
                # Stop scanning once both pools are comfortably filled.
                if len(positive_cand) > 30 and len(negative_cand) > 100:
                    break
            process.positive_cand = positive_cand
            process.negative_cand = negative_cand
        else:
            selected_concept_group = concept.split("-")
            first_name = selected_concept_group[0]
            second_name = selected_concept_group[1]
            positive1_cand = []
            positive2_cand = []
            positive_common_cand = []
            for each_data in data:
                each_data['name'] = os.path.join(image_dir, each_data['name'])
                # NOTE(review): this mode matches the concept against the
                # file path ("name") while single-concept mode matches
                # against "label" - confirm which one is intended.
                in_first = first_name in each_data["name"]
                in_second = second_name in each_data["name"]
                if in_first and in_second:
                    positive_common_cand.append(each_data)
                elif in_first:
                    positive1_cand.append(each_data)
                elif in_second:
                    positive2_cand.append(each_data)
                elif len(negative_cand) <= 100:
                    negative_cand.append(each_data)
            process.positive1_cand = positive1_cand
            process.positive2_cand = positive2_cand
            process.positive_common_cand = positive_common_cand
            process.negative_cand = negative_cand
# Size passed to Image.resize for every image shown in the UI.
TARGET_SIZE = (200, 200)

# Directory holding the OCL images; only the basename of the stored
# "name" field is joined onto it.
_OCL_IMAGE_DIR = "Data/refined_OCL_test"


def _load_display_image(dataset, image_ref):
    """Open, (for OCL) crop, and resize one image referenced by *image_ref*.

    For ``"ocl"`` *image_ref* is indexed as ``[data_idx, obj_idx]`` into
    ``process.pkl_data`` and the object's ``"box"`` is cropped out. For any
    other dataset *image_ref* is a mapping whose ``"name"`` is the file path.
    The source file is closed before returning; crop/resize produce new
    in-memory images.
    """
    if dataset == "ocl":
        sample = process.pkl_data[image_ref[0]]
        img_path = os.path.join(_OCL_IMAGE_DIR, sample["name"].split("/")[-1])
        box = sample['objects'][image_ref[1]]['box']
        with Image.open(img_path) as img:
            return img.crop(box).resize(TARGET_SIZE)
    with Image.open(image_ref['name']) as img:
        return img.resize(TARGET_SIZE)


def load_images(dataset, raw_image_path, candidate_image1_path, candidate_image2_path):
    """Load the raw image and both candidate images.

    Args:
        dataset: ``"ocl"`` selects cropped OCL objects; anything else loads
            whole images from each reference's ``"name"`` path.
        raw_image_path: reference to the raw image.
        candidate_image1_path: reference to the first candidate.
        candidate_image2_path: reference to the second candidate.

    Returns:
        Tuple ``(raw_image, candidate_image1, candidate_image2)`` of images
        resized to ``TARGET_SIZE``.
    """
    raw_image = _load_display_image(dataset, raw_image_path)
    candidate_image1 = _load_display_image(dataset, candidate_image1_path)
    candidate_image2 = _load_display_image(dataset, candidate_image2_path)
    return raw_image, candidate_image1, candidate_image2


def load_candidate_images(dataset, cand_image, candidate_image1_path, candidate_image2_path):
    """Load two new candidates while reusing an already loaded raw image.

    Args:
        dataset: ``"ocl"`` selects cropped OCL objects; anything else loads
            whole images from each reference's ``"name"`` path.
        cand_image: image returned unchanged as the new raw image.
        candidate_image1_path: reference to the first candidate.
        candidate_image2_path: reference to the second candidate.

    Returns:
        Tuple ``(cand_image, candidate_image1, candidate_image2)``.
    """
    candidate_image1 = _load_display_image(dataset, candidate_image1_path)
    candidate_image2 = _load_display_image(dataset, candidate_image2_path)
    return cand_image, candidate_image1, candidate_image2
class InferenceDemo:
    """Placeholder demo object whose constructor only reports success."""

    def __init__(self, args, dataset, exp_mode, concept_choices):
        # None of the constructor arguments are stored or otherwise used.
        status_message = "init success"
        print(status_message)
def get_concept_choices(dataset, exp_mode):
    """Build the dropdown update listing the selectable chains.

    "One concept" mode offers eight chains (Chain_0..Chain_7); every other
    mode offers four (Chain_0..Chain_3). ``dataset`` is accepted but does
    not influence the result.

    Returns:
        The value of ``gr.update(choices=...)`` carrying the chain labels.
    """
    chain_total = 8 if exp_mode == "One concept" else 4
    chain_labels = []
    for chain_idx in range(chain_total):
        chain_labels.append(f"Chain_{chain_idx}")
    return gr.update(choices=chain_labels)
def load_images_and_concepts(dataset, exp_mode, concept_choices):
    """Start a new chain: build the pools and pick the first image triple.

    Stores the chain-label -> concept mapping in ``process.idx_to_chain``,
    rebuilds the candidate pools through ``load_data_and_produce_list``,
    restarts the two-concept schedule, and samples a raw image plus one
    positive and one negative candidate. Which slot holds the positive is
    random, and ``process.gt_image`` records the answer.

    Args:
        dataset: ``"ocl"`` uses the OCL concepts; any other value uses the
            HMDB concepts.
        exp_mode: ``"One concept"`` or a two-concept mode.
        concept_choices: selected dropdown label (e.g. ``"Chain_0"``).

    Returns:
        Tuple ``(raw_image, candidate_image1, candidate_image2)``. If
        *concept_choices* is not a valid label for the current mode (for
        example None, or a label left over from the other mode), the images
        are cleared with ``(None, None, None)`` instead of raising KeyError.
    """
    single_concept = exp_mode == "One concept"
    if dataset == "ocl":
        if single_concept:
            concept = ["furry","metal","fresh","cooked","natural","ripe","painted","rusty"]
        else:
            concept = ["furry-metal","fresh-cooked","natural-ripe","painted-rusty"]
    else:
        if single_concept:
            concept = ["brush_hair","dive","clap","hug","shake_hands","sit","smoke","eat"]
        else:
            concept = ["brush_hair-dive","clap-hug","shake_hands-sit","smoke-eat"]
    process.idx_to_chain = {f"Chain_{idx}": name for idx, name in enumerate(concept)}

    if concept_choices not in process.idx_to_chain:
        # No usable chain selected: invalidate the expected answer so a
        # later selection cannot be scored against stale state.
        process.gt_image = None
        return None, None, None

    load_data_and_produce_list(dataset, exp_mode, concept_choices)

    # A freshly selected chain always starts from a first-concept raw image,
    # so the two-concept schedule has to restart with it.
    process.schedule = 0

    positive_pool = process.positive_cand if single_concept else process.positive1_cand
    positive_in_first_slot = random.random() < 0.5
    process.raw_image_path = random.choice(positive_pool)
    if positive_in_first_slot:
        process.candidate_image1_path = random.choice(positive_pool)
        process.candidate_image2_path = random.choice(process.negative_cand)
        process.gt_image = "Image1"
    else:
        process.candidate_image1_path = random.choice(process.negative_cand)
        process.candidate_image2_path = random.choice(positive_pool)
        process.gt_image = "Image2"

    return load_images(
        dataset,
        process.raw_image_path,
        process.candidate_image1_path,
        process.candidate_image2_path,
    )
def _next_two_concept_round():
    """Decide the pools for the next two-concept round.

    Returns ``(positive_pool, restart_pool, next_schedule)``:

    * ``positive_pool`` - pool the correct candidate is drawn from, or None
      when ``process.schedule`` is outside 0..7 (nothing is reloaded then).
    * ``restart_pool`` - pool a brand-new raw image must be drawn from, or
      None when the previously correct candidate becomes the raw image.
    * ``next_schedule`` - value to store in ``process.schedule`` once the
      round has been loaded.

    Schedule: steps 0-2 use the first concept, step 3 bridges to the second
    concept through a sample carrying both (or restarts on the second
    concept if there is none), steps 4-6 use the second concept, and step 7
    bridges back the same way and wraps the schedule to 0.
    """
    step = process.schedule
    has_common = bool(process.positive_common_cand)
    if step < 3:
        return process.positive1_cand, None, step + 1
    if step == 3:
        if has_common:
            return process.positive_common_cand, None, step + 1
        return process.positive2_cand, process.positive2_cand, step + 1
    if step < 7:
        return process.positive2_cand, None, step + 1
    if step == 7:
        if has_common:
            return process.positive_common_cand, None, 0
        return process.positive1_cand, process.positive1_cand, 0
    return None, None, step


def count_and_reload_images(dataset,exp_mode, select_input,show_result, steps,raw_image,candidate_image1,candidate_image2):
    """Score the user's selection and, if correct, load the next round.

    * ``select_input`` is None: every value is returned unchanged. This
      callback clears the radio itself after a correct answer, so it has to
      tolerate being invoked with no selection.
    * Wrong selection: ``process.gt_image`` is set to None and the result
      label becomes ``"Error, Please reset!"``; images and steps are kept.
    * Correct selection: the chosen candidate becomes the new raw image
      (unless the two-concept schedule restarts with a fresh raw image), a
      new positive/negative pair is sampled into random slots, ``steps`` is
      incremented and the radio selection is cleared.

    Returns:
        Tuple ``(select_input, show_result, steps, raw_image,
        candidate_image1, candidate_image2)`` matching the callback outputs.
    """
    if select_input is None:
        return select_input, show_result, steps, raw_image, candidate_image1, candidate_image2

    if select_input != process.gt_image:
        process.gt_image = None
        return select_input, "Error, Please reset!", steps, raw_image, candidate_image1, candidate_image2

    # The image the user correctly picked is chained in as the next raw image.
    chained_image = candidate_image1 if process.gt_image == "Image1" else candidate_image2
    positive_in_first_slot = random.random() < 0.5

    if exp_mode == "One concept":
        positive_pool, restart_pool, next_schedule = process.positive_cand, None, process.schedule
    else:
        positive_pool, restart_pool, next_schedule = _next_two_concept_round()

    if positive_pool is not None:
        if restart_pool is not None:
            process.raw_image_path = random.choice(restart_pool)
        positive_ref = random.choice(positive_pool)
        negative_ref = random.choice(process.negative_cand)
        if positive_in_first_slot:
            process.candidate_image1_path = positive_ref
            process.candidate_image2_path = negative_ref
        else:
            process.candidate_image1_path = negative_ref
            process.candidate_image2_path = positive_ref

        if restart_pool is not None:
            raw_image, candidate_image1, candidate_image2 = load_images(
                dataset, process.raw_image_path,
                process.candidate_image1_path, process.candidate_image2_path)
        else:
            raw_image, candidate_image1, candidate_image2 = load_candidate_images(
                dataset, chained_image,
                process.candidate_image1_path, process.candidate_image2_path)
        # Only advance the schedule once the round has actually been loaded.
        process.schedule = next_schedule

    process.gt_image = "Image1" if positive_in_first_slot else "Image2"
    steps = int(steps) + 1
    return None, "Success!", steps, raw_image, candidate_image1, candidate_image2
# Build the Gradio UI and wire the callbacks.
# NOTE(review): the Row/Column nesting below was reconstructed from the
# statement order because the source indentation was lost - confirm against
# the deployed layout.
with gr.Blocks() as demo:
    # Page header.
    title_markdown = ("""
# MLLM Association
[[Paper]](https://mvig-rhos.com) [[Code]](https://github.com/lihong2303/MLLMs_Association)
""")
    cur_dir = os.path.dirname(os.path.abspath(__file__))
    gr.Markdown(title_markdown)

    # Image triple: the raw image and the two candidates.
    with gr.Row():
        with gr.Column():
            raw_image = gr.Image(label="Raw Image", interactive=False)
        with gr.Column():
            candidate_image1 = gr.Image(label="Candidate Image 1", interactive=False)
        with gr.Column():
            candidate_image2 = gr.Image(label="Candidate Image 2", interactive=False)

    # Experiment configuration.
    with gr.Row():
        dataset = gr.Dropdown(choices=["ocl", "hmdb"], label="Select a dataset", interactive=True)
        exp_mode = gr.Dropdown(choices=["One concept", "Two concepts"], label="Select a test mode", interactive=True)
        concept_choices = gr.Dropdown(choices=[], label="Select the chain", interactive=True)

    # Answer selection and progress display.
    with gr.Row():
        select_input = gr.Radio(choices=["Image1", "Image2"], label="Select candidate image")
        steps = gr.Label(value="0", label="Steps")
        show_result = gr.Label(value="", label="Selected Result")

    # Changing the test mode refreshes the list of selectable chains.
    exp_mode.change(fn=get_concept_choices, inputs=[dataset, exp_mode], outputs=concept_choices)

    # Selecting a chain builds the candidate pools and shows the first triple.
    concept_choices.change(fn=load_images_and_concepts,
                           inputs=[dataset, exp_mode, concept_choices],
                           outputs=[raw_image, candidate_image1, candidate_image2])

    # Picking a candidate scores the answer and loads the next round.
    select_input.change(fn=count_and_reload_images,
                        inputs=[dataset, exp_mode, select_input, show_result, steps,
                                raw_image, candidate_image1, candidate_image2],
                        outputs=[select_input, show_result, steps,
                                 raw_image, candidate_image1, candidate_image2])

demo.queue()
if __name__ == "__main__":
    import argparse

    argparser = argparse.ArgumentParser()
    argparser.add_argument("--server_name", default="0.0.0.0", type=str)
    # Parsed as int directly so no conversions are needed further down.
    argparser.add_argument("--port", default=6123, type=int)
    args = argparser.parse_args()
    try:
        demo.launch(server_name=args.server_name, server_port=args.port, share=True)
    except Exception as e:
        # Top-level boundary: retry exactly once on the next port. Remember
        # the failed port first so the message reports both ports correctly
        # (it previously printed the new port twice), and include the error
        # because the launch may have failed for another reason.
        occupied_port = args.port
        args.port = occupied_port + 1
        print(f"Port {occupied_port} is occupied, try port {args.port} ({e})")
        demo.launch(server_name=args.server_name, server_port=args.port, share=True)