# Source: eval-pack/models/aux_vision.py
# Uploaded by jun-1001 using huggingface_hub ("Upload folder using huggingface_hub"),
# commit 2e7f2ce (verified).
import os
import torch
import torch.nn as nn
from transformers import CLIPVisionModel, AutoModel
from .configuration_vora import VoRAConfig
def _reload_aimv2_weights(model, model_path: str):
"""Reload AIMv2 weights from mlx_model.safetensors with NHWC→NCHW conversion."""
mlx_path = os.path.join(model_path, "mlx_model.safetensors")
if not os.path.isfile(mlx_path):
return
from safetensors.torch import load_file
sd = load_file(mlx_path)
converted = {}
for k, v in sd.items():
if v.ndim == 4:
converted[k] = v.permute(0, 3, 1, 2).contiguous()
else:
converted[k] = v
msg = model.load_state_dict(converted, strict=False)
if msg.missing_keys or msg.unexpected_keys:
print(f"[AuxVision] _reload_aimv2_weights: missing={msg.missing_keys}, unexpected={msg.unexpected_keys}")
else:
print("[AuxVision] _reload_aimv2_weights: all keys matched successfully")
class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last axis with a learnable gain.

    The normalization runs in float32, is cast back to the input's dtype, and
    is then scaled by ``weight`` (initialized to ones).

    Args:
        dim: size of the last axis / number of gain parameters.
        eps: value added to the mean square before the reciprocal square root.
    """

    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize ``x`` in float32, restore its dtype, then apply the gain."""
        normalized = self._norm(x.float()).type_as(x)
        return normalized * self.weight

    def extra_repr(self) -> str:
        shape = tuple(self.weight.shape)
        return f"{shape}, eps={self.eps}"

    def _norm(self, x: torch.Tensor) -> torch.Tensor:
        """Divide ``x`` by the RMS of its last axis (stabilized by ``eps``)."""
        mean_square = x.pow(2).mean(-1, keepdim=True)
        return x * (mean_square + self.eps).rsqrt()
class CosineLoss(nn.Module):
    """Cosine-distance loss ``1 - cos(input, target)`` along the last (feature) axis.

    Args:
        reduction: ``'mean'`` or ``'sum'`` reduces the per-token loss; any other
            value returns the unreduced loss tensor.
    """

    def __init__(self, reduction='mean'):
        super().__init__()
        self.reduction = reduction

    # Intentionally NOT a @staticmethod: it declares ``self`` and is invoked as
    # ``self.interpolate_tokens_2d(...)``. With the decorator, that call shifted
    # every argument by one and raised TypeError. As a plain function it can
    # still be called from the class as ``CosineLoss.interpolate_tokens_2d(obj, tokens, size)``.
    def interpolate_tokens_2d(self, teacher_tokens, target_size):
        """Bilinearly resample a 2-D token grid to ``target_size``.

        Args:
            teacher_tokens: tensor of shape ``(batch, height, width, feature_dim)``.
            target_size: ``(new_height, new_width)``.

        Returns:
            Tensor of shape ``(batch, new_height * new_width, feature_dim)``.
        """
        grid = teacher_tokens.permute(0, 3, 1, 2)  # -> (batch, feature_dim, height, width)
        resized = torch.nn.functional.interpolate(grid, size=target_size, mode='bilinear', align_corners=True)
        # Flatten the spatial dims, then move features back to the last axis.
        return resized.flatten(2).permute(0, 2, 1)

    def forward(self, input: torch.Tensor, target: torch.Tensor, input_shape=None, target_shape=None) -> torch.Tensor:
        """Compute the cosine-distance loss between ``input`` and ``target``.

        Args:
            input: token features with features on the last axis.
            target: tensor matching ``input``'s shape (after optional resampling).
            input_shape: optional ``(height, width)`` grid of ``input``'s tokens.
            target_shape: optional ``(height, width)`` grid to resample ``input`` to.
                Resampling only happens when both shapes are given.

        Returns:
            Scalar loss for ``'mean'``/``'sum'``; otherwise the per-token loss,
            i.e. ``input``'s shape without its last axis.
        """
        if input_shape is not None and target_shape is not None:
            # Unpacking accepts both tuples and lists for input_shape.
            input = input.reshape((input.shape[0], *input_shape, -1))
            input = self.interpolate_tokens_2d(input, target_shape)
        # dim=-1: the feature axis. Identical to dim=1 for 2-D inputs, and
        # correct for the (batch, tokens, dim) tensors produced above.
        cos_sim = nn.functional.cosine_similarity(input, target, dim=-1)
        loss = 1 - cos_sim
        if self.reduction == 'mean':
            return loss.mean()
        if self.reduction == 'sum':
            return loss.sum()
        return loss
