File size: 24,168 Bytes
302920f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 |
# Copyright 2023-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import pytest
import torch
from torch import nn
from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, LlavaForConditionalGeneration
from peft import LoraConfig, PeftModel, VeraConfig, get_peft_model
from peft.utils.other import ModulesToSaveWrapper, _get_no_split_modules
class ModelWithModuleDict(nn.Module):
    """Toy model whose ``module`` attribute is an ``nn.ModuleDict`` next to a plain linear layer."""

    def __init__(self):
        super().__init__()
        self.other_layer = nn.Linear(10, 10)
        foo_layer = nn.Linear(10, 10)
        self.module = nn.ModuleDict({"foo": foo_layer})

    def forward(self):
        # The model takes no input: a random (1, 10) batch is pushed through the "foo" entry.
        sample = torch.rand(1, 10)
        return self.module["foo"](sample)
class ModelWithModuleList(nn.Module):
    """Toy model whose ``module`` attribute is an ``nn.ModuleList`` next to a plain linear layer."""

    def __init__(self):
        super().__init__()
        self.other_layer = nn.Linear(10, 10)
        first_layer = nn.Linear(10, 10)
        self.module = nn.ModuleList([first_layer])

    def forward(self):
        # The model takes no input: a random (1, 10) batch is pushed through the first list entry.
        sample = torch.rand(1, 10)
        return self.module[0](sample)
class ModelWithParameterDict(nn.Module):
    """Toy model whose ``module`` attribute is an ``nn.ParameterDict`` next to a plain linear layer."""

    def __init__(self):
        super().__init__()
        self.other_layer = nn.Linear(10, 10)
        foo_param = nn.Parameter(torch.rand(10, 10))
        self.module = nn.ParameterDict({"foo": foo_param})

    def forward(self):
        # No computation happens here; the stored "foo" parameter is returned as is.
        return self.module["foo"]
class ModelWithParameterList(nn.Module):
    """Toy model whose ``module`` attribute is an ``nn.ParameterList`` next to a plain linear layer."""

    def __init__(self):
        super().__init__()
        self.other_layer = nn.Linear(10, 10)
        first_param = nn.Parameter(torch.rand(10, 10))
        self.module = nn.ParameterList([first_param])

    def forward(self):
        # No computation happens here; the first stored parameter is returned as is.
        return self.module[0]
@pytest.mark.parametrize(
    "cls", [ModelWithModuleDict, ModelWithModuleList, ModelWithParameterDict, ModelWithParameterList]
)
def test_modules_to_save_targets_module_dict_raises(cls):
    """Listing a module/parameter dict or list in ``modules_to_save`` is expected to raise a ``TypeError``."""
    container_model = cls()
    container_model()  # sanity check that the model would normally work

    config = LoraConfig(target_modules=["other_layer"], modules_to_save=["module"])
    expected_error = "modules_to_save cannot be applied to modules of type"
    with pytest.raises(TypeError, match=expected_error):
        get_peft_model(model=container_model, peft_config=config)
def test_get_peft_model_revision_warning(tmp_path):
    """``get_peft_model`` is expected to warn when ``revision`` overrides the revision stored in the config."""
    model_id = "peft-internal-testing/tiny-random-BertModel"
    config_revision = "v2.0.0"
    requested_revision = "main"

    base_model = AutoModelForCausalLM.from_pretrained(model_id, revision=config_revision)
    base_model = base_model.eval()
    config = LoraConfig(revision=config_revision)

    expected_warning = (
        f"peft config has already set base model revision to {config_revision}, "
        f"overwriting with revision {requested_revision}"
    )
    with pytest.warns(UserWarning, match=expected_warning):
        get_peft_model(base_model, config, revision=requested_revision)
