Text Generation
PEFT
lora
trl
naming
brand-generation
controllable-generation
File size: 3,391 Bytes
ce6edc0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
"""Build and push Nomen-AI SFT/DPO datasets."""
import argparse, random, re
from datasets import Dataset, DatasetDict
from nomen_ai.control import ROOT_FAMILIES, THEMES, ControlVector, SYSTEM_PROMPT
from nomen_ai.synth import make_example
from nomen_ai.antidup import AntiDuplicationMatrix
from nomen_ai.phonetics import count_syllables
# Building blocks for deliberately generic, "derivative" brand names; they are
# combined by derivative_name() to produce the rejected side of DPO pairs.
GENERIC_SUFFIXES = "ify ly ster hub ged io zy able".split()
GENERIC_PREFIXES = "Tech Smart Get My Go The Best Pro".split()
GENERIC_CORES = "brand name tube vlog media studio world zone".split()

def quality_ok(name):
    """Return ``True`` if *name* passes the basic surface-quality filter.

    A name is accepted only when all of the following hold:

    * its length is between 3 and 13 characters (inclusive);
    * it contains at least one vowel (``y`` counts as a vowel here);
    * no character appears three or more times in a row, compared
      case-insensitively (e.g. ``"Zzzap"`` is rejected);
    * ``count_syllables(name)`` reports at most 4 syllables.

    The cheap length/regex checks run first, so ``count_syllables`` is only
    consulted for names that already passed them.

    Returns:
        A real ``bool`` in every case (never ``None`` or a match object).
    """
    if not 3 <= len(name) <= 13:
        return False
    lowered = name.lower()
    if not re.search(r"[aeiouy]", lowered):
        return False
    # Reject runs like "zzz" / "eee": backreference matches a char repeated 3x.
    if re.search(r"(.)\1\1", lowered):
        return False
    return bool(count_syllables(name) <= 4)

def derivative_name(rng):
    """Return a deliberately generic, template-like brand name.

    Used as the "rejected" completion of DPO pairs. A single ``rng.random()``
    draw selects one of three templates:

    * ``< 0.4``  -- generic prefix + capitalised generic core (e.g. "TechZone");
    * ``< 0.7``  -- trendy stem + generic suffix (e.g. "Pixelify");
    * otherwise  -- generic prefix + capitalised generic suffix (e.g. "SmartLy").

    Args:
        rng: a ``random.Random``-like object providing ``random`` and ``choice``.
    """
    roll = rng.random()
    if roll < 0.4:
        head = rng.choice(GENERIC_PREFIXES)
        tail = rng.choice(GENERIC_CORES).capitalize()
    elif roll < 0.7:
        head = rng.choice(["Stream", "Click", "Pixel", "Brand", "Creat", "Snap", "Vibe"])
        tail = rng.choice(GENERIC_SUFFIXES)
    else:
        head = rng.choice(GENERIC_PREFIXES)
        tail = rng.choice(GENERIC_SUFFIXES).capitalize()
    # head is drawn before tail, matching the left-to-right order of "a + b".
    return head + tail

def sample_cv(rng):
    """Sample the free parts of a control vector.

    Returns:
        ``(roots, theme, creativity)`` where ``roots`` is a list of 1-3 distinct
        entries of ``ROOT_FAMILIES`` (1 with weight 5, 2 with weight 4, 3 with
        weight 1), ``theme`` is one entry of ``THEMES`` and ``creativity`` is one
        of 0.1, 0.3, 0.5, 0.7, 0.9.
    """
    (n_roots,) = rng.choices([1, 2, 3], weights=[5, 4, 1])
    roots = rng.sample(ROOT_FAMILIES, n_roots)
    theme = rng.choice(THEMES)
    creativity = rng.choice([0.1, 0.3, 0.5, 0.7, 0.9])
    return roots, theme, creativity

def _draw_named_example(rng, anti):
    """Draw one control vector and try to realise it as an accepted name.

    Returns:
        ``(ControlVector, name)`` on success, or ``None`` when ``make_example``
        returned a falsy value or the name failed ``quality_ok``.
    """
    roots, theme, creativity = sample_cv(rng)
    ex = make_example(roots, theme, creativity, rng, anti)
    if not ex:
        return None
    name, syl, clen = ex
    # NOTE(review): make_example may already have registered `name` with `anti`
    # before this quality check rejects it -- confirm that is intended.
    if not quality_ok(name):
        return None
    cv = ControlVector(roots=roots, theme=theme, syllables=syl, char_len=clen, creativity=creativity)
    return cv, name


def _prompt_messages(cv):
    """Return the system + user chat turns that condition generation on *cv*."""
    return [
        {'role': 'system', 'content': SYSTEM_PROMPT},
        {'role': 'user', 'content': cv.to_prompt()},
    ]


def _collect_rows(count, rng, anti, to_row, max_attempts, phase):
    """Draw candidates until *count* rows are built with ``to_row(cv, name)``.

    Raises:
        RuntimeError: if *max_attempts* is not ``None`` and that many draws did
            not produce *count* rows.
    """
    rows = []
    attempts = 0
    while len(rows) < count:
        if max_attempts is not None and attempts >= max_attempts:
            raise RuntimeError(
                f"{phase}: only {len(rows)}/{count} rows after {attempts} attempts"
            )
        attempts += 1
        drawn = _draw_named_example(rng, anti)
        if drawn is not None:
            rows.append(to_row(*drawn))
    return rows


