File size: 5,397 Bytes
2e7f2ce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import os
import torch
import torch.nn as nn
from transformers import CLIPVisionModel, AutoModel

from .configuration_vora import VoRAConfig


def _reload_aimv2_weights(model, model_path: str):
    """Overwrite ``model``'s weights with those in ``mlx_model.safetensors``.

    The checkpoint is looked up directly inside ``model_path``. If the file
    does not exist the function returns immediately and ``model`` is left
    untouched. Otherwise every 4-D tensor in the checkpoint is permuted with
    ``(0, 3, 1, 2)`` (the NHWC→NCHW conversion) and made contiguous, all other
    tensors are passed through unchanged, and the result is loaded with
    ``strict=False``. Key mismatches are printed, not raised.

    Args:
        model: Object exposing ``load_state_dict``; updated in place.
        model_path: Directory expected to contain ``mlx_model.safetensors``.
    """
    checkpoint_file = os.path.join(model_path, "mlx_model.safetensors")
    if not os.path.isfile(checkpoint_file):
        return

    # Imported only after the file check, so safetensors is not needed when
    # there is nothing to reload.
    from safetensors.torch import load_file

    state_dict = {
        key: tensor.permute(0, 3, 1, 2).contiguous() if tensor.ndim == 4 else tensor
        for key, tensor in load_file(checkpoint_file).items()
    }

    load_result = model.load_state_dict(state_dict, strict=False)
    if not (load_result.missing_keys or load_result.unexpected_keys):
        print("[AuxVision] _reload_aimv2_weights: all keys matched successfully")
        return
    print(f"[AuxVision] _reload_aimv2_weights: missing={load_result.missing_keys}, unexpected={load_result.unexpected_keys}")


class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension with a learned scale.

    The input is normalised in float32, cast back to its original dtype and
    then multiplied element-wise by ``weight`` (one entry per feature,
    initialised to ones).
    """

    def __init__(self, dim: int, eps: float = 1e-5):
        """Create the layer.

        Args:
            dim: Length of the learnable ``weight`` vector.
            eps: Constant added to the mean square before taking the
                reciprocal square root.
        """
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x: torch.Tensor) -> torch.Tensor:
        """Scale ``x`` by the reciprocal RMS of its last dimension."""
        mean_square = x.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        return x * inv_rms

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the normalised input multiplied by ``weight``."""
        normed = self._norm(x.float())
        # Cast back to the input dtype *before* applying the scale, so the
        # final dtype follows normal promotion between ``x`` and ``weight``.
        restored = normed.type_as(x)
        return restored * self.weight

    def extra_repr(self) -> str:
        """Describe the layer as ``(<weight shape>), eps=<eps>``."""
        weight_shape = tuple(self.weight.shape)
        return f"{weight_shape}, eps={self.eps}"


class CosineLoss(nn.Module):
    """Loss defined as ``1 - cosine_similarity(input, target)``.

    Args:
        reduction: ``'mean'`` averages the per-element losses, ``'sum'`` adds
            them up, and any other value returns the unreduced loss tensor.
    """

    def __init__(self, reduction='mean'):
        super().__init__()
        self.reduction = reduction

    # NOTE: this used to be decorated with @staticmethod while still declaring
    # ``self``; ``forward`` calls it through the instance, so that combination
    # always raised ``TypeError`` (missing ``target_size``). It is a regular
    # method now; the parameter list is unchanged.
    def interpolate_tokens_2d(self, teacher_tokens, target_size):
        """
        Interpolate teacher tokens to the target size using bilinear interpolation.

        Args:
            teacher_tokens: Tensor laid out as
                (batch_size, height, width, feature_dim).
            target_size: ``(new_height, new_width)`` passed to
                ``torch.nn.functional.interpolate`` as ``size``.

        Returns:
            Tensor of shape (batch_size, new_height * new_width, feature_dim).
        """
        # (batch_size, height, width, feature_dim) -> (batch_size, feature_dim, height, width)
        channels_first = teacher_tokens.permute(0, 3, 1, 2)
        resized = torch.nn.functional.interpolate(
            channels_first, size=target_size, mode='bilinear', align_corners=True
        )
        # Merge the two spatial axes, then move features back to the last axis.
        return resized.flatten(2).permute(0, 2, 1)

    def forward(self, input: torch.Tensor, target: torch.Tensor, input_shape=None, target_shape=None) -> torch.Tensor:
        """Compute the cosine loss between ``input`` and ``target``.

        Args:
            input: Tensor compared against ``target``.
            target: Reference tensor.
            input_shape: Optional ``(height, width)`` grid of ``input``'s
                tokens. Used only when ``target_shape`` is also given.
            target_shape: Optional ``(height, width)`` grid to resample
                ``input`` to. Used only when ``input_shape`` is also given.

        Returns:
            A scalar for ``'mean'``/``'sum'`` reduction, otherwise the
            unreduced ``1 - cosine_similarity`` tensor.
        """
        if input_shape is not None and target_shape is not None:
            # tuple() lets callers pass the grid as a list as well as a tuple.
            grid_shape = (input.shape[0], ) + tuple(input_shape) + (-1, )
            input = self.interpolate_tokens_2d(input.reshape(grid_shape), target_shape)

        # Similarity is taken along dim=1, which is the feature axis for 2-D
        # (num_tokens, feature_dim) inputs.
        # NOTE(review): after the interpolation branch ``input`` is 3-D
        # (batch, tokens, features), so dim=1 is the token axis there —
        # confirm whether dim=-1 was intended before relying on that path.
        cos_sim = nn.functional.cosine_similarity(input, target, dim=1)
        loss = 1 - cos_sim

        if self.reduction == 'mean':
            return loss.mean()
        if self.reduction == 'sum':
            return loss.sum()
        return loss


