Text Generation
PEFT
lora
trl
naming
brand-generation
controllable-generation
nomen-ai / tests /test_static.py
krystv's picture
Add static validation tests
efffeb1 verified
"""CPU-only static validation for the Nomen-AI public repo.
Run:
pip install -e . datasets trl peft transformers rapidfuzz pyphen
python tests/test_static.py
"""
from datasets import load_dataset
from trl import SFTConfig, DPOConfig
from nomen_ai.control import ControlVector, ROOT_FAMILIES
from nomen_ai.phonetics import count_syllables, char_len
from nomen_ai.antidup import AntiDuplicationMatrix
from nomen_ai.synth import make_example
import random
def main():
    """Run each static check in order and print a pass marker at the end.

    The checks cover, in sequence: the size of ROOT_FAMILIES and the prompt
    produced by a validated ControlVector, the phonetics helpers, the
    anti-duplication matrix, synthetic example generation, the column names
    of the first two rows of the SFT and DPO datasets, and the presence of
    the SFTConfig / DPOConfig dataclass fields listed below.

    Raises:
        AssertionError: as soon as any single check fails.
    """
    # Control vocabulary and prompt serialisation.
    assert len(ROOT_FAMILIES) >= 20
    control = ControlVector(
        roots=['japanese', 'nordic'],
        blend=[40, 60],
        theme='gaming',
        syllables=3,
        char_len=8,
        creativity=0.8,
    ).validate()
    assert '[ROOT:japanese:40+nordic:60]' in control.to_prompt()

    # Phonetic helpers, checked against one fixed sample name.
    sample_name = 'Velorix'
    assert count_syllables(sample_name) == 3
    assert char_len(sample_name) == 7

    # Anti-duplication matrix, then example synthesis with a seeded RNG.
    matrix = AntiDuplicationMatrix()
    assert matrix.is_novel('Spotifu') is False
    rng = random.Random(0)
    example = make_example(['japanese', 'nordic'], 'gaming', 0.8, rng, matrix)
    assert example is not None

    # Only the first two rows of each dataset are requested; the checks below
    # look at column names, not row contents.
    sft_rows = load_dataset('krystv/nomen-ai-sft', split='train[:2]')
    dpo_rows = load_dataset('krystv/nomen-ai-dpo', split='train[:2]')
    assert 'messages' in sft_rows.column_names
    assert {'prompt', 'chosen', 'rejected'}.issubset(dpo_rows.column_names)

    # Config fields that must exist on the installed TRL config dataclasses.
    # Both lists share the same tail; order is kept so the first missing
    # field reported is the same as before.
    common_tail = (
        'fp16',
        'gradient_checkpointing',
        'logging_first_step',
        'disable_tqdm',
        'eval_strategy',
        'push_to_hub',
        'hub_model_id',
    )
    sft_required = ('max_length',) + common_tail
    dpo_required = ('max_length', 'max_prompt_length') + common_tail

    sft_available = set(SFTConfig.__dataclass_fields__)
    dpo_available = set(DPOConfig.__dataclass_fields__)
    for name in sft_required:
        assert name in sft_available, f'missing SFTConfig {name}'
    for name in dpo_required:
        assert name in dpo_available, f'missing DPOConfig {name}'

    print('CPU_STATIC_VALIDATION_PASS')
# Script entry point: run the checks only when this file is executed
# directly, not when it is imported.
if __name__ == '__main__':
    main()