def test_load_multiple_adapters_different_modules_to_save(tmp_path):
    """Check that adapters with different ``modules_to_save`` can be loaded into one model.

    This tests the error described in #2422 where loading multiple adapters with different modules_to_save
    attributes fails (due to a regression from #2376).

    Three adapters are saved and then loaded together: adapters 0 and 1 both save ``0.post_attention_layernorm``
    (adapter 1 with randomized weights), adapter 2 saves ``1.post_attention_layernorm`` with randomized weights.
    For each active adapter, the visible layer norm weights are compared with the expected values.
    """
    model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-random-LlamaForCausalLM")

    def peft_config(**kwargs):
        return LoraConfig(target_modules="all-linear", **kwargs)

    original_model = copy.deepcopy(model)

    # Snapshot the untouched layer norm weights before any adapter is loaded. `original_model` is handed to
    # `PeftModel.from_pretrained` further down, which injects the adapter wrappers into that very object. Reading the
    # reference values from `original_model` after that point would compare the combined model with itself, making
    # the "equal to the original weights" assertions pass unconditionally.
    base_weight_0 = original_model.model.layers[0].post_attention_layernorm.weight.detach().clone()
    base_weight_1 = original_model.model.layers[1].post_attention_layernorm.weight.detach().clone()

    peft_config_0 = peft_config(modules_to_save=["0.post_attention_layernorm"])
    peft_config_1 = peft_config(modules_to_save=["0.post_attention_layernorm"])
    peft_config_2 = peft_config(modules_to_save=["1.post_attention_layernorm"])

    # Save adapter 0, nothing fancy, should be equal to base model weights
    peft_model = get_peft_model(copy.deepcopy(original_model), peft_config_0)
    peft_model.save_pretrained(tmp_path / "adapter_0")

    # Save adapter 1, modules to save weights are modified randomly, should be unique to adapter 1
    peft_model = get_peft_model(copy.deepcopy(original_model), peft_config_1)
    peft_model.model.model.layers[0].post_attention_layernorm.weight.data = torch.rand_like(
        peft_model.model.model.layers[0].post_attention_layernorm.weight.data
    )
    adapter_1_saved = peft_model.model.model.layers[0].post_attention_layernorm.weight.data.clone()
    peft_model.save_pretrained(tmp_path / "adapter_1")

    # Save adapter 2, modules to save weights are modified randomly, should be unique to adapter 2
    peft_model = get_peft_model(copy.deepcopy(original_model), peft_config_2)
    peft_model.model.model.layers[1].post_attention_layernorm.weight.data = torch.rand_like(
        peft_model.model.model.layers[1].post_attention_layernorm.weight.data
    )
    adapter_2_saved = peft_model.model.model.layers[1].post_attention_layernorm.weight.data.clone()
    peft_model.save_pretrained(tmp_path / "adapter_2")

    del peft_model

    combined_model = PeftModel.from_pretrained(original_model, tmp_path / "adapter_0", adapter_name="adapter_0")
    combined_model.load_adapter(tmp_path / "adapter_1", adapter_name="adapter_1")
    combined_model.load_adapter(tmp_path / "adapter_2", adapter_name="adapter_2")

    # For adapter 0 we expect every mentioned modules to save layer of this test to be equal to the original model
    # since we didn't modify it for adapter 0 and only adapter 0 is active.
    combined_model.set_adapter("adapter_0")
    assert torch.allclose(
        combined_model.model.model.layers[0].post_attention_layernorm.weight,
        base_weight_0,
    )
    assert torch.allclose(
        combined_model.model.model.layers[1].post_attention_layernorm.weight,
        base_weight_1,
    )

    # For adapter 1 we expect that the modified module to save 0.post_attention_layernorm is modified, the other
    # module to save layers mentioned above should be untouched.
    combined_model.set_adapter("adapter_1")
    assert torch.allclose(
        combined_model.model.model.layers[0].post_attention_layernorm.weight,
        adapter_1_saved,
    )
    assert torch.allclose(
        combined_model.model.model.layers[1].post_attention_layernorm.weight,
        base_weight_1,
    )

    # For adapter 2 we expect its module to save layer (1.post_attention_layernorm) to be modified but the other
    # module to save weights should be kept original.
    combined_model.set_adapter("adapter_2")
    assert torch.allclose(
        combined_model.model.model.layers[0].post_attention_layernorm.weight,
        base_weight_0,
    )
    assert torch.allclose(
        combined_model.model.model.layers[1].post_attention_layernorm.weight,
        adapter_2_saved,
    )