def build(n_sft, n_dpo, seed=42, *, max_attempts=None):
    """Generate SFT and DPO rows for Nomen-AI.

    Control vectors are sampled with ``sample_cv`` and realised into names by
    ``make_example``. Candidates are discarded and redrawn when ``make_example``
    returns a falsy value or ``quality_ok`` rejects the name. One
    ``AntiDuplicationMatrix(min_novelty=0.7)`` instance is shared by the SFT
    phase and the DPO phase, which run in that order.

    Args:
        n_sft: number of SFT rows, each ``{'messages': [system, user, assistant]}``.
        n_dpo: number of DPO rows, each with ``prompt`` (system + user turns),
            ``chosen`` (the generated name) and ``rejected`` (a generic name
            from ``derivative_name``).
        seed: seed of the ``random.Random`` that drives all sampling; output is
            reproducible for a fixed seed provided the project helpers are
            deterministic given that RNG (assumed, not verified here).
        max_attempts: optional cap on candidate draws per phase. ``None`` (the
            default) keeps drawing until the target count is reached, which can
            loop forever if candidates are persistently rejected.

    Returns:
        ``(sft, dpo)``: two lists of row dicts.

    Raises:
        RuntimeError: if ``max_attempts`` draws in one phase were not enough.
    """
    rng = random.Random(seed)
    anti = AntiDuplicationMatrix(min_novelty=0.7)

    def sft_row(cv, name):
        return {'messages': _prompt_messages(cv) + [{'role': 'assistant', 'content': name}]}

    def dpo_row(cv, name):
        # Dict values are evaluated in order: the prompt is rendered before
        # derivative_name consumes the RNG, exactly as in a literal dict.
        return {
            'prompt': _prompt_messages(cv),
            'chosen': [{'role': 'assistant', 'content': name}],
            'rejected': [{'role': 'assistant', 'content': derivative_name(rng)}],
        }

    sft = _collect_rows(n_sft, rng, anti, sft_row, max_attempts, 'SFT')
    dpo = _collect_rows(n_dpo, rng, anti, dpo_row, max_attempts, 'DPO')
    return sft, dpo

def split(rows, train_frac=0.97):
    """Split *rows* into a ``DatasetDict`` with ``train`` and ``test`` splits.

    The first ``int(len(rows) * train_frac)`` rows become ``train`` and the
    rest become ``test``. Row order is kept; no shuffling is done here.

    Args:
        rows: list of row dicts accepted by ``Dataset.from_list``.
        train_frac: fraction of rows assigned to ``train``, in ``(0, 1]``.

    Returns:
        ``DatasetDict(train=..., test=...)``.

    Raises:
        ValueError: if ``train_frac`` is outside ``(0, 1]``.
    """
    if not 0.0 < train_frac <= 1.0:
        raise ValueError(f"train_frac must be in (0, 1], got {train_frac!r}")
    cut = int(len(rows) * train_frac)
    # Flooring can yield an empty train split for tiny inputs (e.g. one row);
    # always keep at least one training row when any rows exist.
    if rows and cut == 0:
        cut = 1
    return DatasetDict(train=Dataset.from_list(rows[:cut]), test=Dataset.from_list(rows[cut:]))

def main():
    """CLI entry point.

    Builds the SFT/DPO datasets, splits each into train/test, and saves them
    under ``<out>/sft`` and ``<out>/dpo``. It prints one example of each
    non-empty dataset. With ``--push``, it also uploads both datasets to the
    Hugging Face Hub.
    """
    from pathlib import Path

    ap = argparse.ArgumentParser(description=__doc__)
    ap.add_argument('--n_sft', type=int, default=12000, help='number of SFT rows to generate')
    ap.add_argument('--n_dpo', type=int, default=6000, help='number of DPO rows to generate')
    ap.add_argument('--sft_repo', default='krystv/nomen-ai-sft', help='Hub repo id for the SFT dataset')
    ap.add_argument('--dpo_repo', default='krystv/nomen-ai-dpo', help='Hub repo id for the DPO dataset')
    ap.add_argument('--push', action='store_true', help='push both datasets to the Hub')
    ap.add_argument('--out', default='data', help='local output directory')
    ap.add_argument('--seed', type=int, default=42, help='RNG seed passed to build()')
    args = ap.parse_args()

    sft, dpo = build(args.n_sft, args.n_dpo, seed=args.seed)
    print('SFT rows', len(sft), 'DPO rows', len(dpo))
    sft_ds, dpo_ds = split(sft), split(dpo)

    out_dir = Path(args.out)
    sft_ds.save_to_disk(str(out_dir / 'sft'))
    dpo_ds.save_to_disk(str(out_dir / 'dpo'))

    # Guard the previews: indexing [0] would raise IndexError for 0 rows.
    if sft:
        print('sample', sft[0]['messages'][1]['content'], '->', sft[0]['messages'][2]['content'])
    if dpo:
        print('dpo', dpo[0]['chosen'][0]['content'], 'vs', dpo[0]['rejected'][0]['content'])

    if args.push:
        sft_ds.push_to_hub(args.sft_repo)
        dpo_ds.push_to_hub(args.dpo_repo)


if __name__ == '__main__':
    main()