"""
Test precision mode mapping to TrainingArguments.
"""

from unittest.mock import Mock |
|
|
|
|
|
from humigence.train import QLoRATrainer |
|
|
|
|
|
|
|
|
class TestPrecisionModeMapping:
    """Test that precision modes correctly map to TrainingArguments flags.

    NOTE(review): neither test calls a QLoRATrainer method. Each one reads
    ``precision_mode`` back from a bare trainer instance and resolves it
    through the lookup table defined in this class, so the assertions pin
    this file's copy of the mapping. Presumably the table mirrors the
    trainer's own logic -- confirm against humigence.train, and assert on
    the trainer's real output once there is a method to call.
    """

    # precision_mode -> (fp16, bf16). Kept in one place so both tests share
    # a single definition instead of repeating the if/elif chain.
    _PRECISION_FLAGS = {
        "qlora_nf4": (True, False),
        "lora_fp16": (True, False),
        "lora_bf16": (False, True),
        "lora_int8": (False, False),
    }

    # Flag pair used for any mode that is not listed in the table.
    _FALLBACK_FLAGS = (True, False)

    @classmethod
    def _flags_for(cls, precision_mode):
        """Return the ``(fp16, bf16)`` pair for *precision_mode*.

        Modes missing from ``_PRECISION_FLAGS`` resolve to
        ``_FALLBACK_FLAGS``.
        """
        return cls._PRECISION_FLAGS.get(precision_mode, cls._FALLBACK_FLAGS)

    @staticmethod
    def _bare_trainer(precision_mode):
        """Build a QLoRATrainer without running ``__init__``.

        The instance only receives a mocked ``config`` whose
        ``train.precision_mode`` is set to *precision_mode*.
        """
        config = Mock()
        config.train.precision_mode = precision_mode

        # __new__ skips the constructor, so nothing it would do is executed.
        trainer = QLoRATrainer.__new__(QLoRATrainer)
        trainer.config = config
        return trainer

    def test_qlora_nf4_precision_mapping(self):
        """Test qlora_nf4 maps to fp16=True, bf16=False."""
        trainer = self._bare_trainer("qlora_nf4")

        fp16, bf16 = self._flags_for(trainer.config.train.precision_mode)

        assert fp16 is True
        assert bf16 is False

    def test_lora_bf16_precision_mapping(self):
        """Test lora_bf16 maps to fp16=False, bf16=True."""
        trainer = self._bare_trainer("lora_bf16")

        fp16, bf16 = self._flags_for(trainer.config.train.precision_mode)

        assert fp16 is False
        assert bf16 is True
|
|
|