File size: 4,986 Bytes
31ec239
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
 # Copyright (c) 2023, salesforce.com, inc.
 # All rights reserved.
 # SPDX-License-Identifier: BSD-3-Clause
 # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause

from tqdm import tqdm 
import argparse
import pandas as pd
import torch
import os
from transformers import T5Tokenizer, T5ForConditionalGeneration
from fuzzywuzzy import fuzz

# Command-line interface for the pipeline; --mode selects which stage(s) run.
parser = argparse.ArgumentParser(description="")
parser.add_argument("--shard", type=int, help="The shard number to process.")
parser.add_argument(
    "--mode",
    type=str,
    choices=["color_removal", "qa_gen", "rtc", "all"],
    # `help` must be a string: argparse applies %-formatting to it when
    # rendering --help, which fails on a list.
    help="Pipeline stage to run: 'color_removal', 'qa_gen', 'rtc', or 'all'.",
)
parser.add_argument(
    "--split",
    type=str,
    help="Dataset split name, e.g. 'train' or 'val'; used in the per-shard CSV file names.",
)
parser.add_argument(
    "--original_data_file",
    type=str,
    help=(
        "Path to the caption CSV. Download the csv file from "
        "https://huggingface.co/datasets/tiange/Cap3D/blob/main/"
        "Cap3D_automated_Objaverse_no3Dword.csv"
    ),
)

args = parser.parse_args()
shard = args.shard
mode = args.mode
split = args.split
# argparse derives the attribute name from the flag "--original_data_file",
# so the attribute is `original_data_file` (both underscores).
original_data_file = args.original_data_file
# All intermediate and final CSVs are written under this directory.
output_dir = "./3d_qa_data"
os.makedirs(output_dir, exist_ok=True)

## Load Model
# Tokenizer and model come from the same checkpoint; the model weights are
# requested in float16 with `device_map="auto"`.
_FLAN_T5_CHECKPOINT = "google/flan-t5-xxl"
tokenizer = T5Tokenizer.from_pretrained(_FLAN_T5_CHECKPOINT)
model = T5ForConditionalGeneration.from_pretrained(
    _FLAN_T5_CHECKPOINT,
    device_map="auto",
    torch_dtype=torch.float16,
)

def get_output(input_text, input_len=128, output_len=128):
    """Generate one model completion per prompt.

    Args:
        input_text: Sequence of prompt strings forming one batch.
        input_len: Token length every prompt is padded or truncated to.
        output_len: ``max_length`` passed to ``model.generate``.

    Returns:
        A list of decoded strings (special tokens skipped), one per prompt,
        in the same order as ``input_text``. An empty batch returns ``[]``.
    """
    prompts = list(input_text)
    if not prompts:
        # Nothing to tokenize; do not hand an empty batch to the model.
        return []
    # Tokenize the whole batch in one call. `truncation=True` keeps every row
    # at exactly `input_len` tokens; without it an over-long prompt produces a
    # longer row and the batch can no longer be stacked into one tensor.
    encoded = tokenizer(
        prompts,
        padding="max_length",
        truncation=True,
        max_length=input_len,
        return_tensors="pt",
    )
    # Pass the attention mask explicitly so padding positions are ignored.
    generated = model.generate(
        encoded.input_ids.to("cuda"),
        attention_mask=encoded.attention_mask.to("cuda"),
        max_length=output_len,
    )
    return tokenizer.batch_decode(generated, skip_special_tokens=True)
  

if mode == 'color_removal' or mode == 'all':
    # Stage: rewrite every caption of this shard with color mentions removed
    # and save the result next to the original caption.
    df = pd.read_csv(original_data_file, names=["sample_id", "caption"])
    print(f"Total captions: {len(df)}")
    num_shards = 4
    num_rows_to_extract = len(df) // num_shards
    start_index = shard * num_rows_to_extract
    # The last shard also takes the remainder rows; with plain integer
    # division up to `num_shards - 1` trailing captions were never processed.
    if shard == num_shards - 1:
        end_index = len(df)
    else:
        end_index = start_index + num_rows_to_extract
    # .copy() so the new column below is added to an independent frame rather
    # than to a slice of the full table.
    df = df.iloc[start_index:end_index].copy()
    ## remove color.
    no_color_captions = []
    captions = df["caption"].tolist()
    num_examples = len(captions)
    bs = 64
    for i in tqdm(range(0, num_examples, bs)):
        input_text = [f"Rewrite the sentence {c} by removing mentions of color." for c in captions[i:i+bs]]
        no_color_captions.extend(get_output(input_text, input_len=128, output_len=128))

    df['caption_no_color'] = no_color_captions
    df.to_csv(os.path.join(output_dir, f'Cap3D_automated_Objaverse_no_color_shard_{shard}_{split}.csv'))

