Spaces:
Sleeping
Sleeping
File size: 4,357 Bytes
c456c14 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 | from __future__ import annotations
import numpy as np
import torch
_VOLTRON_IMPORT_ERROR = None
try:
import voltron
from voltron import instantiate_extractor, load
except ImportError as e:
_VOLTRON_IMPORT_ERROR = e
from torchvision import transforms as T
import uvd.utils as U
from uvd.models.preprocessors.base import Preprocessor
# Model identifiers accepted as `model_type` by VoltronPreprocessor.
AVAILABLE_VOLTRON_MODEL_TYPES = [
    # Voltron ViT-Small (Sth-Sth) models
    "v-cond", "v-dual", "v-gen",
    # Voltron ViT-Base model
    "v-cond-base",
    # Data-locked reproductions (currently commented out, i.e. not accepted):
    #   "r-mvp", "r-r3m-vit", "r-r3m-rn50"
]
# Preprocessor subclass backed by a model loaded through the `voltron` package.
class VoltronPreprocessor(Preprocessor):
# Constructor: checks that `voltron` was importable, validates `model_type`,
# records the crop/checkpoint options on the instance and forwards the remaining
# flags to `Preprocessor.__init__`.
#
# Raises:
#   ImportError: if `import voltron` failed when this module was imported.
#   AssertionError: if `model_type` is not in AVAILABLE_VOLTRON_MODEL_TYPES.
def __init__(
self,
model_type: str | None = None,
device: torch.device | str | None = None,
remove_bn: bool = False,
bn_to_gn: bool = False,
remove_pool: bool = False,
preprocess_with_fc: bool = False,
save_fc: bool = False,
random_crop: bool = False,
ckpt: str | None = None,
use_language_goal: bool = False,
):
# The voltron import is attempted once at module import time; the failure is
# deferred until someone actually instantiates this class.
if _VOLTRON_IMPORT_ERROR is not None:
raise ImportError(_VOLTRON_IMPORT_ERROR)
# A falsy `model_type` (None or "") falls back to "v-cond".
model_type = model_type or "v-cond"
assert model_type in AVAILABLE_VOLTRON_MODEL_TYPES, (
model_type,
AVAILABLE_VOLTRON_MODEL_TYPES,
)
# `random_crop` is read by `_get_model_and_transform` to pick the transform.
self.random_crop = random_crop
# `ckpt` is only stored here; no other use is visible in this part of the file.
self.ckpt = ckpt
# The fc-related flags are overridden to False before being forwarded.
# NOTE(review): the message says "LIV" although this is the Voltron
# preprocessor - looks copied from another preprocessor; confirm and rename.
# NOTE(review): indentation was lost in this copy of the file, so it is unclear
# whether the three assignments below (in particular `bn_to_gn = False`) belong
# to the `if` body or run unconditionally - verify against the real source.
if save_fc or preprocess_with_fc:
U.rank_zero_print(f"WARNING: LIV no fc to save", color="red")
save_fc = False
preprocess_with_fc = False
bn_to_gn = False
# `remove_bn`, `remove_pool`, `device` and `use_language_goal` are passed through
# unchanged; `random_crop` and `ckpt` are not forwarded to the base class.
super().__init__(
model_type=model_type,
device=device,
remove_bn=remove_bn,
bn_to_gn=bn_to_gn,
remove_pool=remove_pool,
preprocess_with_fc=preprocess_with_fc,
save_fc=save_fc,
use_language_goal=use_language_goal,
)
# Per-instance memo used by `cached_language_embed` (text -> embedding).
# It is created only after the base-class constructor has returned.
self._cached_language_embedding = {}
def _get_model_and_transform(self, model_type: str) -> tuple:
    """Load the Voltron model and build the image transform for it.

    Side effect: stores an extractor built by ``instantiate_extractor`` on
    ``self.vector_extractor`` (moved to ``self.device``).

    Returns:
        ``(model, transform)`` where ``model`` is the loaded (frozen) Voltron
        model on ``self.device`` and ``transform`` reuses the final
        ``Normalize`` step of Voltron's own preprocessing pipeline.
    """
    backbone, default_preprocess = load(model_type, freeze=True)
    extractor = instantiate_extractor(backbone)()
    self.vector_extractor = extractor.to(self.device)

    # Only the normalisation layer of the stock pipeline is kept.
    normalize = default_preprocess.transforms[-1]
    assert isinstance(normalize, T.Normalize)

    if self.random_crop:
        # Resize slightly larger, then take a random 224 crop.
        steps = [T.Resize(232), T.RandomCrop(224), normalize]
    else:
        steps = [T.Resize(224), normalize]
    transform = T.Compose(steps)

    return backbone.to(self.device), transform
