IQA-Interpretation / analysis /qalign_utils.py
dvarfe's picture
init repo
00a0ce5
Raw
History Blame Contribute Delete
3.17 kB
from typing import Callable, Union
import pyiqa
import torch
from torchvision.transforms import functional as TF
class QAlignVisionOnlyWrapper(torch.nn.Module):
    """Expose only the vision branch of a Q-Align metric, bypassing the LLM.

    The wrapped object is expected to provide the following attribute chain
    (as described by the original author):

    - base_model: InferenceModel
    - base_model.net: QAlign
    - base_model.net.model: MPLUGOwl2LlamaForCausalLM
    - base_model.net.model.model: MPLUGOwl2LlamaModel
        - vision_model: MplugOwlVisionModel
        - visual_abstractor: MplugOwlVisualAbstractorModel

    The whole base model is kept frozen and in eval mode for the lifetime of
    the wrapper, regardless of ``train()`` / ``eval()`` calls on the wrapper.
    """

    def __init__(self, base_model: torch.nn.Module):
        """Store ``base_model`` and freeze it.

        Args:
            base_model: Module exposing ``net.model.model.vision_model`` and
                ``net.model.model.visual_abstractor``.
        """
        super().__init__()
        self.base_model = base_model

        # Both sub-modules live on the same innermost model object.
        backbone = base_model.net.model.model
        self.vision_model = backbone.vision_model
        self.visual_abstractor = backbone.visual_abstractor

        # Inference only: eval mode and no gradients for any base parameter.
        self.base_model.eval()
        for weight in self.base_model.parameters():
            weight.requires_grad_(False)

    def train(self, mode: bool = True):
        """Set the wrapper's training flag but keep the base model in eval mode.

        Returns:
            The wrapper itself.
        """
        super().train(mode)
        # super().train() propagates ``mode`` to children; undo that for the
        # frozen base model (which also owns the vision sub-modules).
        self.base_model.eval()
        return self

    def eval(self):
        """Switch to eval mode; the base model stays in eval mode as well.

        Returns:
            The wrapper itself.
        """
        super().eval()
        self.base_model.eval()
        return self

    def forward(self, images: torch.Tensor):
        """Encode ``images`` with the vision model and the visual abstractor.

        Args:
            images: A tensor, or an iterable of tensors that is stacked along
                a new leading dimension before being processed.

        Returns:
            Whatever ``visual_abstractor`` returns when called with the vision
            model's ``last_hidden_state`` as ``encoder_hidden_states``.
        """
        # The first vision parameter determines the target device and dtype.
        reference_param = next(self.vision_model.parameters())

        if isinstance(images, torch.Tensor):
            batch = images
        else:
            batch = torch.stack(list(images))

        pixel_values = batch.to(
            device=reference_param.device,
            dtype=reference_param.dtype,
        )

        with torch.no_grad():
            encoded = self.vision_model(pixel_values)
            return self.visual_abstractor(
                encoder_hidden_states=encoded.last_hidden_state
            )
def flatten_blc_drop_cls(hidden_states: torch.Tensor) -> torch.Tensor:
    """Remove the trailing token and collapse the input to two dimensions.

    For a 3-D input with more than one position along dim 1, the last position
    (the one the original author refers to as the CLS token) is discarded.
    Any other input is left as is. The result is then reshaped to
    ``(-1, C)`` where ``C`` is the size of the last dimension.

    Args:
        hidden_states: Tensor to flatten, typically of shape ``(B, L, C)``.

    Returns:
        A 2-D tensor whose last dimension equals that of the input.
    """
    has_trailing_token = hidden_states.ndim == 3 and hidden_states.size(1) > 1
    tokens = hidden_states[:, :-1, :] if has_trailing_token else hidden_states
    channels = tokens.size(-1)
    return tokens.reshape(-1, channels)
def _center_crop_with_padding(image: torch.Tensor, crop_size: int) -> torch.Tensor:
    """Center-crop ``image`` to a square, zero-padding it first if too small.

    Args:
        image: 3-D tensor; the last two dimensions are read as height and
            width (the unpacking below requires exactly three dimensions).
        crop_size: Side length of the square output.

    Returns:
        The result of ``TF.center_crop`` with size ``[crop_size, crop_size]``.
    """
    _, height, width = image.shape

    # How many rows / columns are missing to reach the crop size.
    missing_rows = max(0, crop_size - height)
    missing_cols = max(0, crop_size - width)

    if missing_rows or missing_cols:
        # Split the deficit between both sides; the odd pixel, if any,
        # goes to the right / bottom.
        left = missing_cols // 2
        top = missing_rows // 2
        right = missing_cols - left
        bottom = missing_rows - top
        image = TF.pad(
            image,
            [left, top, right, bottom],
            fill=0,
            padding_mode="constant",
        )

    return TF.center_crop(image, [crop_size, crop_size])
def qalign_image_transform(crop_size: int) -> Callable[[torch.Tensor], torch.Tensor]:
    """Build a per-image transform that pads and center-crops to a square.

    Args:
        crop_size: Side length passed to ``_center_crop_with_padding``.

    Returns:
        A function taking one image tensor and returning the cropped tensor.
    """

    def _crop_to_square(image: torch.Tensor) -> torch.Tensor:
        # ``crop_size`` is captured from the enclosing call.
        cropped = _center_crop_with_padding(image, crop_size)
        return cropped

    return _crop_to_square
def create_qalign_model(device: Union[str, torch.device]) -> QAlignVisionOnlyWrapper:
    """Instantiate the pyiqa "qalign" metric and wrap its vision branch.

    Args:
        device: Device the freshly created metric is moved to before wrapping.

    Returns:
        A ``QAlignVisionOnlyWrapper`` around the created metric.
    """
    qalign_metric = pyiqa.api_helpers.create_metric("qalign")
    # Move the metric first so the wrapper picks up sub-modules already
    # placed on the requested device.
    qalign_metric.to(device)
    wrapper = QAlignVisionOnlyWrapper(qalign_metric)
    return wrapper
def get_visual_abstractor(model: torch.nn.Module):
    """Return the ``visual_abstractor`` attribute of ``model``.

    Args:
        model: Typically a ``QAlignVisionOnlyWrapper``; any object exposing a
            ``visual_abstractor`` attribute works.

    Returns:
        The object stored in ``model.visual_abstractor``.
    """
    abstractor = model.visual_abstractor
    return abstractor