if mode == 'qa_gen' or mode == 'all':
    # Stage: for each color-free caption, generate an answer word, then a
    # question for that answer, and flag whether the answer fuzzy-matches a
    # span of the caption.
    # File names must not start with "/": os.path.join discards every earlier
    # component when a later one is absolute, which would bypass output_dir.
    df = pd.read_csv(os.path.join(output_dir, f'Cap3D_automated_Objaverse_no_color_shard_{shard}_{split}.csv')).dropna()
    # Keep only captions longer than 10 space-separated tokens; .copy() so the
    # new columns below are added to an independent frame.
    df = df[df['caption_no_color'].apply(lambda x: len(str(x).split(' ')) > 10)].copy()
    print(f"Total number of data: {len(df)}")
    captions = df['caption_no_color'].tolist()
    num_examples = len(captions)
    bs = 32
    questions = []
    answers = []
    extractive = []
    for i in tqdm(range(0, num_examples, bs)):
        batch_captions = captions[i:i+bs]
        try:
            input_text = [f"Generate a potential answer word from the following text: {c} " for c in batch_captions]
            batch_answers = get_output(input_text, input_len=180, output_len=128)
            input_text = [f"Generate a question for the answer using the context. Context: {c} Answer: {q} Question:" for c, q in zip(batch_captions, batch_answers)]
            batch_questions = get_output(input_text, input_len=180, output_len=30)
        except Exception:
            # Report where the failure happened and stop; continuing would
            # leave the result lists shorter than the DataFrame.
            print(f"qa_gen failed on the batch starting at row {i}")
            raise
        # Extend all three lists together so they stay aligned row-for-row.
        answers.extend(batch_answers)
        questions.extend(batch_questions)
        extractive.extend([fuzz.partial_ratio(a, c) > 90 for c, a in zip(batch_captions, batch_answers)])

    df['question'] = questions
    df['answer'] = answers
    df['extractive'] = extractive
    print(f'Number extractive: {len([e for e in extractive if e])}')
    df.to_csv(os.path.join(output_dir, f'Cap3D_automated_Objaverse_no_color_qa_shard_{shard}_{split}.csv'))

if mode == 'rtc' or mode == 'all':
    # Stage: answer each generated question from its caption and mark the row
    # as correct when the model's answer fuzzy-matches the stored answer.
    df = pd.read_csv(os.path.join(output_dir, f'Cap3D_automated_Objaverse_no_color_qa_shard_{shard}_{split}.csv')).dropna()
    print(f"Total number of data: {len(df)}")
    captions = df['caption_no_color'].tolist()
    num_examples = len(captions)
    bs = 32
    questions = df['question'].tolist()
    answers = df['answer'].tolist()
    correct = []
    for i in tqdm(range(0, num_examples, bs)):
        try:
            input_text = [f"Answer the question given the context. Context: {c} Question: {q} Answer:" for c, q in zip(captions[i:i+bs], questions[i:i+bs])]
            outputs = get_output(input_text, input_len=256, output_len=20)
            batch_correct = [fuzz.partial_ratio(a, c) > 90 for c, a in zip(outputs, answers[i:i+bs])]
        except Exception:
            # Report where the failure happened and stop; continuing would
            # leave `correct` shorter than the DataFrame.
            print(f"rtc failed on the batch starting at row {i}")
            raise
        correct.extend(batch_correct)

    df['correct'] = correct
    print(f'Number correct: {len([e for e in correct if e])}')
    # File names must not start with "/": os.path.join discards every earlier
    # component when a later one is absolute, which would bypass output_dir.
    df.to_csv(os.path.join(output_dir, f'Cap3D_automated_Objaverse_no_color_qa_correct_shard_{shard}_{split}.csv'))

    # Final set: rows that are both extractive and correct. The two masks are
    # combined into one instead of chaining two boolean selections.
    keep = (df['extractive'] == True) & (df['correct'] == True)  # noqa: E712
    df[keep].to_csv(os.path.join(output_dir, f'CAP3DQA_final_shard_{shard}_{split}.csv'))