lsnu's picture
Add files using upload-large-folder tool
7eb3f10 verified
import re
from pathlib import Path
import pickle
import tap
import transformers
from tqdm.auto import tqdm
import torch
from typing import Dict, List, Tuple, Literal, Optional
from utils.utils_with_rlbench import RLBenchEnv, task_file_to_task_class
# Names of the supported text encoders; load_model() and load_tokenizer()
# branch on exactly these two values.
TextEncoder = Literal["bert", "clip"]
class Arguments(tap.Tap):
    """Command-line options for the instruction-encoding script."""

    # Task names; each one is looked up as a sub-directory of ``train_dir``.
    # Stored as a tuple so the default matches the annotation and is immutable.
    tasks: Tuple[str, ...] = (
        'bimanual_pick_laptop',
        'bimanual_pick_plate',
        'bimanual_straighten_rope',
        'coordinated_lift_ball',
        'coordinated_lift_tray',
        'coordinated_push_box',
        'coordinated_put_bottle_in_fridge',
        'dual_push_buttons',
        'handover_item',
        'bimanual_sweep_to_dustpan',
        'coordinated_take_tray_out_of_oven',
        'handover_item_easy',
    )
    # Destination of the pickled {task: {variation: embedding}} mapping.
    output: Path = Path('instructions.pkl')
    # NOTE(review): not read anywhere in the visible script.
    batch_size: int = 10
    # Selects both the tokenizer and the text model.
    encoder: TextEncoder = "clip"
    # Assigned to ``tokenizer.model_max_length`` before tokenizing.
    model_max_length: int = 77
    # NOTE(review): forwarded to load_instructions_from_train(), which ignores
    # it and uses the variation directories found on disk instead.
    variations: Tuple[int, ...] = (1,)
    # Device the model and the token tensors are moved to.
    device: str = "cuda"
    # Root directory that contains one sub-directory per task.
    train_dir: Path = Path("train_dir")
    # When set, every embedding is overwritten with zeros before saving.
    zero: bool = False
    # When set, print each (task, variation, instructions) triple while encoding.
    verbose: bool = False
def parse_int(s):
    """Return the first run of decimal digits found in ``s`` as an ``int``.

    Raises:
        IndexError: if ``s`` contains no digits at all.
    """
    digit_runs = re.findall(r"\d+", s)
    leading_run = digit_runs[0]
    return int(leading_run)
def load_model(encoder: TextEncoder) -> transformers.PreTrainedModel:
    """Load the pretrained text model that corresponds to ``encoder``.

    Args:
        encoder: ``"bert"`` or ``"clip"``.

    Returns:
        The model returned by ``from_pretrained`` for the chosen checkpoint.

    Raises:
        ValueError: if ``encoder`` is neither ``"bert"`` nor ``"clip"``, or if
            the loaded object is not a ``transformers.PreTrainedModel``.
    """
    if encoder == "bert":
        model_cls, checkpoint = transformers.BertModel, "bert-base-uncased"
    elif encoder == "clip":
        model_cls, checkpoint = transformers.CLIPTextModel, "openai/clip-vit-base-patch32"
    else:
        raise ValueError(f"Unexpected encoder {encoder}")

    candidate = model_cls.from_pretrained(checkpoint)
    if isinstance(candidate, transformers.PreTrainedModel):
        return candidate
    raise ValueError(f"Unexpected encoder {encoder}")
def load_tokenizer(encoder: TextEncoder) -> transformers.PreTrainedTokenizer:
    """Load the pretrained tokenizer that corresponds to ``encoder``.

    Args:
        encoder: ``"bert"`` or ``"clip"``.

    Returns:
        The tokenizer returned by ``from_pretrained`` for the chosen checkpoint.

    Raises:
        ValueError: if ``encoder`` is neither ``"bert"`` nor ``"clip"``, or if
            the loaded object is not a ``transformers.PreTrainedTokenizer``.
    """
    if encoder == "bert":
        tokenizer_cls, checkpoint = transformers.BertTokenizer, "bert-base-uncased"
    elif encoder == "clip":
        tokenizer_cls, checkpoint = transformers.CLIPTokenizer, "openai/clip-vit-base-patch32"
    else:
        raise ValueError(f"Unexpected encoder {encoder}")

    candidate = tokenizer_cls.from_pretrained(checkpoint)
    if isinstance(candidate, transformers.PreTrainedTokenizer):
        return candidate
    raise ValueError(f"Unexpected encoder {encoder}")
def count_and_set_variations(task_dir: Path) -> List[int]:
    """List the variation numbers that exist as sub-directories of ``task_dir``.

    A sub-directory counts when its name is ``variation`` followed by an
    integer (``variation0``, ``variation12``, ...). Plain files, unrelated
    directories, and directories whose suffix is not an integer (for example
    ``variations`` or ``variation_backup``) are skipped rather than aborting
    the scan.

    Args:
        task_dir: Directory of a single task.

    Returns:
        The variation numbers in ascending order; an empty list when
        ``task_dir`` does not exist or holds no variation directory.
    """
    prefix = 'variation'
    # A missing task directory simply means "no variations"; callers already
    # handle the empty result.
    if not task_dir.is_dir():
        return []

    numbers = []
    for entry in task_dir.iterdir():
        if not (entry.is_dir() and entry.name.startswith(prefix)):
            continue
        try:
            numbers.append(int(entry.name[len(prefix):]))
        except ValueError:
            # Name starts with the prefix but is not a numbered variation.
            continue
    return sorted(numbers)