class TestModulesToSaveAttributeAccess:
    """Attribute access on ``ModulesToSaveWrapper`` is forwarded to the wrapped module.

    With modules_to_save, the original module gets wrapped. Calling forward on the wrapper was always fine, but code
    that reads attributes such as parameters straight from the module would typically fail, because the wrapper itself
    does not have them. Dedicated properties for weight and bias were not enough, so attribute access is now
    transiently delegated to the active adapter (or to the original module if the adapter is disabled).

    For one example, see #2099.
    """

    @pytest.fixture
    def mlp(self):
        class MLP(nn.Module):
            def __init__(self):
                super().__init__()
                self.lin0 = nn.Linear(1, 2)
                self.lin1 = nn.Linear(3, 4)

        return MLP()

    def test_transient_attribute_access_default_adapter(self, mlp):
        peft_model = get_peft_model(mlp, LoraConfig(target_modules=["lin0"], modules_to_save=["lin1"]))

        for attr in ("weight", "bias"):
            assert getattr(peft_model.lin1, attr) is getattr(peft_model.lin1.modules_to_save["default"], attr)

    def test_transient_attribute_access_non_default_adapter(self, mlp):
        lora_config = LoraConfig(target_modules=["lin0"], modules_to_save=["lin1"])
        peft_model = get_peft_model(mlp, lora_config)
        peft_model.add_adapter("other", lora_config)

        # at this point, default is still active
        for attr in ("weight", "bias"):
            assert getattr(peft_model.lin1, attr) is getattr(peft_model.lin1.modules_to_save["default"], attr)
        for attr in ("weight", "bias"):
            assert getattr(peft_model.lin1, attr) is not getattr(peft_model.lin1.modules_to_save["other"], attr)

        peft_model.set_adapter("other")
        for attr in ("weight", "bias"):
            assert getattr(peft_model.lin1, attr) is not getattr(peft_model.lin1.modules_to_save["default"], attr)
        for attr in ("weight", "bias"):
            assert getattr(peft_model.lin1, attr) is getattr(peft_model.lin1.modules_to_save["other"], attr)

    def test_transient_attribute_access_disabled_adapter(self, mlp):
        peft_model = get_peft_model(mlp, LoraConfig(target_modules=["lin0"], modules_to_save=["lin1"]))

        # at this point, default is still active
        for attr in ("weight", "bias"):
            assert getattr(peft_model.lin1, attr) is getattr(peft_model.lin1.modules_to_save["default"], attr)
        for attr in ("weight", "bias"):
            assert getattr(peft_model.lin1, attr) is not getattr(peft_model.lin1.original_module, attr)

        with peft_model.disable_adapter():
            for attr in ("weight", "bias"):
                assert getattr(peft_model.lin1, attr) is not getattr(
                    peft_model.lin1.modules_to_save["default"], attr
                )
            for attr in ("weight", "bias"):
                assert getattr(peft_model.lin1, attr) is getattr(peft_model.lin1.original_module, attr)

    def test_transient_attribute_access_uninitialized_adapter(self, mlp):
        # ensure that there is no weird infinite recursion when accessing a non-existing attribute on the class itself
        with pytest.raises(AttributeError, match="has no attribute 'original_module'"):
            getattr(ModulesToSaveWrapper, "original_module")

    def test_transient_attribute_access_attr_does_not_exist_on_modules_to_save(self, mlp):
        # ensure that there is no weird infinite recursion when accessing a non-existing attribute on the
        # ModulesToSaveWrapper instance
        peft_model = get_peft_model(mlp, LoraConfig(target_modules=["lin0"], modules_to_save=["lin1"]))

        with pytest.raises(AttributeError, match="has no attribute 'foo'"):
            getattr(peft_model.lin1, "foo")

    def test_transient_attribute_access_attr_does_not_exist_on_original_module(self, mlp):
        # ensure that there is no weird infinite recursion when accessing a non-existing attribute on the
        # original module of the ModulesToSaveWrapper instance
        peft_model = get_peft_model(mlp, LoraConfig(target_modules=["lin0"], modules_to_save=["lin1"]))

        with pytest.raises(AttributeError, match="has no attribute 'foo'"), peft_model.disable_adapter():
            getattr(peft_model.lin1, "foo")

    def test_transient_attribute_access_non_existing_adapter(self, mlp):
        # This should normally never happen, as the active adapter should always exist, but it's a failsafe
        peft_model = get_peft_model(mlp, LoraConfig(target_modules=["lin0"], modules_to_save=["lin1"]))
        peft_model.base_model.model.lin1._active_adapter = "does-not-exist"

        with pytest.raises(AttributeError, match="has no attribute 'weight'"):
            getattr(peft_model.lin1, "weight")
