# File size: 5,530 Bytes
# 7eb3f10
import re
from pathlib import Path
import pickle
import tap
import transformers
from tqdm.auto import tqdm
import torch
from typing import Dict, List, Tuple, Literal, Optional
from utils.utils_with_rlbench import RLBenchEnv, task_file_to_task_class
# Names of the supported text encoders; load_model and load_tokenizer accept
# exactly these two values and raise ValueError for anything else.
TextEncoder = Literal["bert", "clip"]
class Arguments(tap.Tap):
    """Command-line options of the instruction-encoding script."""

    # Task names; each one is resolved as a sub-directory of `train_dir`.
    # Kept as an immutable tuple so the default matches the annotation.
    tasks: Tuple[str, ...] = (
        'bimanual_pick_laptop',
        'bimanual_pick_plate',
        'bimanual_straighten_rope',
        'coordinated_lift_ball',
        'coordinated_lift_tray',
        'coordinated_push_box',
        'coordinated_put_bottle_in_fridge',
        'dual_push_buttons',
        'handover_item',
        'bimanual_sweep_to_dustpan',
        'coordinated_take_tray_out_of_oven',
        'handover_item_easy',
    )
    # Pickle file that receives the encoded instructions.
    output: Path = Path('instructions.pkl')
    # NOTE(review): not read anywhere in the visible script.
    batch_size: int = 10
    # Which text encoder (tokenizer + model) to load.
    encoder: TextEncoder = "clip"
    # Assigned to the tokenizer's `model_max_length`; instructions are padded
    # and truncated to this many tokens.
    model_max_length: int = 77
    # NOTE(review): forwarded to load_instructions_from_train, which replaces
    # it with the variations discovered on disk.
    variations: Tuple[int, ...] = (1,)
    # Torch device the model and the token ids are moved to.
    device: str = "cuda"
    # Root directory that holds one sub-directory per task.
    train_dir: Path = Path("train_dir")
    # When set, every embedding is overwritten with zeros before saving.
    zero: bool = False
    # When set, print each (task, variation, instructions) triple while encoding.
    verbose: bool = False
def parse_int(s):
    """Return the first run of decimal digits found in *s* as an ``int``.

    Raises:
        IndexError: if *s* contains no digit at all.
    """
    digit_runs = re.findall(r"\d+", s)
    first_run = digit_runs[0]  # IndexError here when there are no digits
    return int(first_run)
def load_model(encoder: TextEncoder) -> transformers.PreTrainedModel:
    """Load the pretrained text model that corresponds to *encoder*.

    Args:
        encoder: ``"bert"`` loads ``bert-base-uncased`` through ``BertModel``;
            ``"clip"`` loads ``openai/clip-vit-base-patch32`` through
            ``CLIPTextModel``.

    Returns:
        The loaded model.

    Raises:
        ValueError: if *encoder* is neither of the two names, or if the loaded
            object is not a ``transformers.PreTrainedModel``.
    """
    # Pick the model class and checkpoint first, then load in one place.
    if encoder == "bert":
        model_cls, checkpoint = transformers.BertModel, "bert-base-uncased"
    elif encoder == "clip":
        model_cls, checkpoint = transformers.CLIPTextModel, "openai/clip-vit-base-patch32"
    else:
        raise ValueError(f"Unexpected encoder {encoder}")

    model = model_cls.from_pretrained(checkpoint)
    if isinstance(model, transformers.PreTrainedModel):
        return model
    raise ValueError(f"Unexpected encoder {encoder}")
def load_tokenizer(encoder: TextEncoder) -> transformers.PreTrainedTokenizer:
    """Load the pretrained tokenizer that corresponds to *encoder*.

    Args:
        encoder: ``"bert"`` loads ``bert-base-uncased`` through
            ``BertTokenizer``; ``"clip"`` loads ``openai/clip-vit-base-patch32``
            through ``CLIPTokenizer``.

    Returns:
        The loaded tokenizer.

    Raises:
        ValueError: if *encoder* is neither of the two names, or if the loaded
            object is not a ``transformers.PreTrainedTokenizer``.
    """
    # Pick the tokenizer class and checkpoint first, then load in one place.
    if encoder == "bert":
        tokenizer_cls, checkpoint = transformers.BertTokenizer, "bert-base-uncased"
    elif encoder == "clip":
        tokenizer_cls, checkpoint = transformers.CLIPTokenizer, "openai/clip-vit-base-patch32"
    else:
        raise ValueError(f"Unexpected encoder {encoder}")

    tokenizer = tokenizer_cls.from_pretrained(checkpoint)
    if isinstance(tokenizer, transformers.PreTrainedTokenizer):
        return tokenizer
    raise ValueError(f"Unexpected encoder {encoder}")
