File size: 9,930 Bytes
a9bd396 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 | import gc
import unittest
import warnings
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.testing_utils import backend_empty_cache, require_compressed_tensors, require_torch, torch_device
from transformers.utils import is_torch_available
from transformers.utils.quantization_config import CompressedTensorsConfig
if is_torch_available():
import torch
@require_compressed_tensors
@require_torch
class StackCompressedModelTest(unittest.TestCase):
    """
    Compare compressed checkpoints, decompressed at load time, with their uncompressed counterparts.

    Every model in this class is loaded with ``CompressedTensorsConfig(run_compressed=False)``.
    """

    # Define stubs as class attributes: (compressed stub, uncompressed stub) pairs
    # for the quantized-only, sparse-only and stacked checkpoints.
    compressed_uncompressed_model_stubs = [
        (
            "nm-testing/llama2.c-stories42M-gsm8k-quantized-only-compressed",
            "nm-testing/llama2.c-stories42M-gsm8k-quantized-only-uncompressed",
        ),
        (
            "nm-testing/llama2.c-stories42M-gsm8k-sparse-only-compressed",
            "nm-testing/llama2.c-stories42M-gsm8k-sparse-only-uncompressed",
        ),
        (
            "nm-testing/llama2.c-stories42M-gsm8k-stacked-compressed",
            "nm-testing/llama2.c-stories42M-gsm8k-stacked-uncompressed",
        ),
    ]
    # Flatten the list for tests that require a single list of stubs.
    model_stubs = [stub for pair in compressed_uncompressed_model_stubs for stub in pair]
    # For the outputs matching test, use the sparse-only pair.
    sparse_compressed_model = "nm-testing/llama2.c-stories42M-gsm8k-sparse-only-compressed"
    sparse_uncompressed_model = "nm-testing/llama2.c-stories42M-gsm8k-sparse-only-uncompressed"
    prompt = "Paris is the capital of which country?"

    @staticmethod
    def _free_memory():
        """Run the garbage collector and release cached accelerator memory on ``torch_device``."""
        gc.collect()
        backend_empty_cache(torch_device)
        gc.collect()

    def tearDown(self):
        self._free_memory()

    def test_compressed_uncompressed_model_shapes(self):
        """
        Verify that the weights of an uncompressed model and its decompressed compressed counterpart match.

        Sparse-only pairs must match exactly (``torch.equal``). Pairs involving quantization are
        compared with ``torch.allclose(atol=0.2)``. The stored weights of sparsely compressed
        checkpoints may differ due to packing, so the comparison is made after decompression.
        """
        from compressed_tensors.quantization.utils import iter_named_leaf_modules

        def _get_nested_attr(obj, attr_path):
            """Follow a dotted attribute path; return the final object, or None if any hop is missing."""
            for attr in attr_path.split("."):
                if not hasattr(obj, attr):
                    return None
                obj = getattr(obj, attr)
            return obj

        for compressed_model, uncompressed_model in self.compressed_uncompressed_model_stubs:
            with self.subTest(compressed_model=compressed_model, uncompressed_model=uncompressed_model):
                uncompressed = compressed_decompressed = None
                try:
                    uncompressed = AutoModelForCausalLM.from_pretrained(
                        uncompressed_model,
                        device_map="auto",
                        dtype="auto",
                        quantization_config=CompressedTensorsConfig(run_compressed=False),
                    )
                    compressed_decompressed = AutoModelForCausalLM.from_pretrained(
                        compressed_model,
                        device_map="auto",
                        dtype="auto",
                        quantization_config=CompressedTensorsConfig(run_compressed=False),
                    )

                    exact_match = "sparse-only" in uncompressed_model
                    compared = 0
                    for name, submodule in iter_named_leaf_modules(uncompressed):
                        comp_decomp_obj = _get_nested_attr(compressed_decompressed, name)
                        if comp_decomp_obj is None or not hasattr(submodule, "weight"):
                            continue
                        # device_map="auto" may place the two models differently; compare on one device.
                        expected = submodule.weight.to(torch_device)
                        actual = comp_decomp_obj.weight.to(torch_device)
                        if exact_match:
                            self.assertTrue(
                                torch.equal(expected, actual),
                                f"Weight mismatch for module '{name}' in sparse-only model.",
                            )
                        else:
                            self.assertTrue(
                                torch.allclose(expected, actual, atol=0.2),
                                f"Weight mismatch for module '{name}' in quantized-only or stacked model.",
                            )
                        compared += 1

                    # Guard against a vacuous pass when no module name resolves in the other model.
                    self.assertGreater(
                        compared,
                        0,
                        f"No weights were compared between '{compressed_model}' and '{uncompressed_model}'.",
                    )
                finally:
                    # Release this pair before loading the next so peak memory stays at one pair.
                    del uncompressed, compressed_decompressed
                    self._free_memory()

    def test_outputs_match(self):
        """
        Ensure that the generated outputs match between the uncompressed model
        and its decompressed compressed counterpart.
        """
        tokenizer = AutoTokenizer.from_pretrained(self.sparse_uncompressed_model)
        input_ids = tokenizer(self.prompt, return_tensors="pt").input_ids

        uncompressed = AutoModelForCausalLM.from_pretrained(
            self.sparse_uncompressed_model,
            device_map="auto",
            dtype="auto",
            quantization_config=CompressedTensorsConfig(run_compressed=False),
        )
        output_uncompressed = uncompressed.generate(input_ids.to(uncompressed.device), max_new_tokens=100)
        # Only the generated ids are needed from here on; free the first model before loading the second.
        del uncompressed
        self._free_memory()

        decompressed = AutoModelForCausalLM.from_pretrained(
            self.sparse_compressed_model,
            device_map="auto",
            dtype="auto",
            quantization_config=CompressedTensorsConfig(run_compressed=False),
        )
        output_decompressed = decompressed.generate(input_ids.to(decompressed.device), max_new_tokens=100)

        self.assertEqual(
            tokenizer.decode(output_uncompressed[0]),
            tokenizer.decode(output_decompressed[0]),
            "Generated outputs do not match between compressed and uncompressed models.",
        )

    def test_no_warnings_for_all_models(self):
        """
        Confirm that loading any model using compressed tensors does not trigger
        warnings about missing or unexpected keys.
        """
        # NOTE(review): this only inspects `warnings.warn` messages. If transformers reports
        # missing/unexpected keys through its `logging` module instead, they are not seen here — confirm.
        for model_stub in self.model_stubs:
            with self.subTest(model_stub=model_stub):
                with warnings.catch_warnings(record=True) as caught_warnings:
                    warnings.simplefilter("always")
                    AutoModelForCausalLM.from_pretrained(
                        model_stub,
                        device_map="auto",
                        dtype="auto",
                        quantization_config=CompressedTensorsConfig(run_compressed=False),
                    )
                # The loaded model is discarded immediately; reclaim its memory before the next stub.
                self._free_memory()

                for warning in caught_warnings:
                    message = str(warning.message).lower()
                    self.assertNotIn(
                        "missing keys",
                        message,
                        f"'missing keys' found in warnings for model {model_stub}",
                    )
                    self.assertNotIn(
                        "unexpected keys",
                        message,
                        f"'unexpected keys' found in warnings for model {model_stub}",
                    )
