from __future__ import annotations import numpy as np import torch _VOLTRON_IMPORT_ERROR = None try: import voltron from voltron import instantiate_extractor, load except ImportError as e: _VOLTRON_IMPORT_ERROR = e from torchvision import transforms as T import uvd.utils as U from uvd.models.preprocessors.base import Preprocessor AVAILABLE_VOLTRON_MODEL_TYPES = [ # === Voltron ViT-Small (Sth-Sth) Models === "v-cond", "v-dual", "v-gen", # === Voltron ViT-Base Model === "v-cond-base", # === Data-Locked Reproductions === # "r-mvp", # "r-r3m-vit", # "r-r3m-rn50", ] class VoltronPreprocessor(Preprocessor): def __init__( self, model_type: str | None = None, device: torch.device | str | None = None, remove_bn: bool = False, bn_to_gn: bool = False, remove_pool: bool = False, preprocess_with_fc: bool = False, save_fc: bool = False, random_crop: bool = False, ckpt: str | None = None, use_language_goal: bool = False, ): if _VOLTRON_IMPORT_ERROR is not None: raise ImportError(_VOLTRON_IMPORT_ERROR) model_type = model_type or "v-cond" assert model_type in AVAILABLE_VOLTRON_MODEL_TYPES, ( model_type, AVAILABLE_VOLTRON_MODEL_TYPES, ) self.random_crop = random_crop self.ckpt = ckpt if save_fc or preprocess_with_fc: U.rank_zero_print(f"WARNING: LIV no fc to save", color="red") save_fc = False preprocess_with_fc = False bn_to_gn = False super().__init__( model_type=model_type, device=device, remove_bn=remove_bn, bn_to_gn=bn_to_gn, remove_pool=remove_pool, preprocess_with_fc=preprocess_with_fc, save_fc=save_fc, use_language_goal=use_language_goal, ) self._cached_language_embedding = {} def _get_model_and_transform(self, model_type: str) -> tuple: vcond, preprocess = load(model_type, freeze=True) vector_extractor = instantiate_extractor(vcond)() self.vector_extractor = vector_extractor.to(self.device) preprocess: T.Compose normlayer = preprocess.transforms[-1] assert isinstance(normlayer, T.Normalize) transform = ( T.Compose([T.Resize(224), normlayer]) if not self.random_crop else T.Compose([T.Resize(232), T.RandomCrop(224), normlayer]) ) return vcond.to(self.device), transform def _encode_image(self, img_tensors: torch.Tensor) -> torch.FloatTensor: with torch.no_grad(): return self.vector_extractor(self.model(img_tensors, mode="visual")) def _encode_text( self, text: str | np.ndarray | list | torch.Tensor ) -> torch.Tensor: raise NotImplementedError def encode_text(self, text: str | np.ndarray | list | torch.Tensor) -> torch.Tensor: return self.cached_language_embed(text) def cached_language_embed(self, text: str): if text in self._cached_language_embedding: return self._cached_language_embedding[text] text_embed = self._encode_text(text) self._cached_language_embedding[text] = text_embed return text_embed def sim(tensor1, tensor2, metric: str = "l2", device=None): if type(tensor1) == np.ndarray: tensor1 = torch.from_numpy(tensor1).to(device) tensor2 = torch.from_numpy(tensor2).to(device) if metric == "l2": d = -torch.linalg.norm(tensor1 - tensor2, dim=-1) elif metric == "cos": tensor1 = tensor1 / tensor1.norm(dim=-1, keepdim=True) tensor2 = tensor2 / tensor2.norm(dim=-1, keepdim=True) d = torch.nn.CosineSimilarity(-1)(tensor1, tensor2) else: raise NotImplementedError return d PROMPT_DICT = dict( microwave="open the microwave", kettle="move the kettle to the top left stove", light_switch="turn on the light", hinge_cabinet="open the left hinge cabinet", slide_cabinet="open the right slide cabinet", top_burner="turn on the top left burner", bottom_burner="turn on the bottom left burner", ) PROMPT_DICT.update({k.replace("_", " "): v for k, v in PROMPT_DICT.items()})