# NOTE(review): the original file began with a scraped web-page banner
# ("Spaces: Runtime error") that is not part of the source — removed.
"""Model-loading script for the image reference game.

Parses CLI flags, then (with several flags overridden in-code below) loads
four modules: a question-asking model, two response models (one simulated,
one ground-truth), and a caption model.  The same loading sequence is
repeated further down with a rule-based yes/no configuration.
"""
import argparse
import logging
import random

import numpy as np
import pandas as pd
import torch
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from tqdm.auto import tqdm

from model.model.question_asking_model import get_question_model
from model.model.caption_model import get_caption_model
from model.model.response_model import get_response_model
from model.utils import logging_handler, image_saver, assert_checks

# Fixed seed so any random sampling is reproducible across runs.
random.seed(123)

parser = argparse.ArgumentParser()
parser.add_argument('--device', type=str, default='cuda')
parser.add_argument('--include_what', action='store_true')
parser.add_argument('--target_idx', type=int, default=0)
parser.add_argument('--max_num_questions', type=int, default=25)
parser.add_argument('--num_images', type=int, default=10)
parser.add_argument('--beam', type=int, default=1)
parser.add_argument('--num_samples', type=int, default=100)
parser.add_argument('--threshold', type=float, default=0.9)
parser.add_argument('--caption_strategy', type=str, default='simple', choices=['simple', 'granular', 'gtruth'])
parser.add_argument('--sample_strategy', type=str, default='random', choices=['random', 'attribute', 'clip'])
parser.add_argument('--attribute_n', type=int, default=1)  # Number of attributes to split
parser.add_argument('--response_type_simul', type=str, default='VQA1', choices=['simple', 'QA', 'VQA1', 'VQA2', 'VQA3', 'VQA4'])
parser.add_argument('--response_type_gtruth', type=str, default='VQA2', choices=['simple', 'QA', 'VQA1', 'VQA2', 'VQA3', 'VQA4'])
parser.add_argument('--question_strategy', type=str, default='gpt3', choices=['rule', 'gpt3'])
parser.add_argument('--multiplier_mode', type=str, default='soft', choices=['soft', 'hard', 'none'])
parser.add_argument('--gpt3_save_name', type=str, default='questions_gpt3')
parser.add_argument('--save_name', type=str, default=None)
parser.add_argument('--verbose', action='store_true')
args = parser.parse_args()

# NOTE(review): these assignments overwrite whatever was passed on the
# command line, so the corresponding flags above are effectively ignored
# for this loading pass — confirm the pinned configuration is intentional.
args.question_strategy = 'gpt3'
args.include_what = True
args.response_type_simul = 'VQA1'
args.response_type_gtruth = 'VQA3'
args.multiplier_mode = 'soft'
args.sample_strategy = 'attribute'
args.attribute_n = 1
args.caption_strategy = 'gtruth'
assert_checks(args)

# NOTE(review): a logger is created only when --save_name is ABSENT; the
# commented-out game loop below builds per-game loggers when it is set.
# Confirm the condition is not inverted.
if args.save_name is None:
    logger = logging_handler(args.verbose, args.save_name)
args.load_response_model = True

print("1. Loading question model ...")
question_model = get_question_model(args)
args.question_generator = question_model.question_generator
print("2. Loading response model simul ...")
response_model_simul = get_response_model(args, args.response_type_simul)
response_model_simul.to(args.device)
print("3. Loading response model gtruth ...")
response_model_gtruth = get_response_model(args, args.response_type_gtruth)
response_model_gtruth.to(args.device)
print("4. Loading caption model ...")
caption_model = get_caption_model(args, question_model)


def return_modules():
    """Return the GPT-3 / open-question model bundle loaded above."""
    return question_model, response_model_simul, response_model_gtruth, caption_model
# --- Second loading pass: rule-based yes/no questions ----------------------
# Same pipeline, reconfigured for rule-generated questions with no "what"
# questions and no multiplier.  NOTE(review): this mutates the shared `args`
# namespace after the first bundle was built from it — confirm the first
# bundle does not lazily re-read these fields.
args.question_strategy = 'rule'
args.include_what = False
args.response_type_simul = 'VQA1'
args.response_type_gtruth = 'VQA3'
args.multiplier_mode = 'none'
args.sample_strategy = 'attribute'
args.attribute_n = 1
args.caption_strategy = 'gtruth'

print("1. Loading question model ...")
question_model_yn = get_question_model(args)
args.question_generator_yn = question_model_yn.question_generator
print("2. Loading response model simul ...")
response_model_simul_yn = get_response_model(args, args.response_type_simul)
response_model_simul_yn.to(args.device)
print("3. Loading response model gtruth ...")
response_model_gtruth_yn = get_response_model(args, args.response_type_gtruth)
response_model_gtruth_yn.to(args.device)
print("4. Loading caption model ...")
caption_model_yn = get_caption_model(args, question_model_yn)