def _encode_image(self, img_tensors: torch.Tensor) -> torch.FloatTensor:
    """Embed images without tracking gradients.

    Calls ``self.model`` with ``mode="visual"`` and feeds the result to
    ``self.vector_extractor``; both calls run under ``torch.no_grad()``.
    """
    with torch.no_grad():
        visual_features = self.model(img_tensors, mode="visual")
        return self.vector_extractor(visual_features)
def _encode_text(
    self, text: str | np.ndarray | list | torch.Tensor
) -> torch.Tensor:
    """Text-encoding hook; not implemented here, so it always raises.

    Raises:
        NotImplementedError: unconditionally.
    """
    raise NotImplementedError()
def encode_text(self, text: str | np.ndarray | list | torch.Tensor) -> torch.Tensor:
    """Return the language embedding for ``text``.

    Thin public entry point: the work (and the memoisation) is delegated to
    ``cached_language_embed``.
    """
    embedding = self.cached_language_embed(text)
    return embedding
def cached_language_embed(self, text: str | np.ndarray | list | torch.Tensor):
    """Return the embedding of ``text``, memoised per instance.

    Hashable inputs (e.g. ``str``) are looked up in
    ``self._cached_language_embedding`` and encoded with ``self._encode_text``
    only on a miss. Unhashable inputs (e.g. ``list`` or ``numpy.ndarray``,
    both of which ``encode_text`` accepts) cannot be used as dictionary keys,
    so they are encoded directly and not cached instead of failing with
    ``TypeError`` at the lookup.

    Args:
        text: the prompt to embed.

    Returns:
        Whatever ``self._encode_text(text)`` returns.
    """
    cache = self._cached_language_embedding
    try:
        return cache[text]
    except KeyError:
        pass  # cache miss: encode below and remember the result
    except TypeError:
        # Unhashable key: skip the cache entirely.
        return self._encode_text(text)

    embedding = self._encode_text(text)
    cache[text] = embedding
    return embedding
def sim(tensor1, tensor2, metric: str = "l2", device=None):
    """Score how similar two sets of vectors are along their last dimension.

    Args:
        tensor1: ``torch.Tensor`` or ``numpy.ndarray``.
        tensor2: ``torch.Tensor`` or ``numpy.ndarray``.
        metric: ``"l2"`` for the negated Euclidean distance, ``"cos"`` for
            cosine similarity. In both cases a larger value means "closer".
        device: where NumPy inputs are moved after conversion. Inputs that are
            already tensors are used as given and are not moved.

    Returns:
        A tensor with the last dimension reduced away.

    Raises:
        NotImplementedError: if ``metric`` is neither ``"l2"`` nor ``"cos"``.
    """
    # Convert each argument on its own so that mixed ndarray/Tensor input works;
    # checking only the first argument breaks as soon as the two types differ.
    if isinstance(tensor1, np.ndarray):
        tensor1 = torch.from_numpy(tensor1).to(device)
    if isinstance(tensor2, np.ndarray):
        tensor2 = torch.from_numpy(tensor2).to(device)

    if metric == "l2":
        return -torch.linalg.norm(tensor1 - tensor2, dim=-1)
    if metric == "cos":
        # Explicit unit-normalisation is kept as-is to preserve the existing
        # numerics, even though CosineSimilarity normalises as well.
        tensor1 = tensor1 / tensor1.norm(dim=-1, keepdim=True)
        tensor2 = tensor2 / tensor2.norm(dim=-1, keepdim=True)
        return torch.nn.CosineSimilarity(-1)(tensor1, tensor2)
    raise NotImplementedError(f"unsupported metric: {metric!r}")
# Task name -> natural-language instruction.
PROMPT_DICT = {
    "microwave": "open the microwave",
    "kettle": "move the kettle to the top left stove",
    "light_switch": "turn on the light",
    "hinge_cabinet": "open the left hinge cabinet",
    "slide_cabinet": "open the right slide cabinet",
    "top_burner": "turn on the top left burner",
    "bottom_burner": "turn on the bottom left burner",
}
# Also accept the space-separated spelling of every task name
# (e.g. "light switch"). The items are snapshotted first because the dict
# grows while the aliases are being inserted.
PROMPT_DICT.update(
    (task.replace("_", " "), prompt) for task, prompt in tuple(PROMPT_DICT.items())
)
|