Text Generation
PEFT
lora
trl
naming
brand-generation
controllable-generation
File size: 1,813 Bytes
efffeb1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
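"""Static validation for the Nomen-AI public repo (CPU only, no training).

Checks the local ``nomen_ai`` helpers on fixed inputs, the column layout of
the published SFT/DPO datasets, and that the installed TRL ``SFTConfig`` /
``DPOConfig`` classes expose the fields listed below. Loading the datasets
needs access to the Hugging Face Hub (or a warm local ``datasets`` cache).

Run:
  pip install -e . datasets trl peft transformers rapidfuzz pyphen
  python tests/test_static.py
"""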
import dataclasses
import random

from datasets import load_dataset
from trl import SFTConfig, DPOConfig

from nomen_ai.control import ControlVector, ROOT_FAMILIES
from nomen_ai.phonetics import count_syllables, char_len
from nomen_ai.antidup import AntiDuplicationMatrix
from nomen_ai.synth import make_example


# Fields that both TRL config classes must expose.
_COMMON_CONFIG_FIELDS = (
    'max_length',
    'fp16',
    'gradient_checkpointing',
    'logging_first_step',
    'disable_tqdm',
    'eval_strategy',
    'push_to_hub',
    'hub_model_id',
)
# DPOConfig must additionally expose the prompt-length cap.
_DPO_EXTRA_CONFIG_FIELDS = ('max_prompt_length',)


def _require(condition, message):
    """Raise ``AssertionError(message)`` unless *condition* is truthy.

    Unlike a bare ``assert`` statement, this check is not removed when Python
    runs with ``-O``, so the script cannot report success without checking.
    """
    if not condition:
        raise AssertionError(message)


def _missing_fields(config_cls, required):
    """Return the names in *required* that are not dataclass fields of *config_cls*."""
    present = {field.name for field in dataclasses.fields(config_cls)}
    return [name for name in required if name not in present]


def _check_package():
    """Smoke-test the local ``nomen_ai`` helpers with fixed, known inputs."""
    _require(
        len(ROOT_FAMILIES) >= 20,
        f'expected at least 20 root families, got {len(ROOT_FAMILIES)}',
    )

    cv = ControlVector(
        roots=['japanese', 'nordic'],
        blend=[40, 60],
        theme='gaming',
        syllables=3,
        char_len=8,
        creativity=0.8,
    ).validate()
    prompt = cv.to_prompt()
    _require(
        '[ROOT:japanese:40+nordic:60]' in prompt,
        f'control prompt lacks the expected ROOT tag: {prompt!r}',
    )

    syllables = count_syllables('Velorix')
    _require(syllables == 3, f"count_syllables('Velorix') returned {syllables!r}, expected 3")
    length = char_len('Velorix')
    _require(length == 7, f"char_len('Velorix') returned {length!r}, expected 7")

    anti = AntiDuplicationMatrix()
    novel = anti.is_novel('Spotifu')
    _require(novel is False, f"is_novel('Spotifu') returned {novel!r}, expected False")

    # Seeded RNG keeps the generated example deterministic across runs.
    example = make_example(['japanese', 'nordic'], 'gaming', 0.8, random.Random(0), anti)
    _require(example is not None, 'make_example returned None for a valid request')


def _check_datasets():
    """Load two rows of each published dataset and verify their columns.

    Datasets are resolved through the Hugging Face Hub (or the local cache).
    """
    sft = load_dataset('krystv/nomen-ai-sft', split='train[:2]')
    dpo = load_dataset('krystv/nomen-ai-dpo', split='train[:2]')
    _require(
        'messages' in sft.column_names,
        f"SFT dataset lacks a 'messages' column; has {sft.column_names}",
    )
    missing = [c for c in ('prompt', 'chosen', 'rejected') if c not in dpo.column_names]
    _require(not missing, f'DPO dataset is missing columns {missing}; has {dpo.column_names}')


def _check_trl_configs():
    """Verify the installed TRL config classes expose every required field.

    All missing fields of a class are reported together rather than one per run.
    """
    for config_cls, required in (
        (SFTConfig, _COMMON_CONFIG_FIELDS),
        (DPOConfig, _COMMON_CONFIG_FIELDS + _DPO_EXTRA_CONFIG_FIELDS),
    ):
        missing = _missing_fields(config_cls, required)
        _require(not missing, f'missing {config_cls.__name__} fields: {", ".join(missing)}')


def main():
    """Run every static check and print a marker line when all of them pass.

    Returns:
        None.

    Raises:
        AssertionError: on the first failed check, with a message describing it.
    """
    _check_package()
    _check_datasets()
    _check_trl_configs()
    print('CPU_STATIC_VALIDATION_PASS')


if __name__ == '__main__':
    # main() returns None, so a clean run exits with status 0; any exception
    # it raises propagates out and the interpreter exits non-zero.
    raise SystemExit(main())