Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| from transformers import pipeline | |
| from transformers import BlipProcessor, BlipForConditionalGeneration | |
| from transformers import CLIPProcessor, CLIPModel | |
| import torch | |
| from PIL import Image | |
| import requests | |
| import os | |
| import random | |
# Select the torch device: CUDA when torch reports it available, else CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# CLIP checkpoint; its processor and model are loaded from the same id and
# the model is moved to `device`. A different CLIP checkpoint from Hugging
# Face can be substituted here.
model_id = "openai/clip-vit-base-patch16"
clipprocessor = CLIPProcessor.from_pretrained(model_id)
clipmodel = CLIPModel.from_pretrained(model_id)
clipmodel = clipmodel.to(device)

# BLIP captioning checkpoint. NOTE: `model_id` is deliberately rebound, so
# after this point it names the BLIP checkpoint, not the CLIP one. The BLIP
# model is left on the device `from_pretrained` puts it on (no `.to(device)`).
model_id = "Salesforce/blip-image-captioning-base"
blipmodel = BlipForConditionalGeneration.from_pretrained(model_id)
blipprocessor = BlipProcessor.from_pretrained(model_id)

# Directory (under the current working directory) that target images are
# sampled from.
im_dir = os.path.join(os.getcwd(), "images")
def sample_image(im_dir=im_dir):
    """Pick a random entry of ``im_dir`` and build the components showing it.

    Args:
        im_dir: Directory whose entries are candidates; every entry returned
            by ``os.listdir`` is eligible (no filtering by extension).

    Returns:
        A 2-tuple ``(image, filename_box)``: a non-interactive ``gr.Image``
        whose value is the full path of the chosen entry, and a hidden,
        non-interactive ``gr.Textbox`` whose value is the bare file name.
    """
    chosen_name = random.choice(os.listdir(im_dir))
    chosen_path = os.path.join(im_dir, chosen_name)

    image_component = gr.Image(
        label="Target Image",
        interactive=False,
        type="pil",
        value=chosen_path,
        height=500,
    )
    fname_component = gr.Textbox(
        label="Image fname",
        value=chosen_name,
        interactive=False,
        visible=False,
    )
    return image_component, fname_component
def evaluate_caption(image, caption):
    """Caption ``image`` with BLIP and let CLIP judge it against ``caption``.

    BLIP generates a caption for the image (Player 2). CLIP then scores the
    image against both the human caption (Player 1) and the BLIP caption;
    the scores are soft-maxed over the two captions and the larger one wins.

    Args:
        image: The target image, as accepted by the BLIP and CLIP processors
            (the UI supplies it from a ``gr.Image(type="pil")`` component).
        caption: The human player's caption text.

    Returns:
        A 2-tuple ``(blip_caption, winner)`` where ``winner`` is
        ``"Player 1 wins!"`` if the human caption scores strictly higher and
        ``"Player 2 wins!"`` otherwise (ties go to Player 2).
    """
    # Inference only: no autograd graph is needed for either model.
    with torch.no_grad():
        # Move each processor output to the device its model actually lives
        # on; the two models are not guaranteed to share a device.
        blip_input = blipprocessor(image, return_tensors="pt").to(blipmodel.device)
        generated = blipmodel.generate(**blip_input, max_new_tokens=50)
        blip_caption = blipprocessor.decode(generated[0], skip_special_tokens=True)

        # truncation=True keeps an over-long human caption within CLIP's
        # maximum text length instead of failing inside the text encoder.
        clip_input = clipprocessor(
            text=[caption, blip_caption],
            images=image,
            return_tensors="pt",
            padding=True,
            truncation=True,
        ).to(clipmodel.device)
        similarity_score = clipmodel(**clip_input).logits_per_image

    # One row (the single image), two columns: [human caption, BLIP caption].
    # .cpu() is required before .numpy() when the model runs on a GPU.
    score = similarity_score.softmax(dim=1).cpu().numpy()
    print(score)

    if score[0][0] > score[0][1]:
        winner = "Player 1 wins!"
    else:
        winner = "Player 2 wins!"
    return blip_caption, winner
# Flagging callback that logs each round to a Hugging Face dataset.
# SECURITY: the access token must never be committed to source. It is read
# from the HF_TOKEN environment variable (e.g. a Space secret). Any token that
# was previously hard-coded here should be treated as leaked and revoked.
_hf_token = os.environ.get("HF_TOKEN")
if _hf_token:
    callback = gr.HuggingFaceDatasetSaver(_hf_token, "WID_sym_human_vs_ai")
else:
    # Keep the demo usable without credentials; rounds are simply not logged.
    callback = None
    print("HF_TOKEN is not set; flagged data will not be saved.")

# Intro text shown at the top of the page (kept flush-left so the Markdown
# renders the same regardless of how the component treats indentation).
_INTRO_MARKDOWN = """
# Welcome to our Human vs. AI game!
You and an AI agent are trying to convince a third AI agent that each of you are better at describing the visual world. \n
In order to win, describe this image in one sentence. Then the second AI agent will also generate a description and the third agent will decide a winner.
You win if the AI says that "Player 1 wins!"
"""

# NOTE(review): the nesting of the layout below was reconstructed from a
# source whose indentation was lost — confirm it matches the intended layout.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(_INTRO_MARKDOWN)

    # Initial target image; its file name is tracked in a hidden textbox so
    # it can be stored alongside the captions when a round is flagged.
    im_path_str = random.choice(os.listdir(im_dir))
    im_path = gr.Textbox(
        label="Image fname", value=im_path_str, interactive=False, visible=False
    )

    with gr.Row():
        im = gr.Image(
            label="Target Image",
            interactive=False,
            type="pil",
            value=os.path.join(im_dir, im_path_str),
            height=400,
        )
        with gr.Column():
            caps = gr.Textbox(label="Player 1 Caption")
            submit_btn = gr.Button("Submit!!")
            out1 = gr.Textbox(label="Player 2 (Machine) Caption", interactive=False)

    with gr.Row():
        with gr.Column():
            out2 = gr.Textbox(label="Winner", interactive=False)
            reload_btn = gr.Button("Next Image")

    # Score the round when the player submits a caption.
    evaluated = submit_btn.click(
        fn=evaluate_caption,
        inputs=[im, caps],
        outputs=[out1, out2],
        api_name="test",
    )

    if callback is not None:
        callback.setup([caps, out1, out2, im_path], "flagged_data_points")
        # Log the round only after evaluation succeeded; preprocess=False
        # passes the raw component values straight to the saver.
        evaluated.success(
            lambda *args: callback.flag(list(args)),
            [caps, out1, out2, im_path],
            None,
            preprocess=False,
        )

    # Swap in a new random image (and its hidden file name).
    reload_btn.click(fn=sample_image, inputs=None, outputs=[im, im_path])

demo.launch()