class TestModulesToSaveNameSubstringBug:
    """Test a bug that could occur with multiple modules to save where one adapter's name is a substring of another
    adapter's name.

    This bug was the result of an error in the logic of modifying the state_dict for modules_to_save in
    set_peft_model_state_dict. The error in the logic was that it was checked if an entry from modules_to_save (a set
    of strings) is a substring of a key of the state_dict. If it was, a new name was assigned to that key in the
    state_dict, which would allow to load the weight later.

    The issue that stems from the substring check occurs if there are multiple modules_to_save, and one of them has a
    name that is a substring of another. So e.g. if one is named "classifier" and the other is named "classifier2",
    there could be a false match.

    This bug was reported in #2289.
    """

    def get_model(self):
        """Return a freshly built model; the global torch seed is reset to 0 right before it is instantiated."""

        class MyModule(nn.Module):
            def __init__(self):
                super().__init__()
                self.lin = nn.Linear(5, 4)
                # important: "classifier" is a substring of "classifier2", "classifier3", "classifier4"
                self.classifier = nn.Linear(4, 2)
                self.classifier2 = nn.Linear(4, 2)
                self.classifier3 = nn.Linear(4, 2)
                self.classifier4 = nn.Linear(4, 2)

            def forward(self, x):
                x = self.lin(x)
                return self.classifier(x) + self.classifier2(x) + self.classifier3(x) + self.classifier4(x)

        torch.manual_seed(0)
        return MyModule()

    @pytest.fixture
    def path_merged_and_unmerged(self, tmp_path):
        """Create the two checkpoints used by the test and return ``(path_merged, path_unmerged)``.

        1. merged: the state_dict of the model after calling merge_and_unload
        2. unmerged: the PEFT model saved without calling merge_and_unload
        """
        lora_config = LoraConfig(
            target_modules=["lin"],
            # important: "classifier" is a substring of "classifier2", "classifier3", "classifier4"
            modules_to_save=["classifier", "classifier2", "classifier3", "classifier4"],
        )
        model = get_peft_model(self.get_model(), lora_config)

        # mock training; the optimizer is built once and gradients are reset before every step
        optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
        for _ in range(5):
            optimizer.zero_grad()
            output = model(torch.randn(10, 5))
            loss = output.sum()
            loss.backward()
            optimizer.step()

        # save the peft model without merging
        path_unmerged = tmp_path / "unmerged"
        model.save_pretrained(path_unmerged)

        # merge the model and save state_dict
        path_merged = tmp_path / "merged"
        merged = model.merge_and_unload()
        state_dict = merged.state_dict()
        torch.save(state_dict, path_merged)

        return path_merged, path_unmerged

    def test_load_merged_and_unmerged_same_weights(self, path_merged_and_unmerged):
        # Note that this test is quasi flaky, it has a 1 in 4 chance of passing even without the bugfix. It passes when
        # "classifier" happens to be the last element of the set model.modules_to_save. The order of the set is random.
        # It is not possible to just run this test multiple times to minimize the probability of this happening,
        # because within the same process, the hash order is consistent. With the bug fix, this doesn't matter, as the
        # test will always pass, but if there is a regression, there is a 1 in 4 chance of not catching it. Since the
        # CI runs many tests, it is overall very unlikely that none will catch it though. If you see this test failing
        # in CI, thus be aware that some of the passing tests may just pass owing to randomness.
        path_merged, path_unmerged = path_merged_and_unmerged

        # load the merged model directly
        state_dict = torch.load(path_merged, weights_only=True)
        model = self.get_model()
        model.load_state_dict(state_dict)
        sd_merged = model.state_dict()
        del model

        # load the unmerged model and merge it
        unmerged = PeftModel.from_pretrained(self.get_model(), path_unmerged)
        sd_unmerged = unmerged.merge_and_unload().state_dict()

        assert sd_merged.keys() == sd_unmerged.keys()
        for key in sd_merged.keys():
            param_merged = sd_merged[key]
            param_unmerged = sd_unmerged[key]
            assert torch.allclose(param_merged, param_unmerged)
