Instructions to use krystv/nomen-ai with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- PEFT
How to use krystv/nomen-ai with PEFT:
Task type is invalid.
- Notebooks
- Google Colab
- Kaggle
Add static validation tests
Browse files- tests/test_static.py +42 -0
tests/test_static.py
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""CPU-only static validation for the Nomen-AI public repo.
|
| 2 |
+
|
| 3 |
+
Run:
|
| 4 |
+
pip install -e . datasets trl peft transformers rapidfuzz pyphen
|
| 5 |
+
python tests/test_static.py
|
| 6 |
+
"""
|
| 7 |
+
from datasets import load_dataset
|
| 8 |
+
from trl import SFTConfig, DPOConfig
|
| 9 |
+
from nomen_ai.control import ControlVector, ROOT_FAMILIES
|
| 10 |
+
from nomen_ai.phonetics import count_syllables, char_len
|
| 11 |
+
from nomen_ai.antidup import AntiDuplicationMatrix
|
| 12 |
+
from nomen_ai.synth import make_example
|
| 13 |
+
import random
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def main():
|
| 17 |
+
assert len(ROOT_FAMILIES) >= 20
|
| 18 |
+
cv = ControlVector(roots=['japanese','nordic'], blend=[40,60], theme='gaming', syllables=3, char_len=8, creativity=0.8).validate()
|
| 19 |
+
assert '[ROOT:japanese:40+nordic:60]' in cv.to_prompt()
|
| 20 |
+
assert count_syllables('Velorix') == 3
|
| 21 |
+
assert char_len('Velorix') == 7
|
| 22 |
+
anti = AntiDuplicationMatrix()
|
| 23 |
+
assert anti.is_novel('Spotifu') is False
|
| 24 |
+
assert make_example(['japanese','nordic'], 'gaming', 0.8, random.Random(0), anti) is not None
|
| 25 |
+
|
| 26 |
+
sft = load_dataset('krystv/nomen-ai-sft', split='train[:2]')
|
| 27 |
+
dpo = load_dataset('krystv/nomen-ai-dpo', split='train[:2]')
|
| 28 |
+
assert 'messages' in sft.column_names
|
| 29 |
+
assert all(c in dpo.column_names for c in ['prompt','chosen','rejected'])
|
| 30 |
+
|
| 31 |
+
sft_fields=set(SFTConfig.__dataclass_fields__)
|
| 32 |
+
dpo_fields=set(DPOConfig.__dataclass_fields__)
|
| 33 |
+
for f in ['max_length','fp16','gradient_checkpointing','logging_first_step','disable_tqdm','eval_strategy','push_to_hub','hub_model_id']:
|
| 34 |
+
assert f in sft_fields, f'missing SFTConfig {f}'
|
| 35 |
+
for f in ['max_length','max_prompt_length','fp16','gradient_checkpointing','logging_first_step','disable_tqdm','eval_strategy','push_to_hub','hub_model_id']:
|
| 36 |
+
assert f in dpo_fields, f'missing DPOConfig {f}'
|
| 37 |
+
|
| 38 |
+
print('CPU_STATIC_VALIDATION_PASS')
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
if __name__ == '__main__':
|
| 42 |
+
main()
|