File size: 4,357 Bytes
c456c14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
from __future__ import annotations

import numpy as np
import torch

# Importing this module must not fail when `voltron` is missing: the
# ImportError is stored here and only re-raised when a VoltronPreprocessor is
# actually constructed (see VoltronPreprocessor.__init__).
_VOLTRON_IMPORT_ERROR = None
try:
    import voltron
    from voltron import instantiate_extractor, load
except ImportError as e:
    _VOLTRON_IMPORT_ERROR = e

from torchvision import transforms as T

import uvd.utils as U
from uvd.models.preprocessors.base import Preprocessor


# Model identifiers that VoltronPreprocessor accepts as ``model_type``.
AVAILABLE_VOLTRON_MODEL_TYPES = [
    "v-cond",  # Voltron ViT-Small (Sth-Sth)
    "v-dual",  # Voltron ViT-Small (Sth-Sth)
    "v-gen",  # Voltron ViT-Small (Sth-Sth)
    "v-cond-base",  # Voltron ViT-Base
    # Data-locked reproductions, currently disabled:
    #   "r-mvp", "r-r3m-vit", "r-r3m-rn50"
]


class VoltronPreprocessor(Preprocessor):
    """Preprocessor that embeds images with a frozen Voltron model.

    The backbone is loaded through ``voltron.load(model_type, freeze=True)``
    and its output is reduced to a vector by the extractor returned from
    ``voltron.instantiate_extractor``. Text encoding is not implemented for
    this preprocessor (``_encode_text`` raises ``NotImplementedError``).
    """

    def __init__(
        self,
        model_type: str | None = None,
        device: torch.device | str | None = None,
        remove_bn: bool = False,
        bn_to_gn: bool = False,
        remove_pool: bool = False,
        preprocess_with_fc: bool = False,
        save_fc: bool = False,
        random_crop: bool = False,
        ckpt: str | None = None,
        use_language_goal: bool = False,
    ):
        """Validate the options and delegate construction to ``Preprocessor``.

        Args:
            model_type: One of ``AVAILABLE_VOLTRON_MODEL_TYPES``; falls back
                to ``"v-cond"`` when ``None`` or empty.
            device: Forwarded unchanged to the base class.
            remove_bn: Forwarded unchanged to the base class.
            bn_to_gn: Ignored; ``False`` is always forwarded.
            remove_pool: Forwarded unchanged to the base class.
            preprocess_with_fc: Not supported; a warning is printed and the
                flag is reset to ``False``.
            save_fc: Not supported; handled like ``preprocess_with_fc``.
            random_crop: If true, the image transform resizes to 232 and
                takes a random 224 crop instead of resizing to 224.
            ckpt: Stored on the instance as ``self.ckpt``; not otherwise
                used inside this class.
            use_language_goal: Forwarded unchanged to the base class.

        Raises:
            ImportError: If the ``voltron`` package could not be imported.
            AssertionError: If ``model_type`` is not an available model.
        """
        if _VOLTRON_IMPORT_ERROR is not None:
            # Chain to the stored exception so the original traceback survives.
            raise ImportError(_VOLTRON_IMPORT_ERROR) from _VOLTRON_IMPORT_ERROR
        model_type = model_type or "v-cond"
        assert model_type in AVAILABLE_VOLTRON_MODEL_TYPES, (
            model_type,
            AVAILABLE_VOLTRON_MODEL_TYPES,
        )
        # Assigned before super().__init__() because _get_model_and_transform
        # reads self.random_crop -- presumably the base constructor calls it
        # (TODO confirm against Preprocessor).
        self.random_crop = random_crop
        self.ckpt = ckpt
        if save_fc or preprocess_with_fc:
            U.rank_zero_print("WARNING: Voltron has no fc to save", color="red")
            save_fc = False
            preprocess_with_fc = False
        # The caller's bn_to_gn is deliberately overridden -- presumably
        # because the supported backbones are ViTs (TODO confirm).
        bn_to_gn = False
        super().__init__(
            model_type=model_type,
            device=device,
            remove_bn=remove_bn,
            bn_to_gn=bn_to_gn,
            remove_pool=remove_pool,
            preprocess_with_fc=preprocess_with_fc,
            save_fc=save_fc,
            use_language_goal=use_language_goal,
        )

        # Maps a prompt string to its embedding; see cached_language_embed.
        self._cached_language_embedding = {}

    def _get_model_and_transform(self, model_type: str) -> tuple:
        """Load the Voltron model and build the image transform.

        Also stores the vector extractor on ``self.vector_extractor``.

        Returns:
            ``(model, transform)``: the loaded model moved to ``self.device``
            and a ``torchvision`` ``Compose`` that resizes (optionally
            random-crops) and then applies Voltron's own normalisation.
        """
        vcond, preprocess = load(model_type, freeze=True)
        vector_extractor = instantiate_extractor(vcond)()
        self.vector_extractor = vector_extractor.to(self.device)
        preprocess: T.Compose
        # Only the final Normalize step of Voltron's preprocessing is reused;
        # resizing and cropping are rebuilt below.
        normlayer = preprocess.transforms[-1]
        assert isinstance(normlayer, T.Normalize)
        transform = (
            T.Compose([T.Resize(224), normlayer])
            if not self.random_crop
            else T.Compose([T.Resize(232), T.RandomCrop(224), normlayer])
        )
        return vcond.to(self.device), transform

    def _encode_image(self, img_tensors: torch.Tensor) -> torch.FloatTensor:
        """Embed images: visual forward pass, then the vector extractor.

        Runs under ``torch.no_grad()``, so the result carries no gradient.
        """
        with torch.no_grad():
            return self.vector_extractor(self.model(img_tensors, mode="visual"))

    def _encode_text(
        self, text: str | np.ndarray | list | torch.Tensor
    ) -> torch.Tensor:
        """Text encoding hook; not implemented for this preprocessor."""
        raise NotImplementedError

    def encode_text(self, text: str | np.ndarray | list | torch.Tensor) -> torch.Tensor:
        """Return the embedding of ``text``, memoised for string inputs."""
        return self.cached_language_embed(text)

    def cached_language_embed(self, text: str):
        """Return ``self._encode_text(text)``, caching results per string.

        Non-string inputs are encoded directly without caching: lists and
        NumPy arrays are unhashable and cannot be used as dictionary keys.
        """
        if not isinstance(text, str):
            return self._encode_text(text)
        if text in self._cached_language_embedding:
            return self._cached_language_embedding[text]
        text_embed = self._encode_text(text)
        self._cached_language_embedding[text] = text_embed
        return text_embed


