# Origin: nas / LAVIS-main / projects / xinstructblip / data_aug / 3d_qa_data_generation.py
# (uploaded by yuccaaa with the upload-large-folder tool; commit 31ec239, verified)
# Copyright (c) 2023, salesforce.com, inc.
# All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
from tqdm import tqdm
import argparse
import pandas as pd
import torch
import os
from transformers import T5Tokenizer, T5ForConditionalGeneration
from fuzzywuzzy import fuzz
# Command-line interface: each invocation runs one pipeline stage (or all of
# them, in order, with --mode all) on one shard of the Cap3D captions.
parser = argparse.ArgumentParser(description="")
parser.add_argument("--shard", type=int, required=True, help="The shard number to process.")
parser.add_argument(
    "--mode",
    type=str,
    required=True,
    # 'all' is handled by every stage guard below, so it must be accepted here.
    choices=['color_removal', 'qa_gen', 'rtc', 'all'],
    help="Pipeline stage to run: color_removal, qa_gen, rtc, or all (runs the three in order).",
)
parser.add_argument(
    "--split",
    type=str,
    required=True,
    help="Split name used in the output file names, e.g. 'train' or 'val'.",
)
parser.add_argument(
    "--original_data_file",
    type=str,
    help=(
        "Cap3D captions CSV, downloadable from "
        "https://huggingface.co/datasets/tiange/Cap3D/blob/main/Cap3D_automated_Objaverse_no3Dword.csv. "
        "Required for --mode color_removal and --mode all."
    ),
)
args = parser.parse_args()
# Only the color_removal stage reads the original CSV; fail early with a clear
# CLI error instead of a pandas error deep inside the stage.
if args.mode in ('color_removal', 'all') and not args.original_data_file:
    parser.error("--original_data_file is required for --mode color_removal and --mode all")
shard = args.shard
mode = args.mode
split = args.split
# Attribute name must match the "--original_data_file" flag (dest keeps underscores).
original_data_file = args.original_data_file
# All stage inputs/outputs live in this directory.
output_dir = "./3d_qa_data"
os.makedirs(output_dir, exist_ok=True)
## Load Model
# Tokenizer and fp16 weights come from the same FLAN-T5-XXL checkpoint.
# device_map="auto" delegates placement of the weights across the available
# devices to transformers/accelerate.
model_name = "google/flan-t5-xxl"
tokenizer = T5Tokenizer.from_pretrained(model_name)
model = T5ForConditionalGeneration.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    device_map="auto",
)
def get_output(input_text, input_len=128, output_len=128):
    """Run one batch of prompts through the module-level FLAN-T5 model.

    Args:
        input_text: Iterable of prompt strings forming one batch.
        input_len: Minimum padded input length in tokens. Prompts longer than
            this are NOT truncated, because truncation would cut off the
            instruction text at the end of the prompts. Instead the whole
            batch is padded to its longest prompt.
        output_len: ``max_length`` passed to ``model.generate``.

    Returns:
        List of decoded generations (special tokens stripped), one per prompt,
        in input order. An empty input returns an empty list.
    """
    prompts = list(input_text)
    if not prompts:
        return []
    # Tokenizing each prompt separately with padding='max_length' and then
    # concatenating breaks as soon as one prompt exceeds input_len (ragged
    # rows). Tokenize the batch together, padded to max(input_len, longest).
    longest = max(len(ids) for ids in tokenizer(prompts).input_ids)
    encoded = tokenizer(
        prompts,
        padding='max_length',
        max_length=max(input_len, longest),
        return_tensors="pt",
    ).to(model.device)
    # Pass the attention mask explicitly so padding tokens are ignored.
    outputs = model.generate(
        input_ids=encoded.input_ids,
        attention_mask=encoded.attention_mask,
        max_length=output_len,
    )
    return tokenizer.batch_decode(outputs, skip_special_tokens=True)
