Text Generation
PEFT
lora
trl
naming
brand-generation
controllable-generation
File size: 1,536 Bytes
40f0fee
 
a6cbdcb
 
40f0fee
 
 
 
 
 
 
 
 
 
a6cbdcb
 
 
 
 
40f0fee
a6cbdcb
40f0fee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
"""Check whether required Nomen-AI Hub artifacts exist.

Run before DPO/evaluation. It checks that dataset repos load and that adapter
repos contain PEFT adapter weight files.
"""
from huggingface_hub import HfApi
from datasets import load_dataset

# Hub dataset repo ids; main() loads a one-row slice of each train split.
SFT_DATASET = 'krystv/nomen-ai-sft'
DPO_DATASET = 'krystv/nomen-ai-dpo'
# Hub model repo ids; main() checks each for PEFT adapter weight files.
SFT_ADAPTER = 'krystv/nomen-ai-sft-lora'
DPO_ADAPTER = 'krystv/nomen-ai-dpo-lora'


def repo_files(repo_id: str):
    """Return the file paths listed for a Hub model repository.

    Args:
        repo_id: Repository id on the Hub, looked up with
            ``repo_type='model'``.

    Returns:
        A list with the ``rfilename`` of every entry in the repo's
        ``siblings`` listing, in the order the Hub reports them.
    """
    client = HfApi()
    model_info = client.repo_info(repo_id=repo_id, repo_type='model')
    filenames = []
    for sibling in model_info.siblings:
        filenames.append(sibling.rfilename)
    return filenames


def has_adapter_weights(repo_id: str) -> bool:
    """Report whether a Hub model repo contains PEFT adapter weight files.

    Args:
        repo_id: Repository id on the Hub (model repo).

    Returns:
        ``True`` if any file path in the repo ends with
        ``adapter_model.safetensors`` or ``adapter_model.bin``; ``False``
        otherwise, including when the Hub reports the repo as not found.
    """
    # Function-scope import: brings the exception class into scope without
    # touching the module-level import block.
    from huggingface_hub.utils import RepositoryNotFoundError

    try:
        files = repo_files(repo_id)
    except RepositoryNotFoundError:
        # An adapter that has never been pushed has no repo yet. That is the
        # "not ready" answer this check exists to give, not a crash.
        # NOTE: the Hub also reports private repos as not found when the
        # caller is unauthenticated, so those read as "missing" here too.
        return False

    weight_suffixes = ('adapter_model.safetensors', 'adapter_model.bin')
    return any(name.endswith(weight_suffixes) for name in files)


def main():
    """Check the datasets and adapters, printing status and next actions.

    Loads a one-row slice of the train split of both dataset repos and prints
    their column names, then checks both adapter repos for weight files and
    prints whether each is present. Finally prints either the training script
    to run next or ``ALL_ARTIFACTS_READY``.
    """
    print('Checking datasets...')
    # Both slices are loaded before anything is printed about them, so a
    # failure on either dataset stops the run before any column output.
    sft_sample = load_dataset(SFT_DATASET, split='train[:1]')
    dpo_sample = load_dataset(DPO_DATASET, split='train[:1]')
    print('SFT columns:', sft_sample.column_names)
    print('DPO columns:', dpo_sample.column_names)

    print('Checking adapters...')
    have_sft = has_adapter_weights(SFT_ADAPTER)
    have_dpo = has_adapter_weights(DPO_ADAPTER)
    print(f'SFT adapter weights present: {have_sft}')
    print(f'DPO adapter weights present: {have_dpo}')

    # Without the SFT adapter only the SFT hint is shown, whatever the DPO
    # adapter's state.
    if not have_sft:
        print('ACTION: run scripts/train_sft.py before DPO/evaluation.')
        return

    if have_dpo:
        print('ALL_ARTIFACTS_READY')
    else:
        print('ACTION: run scripts/train_dpo.py to create the DPO adapter.')


# Run the checks only when executed as a script, not when imported.
if __name__ == '__main__':
    main()