class AuxVision(nn.Module):
    """Auxiliary vision-distillation heads on top of a frozen vision encoder.

    The encoder is a ``CLIPVisionModel`` when ``config.aux_vision`` contains
    ``'clip'`` (case-insensitive). Otherwise it is an ``AutoModel`` loaded with
    ``trust_remote_code=True``, followed by an attempted AIMv2 weight reload from
    ``mlx_model.safetensors``. One ``RMSNorm -> Linear`` head per encoder layer
    projects LLM hidden states at vision positions to the encoder's hidden size.
    A cosine loss compares each head's output with the encoder's hidden states
    at the same layer index.

    Args:
        config: model configuration; ``aux_vision``, ``skip_aux_cls`` and
            ``hidden_size`` are read from it. Although the parameter defaults
            to ``None``, a real config is required.
    """

    def __init__(self,
                 config: VoRAConfig = None,
                 ):
        super().__init__()
        self.skip_aux_cls = config.skip_aux_cls  # whether to skip the cls token in ViT
        # ---------------- Setup Aux Model ----------------
        if 'clip' in config.aux_vision.lower():
            self.aux_model = CLIPVisionModel.from_pretrained(config.aux_vision)
            encoder_config = self.aux_model.vision_model.config
        else:
            self.aux_model = AutoModel.from_pretrained(config.aux_vision, trust_remote_code=True)
            _reload_aimv2_weights(self.aux_model, config.aux_vision)
            encoder_config = self.aux_model.config
        vision_hidden_size = encoder_config.hidden_size
        num_hidden_layers = encoder_config.num_hidden_layers
        # The encoder is a fixed teacher: freeze every one of its parameters.
        self.aux_model.requires_grad_(False)
        # -------------------------------------------------
        # ---------------- Setup Aux Heads ----------------
        self.aux_layers = list(range(num_hidden_layers))
        for layer_id in self.aux_layers:
            # Registered as "aux_layer_{i}" attributes so state_dict keys stay stable.
            self.add_module(f"aux_layer_{layer_id}", self.build_aux_layer(config.hidden_size, vision_hidden_size))
        # -------------------------------------------------
        self.loss_function = CosineLoss()
        self.loss_keys = [f"loss_aux_layer_{layer_id}" for layer_id in self.aux_layers]

    def build_aux_layer(self, llm_hidden_size, vit_hidden_size):
        """Return a bias-free ``RMSNorm -> Linear`` projection from LLM to ViT width."""
        return nn.Sequential(
            RMSNorm(llm_hidden_size),
            nn.Linear(
                llm_hidden_size,
                vit_hidden_size,
                bias=False,
            )
        )

    def forward(self, frames, llm_hidden_states, vision_mask):
        """Compute one cosine distillation loss per auxiliary layer.

        Args:
            frames: pixel input passed positionally to the frozen encoder.
            llm_hidden_states: per-layer LLM hidden states, indexed by layer id.
            vision_mask: marks vision-token positions with ``1``.

        Returns:
            dict mapping ``"loss_aux_layer_{i}"`` to that layer's loss.
        """
        vision_hidden_states = self.aux_model(frames, output_hidden_states=True).hidden_states
        # Loop invariants, computed once instead of once per layer.
        start_id = 1 if self.skip_aux_cls else 0
        is_vision_token = vision_mask == 1
        losses = {}
        # NOTE(review): encoder hidden_states and llm_hidden_states are both
        # indexed by the same layer id; confirm the intended alignment (e.g.
        # whether index 0 is an embedding output) against the callers.
        for layer_idx in self.aux_layers:
            head = getattr(self, f"aux_layer_{layer_idx}")
            aux_hidden_states = head(llm_hidden_states[layer_idx][is_vision_token])
            # Drop the leading token when requested. The teacher is assumed to be
            # (batch, tokens, dim) and is reshaped to the head output's layout
            # (presumably (num_vision_tokens, dim)) — TODO confirm.
            teacher = vision_hidden_states[layer_idx][:, start_id:].reshape(aux_hidden_states.shape)
            losses[f"loss_aux_layer_{layer_idx}"] = self.loss_function(teacher, aux_hidden_states)
        return losses