Text Generation
PEFT
lora
trl
naming
brand-generation
controllable-generation
krystv committed on
Commit
efffeb1
·
verified ·
1 Parent(s): 451bc6a

Add static validation tests

Browse files
Files changed (1) hide show
  1. tests/test_static.py +42 -0
tests/test_static.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """CPU-only static validation for the Nomen-AI public repo.
2
+
3
+ Run:
4
+ pip install -e . datasets trl peft transformers rapidfuzz pyphen
5
+ python tests/test_static.py
6
+ """
7
+ from datasets import load_dataset
8
+ from trl import SFTConfig, DPOConfig
9
+ from nomen_ai.control import ControlVector, ROOT_FAMILIES
10
+ from nomen_ai.phonetics import count_syllables, char_len
11
+ from nomen_ai.antidup import AntiDuplicationMatrix
12
+ from nomen_ai.synth import make_example
13
+ import random
14
+
15
+
16
def main():
    """Run the CPU-only static checks and print a pass marker on success.

    The checks run in this order:

    1. Sanity checks on the ``nomen_ai`` helpers imported by this file.
    2. Column checks on the first two training rows of the
       ``krystv/nomen-ai-sft`` and ``krystv/nomen-ai-dpo`` datasets.
    3. Presence checks for the config field names this script expects on
       ``SFTConfig`` and ``DPOConfig``.

    Raises:
        AssertionError: if any check fails.
    """
    # --- nomen_ai helpers -------------------------------------------------
    assert len(ROOT_FAMILIES) >= 20

    control = ControlVector(
        roots=['japanese', 'nordic'],
        blend=[40, 60],
        theme='gaming',
        syllables=3,
        char_len=8,
        creativity=0.8,
    ).validate()
    prompt = control.to_prompt()
    assert '[ROOT:japanese:40+nordic:60]' in prompt

    assert count_syllables('Velorix') == 3
    assert char_len('Velorix') == 7

    anti = AntiDuplicationMatrix()
    assert anti.is_novel('Spotifu') is False

    # Fixed seed so the synthesized example is reproducible between runs.
    rng = random.Random(0)
    example = make_example(['japanese', 'nordic'], 'gaming', 0.8, rng, anti)
    assert example is not None

    # --- dataset columns --------------------------------------------------
    # Only the first two rows are requested; the checks look at column names.
    sft = load_dataset('krystv/nomen-ai-sft', split='train[:2]')
    dpo = load_dataset('krystv/nomen-ai-dpo', split='train[:2]')
    assert 'messages' in sft.column_names
    for column in ('prompt', 'chosen', 'rejected'):
        assert column in dpo.column_names

    # --- TRL config fields ------------------------------------------------
    sft_fields = set(SFTConfig.__dataclass_fields__)
    dpo_fields = set(DPOConfig.__dataclass_fields__)

    sft_required = (
        'max_length',
        'fp16',
        'gradient_checkpointing',
        'logging_first_step',
        'disable_tqdm',
        'eval_strategy',
        'push_to_hub',
        'hub_model_id',
    )
    dpo_required = (
        'max_length',
        'max_prompt_length',
        'fp16',
        'gradient_checkpointing',
        'logging_first_step',
        'disable_tqdm',
        'eval_strategy',
        'push_to_hub',
        'hub_model_id',
    )
    for name in sft_required:
        assert name in sft_fields, f'missing SFTConfig {name}'
    for name in dpo_required:
        assert name in dpo_fields, f'missing DPOConfig {name}'

    print('CPU_STATIC_VALIDATION_PASS')


if __name__ == '__main__':
    main()