| from operator import truediv |
| import os |
| import re |
| import json |
| import sys |
| import argparse |
| |
| |
| import openai |
| from abc import ABC, abstractmethod |
| |
| |
| |
| from tqdm import tqdm |
| from llava.eval.masp_eval.utils import GPTAPIWrapper |
|
|
| import time |
|
|
class BaseAPIWrapper(ABC):
    """Minimal interface every chat-completion API wrapper must implement."""

    @abstractmethod
    def get_completion(self, user_prompt, system_prompt=None):
        """Return the model completion for *user_prompt*, optionally primed with *system_prompt*."""
|
|
class CHAIR():
    """Turn free-form video captions into structured JSON info via GPT-4.

    On construction this loads two prompt templates from disk and builds a
    GPT API wrapper.  ``cap2info_gpt4`` then runs a two-turn conversation:
    caption -> draft JSON, draft JSON -> deduplicated/refined JSON, and
    parses the model's final reply into a Python object.
    """

    # SECURITY(review): this credential was hard-coded in source control.
    # It should be rotated/revoked.  It is kept only as a fallback so
    # existing deployments keep working; prefer the GPT_API_AK env var.
    _DEFAULT_AK = "GjrgjjyJHUbLa15DLnr7t0Bhu6IPqFPj"

    def __init__(self) -> None:
        super().__init__()
        self.system_prompt = "I am ChatGPT, a virtual assistant based on OpenAI's GPT-4 model. I'm designed to understand and generate human-like text based on the input I receive. My main purpose is to assist with information, answer questions, help with tasks that involve natural language processing, and engage in conversations with users.Please note that while I aim to provide accurate and reliable information, I can't guarantee perfection, and it's always a good idea to consult additional resources or professionals when making critical decisions based on the information I provide."
        # Env var takes precedence over the (legacy) hard-coded key.
        self.openai_obj = GPTAPIWrapper(ak=os.environ.get("GPT_API_AK", self._DEFAULT_AK))
        # Template for caption -> structured-info extraction.
        self.cap_user_prompt = self._read_prompt('llava/eval/masp_eval/video_chair/prompts/cap2info.txt')
        # Template asking the model to deduplicate/refine the draft JSON.
        self.cap_user_prompt_deduplicate = self._read_prompt('llava/eval/masp_eval/video_chair/prompts/refine_json.txt')

    @staticmethod
    def _read_prompt(path):
        """Return the full text of the prompt-template file at *path*."""
        with open(path, 'r') as file:
            return file.read()

    @staticmethod
    def _extract_json(text):
        """Parse the JSON object embedded in a GPT reply.

        Tries a fenced ```json ... ``` block first; otherwise falls back to
        the outermost '{' ... '}' span.  Returns the parsed object, or None
        when no parseable JSON is found.
        """
        match = re.search(r"(?<=```json\n)([\s\S]*?)(?=```)", text)
        if match:
            candidate = match.group(1)
        else:
            start = text.find('{')
            end = text.rfind('}')
            candidate = text[start:end + 1]
        try:
            return json.loads(candidate)
        except Exception:
            return None

    def cap2info_gpt4(self, cap):
        """Run the two-turn GPT conversation for caption *cap*.

        Returns the parsed info object, or None when the model's reply
        could not be parsed as JSON.
        """
        user_prompt = self.cap_user_prompt.replace('/video caption/', cap)
        draft, msgs = self.openai_obj.get_completion(user_prompt=user_prompt, system_prompt=self.system_prompt)
        refine_prompt = self.cap_user_prompt_deduplicate.replace('/json file/', draft)
        refined, msgs = self.openai_obj.get_completion(user_prompt=refine_prompt, system_prompt=self.system_prompt, previous_msgs=msgs, last_answer=draft)
        info = self._extract_json(refined)
        if info is None:
            # Parsing failed: dump both raw replies so the failure stays
            # diagnosable (the original printed one or both depending on path).
            print(draft)
            print(refined)
        return info
|
|
|
|
def post_process_masp_cap_label(evaluator, annotations_file, gt=True):
    """Sequentially enrich every annotation with GPT-derived caption info.

    Reads a JSON list from *annotations_file*, runs ``cap2info_gpt4`` on the
    'refine_caption' field (when *gt* is true) or 'masp_inference' field
    (otherwise), stores the result under 'cap_info', and returns the list.
    """
    with open(annotations_file, 'r', encoding='utf-8') as f:
        annotations = json.load(f)

    caption_key = 'refine_caption' if gt else 'masp_inference'
    processed = []
    for entry in tqdm(annotations):
        entry['cap_info'] = evaluator.cap2info_gpt4(entry[caption_key])
        processed.append(entry)
    return processed
|
|
|
|
# NOTE(review): mid-file import — kept here (with the module-level evaluator
# below) so that forked Pool workers inherit both; moving this requires care.
from multiprocessing import Pool


# Module-level singleton shared by every Pool worker via fork.
# NOTE(review): constructing CHAIR at import time performs file I/O and
# builds the GPT client as a side effect of importing this module.
evaluator = CHAIR()
|
|
| |
def process_data(data, gt):
    """Attach GPT-derived caption info to one annotation dict (mutates and returns it).

    Uses the module-level ``evaluator``; *gt* selects which caption field is read.
    """
    field = 'refine_caption' if gt else 'masp_inference'
    data['cap_info'] = evaluator.cap2info_gpt4(data[field])
    return data
|
|
| |
def process_annotations(annotations_file, gt=False):
    """Enrich every annotation in *annotations_file* with GPT info, in parallel.

    Args:
        annotations_file: path to a JSON file containing a list of dicts.
        gt: when True read 'refine_caption', otherwise 'masp_inference'.

    Returns:
        List of annotated dicts.  Order is NOT preserved (imap_unordered).
    """
    from functools import partial

    with open(annotations_file, 'r', encoding='utf-8') as f:
        annotations = json.load(f)

    worker = partial(process_data, gt=gt)
    # Context manager guarantees the pool is torn down even if a worker
    # raises; the original leaked the pool on any error before close().
    with Pool(processes=32) as pool:
        results = list(tqdm(pool.imap_unordered(worker, annotations),
                            total=len(annotations)))
        pool.close()
        pool.join()
    return results
|
|
|
|
|
|
if __name__ == '__main__':
    def _str2bool(value):
        """Parse a boolean CLI value explicitly.

        BUG FIX: the original used ``type=bool``, so ``--gt False`` parsed
        as True (bool() of any non-empty string is truthy).
        """
        return str(value).strip().lower() in ('1', 'true', 'yes', 'y', 't')

    parser = argparse.ArgumentParser()
    parser.add_argument("--cap_file", type=str, default='/mnt/bn/algo-masp-nas-2/xiangchen/model/masp_models/checkpoints/llava-mistral_gpt4v_adso65k_unfreeze_qformer/video_chair/vid_top1k_res.json')
    parser.add_argument("--output_file", type=str, default='/mnt/bn/algo-masp-nas-2/xiangchen/model/masp_models/checkpoints/llava-mistral_gpt4v_adso65k_unfreeze_qformer/video_chair/vid_top1k_res_info.json')
    # Boolean flag: accepts "true"/"false" etc.; default stays False.
    parser.add_argument("--gt", type=_str2bool, default=False)

    args = parser.parse_args()

    post_anno = process_annotations(args.cap_file, args.gt)
    with open(args.output_file, "w") as file:
        json.dump(post_anno, file, indent=4)
|
|
|
|