def count_and_set_variations(task_dir: Path) -> List[int]:
    """Return the sorted variation numbers found directly under *task_dir*.

    A variation is a sub-directory named ``variation<N>`` where ``<N>`` is
    accepted by ``int``. Plain files, and directories that merely start with
    ``variation`` but carry no integer suffix (for example ``variation_notes``),
    are ignored instead of aborting the scan.

    Args:
        task_dir: Directory of a single task.

    Returns:
        The variation numbers in ascending order; an empty list when
        *task_dir* does not exist or is not a directory, so callers can report
        the task as having no variations.
    """
    prefix = 'variation'
    if not task_dir.is_dir():
        return []

    variations = []
    for entry in task_dir.iterdir():
        if not (entry.is_dir() and entry.name.startswith(prefix)):
            continue
        # Strip only the leading prefix; the remainder must be an integer.
        suffix = entry.name[len(prefix):]
        try:
            variations.append(int(suffix))
        except ValueError:
            continue
    return sorted(variations)
def load_instructions_from_train(train_dir: Path, tasks: Tuple[str, ...], variations: Tuple[int, ...]) -> Dict[str, Dict[int, List[str]]]:
    """Read the pickled variation descriptions of every task under *train_dir*.

    For each task the variations are discovered on disk with
    ``count_and_set_variations``; for each of them
    ``<train_dir>/<task>/variation<N>/variation_descriptions.pkl`` is unpickled
    when it exists.

    Args:
        train_dir: Root directory holding one sub-directory per task.
        tasks: Task names to look up.
        variations: NOTE(review): accepted but never consulted; the variations
            actually loaded are always the ones discovered on disk.

    Returns:
        ``{task: {variation: unpickled content}}``. Tasks without any variation
        are left out; variations without a pickle file are left out of their
        task's mapping. The content is presumably a list of instruction
        strings -- confirm against whatever wrote the pickle files.
    """
    instructions = {}
    for task in tasks:
        task_dir = train_dir / task
        found_variations = count_and_set_variations(task_dir)
        print("variations: ",found_variations)
        if not found_variations:
            print(f"No variations found for task {task}")
            continue

        per_variation = {}
        instructions[task] = per_variation
        for variation in found_variations:
            pkl_file = task_dir / f"variation{variation}" / "variation_descriptions.pkl"
            if not pkl_file.exists():
                continue
            # NOTE: pickle executes arbitrary code on load; only point this at
            # trusted dataset directories.
            with open(pkl_file, 'rb') as f:
                variation_instructions = pickle.load(f)
            print("variation_instructions: ",variation_instructions)
            per_variation[variation] = variation_instructions
    return instructions
if __name__ == "__main__":
    # Encode the instructions of every (task, variation) found under
    # args.train_dir with the selected text encoder and pickle the result as
    # {task: {variation: last_hidden_state tensor on CPU}}.
    args = Arguments().parse_args()
    print(args)

    tokenizer = load_tokenizer(args.encoder)
    tokenizer.model_max_length = args.model_max_length

    model = load_model(args.encoder)
    model = model.to(args.device)

    # NOTE(review): `env` is not used below (the code that needed it is
    # commented out inside the loop). It is kept because the side effects of
    # the constructor are not visible from this file -- confirm, then remove.
    env = RLBenchEnv(
        data_path="",
        apply_rgb=True,
        apply_pc=True,
        apply_cameras=("over_shoulder_left", "over_shoulder_right", "overhead", "wrist_right", "wrist_left", "front"),
        headless=True,
    )

    instructions: Dict[str, Dict[int, torch.Tensor]] = {}
    # De-duplicate the task names, then sort them: iterating a bare set of
    # strings gives a different order from run to run, which made the key
    # order of the pickled dict non-reproducible.
    tasks = tuple(sorted(set(args.tasks)))
    loaded_instructions = load_instructions_from_train(args.train_dir, tasks, args.variations)

    for task in tqdm(tasks):
        # task_type = task_file_to_task_class(task)
        # task_inst = env.env.get_task(task_type)._task
        # task_inst.init_task()
        instructions[task] = {}
        task_dir = args.train_dir / task
        variations = count_and_set_variations(task_dir)
        for variation in variations:
            instr = loaded_instructions.get(task, {}).get(variation)
            if instr is None:
                print(f"No instructions found for task {task} variation {variation}")
                continue
            if args.verbose:
                print(task, variation, instr)

            # Pad/truncate every instruction to the tokenizer's maximum length
            # so the batch can be turned into one rectangular tensor.
            tokens = tokenizer(instr, padding="max_length", truncation=True)["input_ids"]
            lengths = [len(t) for t in tokens]
            if any(l > args.model_max_length for l in lengths):
                raise RuntimeError(f"Too long instructions: {lengths}")

            tokens = torch.tensor(tokens).to(args.device)
            with torch.no_grad():
                pred = model(tokens).last_hidden_state
            instructions[task][variation] = pred.cpu()

    if args.zero:
        # Replace every embedding with zeros, in place.
        for instr_task in instructions.values():
            for embedding in instr_task.values():
                embedding.fill_(0)

    print("Instructions:", sum(len(inst) for inst in instructions.values()))

    # parents=True so that a nested, not-yet-existing output directory is
    # created instead of raising FileNotFoundError.
    output_path = Path(args.output)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, "wb") as f:
        pickle.dump(instructions, f)