if mode == 'color_removal' or mode == 'all':
    # Stage 1: take this shard of the original captions and ask the model to
    # rewrite each one without colour mentions.
    df = pd.read_csv(original_data_file, names=["sample_id", "caption"])
    print(f"Total captions: {len(df)}")
    num_shards = 4
    if not 0 <= shard < num_shards:
        raise ValueError(f"--shard must be in [0, {num_shards}), got {shard}")
    shard_size = len(df) // num_shards
    start_index = shard * shard_size
    # The last shard also takes the len(df) % num_shards remainder rows;
    # otherwise those captions would never be processed by any shard.
    end_index = len(df) if shard == num_shards - 1 else start_index + shard_size
    df = df.iloc[start_index:end_index]
    ## remove color.
    captions = df["caption"].tolist()
    bs = 64
    no_color_captions = []
    for i in tqdm(range(0, len(captions), bs)):
        input_text = [f"Rewrite the sentence {c} by removing mentions of color." for c in captions[i:i + bs]]
        no_color_captions.extend(get_output(input_text, input_len=128, output_len=128))
    df['caption_no_color'] = no_color_captions
    df.to_csv(os.path.join(output_dir, f'Cap3D_automated_Objaverse_no_color_shard_{shard}_{split}.csv'))
if mode == 'qa_gen' or mode == 'all':
    # Stage 2: for every colour-free caption longer than 10 words, generate a
    # candidate answer, then a question for that answer, and record whether
    # the answer fuzzily appears in the caption ("extractive").
    # NOTE: no leading '/' in the file name. os.path.join discards output_dir
    # when a later component is absolute, which previously sent reads and
    # writes to the filesystem root.
    df = pd.read_csv(os.path.join(output_dir, f'Cap3D_automated_Objaverse_no_color_shard_{shard}_{split}.csv')).dropna()
    df = df[df['caption_no_color'].apply(lambda x: len(str(x).split(' ')) > 10)]
    print(f"Total number of data: {len(df)}")
    captions = df['caption_no_color'].tolist()
    bs = 32
    questions = []
    answers = []
    extractive = []
    # No blanket try/except + pdb here: a breakpoint hangs unattended runs,
    # and continuing past it left the three lists with different lengths.
    for i in tqdm(range(0, len(captions), bs)):
        batch_captions = captions[i:i + bs]
        input_text = [f"Generate a potential answer word from the following text: {c} " for c in batch_captions]
        batch_answers = get_output(input_text, input_len=180, output_len=128)
        input_text = [f"Generate a question for the answer using the context. Context: {c} Answer: {q} Question:" for c, q in zip(batch_captions, batch_answers)]
        questions.extend(get_output(input_text, input_len=180, output_len=30))
        answers.extend(batch_answers)
        # partial_ratio > 90: the answer (nearly) occurs as a substring of the caption.
        extractive.extend([fuzz.partial_ratio(a, c) > 90 for c, a in zip(batch_captions, batch_answers)])
    df['question'] = questions
    df['answer'] = answers
    df['extractive'] = extractive
    print(f'Number extractive: {sum(extractive)}')
    df.to_csv(os.path.join(output_dir, f'Cap3D_automated_Objaverse_no_color_qa_shard_{shard}_{split}.csv'))
if mode == 'rtc' or mode == 'all':
    # Stage 3: have the model answer each generated question from its caption
    # and keep the pairs whose model answer fuzzily matches the generated one.
    df = pd.read_csv(os.path.join(output_dir, f'Cap3D_automated_Objaverse_no_color_qa_shard_{shard}_{split}.csv')).dropna()
    print(f"Total number of data: {len(df)}")
    captions = df['caption_no_color'].tolist()
    questions = df['question'].tolist()
    answers = df['answer'].tolist()
    bs = 32
    correct = []
    # No blanket try/except + pdb: a breakpoint hangs unattended runs and
    # continuing past it left `correct` shorter than `df`.
    for i in tqdm(range(0, len(captions), bs)):
        batch_captions = captions[i:i + bs]
        batch_questions = questions[i:i + bs]
        batch_answers = answers[i:i + bs]
        input_text = [f"Answer the question given the context. Context: {c} Question: {q} Answer:" for c, q in zip(batch_captions, batch_questions)]
        outputs = get_output(input_text, input_len=256, output_len=20)
        # str(): read_csv may parse a purely numeric answer as a number.
        correct.extend([fuzz.partial_ratio(str(a), o) > 90 for o, a in zip(outputs, batch_answers)])
    df['correct'] = correct
    print(f'Number correct: {sum(correct)}')
    # No leading '/' in the file names: it made os.path.join drop output_dir.
    df.to_csv(os.path.join(output_dir, f'Cap3D_automated_Objaverse_no_color_qa_correct_shard_{shard}_{split}.csv'))
    # One combined mask instead of chained df[m1][m2], which reindexes m2
    # against the filtered frame and emits a pandas UserWarning.
    keep = df['extractive'].eq(True) & df['correct'].eq(True)
    df[keep].to_csv(os.path.join(output_dir, f'CAP3DQA_final_shard_{shard}_{split}.csv'))