def sim(tensor1, tensor2, metric: str = "l2", device=None):
    """Return a similarity score between two sets of vectors.

    Args:
        tensor1: ``torch.Tensor`` or ``numpy.ndarray``.
        tensor2: ``torch.Tensor`` or ``numpy.ndarray``; the two inputs may be
            of different kinds.
        metric: ``"l2"`` gives the negative Euclidean distance over the last
            dimension; ``"cos"`` gives the cosine similarity over the last
            dimension.
        device: If given, NumPy inputs are moved to this device after
            conversion. Inputs that are already tensors are left where
            they are.

    Returns:
        A tensor of scores with the last dimension reduced; larger means
        more similar for both metrics.

    Raises:
        NotImplementedError: If ``metric`` is neither ``"l2"`` nor ``"cos"``.
    """
    # Convert each argument on its own so mixed ndarray/Tensor pairs work.
    if isinstance(tensor1, np.ndarray):
        tensor1 = torch.from_numpy(tensor1)
        if device is not None:
            tensor1 = tensor1.to(device)
    if isinstance(tensor2, np.ndarray):
        tensor2 = torch.from_numpy(tensor2)
        if device is not None:
            tensor2 = tensor2.to(device)

    if metric == "l2":
        return -torch.linalg.norm(tensor1 - tensor2, dim=-1)
    if metric == "cos":
        # NOTE: the explicit normalisation divides by the norm, so an
        # all-zero vector yields NaN rather than 0.
        tensor1 = tensor1 / tensor1.norm(dim=-1, keepdim=True)
        tensor2 = tensor2 / tensor2.norm(dim=-1, keepdim=True)
        return torch.nn.CosineSimilarity(-1)(tensor1, tensor2)
    raise NotImplementedError(f"unsupported metric: {metric!r}")


# Language prompt for each task, keyed by the task's snake_case name.
PROMPT_DICT = {
    "microwave": "open the microwave",
    "kettle": "move the kettle to the top left stove",
    "light_switch": "turn on the light",
    "hinge_cabinet": "open the left hinge cabinet",
    "slide_cabinet": "open the right slide cabinet",
    "top_burner": "turn on the top left burner",
    "bottom_burner": "turn on the bottom left burner",
}
# Register the space-separated spelling of every task name as an alias.
# Iterate over a snapshot so the dict is not mutated while being read.
PROMPT_DICT.update(
    (task_name.replace("_", " "), prompt)
    for task_name, prompt in list(PROMPT_DICT.items())
)