from typing import Callable, Union

import pyiqa
import torch
from torchvision.transforms import functional as TF


class QAlignVisionOnlyWrapper(torch.nn.Module):
    """Wrapper that only runs the vision encoder up to visual_abstractor, skipping the LLM.

    Attribute path inside the pyiqa metric (as accessed in ``__init__``):
        - base_model: InferenceModel
        - base_model.net: QAlign
        - base_model.net.model: MPLUGOwl2LlamaForCausalLM
        - base_model.net.model.model: MPLUGOwl2LlamaModel
            - vision_model: MplugOwlVisionModel
            - visual_abstractor: MplugOwlVisualAbstractorModel

    All parameters of ``base_model`` are frozen (``requires_grad=False``), and
    ``base_model`` is kept in eval mode even when the wrapper itself is put into
    train mode.
    """

    def __init__(self, base_model: torch.nn.Module):
        super().__init__()
        self.base_model = base_model
        inner_model = base_model.net.model.model
        self.vision_model = inner_model.vision_model
        self.visual_abstractor = inner_model.visual_abstractor
        self.base_model.eval()
        for parameter in self.base_model.parameters():
            parameter.requires_grad_(False)

    def train(self, mode: bool = True):
        """Set the wrapper's training flag while forcing ``base_model`` back to eval mode."""
        super().train(mode)
        # super().train() recursed into base_model (and the aliased vision_model /
        # visual_abstractor, which live inside it); undo that for the frozen backbone.
        self.base_model.eval()
        return self

    def eval(self):
        """Put the wrapper and ``base_model`` in eval mode."""
        super().eval()
        self.base_model.eval()
        return self

    def forward(self, images: torch.Tensor):
        """Encode ``images`` with the vision model and return the visual abstractor output.

        Args:
            images: A tensor, or an iterable of equally-shaped tensors that is
                stacked along a new leading dimension. Presumably (B, 3, H, W)
                pixel values already preprocessed for the vision encoder —
                TODO confirm against the caller.

        Returns:
            Whatever ``visual_abstractor`` returns; no post-processing is applied.
        """
        # Use the first vision-model parameter as the device/dtype reference
        # (e.g. so inputs match a half-precision backbone).
        reference_parameter = next(self.vision_model.parameters())
        if not isinstance(images, torch.Tensor):
            images = torch.stack(list(images))
        pixel_values = images.to(device=reference_parameter.device, dtype=reference_parameter.dtype)
        with torch.no_grad():
            hidden_states = self.vision_model(pixel_values).last_hidden_state
            abstract_output = self.visual_abstractor(encoder_hidden_states=hidden_states)
        return abstract_output


def flatten_blc_drop_cls(hidden_states: torch.Tensor) -> torch.Tensor:
    """Drop the last token (assumed CLS) from (B, L, C) and flatten to (B*L', C).

    The last token along dim 1 is removed only when the input is 3-D and has more
    than one token; any other input is just reshaped to (-1, C).
    """
    if hidden_states.dim() == 3 and hidden_states.shape[1] > 1:
        hidden_states = hidden_states[:, :-1, :]
    return hidden_states.reshape(-1, hidden_states.shape[-1])


def _center_crop_with_padding(image: torch.Tensor, crop_size: int) -> torch.Tensor:
    """Center-crop ``image`` to ``crop_size`` x ``crop_size``, zero-padding sides that are too small.

    Works on any tensor whose last two dims are (H, W), e.g. (C, H, W) or
    (B, C, H, W).

    Raises:
        ValueError: if ``crop_size`` is not positive.
    """
    if crop_size <= 0:
        raise ValueError(f"crop_size must be positive, got {crop_size}")
    height, width = image.shape[-2:]
    pad_height = max(0, crop_size - height)
    pad_width = max(0, crop_size - width)
    if pad_height > 0 or pad_width > 0:
        # TF.pad order is [left, top, right, bottom]; any odd pixel goes to right/bottom.
        padding = [
            pad_width // 2,
            pad_height // 2,
            pad_width - pad_width // 2,
            pad_height - pad_height // 2,
        ]
        image = TF.pad(image, padding, fill=0, padding_mode="constant")
    return TF.center_crop(image, [crop_size, crop_size])


def qalign_image_transform(crop_size: int) -> Callable[[torch.Tensor], torch.Tensor]:
    """Return a transform that zero-pads/center-crops tensors to ``crop_size`` x ``crop_size``.

    Raises:
        ValueError: if ``crop_size`` is not positive (checked eagerly).
    """
    if crop_size <= 0:
        raise ValueError(f"crop_size must be positive, got {crop_size}")

    def _transform(image: torch.Tensor) -> torch.Tensor:
        return _center_crop_with_padding(image, crop_size)

    return _transform


def create_qalign_model(device: Union[str, torch.device]) -> QAlignVisionOnlyWrapper:
    """Create the pyiqa Q-Align metric on ``device`` and wrap it for vision-only extraction."""
    # Pass the device to create_metric so the metric is not first built on its
    # default device (CUDA when available) and only then moved.
    base_model = pyiqa.api_helpers.create_metric("qalign", device=device)
    base_model.to(device)
    return QAlignVisionOnlyWrapper(base_model)


def get_visual_abstractor(model: torch.nn.Module) -> torch.nn.Module:
    """Return the ``visual_abstractor`` submodule of a QAlignVisionOnlyWrapper."""
    return model.visual_abstractor