Spaces:
Running
Running
| from typing import Callable, Union | |
| import pyiqa | |
| import torch | |
| from torchvision.transforms import functional as TF | |
class QAlignVisionOnlyWrapper(torch.nn.Module):
    """Run only Q-Align's vision tower and visual abstractor, never the LLM.

    Module path inside the pyiqa metric:

    - base_model: InferenceModel
    - base_model.net: QAlign
    - base_model.net.model: MPLUGOwl2LlamaForCausalLM
    - base_model.net.model.model: MPLUGOwl2LlamaModel
        - vision_model: MplugOwlVisionModel
        - visual_abstractor: MplugOwlVisualAbstractorModel

    The wrapped model is frozen: every parameter gets ``requires_grad=False``
    and the whole base model is forced back into eval mode whenever
    ``train()`` or ``eval()`` is called on this wrapper.
    """

    def __init__(self, base_model: torch.nn.Module):
        super().__init__()
        # Registration order (base_model, vision_model, visual_abstractor) is
        # kept so named_modules()/state_dict() ordering is unchanged.
        self.base_model = base_model
        backbone = base_model.net.model.model
        self.vision_model = backbone.vision_model
        self.visual_abstractor = backbone.visual_abstractor
        self.base_model.eval()
        for weight in self.base_model.parameters():
            weight.requires_grad_(False)

    def train(self, mode: bool = True):
        """Set this wrapper's mode, but keep the frozen backbone in eval mode."""
        # nn.Module.train also flips vision_model/visual_abstractor (they are
        # direct children here); they live inside base_model, so the eval()
        # call below switches them back.
        wrapper = super().train(mode)
        self.base_model.eval()
        return wrapper

    def eval(self):
        """Switch to eval mode and re-assert eval mode on the backbone."""
        wrapper = super().eval()
        self.base_model.eval()
        return wrapper

    def forward(self, images: torch.Tensor):
        """Encode images with the vision tower and pass them through the abstractor.

        Args:
            images: A batched tensor, or an iterable of per-image tensors that
                is stacked along a new leading dimension with ``torch.stack``.

        Returns:
            Whatever ``visual_abstractor`` returns, unmodified; it is computed
            under ``torch.no_grad()``.
        """
        # The first vision-tower parameter defines the target device and dtype.
        reference = next(self.vision_model.parameters())
        if not isinstance(images, torch.Tensor):
            images = torch.stack(list(images))
        # NOTE(review): only a device/dtype cast happens here; any resize or
        # mean/std normalization is presumably the caller's job — confirm.
        pixel_values = images.to(device=reference.device, dtype=reference.dtype)
        with torch.no_grad():
            features = self.vision_model(pixel_values).last_hidden_state
            return self.visual_abstractor(encoder_hidden_states=features)
def flatten_blc_drop_cls(hidden_states: torch.Tensor) -> torch.Tensor:
    """Drop the trailing token of a (B, L, C) tensor and flatten it to rows of C.

    For a 3-D input with ``L > 1``, the last position along dim 1 is removed
    and the result is reshaped to ``(B * (L - 1), C)``. That position is
    treated as a CLS-style summary token; it is presumably the abstractor's
    appended end token, which should be confirmed against the model.

    A 3-D input with ``L == 1`` keeps its only token, giving ``(B, C)``, so a
    single-token sequence is never emptied. Input of any other rank is only
    reshaped to ``(-1, C)``.

    Args:
        hidden_states: Token features whose last dimension is the channel size C.

    Returns:
        A 2-D tensor of shape ``(N, C)``.
    """
    if hidden_states.dim() == 3 and hidden_states.shape[1] > 1:
        hidden_states = hidden_states[:, :-1, :]
    # reshape (not view): the slice above is non-contiguous.
    return hidden_states.reshape(-1, hidden_states.shape[-1])
def _center_crop_with_padding(image: torch.Tensor, crop_size: int) -> torch.Tensor:
    """Zero-pad ``image`` up to ``crop_size`` where needed, then center-crop it.

    Works on any tensor whose last two dimensions are (H, W). That covers
    ``(H, W)``, ``(C, H, W)`` and batched ``(B, C, H, W)`` input; leading
    dimensions are preserved.

    Args:
        image: Image tensor with spatial dims last.
        crop_size: Side length of the square output.

    Returns:
        Tensor of shape ``(..., crop_size, crop_size)``.
    """
    # Read only the spatial dims so 2-D and batched tensors are accepted too.
    height, width = image.shape[-2:]
    pad_height = max(0, crop_size - height)
    pad_width = max(0, crop_size - width)
    if pad_height > 0 or pad_width > 0:
        # TF.pad order is [left, top, right, bottom]. An odd deficit puts the
        # extra pixel on the right/bottom edge.
        padding = [
            pad_width // 2,
            pad_height // 2,
            pad_width - pad_width // 2,
            pad_height - pad_height // 2,
        ]
        image = TF.pad(image, padding, fill=0, padding_mode="constant")
    return TF.center_crop(image, [crop_size, crop_size])
def qalign_image_transform(crop_size: int) -> Callable[[torch.Tensor], torch.Tensor]:
    """Build a per-image transform that pads and center-crops to ``crop_size``.

    Args:
        crop_size: Side length of the square crop applied by the returned callable.

    Returns:
        A callable that maps an image tensor to its zero-padded,
        center-cropped version.
    """

    def _pad_and_crop(image: torch.Tensor) -> torch.Tensor:
        return _center_crop_with_padding(image, crop_size)

    return _pad_and_crop
def create_qalign_model(device: Union[str, torch.device]) -> QAlignVisionOnlyWrapper:
    """Create the pyiqa Q-Align metric, wrapped for vision-only activation extraction.

    Args:
        device: Device on which the metric's weights should live.

    Returns:
        A frozen QAlignVisionOnlyWrapper around the pyiqa metric.
    """
    # Build the metric directly on the requested device. Otherwise pyiqa picks
    # its own default device (CUDA when available): a CPU request would first
    # load the large model onto the GPU, and the metric's ``device`` attribute
    # would disagree with where its weights end up.
    base_model = pyiqa.api_helpers.create_metric("qalign", device=device)
    # Safeguard so every submodule is on ``device``; a no-op when the line
    # above already placed it there.
    base_model.to(device)
    return QAlignVisionOnlyWrapper(base_model)
def get_visual_abstractor(model: torch.nn.Module):
    """Return the ``visual_abstractor`` submodule held by a QAlignVisionOnlyWrapper.

    Args:
        model: The wrapper, or any object exposing a ``visual_abstractor`` attribute.

    Returns:
        That attribute, unchanged; AttributeError propagates if it is missing.
    """
    abstractor = model.visual_abstractor
    return abstractor