class TestTargetingAuxiliaryTrainingWrapper:
    """AuxiliaryTrainingWrapper such as ModulesToSaveWrapper and TrainableTokensWrapper are
    in general not to be targeted by PEFT methods such as adapters. For example, a ModulesToSaveWrapper's children
    modules should not be targeted by `LoraConfig(target_modules='all-linear')`, among other things.
    """

    @pytest.fixture
    def plain_model_cls(self):
        """Return a model class that wraps a single linear layer named ``layer1``."""

        class PlainModel(nn.Module):
            def __init__(self, i, o):
                super().__init__()
                self.layer1 = nn.Linear(i, o)

            def forward(self, x):
                return self.layer1(x)

        return PlainModel

    @pytest.fixture
    def nested_model_cls(self, plain_model_cls):
        """Return a model class whose ``layer3`` is a ``plain_model_cls`` instance with its own ``layer1``."""

        class NestedModel(nn.Module):
            def __init__(self):
                super().__init__()
                self.layer1 = nn.Linear(10, 20)
                self.layer2 = nn.Linear(20, 5)
                self.layer3 = plain_model_cls(5, 10)

            def forward(self, x):
                x = self.layer1(x)
                x = self.layer2(x)
                x = self.layer3(x)
                return x

        return NestedModel

    def test_nested_ignores_modules_to_save(self, nested_model_cls, plain_model_cls):
        # Make sure that `target_modules` is not targeting the nested modules of a module marked as module to save.
        model = nested_model_cls()
        config = LoraConfig(
            target_modules=["layer1"],
            modules_to_save=["layer3"],
        )

        peft_model = get_peft_model(model, config)
        assert isinstance(peft_model.model.layer3.modules_to_save.default, plain_model_cls)

    def test_targeting_module_to_save_raises(self, nested_model_cls):
        model = nested_model_cls()
        config = LoraConfig(
            target_modules=["layer1"],
            modules_to_save=["layer1"],
        )
        msg = "No modules were targeted for adaptation. This might be caused by a combination"
        with pytest.raises(ValueError, match=msg):
            get_peft_model(model, config)

    def test_modules_to_save_targets_tuner_layer_raises(self):
        # See e.g. issue 2027 and 2477
        # Prevent users from (accidentally) targeting the same layer both with a tuner and modules_to_save. Normally,
        # PEFT will not target the same layer with both a tuner and ModulesToSaveWrapper. However, if modules_to_save
        # is automatically inferred, e.g. when using AutoModelForSequenceClassification, the ModulesToSaveWrapper is
        # applied ex post, which can lead to the double wrapping.
        model_id = "hf-internal-testing/tiny-random-OPTForCausalLM"
        model = AutoModelForSequenceClassification.from_pretrained(model_id)

        # Note: target_modules="all-linear" would also work and is closer to the original issue, but let's explicitly
        # target "score" here in case that "all-linear" will be fixed to no longer target the score layer.
        peft_config = LoraConfig(target_modules=["score"], task_type="SEQ_CLS")

        # Since the `score` layer is in `model.modules_to_save` it should be ignored when targeted,
        # therefore the layer should not be adapted.
        msg = "No modules were targeted for adaptation. This might be caused by a combination"
        # Only the exception type and message matter, so the exception info is not bound to a name.
        with pytest.raises(ValueError, match=msg):
            get_peft_model(model, peft_config)

    def test_targeting_trainable_tokens_raises(self):
        model_id = "hf-internal-testing/tiny-random-OPTForCausalLM"
        model = AutoModelForSequenceClassification.from_pretrained(model_id)

        peft_config = LoraConfig(target_modules=["embed_tokens"], task_type="SEQ_CLS", trainable_token_indices=[0, 1])

        # While this message might not be the most helpful message, at least it is not silently failing
        msg = "trainable_token_indices cannot be applied to modules of type <class 'peft.tuners.lora.layer.Embedding'>"
        # Only the exception type and message matter, so the exception info is not bound to a name.
        with pytest.raises(TypeError, match=msg):
            get_peft_model(model, peft_config)