def load_instructions_from_train(train_dir: Path, tasks: Tuple[str, ...], variations: Tuple[int, ...]) -> Dict[str, Dict[int, List[str]]]:
    """Collect the pickled variation descriptions stored under ``train_dir``.

    For every task, each ``<train_dir>/<task>/variation<N>`` directory is
    checked for a ``variation_descriptions.pkl`` file; when present, its
    unpickled content is stored under ``result[task][N]``.

    Args:
        train_dir: Root directory containing one sub-directory per task.
        tasks: Task names to scan (any iterable of strings works).
        variations: Ignored. The variations actually present on disk are used
            instead. The parameter is kept so existing callers stay valid.
            NOTE(review): confirm whether it was meant to filter the result.

    Returns:
        ``{task: {variation_number: unpickled_descriptions}}``. Tasks with no
        variation directory are omitted; variations without a pickle file are
        omitted from their task's mapping. The declared value type assumes
        each pickle holds a list of strings -- TODO confirm against the data.
    """
    instructions = {}
    for task in tasks:
        task_dir = train_dir / task
        # Use a distinct name so the ``variations`` argument is not clobbered,
        # and treat a missing task directory like an empty one.
        found = count_and_set_variations(task_dir) if task_dir.is_dir() else []
        print("variations: ", found)
        if not found:
            print(f"No variations found for task {task}")
            continue

        instructions[task] = {}
        for variation in found:
            pkl_file = task_dir / f"variation{variation}" / "variation_descriptions.pkl"
            if not pkl_file.exists():
                continue
            # SECURITY: pickle.load can execute arbitrary code; only point
            # ``train_dir`` at data from a trusted source.
            with open(pkl_file, 'rb') as f:
                variation_instructions = pickle.load(f)
            print("variation_instructions: ", variation_instructions)
            instructions[task][variation] = variation_instructions
    return instructions
def main() -> None:
    """Encode the instructions of every task/variation and pickle the result.

    The output file holds ``{task: {variation: last_hidden_state}}`` where each
    value is the CPU tensor produced by the selected text model for that
    variation's instructions.
    """
    args = Arguments().parse_args()
    print(args)

    tokenizer = load_tokenizer(args.encoder)
    tokenizer.model_max_length = args.model_max_length

    model = load_model(args.encoder)
    model = model.to(args.device)

    # NOTE(review): ``env`` is never read below (its only use was in
    # commented-out code). It is kept because the side effects of constructing
    # RLBenchEnv are not visible from this file -- confirm it can be dropped.
    env = RLBenchEnv(
        data_path="",
        apply_rgb=True,
        apply_pc=True,
        apply_cameras=("over_shoulder_left", "over_shoulder_right", "overhead", "wrist_right", "wrist_left", "front"),
        headless=True,
    )

    instructions: Dict[str, Dict[int, torch.Tensor]] = {}
    # De-duplicate while preserving the given order, so the key order of the
    # saved dict is reproducible between runs (a set is not).
    tasks = tuple(dict.fromkeys(args.tasks))
    loaded_instructions = load_instructions_from_train(args.train_dir, tasks, args.variations)

    for task in tqdm(tasks):
        instructions[task] = {}
        task_dir = args.train_dir / task
        # A task without a directory has nothing to encode; keep its empty entry.
        variations = count_and_set_variations(task_dir) if task_dir.is_dir() else []

        for variation in variations:
            instr = loaded_instructions.get(task, {}).get(variation)
            if instr is None:
                print(f"No instructions found for task {task} variation {variation}")
                continue

            if args.verbose:
                print(task, variation, instr)

            tokens = tokenizer(instr, padding="max_length", truncation=True)["input_ids"]
            # NOTE(review): with truncation=True over-long instructions are cut
            # to model_max_length, so this check is only a safety net.
            lengths = [len(t) for t in tokens]
            if any(l > args.model_max_length for l in lengths):
                raise RuntimeError(f"Too long instructions: {lengths}")

            tokens = torch.tensor(tokens).to(args.device)
            with torch.no_grad():
                pred = model(tokens).last_hidden_state
            instructions[task][variation] = pred.cpu()

    if args.zero:
        # Overwrite every embedding in place with zeros.
        for instr_task in instructions.values():
            for embedding in instr_task.values():
                embedding.fill_(0)

    print("Instructions:", sum(len(inst) for inst in instructions.values()))

    # Path() makes this work even if the option arrives as a plain string;
    # parents=True allows nested output locations that do not exist yet.
    output = Path(args.output)
    output.parent.mkdir(parents=True, exist_ok=True)
    with open(output, "wb") as f:
        pickle.dump(instructions, f)


if __name__ == "__main__":
    main()