Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| from response_db import StResponseDb | |
| from create_cache import Game_Cache | |
| import numpy as np | |
| from PIL import Image | |
| import pandas as pd | |
| import torch | |
| import pickle | |
| import uuid | |
# Response store; ask_a_question() records every question/answer turn in it.
db = StResponseDb()

# Stylesheet handed to gr.Blocks: chat bubbles for the transcript plus the
# class attached to the N/A button.
css = """
.chatbot {display:flex;flex-direction:column}
.msg {padding:4px;margin-bottom:4px;border-radius:4px;width:80%}
.msg.user {background-color:cornflowerblue;color:white;align-self:self-end}
.msg.bot {background-color:lightgray}
.na_button {background-color:red;color:red}
"""

from model.run_question_asking_model import return_modules, return_modules_yn

# Both model families (open-ended and yes/no) are built once at import time
# and shared by every Game_Session; the third returned module is not used here.
(question_model, response_model_simul, _, caption_model) = return_modules()
(question_model_yn, response_model_simul_yn, _, caption_model_yn) = return_modules_yn()
class Game_Session:
    """State of a single guessing game: models, belief vector and transcript.

    The per-game fields (image data, captions, questions and the probability
    tensors) start as ``None`` and are filled in by ``set_images`` before the
    first question is asked.
    """

    def __init__(self, taskid, yn, hard_setting):
        """Create an empty session bound to the module-level models.

        Args:
            taskid: Identifier of the task being played. Accepted for call
                compatibility; it is not stored on the session.
            yn: Truthy when the game is restricted to yes/no answers.
            hard_setting: Difficulty flag taken from the game cache; only
                stored here.
        """
        self.yn = yn
        self.hard_setting = hard_setting

        # Module-level models are only read, so no ``global`` statement is
        # needed to reach them.
        self.question_model = question_model
        self.response_model_simul = response_model_simul
        self.caption_model = caption_model
        self.question_model_yn = question_model_yn
        self.response_model_simul_yn = response_model_simul_yn
        self.caption_model_yn = caption_model_yn

        # Per-game data. ``images_np`` must carry this exact name because
        # get_next_question() and the request handlers read it.
        self.image_files = None
        self.images_np = None
        self.p_y_x = None
        self.p_r_qy = None
        self.p_y_xqr = None
        self.captions = None
        self.questions = None
        self.target_questions = None

        # Alternating transcript: question, answer, question, answer, ...
        self.history = []
        self.game_id = str(uuid.uuid4())
        self.set_curr_models()

    def set_curr_models(self):
        """Point the ``curr_*`` attributes at the models matching ``self.yn``."""
        if self.yn:
            self.curr_question_model = self.question_model_yn
            self.curr_caption_model = self.caption_model_yn
            self.curr_response_model_simul = self.response_model_simul_yn
        else:
            self.curr_question_model = self.question_model
            self.curr_caption_model = self.caption_model
            self.curr_response_model_simul = self.response_model_simul

    def get_next_question(self):
        """Return the question chosen by the current question model.

        Delegates to ``select_best_question`` with the current belief
        ``p_y_x``, the remaining questions, the image arrays, the captions and
        the current response model.
        """
        return self.curr_question_model.select_best_question(
            self.p_y_x,
            self.questions,
            self.images_np,
            self.captions,
            self.curr_response_model_simul,
        )
def _render_history_html(history):
    """Render the alternating bot/user transcript as the chat HTML fragment."""
    from html import escape

    bubbles = []
    for turn, message in enumerate(history):
        # "nothing" is the sentinel sent by the N/A button.
        if message == "nothing":
            message = "n/a"
        speaker = "bot" if turn % 2 == 0 else "user"
        # Answers are typed by the user, so escape them before embedding.
        bubbles.append("<div class='msg {}'> {}</div>".format(speaker, escape(str(message))))
    return "<div class='chatbot'>" + "".join(bubbles) + "</div>"


