| from typing import List |
|
|
| import pytest |
|
|
| import lm_eval |
|
|
|
|
| def assert_less_than(value, threshold, desc): |
| if value is not None: |
| assert float(value) < threshold, f"{desc} should be less than {threshold}" |
|
|
|
|
| @pytest.mark.skip(reason="requires CUDA") |
| class Test_GPTQModel: |
| gptqmodel = pytest.importorskip("gptqmodel", minversion="1.0.9") |
| MODEL_ID = "ModelCloud/Opt-125-GPTQ-4bit-10-25-2024" |
|
|
| def test_gptqmodel(self) -> None: |
| acc = "acc" |
| acc_norm = "acc_norm" |
| acc_value = None |
| acc_norm_value = None |
| task = "arc_easy" |
|
|
| model_args = f"pretrained={self.MODEL_ID},gptqmodel=True" |
|
|
| tasks: List[str] = [task] |
|
|
| results = lm_eval.simple_evaluate( |
| model="hf", |
| model_args=model_args, |
| tasks=tasks, |
| device="cuda", |
| ) |
|
|
| column = "results" |
| dic = results.get(column, {}).get(self.task) |
| if dic is not None: |
| if "alias" in dic: |
| _ = dic.pop("alias") |
| items = sorted(dic.items()) |
| for k, v in items: |
| m, _, f = k.partition(",") |
| if m.endswith("_stderr"): |
| continue |
|
|
| if m == acc: |
| acc_value = "%.4f" % v if isinstance(v, float) else v |
|
|
| if m == acc_norm: |
| acc_norm_value = "%.4f" % v if isinstance(v, float) else v |
|
|
| assert_less_than(acc_value, 0.43, "acc") |
| assert_less_than(acc_norm_value, 0.39, "acc_norm") |
|
|