def return_modules_yn():
    """Return the rule-based yes/no model bundle loaded above."""
    return question_model_yn, response_model_simul_yn, response_model_gtruth_yn, caption_model_yn
# args.question_strategy='gpt3'
# args.include_what=True
# args.response_type_simul='VQA1'
# args.response_type_gtruth='VQA3'
# args.multiplier_mode='none'
# args.sample_strategy='attribute'
# args.attribute_n=1
# args.caption_strategy='gtruth'
# assert_checks(args)
# if args.save_name is None: logger = logging_handler(args.verbose, args.save_name)
# args.load_response_model = True
# print("1. Loading question model ...")
# question_model = get_question_model(args)
# args.question_generator = question_model.question_generator
# print("2. Loading response model simul ...")
# response_model_simul = get_response_model(args, args.response_type_simul)
# response_model_simul.to(args.device)
# print("3. Loading response model gtruth ...")
# response_model_gtruth = get_response_model(args, args.response_type_gtruth)
# response_model_gtruth.to(args.device)
# print("4. Loading caption model ...")
# caption_model = get_caption_model(args, question_model)
# # dataloader = DataLoader(dataset=ReferenceGameData(split='test',
# #                         num_images=args.num_images,
# #                         num_samples=args.num_samples,
# #                         sample_strategy=args.sample_strategy,
# #                         attribute_n=args.attribute_n))
# def return_modules():
#     return question_model, response_model_simul, response_model_gtruth, caption_model
# # game_lens, game_preds = [], []
# for t, batch in enumerate(tqdm(dataloader)):
#     image_files = [image[0] for image in batch['images'][:args.num_images]]
#     image_files = [str(i).split('/')[1] for i in image_files]
#     with open("mscoco_images_attribute_n=1.txt", 'a') as f:
#         for i in image_files:
#             f.write(str(i)+"\n")
#     images = [np.asarray(Image.open(f"./../../../data/ms-coco/images/{i}")) for i in image_files]
#     images = [np.dstack([i]*3) if len(i.shape)==2 else i for i in images]
#     p_y_x = (torch.ones(args.num_images)/args.num_images).to(question_model.device)
#     if args.save_name is not None:
#         logger = logging_handler(args.verbose, args.save_name, t)
#         image_saver(images, args.save_name, t)
#     captions = caption_model.get_captions(image_files)
#     questions, target_questions = question_model.get_questions(image_files, captions, args.target_idx)
#     question_model.reset_question_bank()
#     logger.info(questions)
#     for idx, c in enumerate(captions): logger.info(f"Image_{idx}: {c}")
#     num_questions_original = len(questions)
#     for j in range(min(args.max_num_questions, num_questions_original)):
#         # Select best question
#         question = question_model.select_best_question(p_y_x, questions, images, captions, response_model_simul)
#         logger.info(f"Question: {question}")
#         # Ask the question and get the model's response
#         response = response_model_gtruth.get_response(question, images[args.target_idx], captions[args.target_idx], target_questions, is_a=1-args.include_what)
#         logger.info(f"Response: {response}")
#         # Update the probabilities
#         p_r_qy = response_model_simul.get_p_r_qy(response, question, images, captions)
#         logger.info(f"P(r|q,y):\n{np.around(p_r_qy.cpu().detach().numpy(), 3)}")
#         p_y_xqr = p_y_x*p_r_qy
#         p_y_xqr = p_y_xqr/torch.sum(p_y_xqr) if torch.sum(p_y_xqr) != 0 else torch.zeros_like(p_y_xqr)
#         p_y_x = p_y_xqr
#         logger.info(f"Updated distribution:\n{np.around(p_y_x.cpu().detach().numpy(), 3)}\n")
#         # Don't repeat the same question again in the future
#         questions.remove(question)
#         # Terminate if probability exceeds threshold or if out of questions to ask
#         top_prob = torch.max(p_y_x).item()
#         if top_prob >= args.threshold or j==min(args.max_num_questions, num_questions_original)-1:
#             game_preds.append(torch.argmax(p_y_x).item())
#             game_lens.append(j+1)
#             logger.info(f"pred:{game_preds[-1]} game_len:{game_lens[-1]}")
#             break
# logger = logging_handler(args.verbose, args.save_name, "final_results")
# logger.info(f"Game lenths:\n{game_lens}")
# logger.info(sum(game_lens)/len(game_lens))
# logger.info(f"Predictions:\n{game_preds}")
# logger.info(f"Accuracy:\n{sum([i==args.target_idx for i in game_preds])/len(game_preds)}")