| | |
| | |
| | |
| | |
| |
|
| | from tqdm import tqdm |
| | import argparse |
| | import pandas as pd |
| | import torch |
| | import os |
| | from transformers import T5Tokenizer, T5ForConditionalGeneration |
| | from fuzzywuzzy import fuzz |
| |
|
# Command-line interface: selects the shard, the pipeline stage, the split
# label used in output filenames, and the source caption csv.
parser = argparse.ArgumentParser(description="")
parser.add_argument("--shard", type=int, help="The shard number to process.")
# argparse %-formats help text, so it must be a string; a list makes
# `--help` raise a TypeError.
parser.add_argument(
    "--mode",
    type=str,
    help="Pipeline stage to run: 'color_removal', 'qa_gen', 'rtc', or 'all'.",
)
parser.add_argument(
    "--split",
    type=str,
    help="Split label used in the output filenames: 'train' or 'val'.",
)
parser.add_argument(
    "--original_data_file",
    type=str,
    help=(
        "Path to the caption csv. Download csv file from "
        "https://huggingface.co/datasets/tiange/Cap3D/blob/main/"
        "Cap3D_automated_Objaverse_no3Dword.csv"
    ),
)

args = parser.parse_args()
shard = args.shard
mode = args.mode
split = args.split
# argparse derives the attribute name from the flag: --original_data_file
# becomes args.original_data_file (the old `original_datafile` did not exist).
original_data_file = args.original_data_file

# Every stage writes its csv files into this directory.
output_dir = "./3d_qa_data"
os.makedirs(output_dir, exist_ok=True)
| |
|
| | |
# Load the tokenizer and the half-precision model once at start-up;
# get_output() reads both as module-level globals.
_checkpoint = "google/flan-t5-xxl"
tokenizer = T5Tokenizer.from_pretrained(_checkpoint)
model = T5ForConditionalGeneration.from_pretrained(
    _checkpoint,
    device_map="auto",
    torch_dtype=torch.float16,
)
| |
|
def get_output(input_text, input_len=128, output_len=128):
    """Generate one decoded string per prompt with the module-level model.

    Args:
        input_text: Iterable of prompt strings forming one batch.
        input_len: Token length each prompt is padded or truncated to.
        output_len: ``max_length`` passed to ``model.generate``.

    Returns:
        List of generated strings (special tokens stripped), one per prompt.
        An empty batch returns an empty list without calling the model.
    """
    prompts = list(input_text)
    if not prompts:
        return []

    # Tokenize the whole batch in one call. truncation=True guarantees every
    # row is exactly `input_len` tokens; without it an over-long prompt gives
    # a longer tensor and the batch cannot be stacked.
    encoded = tokenizer(
        prompts,
        padding='max_length',
        truncation=True,
        max_length=input_len,
        return_tensors="pt",
    )
    input_ids = encoded.input_ids.to("cuda")
    # Pass the mask explicitly so padding positions are ignored by the encoder.
    attention_mask = encoded.attention_mask.to("cuda")

    generated = model.generate(
        input_ids,
        attention_mask=attention_mask,
        max_length=output_len,
    )
    return tokenizer.batch_decode(generated, skip_special_tokens=True)
| | |
| |
|
if mode == 'color_removal' or mode == 'all':
    # Stage 1: have the model rewrite each caption without color mentions.
    df = pd.read_csv(original_data_file, names=["sample_id", "caption"])
    print(f"Total captions: {len(df)}")

    # The captions are processed as four contiguous shards. The last shard
    # also takes the remainder rows, which plain integer division would
    # silently drop from the dataset.
    num_shards = 4
    num_rows_to_extract = len(df) // num_shards
    start_index = shard * num_rows_to_extract
    if shard == num_shards - 1:
        end_index = len(df)
    else:
        end_index = start_index + num_rows_to_extract
    # .copy() gives an independent frame so the column assignment below does
    # not write through a slice of the full table.
    df = df.iloc[start_index:end_index].copy()

    captions = df["caption"].tolist()
    num_examples = len(captions)
    bs = 64
    no_color_captions = []
    for i in tqdm(range(0, num_examples, bs)):
        input_text = [f"Rewrite the sentence {c} by removing mentions of color." for c in captions[i:i+bs]]
        no_color_captions.extend(get_output(input_text, input_len=128, output_len=128))

    df['caption_no_color'] = no_color_captions
    df.to_csv(os.path.join(output_dir, f'Cap3D_automated_Objaverse_no_color_shard_{shard}_{split}.csv'))
