Text Generation
PEFT
lora
trl
naming
brand-generation
controllable-generation
krystv committed on
Commit
ce6edc0
·
verified ·
1 Parent(s): 7a8a6e2

Add dataset builder script

Browse files
Files changed (1) hide show
  1. scripts/build_dataset.py +51 -0
scripts/build_dataset.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Build and push Nomen-AI SFT/DPO datasets."""
2
+ import argparse, random, re
3
+ from datasets import Dataset, DatasetDict
4
+ from nomen_ai.control import ROOT_FAMILIES, THEMES, ControlVector, SYSTEM_PROMPT
5
+ from nomen_ai.synth import make_example
6
+ from nomen_ai.antidup import AntiDuplicationMatrix
7
+ from nomen_ai.phonetics import count_syllables
8
# Word fragments used by derivative_name() to assemble deliberately generic
# names; those names become the "rejected" side of each DPO pair.
GENERIC_SUFFIXES = [
    "ify", "ly", "ster", "hub", "ged", "io", "zy", "able",
]
GENERIC_PREFIXES = [
    "Tech", "Smart", "Get", "My", "Go", "The", "Best", "Pro",
]
GENERIC_CORES = [
    "brand", "name", "tube", "vlog", "media", "studio", "world", "zone",
]
11
+
12
def quality_ok(name):
    """Return True when *name* passes the basic shape filters.

    A name is accepted only if all of the following hold, checked in order:

    * it is between 3 and 13 characters long (inclusive);
    * it contains at least one of ``a e i o u y`` (case-insensitive);
    * ``count_syllables(name)`` is at most 4;
    * no character occurs three times in a row (case-insensitive).

    The result is always a real ``bool``; chaining the raw ``re.search``
    results with ``and`` would leak ``re.Match`` objects or ``None`` to callers.
    """
    if not 3 <= len(name) <= 13:
        return False
    # Lowercase once; both regex checks are case-insensitive.
    lowered = name.lower()
    if re.search(r"[aeiouy]", lowered) is None:
        return False
    if not count_syllables(name) <= 4:
        return False
    return re.search(r"(.)\1\1", lowered) is None
14
+
15
def derivative_name(rng):
    """Assemble a generic-sounding name from the GENERIC_* fragment lists.

    *rng* must provide ``random()`` and ``choice()`` (e.g. ``random.Random``).
    One uniform draw selects the recipe:

    * below 0.4        -> generic prefix + capitalised generic core;
    * 0.4 up to 0.7    -> fixed trendy stem + generic suffix;
    * 0.7 and above    -> generic prefix + capitalised generic suffix.
    """
    roll = rng.random()

    if roll < 0.4:
        prefix = rng.choice(GENERIC_PREFIXES)
        core = rng.choice(GENERIC_CORES)
        return prefix + core.capitalize()

    if roll < 0.7:
        stem = rng.choice(["Stream", "Click", "Pixel", "Brand", "Creat", "Snap", "Vibe"])
        suffix = rng.choice(GENERIC_SUFFIXES)
        return stem + suffix

    prefix = rng.choice(GENERIC_PREFIXES)
    suffix = rng.choice(GENERIC_SUFFIXES)
    return prefix + suffix.capitalize()
20
+
21
def sample_cv(rng):
    """Draw the random ingredients for one control vector.

    Returns a ``(roots, theme, creativity)`` tuple:

    * ``roots`` - 1, 2 or 3 distinct entries of ``ROOT_FAMILIES``; the count
      is drawn with weights 5:4:1, so a single root is the most likely;
    * ``theme`` - one entry of ``THEMES``;
    * ``creativity`` - one of 0.1, 0.3, 0.5, 0.7, 0.9.
    """
    root_count = rng.choices([1, 2, 3], weights=[5, 4, 1])[0]
    roots = rng.sample(ROOT_FAMILIES, root_count)
    theme = rng.choice(THEMES)
    creativity = rng.choice([0.1, 0.3, 0.5, 0.7, 0.9])
    return roots, theme, creativity
24
+
25
def _draw_example(rng, anti, max_rejects=None):
    """Keep sampling control settings until one yields an acceptable name.

    Each attempt draws settings with ``sample_cv`` and passes them to
    ``make_example``. An attempt is rejected when ``make_example`` returns a
    falsy value or the produced name fails ``quality_ok``.

    Returns ``(prompt, name)`` where ``prompt`` is a fresh
    ``[system message, user message]`` list built from ``SYSTEM_PROMPT`` and
    ``ControlVector.to_prompt()``.

    Raises ``RuntimeError`` if *max_rejects* is not ``None`` and that many
    attempts in a row were rejected.
    """
    rejects = 0
    while True:
        roots, theme, creativity = sample_cv(rng)
        example = make_example(roots, theme, creativity, rng, anti)
        if example:
            name, syllables, char_len = example
            if quality_ok(name):
                break
        rejects += 1
        if max_rejects is not None and rejects >= max_rejects:
            raise RuntimeError(
                f"no acceptable name after {rejects} consecutive attempts"
            )

    cv = ControlVector(
        roots=roots,
        theme=theme,
        syllables=syllables,
        char_len=char_len,
        creativity=creativity,
    )
    prompt = [
        {'role': 'system', 'content': SYSTEM_PROMPT},
        {'role': 'user', 'content': cv.to_prompt()},
    ]
    return prompt, name


def build(n_sft, n_dpo, seed=42, max_rejects=None):
    """Generate the SFT and DPO rows.

    Args:
        n_sft: number of SFT rows to produce.
        n_dpo: number of DPO rows to produce.
        seed: seed for the single ``random.Random`` that drives all sampling.
        max_rejects: optional cap on consecutive rejected attempts per row.
            ``None`` (the default) retries without limit, which never returns
            if no acceptable name can be produced any more.

    Returns:
        ``(sft, dpo)``. Each SFT row is ``{'messages': [system, user,
        assistant]}``. Each DPO row has ``'prompt'`` (system + user),
        ``'chosen'`` (the generated name) and ``'rejected'`` (a name from
        ``derivative_name``).

    Raises:
        RuntimeError: only when *max_rejects* is set and is reached.
    """
    rng = random.Random(seed)
    # One matrix is shared by both phases, exactly as make_example receives it.
    anti = AntiDuplicationMatrix(min_novelty=0.7)

    sft = []
    while len(sft) < n_sft:
        prompt, name = _draw_example(rng, anti, max_rejects)
        sft.append({'messages': prompt + [{'role': 'assistant', 'content': name}]})

    dpo = []
    while len(dpo) < n_dpo:
        prompt, name = _draw_example(rng, anti, max_rejects)
        dpo.append({
            'prompt': prompt,
            'chosen': [{'role': 'assistant', 'content': name}],
            # Drawn after the chosen name so the RNG stream order is stable.
            'rejected': [{'role': 'assistant', 'content': derivative_name(rng)}],
        })

    return sft, dpo
42
+
43
def split(rows, train_frac=0.97):
    """Split *rows* in order into a train/test ``DatasetDict``.

    Args:
        rows: list of row dicts; the order is kept, nothing is shuffled here.
        train_frac: fraction of rows placed in ``train``. The cut index is
            ``int(len(rows) * train_frac)``, so it rounds down and the
            remainder goes to ``test``.

    Returns:
        ``DatasetDict`` with ``train`` and ``test`` splits.
    """
    cut = int(len(rows) * train_frac)
    return DatasetDict(
        train=Dataset.from_list(rows[:cut]),
        test=Dataset.from_list(rows[cut:]),
    )
45
+
46
def main():
    """Command-line entry point: build both datasets, save them, optionally push.

    Flags:
        --n_sft / --n_dpo: number of rows to generate for each dataset.
        --seed: RNG seed passed to ``build`` (default 42).
        --out: directory that receives the ``sft`` and ``dpo`` sub-directories.
        --push: also upload both datasets to the Hub.
        --sft_repo / --dpo_repo: Hub repository ids used with ``--push``.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument('--n_sft', type=int, default=12000)
    ap.add_argument('--n_dpo', type=int, default=6000)
    ap.add_argument('--seed', type=int, default=42)
    ap.add_argument('--sft_repo', default='krystv/nomen-ai-sft')
    ap.add_argument('--dpo_repo', default='krystv/nomen-ai-dpo')
    ap.add_argument('--push', action='store_true')
    ap.add_argument('--out', default='data')
    args = ap.parse_args()

    sft, dpo = build(args.n_sft, args.n_dpo, seed=args.seed)
    print('SFT rows', len(sft), 'DPO rows', len(dpo))

    sft_ds, dpo_ds = split(sft), split(dpo)
    sft_ds.save_to_disk(f'{args.out}/sft')
    dpo_ds.save_to_disk(f'{args.out}/dpo')

    # Either list is empty when the matching --n_* flag is 0; skip its preview
    # instead of raising IndexError.
    if sft:
        print('sample', sft[0]['messages'][1]['content'], '->', sft[0]['messages'][2]['content'])
    if dpo:
        print('dpo', dpo[0]['chosen'][0]['content'], 'vs', dpo[0]['rejected'][0]['content'])

    if args.push:
        sft_ds.push_to_hub(args.sft_repo)
        dpo_ds.push_to_hub(args.dpo_repo)


if __name__ == '__main__':
    main()