File size: 12,955 Bytes
b381930 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 |
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testing suite for the PyTorch ColPali model."""
import gc
import unittest
from typing import ClassVar
import torch
from datasets import load_dataset
from tests.test_configuration_common import ConfigTester
from tests.test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
from transformers import (
is_torch_available,
)
from transformers.models.colpali.configuration_colpali import ColPaliConfig
from transformers.models.colpali.modeling_colpali import ColPaliForRetrieval, ColPaliForRetrievalOutput
from transformers.models.colpali.processing_colpali import ColPaliProcessor
from transformers.testing_utils import (
require_torch,
require_vision,
slow,
torch_device,
)
if is_torch_available():
import torch
class ColPaliForRetrievalModelTester:
    """
    Helper that builds a small `ColPaliConfig` and matching dummy inputs for the model tests.

    The tester stores the hyper-parameters it receives, assembles a `paligemma` VLM config
    dictionary from them, and exposes `prepare_config_and_inputs*` helpers that return the
    config together with the input tensors.
    """

    def __init__(
        self,
        parent,
        ignore_index=-100,
        image_token_index=0,
        projector_hidden_act="gelu",
        seq_length=25,
        vision_feature_select_strategy="default",
        vision_feature_layer=-1,
        projection_dim=32,
        text_config=None,
        is_training=False,
        vision_config=None,
        use_cache=False,
        embedding_dim=128,
    ):
        """
        Store the test hyper-parameters and build the VLM config dictionary.

        Args:
            parent: The object that owns this tester (stored as `self.parent`).
            ignore_index: Forwarded to the VLM config as `ignore_index`.
            image_token_index: Token id used for image placeholders in `input_ids`.
            projector_hidden_act: Forwarded to the VLM config as `projector_hidden_act`.
            seq_length: Length of the generated `input_ids` sequences.
            vision_feature_select_strategy: Forwarded to the VLM config.
            vision_feature_layer: Forwarded to the VLM config.
            projection_dim: Forwarded to the VLM config as `projection_dim`.
            text_config: Dictionary describing the text backbone. When `None`, a small
                `gemma` configuration is created for this instance.
            is_training: Stored as `self.is_training`.
            vision_config: Dictionary describing the vision backbone. When `None`, a small
                default configuration is created for this instance.
            use_cache: Stored as `self.use_cache`.
            embedding_dim: Passed to `ColPaliConfig` as `embedding_dim`.
        """
        # The default sub-configs are built per instance rather than as default arguments:
        # a dict default is created once and shared by every tester, so an in-place
        # modification through one instance would silently leak into all later ones.
        if text_config is None:
            text_config = {
                "model_type": "gemma",
                "seq_length": 128,
                "is_training": True,
                "use_token_type_ids": False,
                "use_labels": True,
                "vocab_size": 99,
                "hidden_size": 32,
                "num_hidden_layers": 2,
                "num_attention_heads": 4,
                "num_key_value_heads": 1,
                "head_dim": 8,
                "intermediate_size": 37,
                "hidden_activation": "gelu_pytorch_tanh",
                "hidden_dropout_prob": 0.1,
                "attention_probs_dropout_prob": 0.1,
                "max_position_embeddings": 512,
                "type_vocab_size": 16,
                "type_sequence_label_size": 2,
                "initializer_range": 0.02,
                "num_labels": 3,
                "num_choices": 4,
                "pad_token_id": 1,
            }
        if vision_config is None:
            vision_config = {
                "use_labels": True,
                "image_size": 20,
                "patch_size": 5,
                "num_image_tokens": 4,
                "num_channels": 3,
                "is_training": True,
                "hidden_size": 32,
                "projection_dim": 32,
                "num_key_value_heads": 1,
                "num_hidden_layers": 2,
                "num_attention_heads": 4,
                "intermediate_size": 37,
                "dropout": 0.1,
                "attention_dropout": 0.1,
                "initializer_range": 0.02,
            }

        self.parent = parent
        self.ignore_index = ignore_index
        # `image_token_index` is set to 0 to pass "resize_embeddings" test, do not modify
        self.image_token_index = image_token_index
        self.projector_hidden_act = projector_hidden_act
        self.vision_feature_select_strategy = vision_feature_select_strategy
        self.vision_feature_layer = vision_feature_layer
        self.text_config = text_config
        self.vision_config = vision_config
        self.seq_length = seq_length
        self.projection_dim = projection_dim

        # Convenience attributes mirrored from the text sub-config.
        self.pad_token_id = text_config["pad_token_id"]
        self.num_hidden_layers = text_config["num_hidden_layers"]
        self.vocab_size = text_config["vocab_size"]
        self.hidden_size = text_config["hidden_size"]
        self.num_attention_heads = text_config["num_attention_heads"]
        self.is_training = is_training

        self.batch_size = 3
        self.num_channels = vision_config["num_channels"]
        self.image_size = vision_config["image_size"]
        self.encoder_seq_length = seq_length
        self.use_cache = use_cache

        self.embedding_dim = embedding_dim
        self.vlm_config = {
            "model_type": "paligemma",
            "text_config": self.text_config,
            "vision_config": self.vision_config,
            "ignore_index": self.ignore_index,
            "image_token_index": self.image_token_index,
            "projector_hidden_act": self.projector_hidden_act,
            "projection_dim": self.projection_dim,
            "vision_feature_select_strategy": self.vision_feature_select_strategy,
            "vision_feature_layer": self.vision_feature_layer,
        }

    def get_config(self):
        """Return a `ColPaliConfig` built from the stored VLM config and embedding dimension."""
        return ColPaliConfig(
            vlm_config=self.vlm_config,
            embedding_dim=self.embedding_dim,
        )

    def prepare_config_and_inputs(self):
        """
        Return `(config, pixel_values)`.

        `pixel_values` is requested from `floats_tensor` with the shape
        `[batch_size, num_channels, image_size, image_size]`.
        """
        pixel_values = floats_tensor(
            [
                self.batch_size,
                self.vision_config["num_channels"],
                self.vision_config["image_size"],
                self.vision_config["image_size"],
            ]
        )
        config = self.get_config()

        return config, pixel_values

    def prepare_config_and_inputs_for_common(self):
        """
        Return `(config, inputs_dict)` for the common model tests.

        `inputs_dict` holds `pixel_values`, `input_ids`, `attention_mask`, `labels` (the same
        tensor object as `input_ids`) and an all-zero `token_type_ids`.
        """
        config, pixel_values = self.prepare_config_and_inputs()
        vlm_config = config.vlm_config

        # Draw ids with an upper bound of `vocab_size - 1`, then shift them up by one.
        input_ids = ids_tensor([self.batch_size, self.seq_length], vlm_config.text_config.vocab_size - 1) + 1
        # Mask out padding positions. The mask is derived from `pad_token_id` instead of a
        # hard-coded 1 so it stays correct when a custom `text_config` changes the pad id.
        attention_mask = input_ids.ne(self.pad_token_id).to(torch_device)

        # set the 16 first tokens to be image, and ensure that no other tokens are image tokens
        # do not change this unless you modified image size or patch size
        input_ids[input_ids == vlm_config.image_token_index] = self.pad_token_id
        input_ids[:, :16] = vlm_config.image_token_index
        inputs_dict = {
            "pixel_values": pixel_values,
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": input_ids,
            "token_type_ids": torch.zeros_like(input_ids),
        }
        return config, inputs_dict
@require_torch
class ColPaliForRetrievalModelTest(ModelTesterMixin, unittest.TestCase):
    """
    Runs the shared `ModelTesterMixin` test-suite against `ColPaliForRetrieval`.
    """

    all_model_classes = (ColPaliForRetrieval,) if is_torch_available() else ()

    # Switches set on the class for the shared test-suite.
    fx_compatible = False
    test_torchscript = False
    test_pruning = False
    test_resize_embeddings = True
    test_head_masking = False

    def setUp(self):
        self.model_tester = ColPaliForRetrievalModelTester(self)
        self.config_tester = ConfigTester(self, config_class=ColPaliConfig, has_text_modality=False)

    # Overridden from the common tests: "pixel_values" has to be removed from the inputs
    # before the model is called with `inputs_embeds`.
    def test_inputs_embeds(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config).to(torch_device).eval()

            model_inputs = self._prepare_for_class(inputs_dict, model_class)
            token_ids = model_inputs.pop("input_ids")
            model_inputs.pop("pixel_values")

            # Feed the embedded tokens instead of the raw ids.
            embedding_layer = model.get_input_embeddings()
            model_inputs["inputs_embeds"] = embedding_layer(token_ids)

            with torch.no_grad():
                model(**model_inputs)

    # Overridden from the common tests: "pixel_values" has to be removed for LVLMs,
    # while some other models require pixel_values to be present.
    def test_inputs_embeds_matches_input_ids(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config).to(torch_device).eval()

            shared_inputs = self._prepare_for_class(inputs_dict, model_class)
            token_ids = shared_inputs.pop("input_ids")
            shared_inputs.pop("pixel_values")

            token_embeds = model.get_input_embeddings()(token_ids)

            # Both call styles must produce the same first output.
            with torch.no_grad():
                from_ids = model(input_ids=token_ids, **shared_inputs)[0]
                from_embeds = model(inputs_embeds=token_embeds, **shared_inputs)[0]
            torch.testing.assert_close(from_embeds, from_ids)

    @slow
    @require_vision
    def test_colpali_forward_inputs(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config).to(torch_device).eval()
            model_inputs = self._prepare_for_class(inputs_dict, model_class)

            with torch.no_grad():
                result = model(**model_inputs, return_dict=True)

            # The forward pass must return the dedicated retrieval output class.
            self.assertIsInstance(result, ColPaliForRetrievalOutput)

    # --- Gradient-checkpointing tests are skipped (see the skip reason for the tracking PR). ---

    @unittest.skip(
        reason="This architecture seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
    )
    def test_training_gradient_checkpointing(self):
        pass

    @unittest.skip(
        reason="This architecture seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
    )
    def test_training_gradient_checkpointing_use_reentrant(self):
        pass

    @unittest.skip(
        reason="This architecture seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
    )
    def test_training_gradient_checkpointing_use_reentrant_false(self):
        pass

    # --- Remaining common tests that are skipped for this model. ---

    @unittest.skip(
        reason="From PaliGemma: Some undefined behavior encountered with test versions of this model. Skip for now."
    )
    def test_model_parallelism(self):
        pass

    @unittest.skip(
        reason="PaliGemmma's SigLip encoder uses the same initialization scheme as the Flax original implementation"
    )
    def test_initialization(self):
        pass

    # TODO extend valid outputs to include this test @Molbap
    @unittest.skip(reason="PaliGemma has currently one output format.")
    def test_model_outputs_equivalence(self):
        pass

    @unittest.skip(reason="Pass because ColPali requires `attention_mask is not None`")
    def test_sdpa_can_dispatch_on_flash(self):
        pass

    @unittest.skip(reason="Pass because ColPali requires `attention_mask is not None`")
    def test_sdpa_can_compile_dynamic(self):
        pass
@require_torch
class ColPaliModelIntegrationTest(unittest.TestCase):
    """
    Integration test that runs the pretrained checkpoint named by `model_name` end to end.
    """

    # Hub identifier used for both the processor and the model.
    model_name: ClassVar[str] = "vidore/colpali-v1.2-hf"

    def setUp(self):
        """Load the processor for `model_name` before each test."""
        self.processor = ColPaliProcessor.from_pretrained(self.model_name)

    def tearDown(self):
        """Run the garbage collector and empty the CUDA cache after each test."""
        gc.collect()
        torch.cuda.empty_cache()

    @slow
    def test_model_integration_test(self):
        """
        Test if the model is able to retrieve the correct pages for a small and easy dataset.
        """
        model = ColPaliForRetrieval.from_pretrained(
            self.model_name,
            torch_dtype=torch.bfloat16,
            device_map=torch_device,
        ).eval()

        # Load the test dataset
        ds = load_dataset("hf-internal-testing/document-visual-retrieval-test", split="test")
        num_examples = len(ds)

        # Preprocess the examples
        batch_images = self.processor(images=ds["image"]).to(torch_device)
        batch_queries = self.processor(text=ds["query"]).to(torch_device)

        # Run inference
        with torch.inference_mode():
            image_embeddings = model(**batch_images).embeddings
            query_embeddings = model(**batch_queries).embeddings

        # Compute retrieval scores
        scores = self.processor.score_retrieval(
            query_embeddings=query_embeddings,
            passage_embeddings=image_embeddings,
        )  # (len(qs), len(ps))

        # Use unittest assertions rather than bare `assert` statements: the latter are
        # stripped when Python runs with `-O`, which would silently disable these checks.
        self.assertEqual(scores.ndim, 2, f"Expected 2D tensor, got {scores.ndim}")
        self.assertEqual(
            tuple(scores.shape),
            (num_examples, num_examples),
            f"Expected shape {(num_examples, num_examples)}, got {scores.shape}",
        )

        # Check if the maximum scores per row are in the diagonal of the matrix score
        self.assertTrue((scores.argmax(axis=1) == torch.arange(num_examples, device=scores.device)).all())

        # Further validation: fine-grained check, with a hardcoded score from the original implementation
        # The reference tensor is created on the same device as `scores` so the comparison
        # does not fail on a device mismatch.
        expected_scores = torch.tensor(
            [
                [15.5625, 6.5938, 14.4375],
                [12.2500, 16.2500, 11.0000],
                [15.0625, 11.7500, 21.0000],
            ],
            dtype=scores.dtype,
            device=scores.device,
        )

        self.assertTrue(
            torch.allclose(scores, expected_scores, atol=1),
            f"Expected scores {expected_scores}, got {scores}",
        )
|