| import os |
| import torch |
| import torch.nn as nn |
| from transformers import CLIPVisionModel, AutoModel |
|
|
| from .configuration_vora import VoRAConfig |
|
|
|
|
| def _reload_aimv2_weights(model, model_path: str): |
| """Reload AIMv2 weights from mlx_model.safetensors with NHWC→NCHW conversion.""" |
| mlx_path = os.path.join(model_path, "mlx_model.safetensors") |
| if not os.path.isfile(mlx_path): |
| return |
| from safetensors.torch import load_file |
| sd = load_file(mlx_path) |
| converted = {} |
| for k, v in sd.items(): |
| if v.ndim == 4: |
| converted[k] = v.permute(0, 3, 1, 2).contiguous() |
| else: |
| converted[k] = v |
| msg = model.load_state_dict(converted, strict=False) |
| if msg.missing_keys or msg.unexpected_keys: |
| print(f"[AuxVision] _reload_aimv2_weights: missing={msg.missing_keys}, unexpected={msg.unexpected_keys}") |
| else: |
| print("[AuxVision] _reload_aimv2_weights: all keys matched successfully") |
|
|
|
|
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension with a learned scale.

    The input is normalised in float32, cast back to the input's dtype, and
    then multiplied element-wise by the ``weight`` parameter (initialised to
    ones, one entry per feature).
    """

    def __init__(self, dim: int, eps: float = 1e-5):
        """Create the scale parameter of size ``dim``; ``eps`` is added under the root."""
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def _norm(self, x: torch.Tensor) -> torch.Tensor:
        """Divide ``x`` by the RMS of its last dimension (``eps`` added for stability)."""
        mean_square = x.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        return x * inv_rms

    def extra_repr(self) -> str:
        """Return the weight shape and epsilon for the module's printed repr."""
        weight_shape = tuple(self.weight.shape)
        return f"{weight_shape}, eps={self.eps}"

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalise ``x`` and apply the learned per-feature scale."""
        # Normalise in float32, then cast back to the caller's dtype
        # before applying the scale.
        normalised = self._norm(x.float())
        output = normalised.type_as(x)
        return output * self.weight
|
|
|
|
class CosineLoss(nn.Module):
    """Cosine-distance loss: ``1 - cosine_similarity(input, target, dim=1)``.

    Args:
        reduction: ``'mean'`` averages the per-element loss, ``'sum'`` adds it
            up, and any other value returns the unreduced loss tensor.
    """

    def __init__(self, reduction='mean'):
        super().__init__()
        self.reduction = reduction

    # NOTE: this used to be decorated with @staticmethod while still declaring
    # ``self``, which shifted every argument by one when it was called through
    # an instance (as ``forward`` does) and raised TypeError. It is now a plain
    # method; ``CosineLoss.interpolate_tokens_2d(obj, tokens, size)`` still works.
    def interpolate_tokens_2d(self, teacher_tokens, target_size):
        """
        Interpolate teacher tokens to the target size using bilinear interpolation.

        Args:
            teacher_tokens: 4-D tensor laid out as (batch, height, width, channels);
                it is permuted to channels-first before interpolation.
            target_size: Output (height, width) passed to ``interpolate``.

        Returns:
            Tensor of shape (batch, target_height * target_width, channels).
        """
        channels_first = teacher_tokens.permute(0, 3, 1, 2)
        interpolated = torch.nn.functional.interpolate(
            channels_first, size=target_size, mode='bilinear', align_corners=True
        ).flatten(2)
        # Back to token-major layout: (batch, tokens, channels).
        return interpolated.permute(0, 2, 1)

    def forward(self, input: torch.Tensor, target: torch.Tensor, input_shape=None, target_shape=None) -> torch.Tensor:
        """Compute the cosine-distance loss between ``input`` and ``target``.

        Args:
            input: Tensor compared against ``target`` along dim 1.
            target: Tensor broadcast-compatible with the (possibly resized) input.
            input_shape: Optional (height, width) of the token grid in ``input``.
                When given together with ``target_shape``, ``input`` is reshaped
                to (batch, height, width, -1) and resized to ``target_shape``.
            target_shape: Optional (height, width) to resize the input grid to.

        Returns:
            A scalar for ``'mean'`` / ``'sum'`` reduction, otherwise the
            unreduced ``1 - cos_sim`` tensor.
        """
        if input_shape is not None and target_shape is not None:
            # tuple() lets callers pass a list or torch.Size as well as a tuple.
            grid_shape = (input.shape[0], ) + tuple(input_shape) + (-1, )
            input = input.reshape(grid_shape)
            input = self.interpolate_tokens_2d(input, target_shape)

        # NOTE(review): dim=1 is the feature axis for 2-D (tokens, features)
        # inputs; after interpolation the input is (batch, tokens, channels), so
        # dim=1 is then the token axis — confirm this is intended by callers.
        cos_sim = nn.functional.cosine_similarity(input, target, dim=1)
        loss = 1 - cos_sim

        if self.reduction == 'mean':
            return loss.mean()
        elif self.reduction == 'sum':
            return loss.sum()
        else:
            return loss
|
|
|
|
class AuxVision(nn.Module):
    """Auxiliary vision branch producing one cosine loss per vision layer.

    A pretrained vision model is loaded from ``config.aux_vision`` and all of
    its parameters are frozen. For every hidden layer index of that model a
    small projection head (RMSNorm followed by a bias-free Linear) is
    registered as ``aux_layer_<i>``; it maps LLM hidden states to the vision
    model's hidden size so the two can be compared with ``CosineLoss``.
    """

    def __init__(self,
                 config: VoRAConfig = None,
                 ):
        super().__init__()
        self.skip_aux_cls = config.skip_aux_cls

        aux_name = config.aux_vision
        if 'clip' in aux_name.lower():
            # CLIP checkpoints keep their config on the inner vision_model.
            self.aux_model = CLIPVisionModel.from_pretrained(aux_name)
            vision_config = self.aux_model.vision_model.config
        else:
            self.aux_model = AutoModel.from_pretrained(aux_name, trust_remote_code=True)
            _reload_aimv2_weights(self.aux_model, aux_name)
            vision_config = self.aux_model.config
        vision_hidden_size = vision_config.hidden_size
        num_hidden_layers = vision_config.num_hidden_layers

        # The auxiliary model is never trained: freeze all of its parameters.
        for frozen_param in self.aux_model.parameters():
            frozen_param.requires_grad = False

        # One projection head per vision layer index, registered by name.
        self.aux_layers = list(range(num_hidden_layers))
        for layer_id in self.aux_layers:
            head = self.build_aux_layer(config.hidden_size, vision_hidden_size)
            self.add_module(f"aux_layer_{layer_id}", head)

        self.loss_function = CosineLoss()
        self.loss_keys = [f"loss_aux_layer_{layer_id}" for layer_id in self.aux_layers]

    def build_aux_layer(self, llm_hidden_size, vit_hidden_size):
        """Return an RMSNorm + bias-free Linear head mapping LLM width to vision width."""
        norm = RMSNorm(llm_hidden_size)
        projection = nn.Linear(
            llm_hidden_size,
            vit_hidden_size,
            bias=False,
        )
        return nn.Sequential(norm, projection)

    def forward(self, frames, llm_hidden_states, vision_mask):
        """Compute the per-layer auxiliary losses.

        Args:
            frames: Input passed straight to the auxiliary vision model.
            llm_hidden_states: Indexable by layer index; each entry is
                boolean-masked with ``vision_mask == 1``.
            vision_mask: Mask whose entries equal to 1 select the positions
                compared against the vision model — presumably the vision
                tokens; verify against the caller.

        Returns:
            Dict mapping ``"loss_aux_layer_<i>"`` to the loss for layer ``i``.
        """
        vision_hidden_states = self.aux_model(frames, output_hidden_states=True).hidden_states
        # Drop the first token of each vision sequence when skip_aux_cls is set.
        start_id = 1 if self.skip_aux_cls else 0

        losses = {}
        for layer_idx in self.aux_layers:
            head = getattr(self, f"aux_layer_{layer_idx}")
            selected_llm_states = llm_hidden_states[layer_idx][vision_mask == 1]
            aux_hidden_states = head(selected_llm_states)

            # Flatten the vision tokens to match the projected LLM states.
            vision_tokens = vision_hidden_states[layer_idx][:, start_id:]
            vision_tokens = vision_tokens.reshape(aux_hidden_states.shape)

            losses[f"loss_aux_layer_{layer_idx}"] = self.loss_function(vision_tokens, aux_hidden_states)
        return losses
|
|