| |
|
if mode == 'qa_gen' or mode == 'all':
    # Stage 2: generate an answer and a matching question for each caption.
    # The filename must not start with '/': os.path.join discards output_dir
    # when a later component is absolute, so the stage-1 file was never found.
    df = pd.read_csv(os.path.join(output_dir, f'Cap3D_automated_Objaverse_no_color_shard_{shard}_{split}.csv')).dropna()
    # Keep only captions with more than 10 space-separated words; .copy() so
    # the new columns are added to an independent frame.
    df = df[df['caption_no_color'].apply(lambda x: len(str(x).split(' ')) > 10)].copy()
    print(f"Total number of data: {len(df)}")
    captions = df['caption_no_color'].tolist()
    num_examples = len(captions)
    bs = 32
    questions = []
    answers = []
    extractive = []
    for i in tqdm(range(0, num_examples, bs)):
        batch_captions = captions[i:i+bs]
        try:
            input_text = [f"Generate a potential answer word from the following text: {c} " for c in batch_captions]
            batch_answers = get_output(input_text, input_len=180, output_len=128)
            input_text = [f"Generate a question for the answer using the context. Context: {c} Answer: {q} Question:" for c, q in zip(batch_captions, batch_answers)]
            batch_questions = get_output(input_text, input_len=180, output_len=30)
        except Exception:
            # Report where the failure happened and stop. Carrying on would
            # leave the result lists shorter than the frame they are assigned to.
            print(f"QA generation failed for the batch starting at row {i}")
            raise
        # Extend all three lists only after the whole batch succeeded so they
        # stay aligned with `captions`.
        answers.extend(batch_answers)
        questions.extend(batch_questions)
        # An answer counts as extractive when its best partial fuzzy match
        # against the caption scores above 90.
        extractive.extend(fuzz.partial_ratio(a, c) > 90 for c, a in zip(batch_captions, batch_answers))

    df['question'] = questions
    df['answer'] = answers
    df['extractive'] = extractive
    print(f'Number extractive: {sum(extractive)}')
    # Written inside output_dir, under the name the 'rtc' stage reads.
    df.to_csv(os.path.join(output_dir, f'Cap3D_automated_Objaverse_no_color_qa_shard_{shard}_{split}.csv'))
| |
|
if mode == 'rtc' or mode == 'all':
    # Stage 3 ('rtc', presumably round-trip consistency - confirm): answer
    # each generated question from its caption and compare the model's reply
    # with the stored answer.
    df = pd.read_csv(os.path.join(output_dir, f'Cap3D_automated_Objaverse_no_color_qa_shard_{shard}_{split}.csv')).dropna()
    print(f"Total number of data: {len(df)}")
    captions = df['caption_no_color'].tolist()
    num_examples = len(captions)
    bs = 32
    questions = df['question'].tolist()
    answers = df['answer'].tolist()
    correct = []
    for i in tqdm(range(0, num_examples, bs)):
        try:
            input_text = [f"Answer the question given the context. Context: {c} Question: {q} Answer:" for c, q in zip(captions[i:i+bs], questions[i:i+bs])]
            outputs = get_output(input_text, input_len=256, output_len=20)
        except Exception:
            # Report where the failure happened and stop. Carrying on would
            # leave `correct` shorter than the frame it is assigned to.
            print(f"Answer checking failed for the batch starting at row {i}")
            raise
        # A row is correct when the stored answer's best partial fuzzy match
        # against the model's reply scores above 90.
        correct.extend(fuzz.partial_ratio(a, c) > 90 for c, a in zip(outputs, answers[i:i+bs]))

    df['correct'] = correct
    print(f'Number correct: {sum(correct)}')
    # No leading '/' in the filenames: os.path.join would otherwise discard
    # output_dir and write to the filesystem root.
    df.to_csv(os.path.join(output_dir, f'Cap3D_automated_Objaverse_no_color_qa_correct_shard_{shard}_{split}.csv'))

    # Combine both conditions into one mask instead of chaining two boolean
    # selections, which applies the second mask to an already-filtered frame.
    keep = (df['extractive'] == True) & (df['correct'] == True)  # noqa: E712
    df[keep].to_csv(os.path.join(output_dir, f'CAP3DQA_final_shard_{shard}_{split}.csv'))