class TestAdapterTargeting:
    """Make sure that already existing adapters cannot be targeted to avoid conflicts."""

    @pytest.fixture
    def base_model_cls(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.l1 = torch.nn.Linear(10, 20)
                self.l2 = torch.nn.Conv2d(1, 1, 2)

            def forward(self, x):
                hidden = self.l1(x)
                return self.l2(hidden)

        return M

    @pytest.mark.parametrize(
        "config_cls, config_kwargs",
        [
            (LoraConfig, {"target_modules": "l1.*"}),
            (LoraConfig, {"target_modules": "l2.*"}),
            (VeraConfig, {"target_modules": "l1.*"}),
            (VeraConfig, {"target_modules": "(l1|vera_A).*"}),  # also target the shared layer
        ],
    )
    def test_self_targeting_is_ignored(self, base_model_cls, config_cls, config_kwargs):
        # sufficiently unique names to make reliable testing easier
        first_adapter = "ADAPTER_1_512858"
        second_adapter = "ADAPTER_2_845781"

        def strip_adapter_portion(adapter_name, key):
            # Cut the key at the adapter name: either the name is the last path component, or everything from
            # ".<adapter_name>." onwards is dropped. Keys without the adapter name are returned unchanged.
            marker = f".{adapter_name}"
            if key.endswith(marker):
                return key[: -len(marker)]
            return key.partition(f"{marker}.")[0]

        base_model = base_model_cls()
        first_config = config_cls(**config_kwargs)
        second_config = config_cls(**config_kwargs)

        peft_model = get_peft_model(base_model, first_config, adapter_name=first_adapter)
        keys_with_one_adapter = peft_model.state_dict().keys()

        peft_model.add_adapter(second_adapter, second_config)
        keys_with_two_adapters = peft_model.state_dict().keys()

        # Ideally there should be no new modules targeted beyond existing ModuleDicts. Therefore the keys
        # of the new state dict should only differ after the adapter name portion of the keys - not before.
        # Expected:
        # - a.b.<adapter_name_1>.xyz
        # - a.b.<adapter_name_2>.xyz
        # We're not expecting this to happen and test against it:
        # - a.b.<adapter_name_1>.xyz
        # - a.<adapter_name_2>.xyz
        invariant_keys_before = {strip_adapter_portion(first_adapter, key) for key in keys_with_one_adapter}
        invariant_keys_after = set()
        for key in keys_with_two_adapters:
            without_first = strip_adapter_portion(first_adapter, key)
            invariant_keys_after.add(strip_adapter_portion(second_adapter, without_first))

        assert invariant_keys_before == invariant_keys_after
class TestGetNoSplitModules:
    # Ensure that children are considered when determining _no_split_modules
    # see https://github.com/huggingface/transformers/pull/38141

    def test_get_no_split_modules_simple(self):
        # choose a model where recursively visiting children is *not* required
        opt_model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
        assert opt_model._no_split_modules == ["OPTDecoderLayer"]

        collected = _get_no_split_modules(opt_model)
        assert collected == {"OPTDecoderLayer"}

    def test_get_no_split_modules_recursive(self):
        # choose a model where recursively visiting children is required
        llava_model = LlavaForConditionalGeneration.from_pretrained(
            "hf-internal-testing/tiny-random-LlavaForConditionalGeneration"
        )
        # sanity check: just visiting the model itself is not enough:
        assert llava_model._no_split_modules == []

        collected = _get_no_split_modules(llava_model)
        assert collected == {"CLIPEncoderLayer", "LlamaDecoderLayer"}
|