class AuxVision(nn.Module):
    """Frozen auxiliary vision model plus per-layer heads that produce cosine losses.

    For each index in ``aux_layers`` a small head (``RMSNorm`` followed by a
    bias-free ``nn.Linear``) maps LLM hidden states to the auxiliary model's
    hidden size. ``forward`` compares the projected LLM states with the
    auxiliary model's hidden states using ``CosineLoss``.
    """

    def __init__(self, config: VoRAConfig = None):
        """Load the auxiliary model, freeze it and build one head per layer.

        Args:
            config: Provides ``skip_aux_cls``, ``aux_vision`` (model name or
                path) and ``hidden_size`` (LLM width). Despite the ``None``
                default a config is required: its attributes are read
                immediately.
        """
        super().__init__()
        # Whether to skip the cls token in ViT.
        self.skip_aux_cls = config.skip_aux_cls

        # ---------------- Setup Aux Model ----------------
        if 'clip' in config.aux_vision.lower():
            self.aux_model = CLIPVisionModel.from_pretrained(config.aux_vision)
            teacher_config = self.aux_model.vision_model.config
        else:
            self.aux_model = AutoModel.from_pretrained(config.aux_vision, trust_remote_code=True)
            # No-op unless an mlx_model.safetensors file sits next to the model.
            _reload_aimv2_weights(self.aux_model, config.aux_vision)
            teacher_config = self.aux_model.config
        vision_hidden_size = teacher_config.hidden_size
        num_hidden_layers = teacher_config.num_hidden_layers

        # The auxiliary model is never trained.
        for frozen_param in self.aux_model.parameters():
            frozen_param.requires_grad = False
        # -------------------------------------------------

        # ---------------- Setup Aux Heads ----------------
        self.aux_layers = list(range(num_hidden_layers))
        for layer_id in self.aux_layers:
            head = self.build_aux_layer(config.hidden_size, vision_hidden_size)
            self.add_module(f"aux_layer_{layer_id}", head)
        # -------------------------------------------------

        self.loss_function = CosineLoss()
        self.loss_keys = [f"loss_aux_layer_{layer_id}" for layer_id in self.aux_layers]

    def build_aux_layer(self, llm_hidden_size, vit_hidden_size):
        """Return a head that normalises LLM features and projects them to the ViT width."""
        norm = RMSNorm(llm_hidden_size)
        projection = nn.Linear(llm_hidden_size, vit_hidden_size, bias=False)
        return nn.Sequential(norm, projection)

    def forward(self, frames, llm_hidden_states, vision_mask):
        """Compute one auxiliary loss per entry of ``aux_layers``.

        Args:
            frames: Input passed to the auxiliary vision model.
            llm_hidden_states: Sequence indexed by layer; each entry is
                indexed with ``vision_mask == 1`` to select positions.
                (Presumably the vision-token positions of shape
                (batch, seq, hidden) — confirm against the caller.)
            vision_mask: Mask whose entries equal to 1 mark the positions
                that are fed to the heads.

        Returns:
            Dict mapping ``"loss_aux_layer_<idx>"`` to the loss for that layer.
        """
        teacher_output = self.aux_model(frames, output_hidden_states=True)
        vision_hidden_states = teacher_output.hidden_states

        # Drop the first token of every teacher sequence when skip_aux_cls is set.
        first_token = 1 if self.skip_aux_cls else 0
        selected = vision_mask == 1

        losses = {}
        for layer_idx in self.aux_layers:
            head = getattr(self, f"aux_layer_{layer_idx}")
            student_states = head(llm_hidden_states[layer_idx][selected])
            # Teacher features are reshaped to line up with the projected
            # LLM features before comparison.
            teacher_states = vision_hidden_states[layer_idx][:, first_token:].reshape(student_states.shape)
            losses[f"loss_aux_layer_{layer_idx}"] = self.loss_function(teacher_states, student_states)
        return losses