Instructions to use krystv/nomen-ai with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- PEFT
How to use krystv/nomen-ai with PEFT:
Task type is invalid.
- Notebooks
- Google Colab
- Kaggle
Add dataset builder script
Browse files
- scripts/build_dataset.py +51 -0
scripts/build_dataset.py
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Build and push Nomen-AI SFT/DPO datasets."""
|
| 2 |
+
import argparse, random, re
|
| 3 |
+
from datasets import Dataset, DatasetDict
|
| 4 |
+
from nomen_ai.control import ROOT_FAMILIES, THEMES, ControlVector, SYSTEM_PROMPT
|
| 5 |
+
from nomen_ai.synth import make_example
|
| 6 |
+
from nomen_ai.antidup import AntiDuplicationMatrix
|
| 7 |
+
from nomen_ai.phonetics import count_syllables
|
| 8 |
+
# Word fragments that derivative_name() glues together to fabricate
# deliberately generic names. They are read-only lookup tables, so they are
# tuples: nothing can append to or reorder them by accident, and
# random.Random.choice indexes a tuple exactly as it indexes a list.
GENERIC_SUFFIXES = ("ify", "ly", "ster", "hub", "ged", "io", "zy", "able")
GENERIC_PREFIXES = ("Tech", "Smart", "Get", "My", "Go", "The", "Best", "Pro")
GENERIC_CORES = ("brand", "name", "tube", "vlog", "media", "studio", "world", "zone")
|
| 11 |
+
|
| 12 |
+
def quality_ok(name):
    """Return True when *name* passes the basic shape checks, else False.

    A name is accepted only if all of the following hold:

    * its length is between 3 and 13 characters inclusive;
    * it contains at least one of the letters ``a e i o u y``
      (case-insensitive);
    * no character appears three times in a row (case-insensitive);
    * ``count_syllables(name)`` is at most 4.

    The result is always a real ``bool``. The previous one-line ``and`` chain
    could leak ``None`` (no vowel found) to callers instead of ``False``.
    """
    if not 3 <= len(name) <= 13:
        return False

    lowered = name.lower()
    if re.search(r"[aeiouy]", lowered) is None:
        return False
    if re.search(r"(.)\1\1", lowered) is not None:
        return False

    # The syllable counter is a project helper; run it last so the cheap
    # structural checks above can reject a name without calling it.
    return bool(count_syllables(name) <= 4)
|
| 14 |
+
|
| 15 |
+
def derivative_name(rng):
    """Return a deliberately generic name assembled from stock fragments.

    A single ``rng.random()`` draw selects the pattern:

    * below 0.4: generic prefix + capitalized generic core;
    * from 0.4 to below 0.7: stock stem + generic suffix;
    * otherwise: generic prefix + capitalized generic suffix.

    Args:
        rng: a ``random.Random``-like object providing ``random()`` and
            ``choice()``.
    """
    roll = rng.random()

    if roll < 0.4:
        prefix = rng.choice(GENERIC_PREFIXES)
        core = rng.choice(GENERIC_CORES)
        return prefix + core.capitalize()

    if roll < 0.7:
        stems = ["Stream", "Click", "Pixel", "Brand", "Creat", "Snap", "Vibe"]
        stem = rng.choice(stems)
        return stem + rng.choice(GENERIC_SUFFIXES)

    prefix = rng.choice(GENERIC_PREFIXES)
    suffix = rng.choice(GENERIC_SUFFIXES)
    return prefix + suffix.capitalize()
|
| 20 |
+
|
| 21 |
+
def sample_cv(rng):
    """Draw one random set of control inputs.

    Args:
        rng: a ``random.Random``-like object.

    Returns:
        A ``(roots, theme, creativity)`` tuple where ``roots`` is a sample of
        1, 2 or 3 entries of ``ROOT_FAMILIES`` (the count is drawn with
        weights 5:4:1), ``theme`` is one entry of ``THEMES`` and
        ``creativity`` is one of 0.1, 0.3, 0.5, 0.7, 0.9.
    """
    root_counts = [1, 2, 3]
    (n_roots,) = rng.choices(root_counts, weights=[5, 4, 1])

    # Draw order (roots, then theme, then creativity) is kept fixed so a
    # seeded rng reproduces the same sequence.
    roots = rng.sample(ROOT_FAMILIES, n_roots)
    theme = rng.choice(THEMES)
    creativity = rng.choice([0.1, 0.3, 0.5, 0.7, 0.9])
    return roots, theme, creativity
|
| 24 |
+
|
| 25 |
+
def _accepted_examples(rng, anti):
    """Yield ``(prompt_messages, name)`` for every generated name that is kept.

    Each iteration samples control inputs, asks ``make_example`` for a name
    and discards falsy results and names that fail ``quality_ok``. The
    generator never finishes on its own; the caller decides how many items
    to pull.
    """
    while True:
        roots, theme, creativity = sample_cv(rng)
        example = make_example(roots, theme, creativity, rng, anti)
        if not example:
            continue
        name, syl, clen = example
        if not quality_ok(name):
            continue
        cv = ControlVector(
            roots=roots,
            theme=theme,
            syllables=syl,
            char_len=clen,
            creativity=creativity,
        )
        prompt = [
            {'role': 'system', 'content': SYSTEM_PROMPT},
            {'role': 'user', 'content': cv.to_prompt()},
        ]
        yield prompt, name


def build(n_sft, n_dpo, seed=42):
    """Generate the SFT and DPO rows.

    Args:
        n_sft: number of SFT rows to produce.
        n_dpo: number of DPO rows to produce.
        seed: seed for the ``random.Random`` instance driving all sampling.

    Returns:
        ``(sft, dpo)``. Each SFT row is ``{'messages': [system, user,
        assistant]}``. Each DPO row has ``'prompt'`` (system + user
        messages), ``'chosen'`` (the generated name) and ``'rejected'``
        (a name from ``derivative_name``).

    The same ``AntiDuplicationMatrix`` instance and the same rng are used for
    both datasets, SFT rows first, exactly as before; only the duplicated
    sample/filter/prompt code has been folded into ``_accepted_examples``.

    NOTE(review): both loops run until enough names are accepted. If
    ``make_example`` stops yielding names that pass the filters, this does
    not terminate.
    """
    rng = random.Random(seed)
    anti = AntiDuplicationMatrix(min_novelty=0.7)
    examples = _accepted_examples(rng, anti)

    sft = []
    while len(sft) < n_sft:
        prompt, name = next(examples)
        sft.append({'messages': prompt + [{'role': 'assistant', 'content': name}]})

    dpo = []
    while len(dpo) < n_dpo:
        prompt, name = next(examples)
        # derivative_name draws from rng after the chosen name is fixed,
        # matching the original draw order.
        dpo.append({
            'prompt': prompt,
            'chosen': [{'role': 'assistant', 'content': name}],
            'rejected': [{'role': 'assistant', 'content': derivative_name(rng)}],
        })

    return sft, dpo
|
| 42 |
+
|
| 43 |
+
def split(rows, train_frac=0.97):
    """Split *rows* into a ``DatasetDict`` with ``train`` and ``test`` splits.

    Args:
        rows: list of row dicts, in the order they should be split.
        train_frac: fraction of rows (from the front of the list) that go to
            ``train``; the remainder goes to ``test``. Defaults to 0.97, the
            previously hard-coded value.

    Returns:
        ``DatasetDict(train=..., test=...)`` built with ``Dataset.from_list``.

    Raises:
        ValueError: if ``train_frac`` is outside ``[0, 1]``.

    Rows are not shuffled here. For very short inputs one of the splits can
    be empty because the cut index is truncated with ``int``.
    """
    if not 0.0 <= train_frac <= 1.0:
        raise ValueError(f"train_frac must be between 0 and 1, got {train_frac!r}")

    cut = int(len(rows) * train_frac)
    return DatasetDict(
        train=Dataset.from_list(rows[:cut]),
        test=Dataset.from_list(rows[cut:]),
    )
|
| 45 |
+
|
| 46 |
+
def main():
    """Command-line entry point: build both datasets, save them, optionally push.

    Options:
        --n_sft / --n_dpo: number of rows to generate (defaults 12000 / 6000).
        --seed: seed forwarded to ``build`` (default 42, the value ``build``
            used implicitly before this flag existed).
        --out: directory that receives the ``sft`` and ``dpo`` datasets.
        --push: also upload to ``--sft_repo`` and ``--dpo_repo``.
    """
    ap = argparse.ArgumentParser(description=__doc__)
    ap.add_argument('--n_sft', type=int, default=12000)
    ap.add_argument('--n_dpo', type=int, default=6000)
    ap.add_argument('--seed', type=int, default=42)
    ap.add_argument('--sft_repo', default='krystv/nomen-ai-sft')
    ap.add_argument('--dpo_repo', default='krystv/nomen-ai-dpo')
    ap.add_argument('--push', action='store_true')
    ap.add_argument('--out', default='data')
    args = ap.parse_args()

    sft, dpo = build(args.n_sft, args.n_dpo, seed=args.seed)
    print('SFT rows', len(sft), 'DPO rows', len(dpo))

    sft_ds, dpo_ds = split(sft), split(dpo)
    sft_ds.save_to_disk(f'{args.out}/sft')
    dpo_ds.save_to_disk(f'{args.out}/dpo')

    # Preview one row of each dataset. Guarded because indexing [0] raised
    # IndexError when a dataset was requested with zero rows.
    if sft:
        messages = sft[0]['messages']
        print('sample', messages[1]['content'], '->', messages[2]['content'])
    if dpo:
        print('dpo', dpo[0]['chosen'][0]['content'], 'vs', dpo[0]['rejected'][0]['content'])

    if args.push:
        sft_ds.push_to_hub(args.sft_repo)
        dpo_ds.push_to_hub(args.dpo_repo)


if __name__ == '__main__':
    main()
|