@require_compressed_tensors
@require_torch
class RunCompressedTest(unittest.TestCase):
    """Check the ``run_compressed`` switch of ``CompressedTensorsConfig`` on compressed TinyLlama checkpoints."""

    tinyllama_w4a16 = "nm-testing/tinyllama-w4a16-compressed"
    tinyllama_w8a8 = "nm-testing/tinyllama-w8a8-compressed"
    prompt = "Paris is the capital of which country?"
    stubs = [tinyllama_w4a16, tinyllama_w8a8]

    @staticmethod
    def _free_memory():
        """Run the garbage collector and release cached accelerator memory on ``torch_device``."""
        gc.collect()
        backend_empty_cache(torch_device)
        gc.collect()

    @staticmethod
    def _count_compressed_linear(model):
        """Return how many leaf modules of ``model`` are ``CompressedLinear`` instances."""
        from compressed_tensors.linear.compressed_linear import CompressedLinear
        from compressed_tensors.quantization.utils import iter_named_leaf_modules

        return sum(
            1 for _, submodule in iter_named_leaf_modules(model) if isinstance(submodule, CompressedLinear)
        )

    def tearDown(self):
        self._free_memory()

    def test_default_run_compressed__True(self):
        """Loading without a quantization_config should keep (some) layers as CompressedLinear."""
        for stub in self.stubs:
            with self.subTest(stub=stub):
                model = AutoModelForCausalLM.from_pretrained(stub)
                compressed_linear_counts = self._count_compressed_linear(model)
                del model
                self._free_memory()

                # some linear models are not compressed - ex. lm_head, so only require at least one.
                self.assertGreater(
                    compressed_linear_counts,
                    0,
                    f"Expected CompressedLinear modules with the default run_compressed for {stub}",
                )

    def test_default_run_compressed__False(self):
        """Loading with run_compressed=False should leave no CompressedLinear modules."""
        for stub in self.stubs:
            with self.subTest(stub=stub):
                model = AutoModelForCausalLM.from_pretrained(
                    stub,
                    quantization_config=CompressedTensorsConfig(run_compressed=False),
                )
                compressed_linear_counts = self._count_compressed_linear(model)
                del model
                self._free_memory()

                # No modules should be CompressedLinear
                self.assertEqual(
                    compressed_linear_counts,
                    0,
                    f"Found CompressedLinear modules with run_compressed=False for {stub}",
                )

    def test_run_compressed_outputs_match(self):
        """Check that run_compressed=True/False output are the same"""
        for stub in self.stubs:
            with self.subTest(stub=stub):
                tokenizer = AutoTokenizer.from_pretrained(stub)
                input_ids = tokenizer(self.prompt, return_tensors="pt").input_ids

                model_run_compressed__True = AutoModelForCausalLM.from_pretrained(stub)
                output_rc_true = model_run_compressed__True.generate(
                    input_ids.to(model_run_compressed__True.device), max_new_tokens=100
                )
                # Free the first model before loading the second so only one is resident at a time.
                del model_run_compressed__True
                self._free_memory()

                model_run_compressed__False = AutoModelForCausalLM.from_pretrained(
                    stub,
                    quantization_config=CompressedTensorsConfig(run_compressed=False),
                )
                output_rc_false = model_run_compressed__False.generate(
                    input_ids.to(model_run_compressed__False.device), max_new_tokens=100
                )
                del model_run_compressed__False
                self._free_memory()

                self.assertEqual(
                    tokenizer.decode(output_rc_true[0]),
                    tokenizer.decode(output_rc_false[0]),
                    f"run_compressed=True/False outputs differ for {stub}",
                )
|