def ask_a_question(input, taskid, gs):
    """Process one user answer and advance the game by a turn.

    The answer is appended to the transcript, the belief over images is
    updated and renormalised, the answered question is removed from the pool,
    the turn is written to ``db`` and the next question is selected. When the
    most likely image exceeds the confidence threshold the game ends and the
    guess is logged instead of asking another question.

    Args:
        input: The user's answer ("nothing" for N/A, "yes"/"no" in yes/no mode).
        taskid: Task id currently shown in the UI; only forwarded to ``db.add``.
        gs: The ``Game_Session`` holding the state of this game.

    Returns:
        A 9-tuple in the order of the click handlers' outputs: chat HTML, the
        session, then updates for the answer box, N/A button, submit button,
        task-id field, start button, yes button and no button.
    """
    confidence_threshold = 0.8

    gs.history.append(input)
    question = gs.history[-2]

    # Bayesian update of the belief; fall back to all zeros when the product
    # has no mass so we never divide by zero.
    gs.p_r_qy = gs.curr_response_model_simul.get_p_r_qy(input, question, gs.images_np, gs.captions)
    unnormalised = gs.p_y_x * gs.p_r_qy
    mass = torch.sum(unnormalised)
    gs.p_y_xqr = unnormalised / mass if mass != 0 else torch.zeros_like(unnormalised)
    gs.p_y_x = gs.p_y_xqr

    gs.questions.remove(question)
    db.add(gs.game_id, taskid, len(gs.history) // 2 - 1, question, input)

    # The next question is always selected, even if the game turns out to be
    # finished, to keep the original call sequence on the question model.
    gs.history.append(gs.get_next_question())

    top_prob = torch.max(gs.p_y_x).item()
    top_pred = torch.argmax(gs.p_y_x).item()
    finished = top_prob > confidence_threshold
    if finished:
        # Drop the question that will never be asked and log the guess
        # (the logged index is zero-based, the displayed one is one-based).
        gs.history = gs.history[:-1]
        db.add(gs.game_id, taskid, len(gs.history) // 2, f"Guess: Image {top_pred}", "")

    chat_html = _render_history_html(gs.history)
    if finished:
        chat_html += f"<p>The model identified <b>Image {top_pred+1}</b> as the image. Please select a new task ID to continue.</p>"

    show_free_text = (not finished) and (not gs.yn)
    show_yes_no = (not finished) and bool(gs.yn)
    return (
        chat_html,
        gs,
        gr.Textbox.update(visible=show_free_text),  # answer
        gr.Button.update(visible=show_free_text),   # na_box
        gr.Button.update(visible=show_free_text),   # submit
        gr.Number.update(visible=finished),         # taskid
        gr.Button.update(visible=finished),         # start_button
        gr.Button.update(visible=show_yes_no),      # yes_box
        gr.Button.update(visible=show_yes_no),      # no_box
    )
def _load_image_array(path):
    """Load an image as a NumPy array, expanding greyscale to three channels."""
    # The context manager closes the underlying file once pixels are copied.
    with Image.open(path) as img:
        pixels = np.asarray(img)
    return np.dstack([pixels] * 3) if pixels.ndim == 2 else pixels


def set_images(taskid):
    """Start a new game for the task selected in the UI.

    Looks up the MSCOCO id for ``taskid`` in ``pilot-study.csv``, loads the
    matching pickled game cache, builds a ``Game_Session`` with a uniform
    belief over the ten images and selects the first question.

    Args:
        taskid: Zero-based row index into the ``mscoco-id`` column of
            ``pilot-study.csv`` (the value of the Task ID number field).

    Returns:
        A 20-tuple in the order of the start button's outputs: ten image
        paths, the session, the chat HTML with the first question, then
        updates for the task label, answer box, N/A button, submit button,
        task-id field, start button, yes button and no button.

    Raises:
        IndexError: If ``taskid`` is negative or beyond the last row, or the
            cache lists fewer than ten images.
    """
    from html import escape

    num_images = 10
    image_root = "./mscoco-images"

    pilot_study = pd.read_csv("pilot-study.csv")
    coco_ids = pilot_study['mscoco-id'].tolist()
    task_index = int(taskid)
    # Reject negatives explicitly; plain indexing would silently wrap around.
    if not 0 <= task_index < len(coco_ids):
        raise IndexError(f"Task ID {task_index} is out of range (0 to {len(coco_ids) - 1}).")
    coco_id = int(coco_ids[task_index])

    # NOTE(review): pickle.load executes arbitrary code from the file; this is
    # only safe while cache/ contains trusted, locally generated files.
    with open(f'cache/{coco_id}.p', 'rb') as fp:
        game_cache = pickle.load(fp)

    gs = Game_Session(coco_id, game_cache.yn, game_cache.hard_setting)

    file_names = [game_cache.image_files[i] for i in range(num_images)]
    display_paths = [f"{image_root}/val2014/{name}" for name in file_names]

    # The models receive the paths relative to the image root, including the
    # leading slash ("/val2014/<file>"), exactly as before.
    gs.image_files = [path[len(image_root):] for path in display_paths]
    gs.images_np = [_load_image_array(path) for path in display_paths]

    # Uniform prior over the candidate images.
    gs.p_y_x = (torch.ones(num_images) / num_images).to(gs.curr_question_model.device)
    gs.captions = gs.curr_caption_model.get_captions(gs.image_files)
    gs.questions, gs.target_questions = gs.curr_question_model.get_questions(gs.image_files, gs.captions, 0)

    # NOTE(review): the question model object comes from module level, so
    # swapping its question bank here appears to affect all sessions — confirm.
    gs.curr_question_model.reset_question_bank()
    gs.curr_question_model.question_bank = game_cache.question_dict

    first_question = gs.curr_question_model.select_best_question(
        gs.p_y_x, gs.questions, gs.images_np, gs.captions, gs.curr_response_model_simul
    )
    gs.history.append(first_question)

    first_question_html = f"<div class='chatbot'><div class='msg bot'>{escape(str(first_question))}</div></div>"
    task_html = f"<p>Current Task ID: <b>{task_index}</b></p>"

    open_ended = not gs.yn
    if open_ended:
        answer_update = gr.Textbox.update(visible=True, value='')
    else:
        answer_update = gr.Textbox.update(visible=False)

    return (
        *display_paths,
        gs,
        first_question_html,
        gr.HTML.update(value=task_html, visible=True),  # task_text
        answer_update,                                   # answer
        gr.Button.update(visible=open_ended),            # na_box
        gr.Button.update(visible=open_ended),            # submit
        gr.Number.update(visible=False),                 # taskid
        gr.Button.update(visible=False),                 # start_button
        gr.Button.update(visible=not open_ended),        # yes_box
        gr.Button.update(visible=not open_ended),        # no_box
    )
# NOTE(review): the widget nesting below is reconstructed from a source whose
# indentation was lost — confirm the Column/Row grouping against the live app.
with gr.Blocks(title="Image Q&A Guessing Game", css=css) as demo:
    # Instructions shown at the top of the page.
    gr.HTML(
        "<h1>Image Q&A Guessing Game</h1>"
        "<p style='font-size:120%;'>"
        "Imagine you are playing 20-questions with an AI model.<br>"
        "The AI model plays the role of the question asker. You play the role of the responder. <br>"
        "There are 10 images. <b>Your image is Image 1</b>. The other images are distraction images."
        "The model can see all 10 images and all the questions and answers for the current set of images. It will ask a question based on the available information.<br>"
        "<span style='color: #0000ff'>The goal of the model is to accurately guess the correct image (i.e. <b><span style='color: #0000ff'>Image 1</span></b>) in as few turns as possible.<br>"
        "Your goal is to help the model guess the image by answering as clearly and accurately as possible.</span><br><br>"
        "<b>Guidelines:</b><br>"
        "<ol style='font-size:120%;'>"
        "<li>It is best to keep your answers short (a single word or a short phrase). No need to answer in full sentences.</li>"
        "<li>If you feel that the question cannot be answered or does not apply to Image 1, please select N/A.</li>"
        "</ol> "
        "<br>"
        "(Note: We are testing multiple game settings. In some instances, the game will be open-ended, while in other instances, the answer choices will be limited to yes/no.)<br></p>"
        "<br>"
        "<h2>Please enter a TaskID to start</h2>"
    )

    # Task selection and the "current task" label.
    with gr.Column():
        with gr.Row():
            taskid = gr.Number(label="Task ID (Enter a number from 0 to 160)", value=0)
            start_button = gr.Button("Enter")
        with gr.Row():
            task_text = gr.HTML()

    # Two rows of five candidate images.
    with gr.Column() as img_block:
        with gr.Row():
            img1 = gr.Image(label="Image 1", show_label=True)
            img2 = gr.Image(label="Image 2", show_label=True)
            img3 = gr.Image(label="Image 3", show_label=True)
            img4 = gr.Image(label="Image 4", show_label=True)
            img5 = gr.Image(label="Image 5", show_label=True)
        with gr.Row():
            img6 = gr.Image(label="Image 6", show_label=True)
            img7 = gr.Image(label="Image 7", show_label=True)
            img8 = gr.Image(label="Image 8", show_label=True)
            img9 = gr.Image(label="Image 9", show_label=True)
            img10 = gr.Image(label="Image 10", show_label=True)

    conversation = gr.HTML()
    game_session_state = gr.State()
    answer = gr.Textbox(placeholder="Insert answer here.", label="Answer the given question.", visible=False)

    # Hidden textboxes that feed fixed answers to ask_a_question.
    null_answer = gr.Textbox("nothing", visible=False)
    yes_answer = gr.Textbox("yes", visible=False)
    no_answer = gr.Textbox("no", visible=False)

    with gr.Column():
        with gr.Row():
            yes_box = gr.Button("Yes", visible=False)
            no_box = gr.Button("No", visible=False)
    with gr.Column():
        with gr.Row():
            na_box = gr.Button("N/A", visible=False, elem_classes="na_button")
            submit = gr.Button("Submit", visible=False)

    ### Button click events
    _control_outputs = [answer, na_box, submit, taskid, start_button, yes_box, no_box]

    start_button.click(
        fn=set_images,
        inputs=taskid,
        outputs=[
            img1, img2, img3, img4, img5, img6, img7, img8, img9, img10,
            game_session_state, conversation, task_text,
        ] + _control_outputs,
    )

    # Every answer button runs the same handler; only the answer source differs.
    for _button, _reply in (
        (submit, answer),
        (na_box, null_answer),
        (yes_box, yes_answer),
        (no_box, no_answer),
    ):
        _button.click(
            fn=ask_a_question,
            inputs=[_reply, taskid, game_session_state],
            outputs=[conversation, game_session_state] + _control_outputs,
        )

demo.launch()