File size: 2,178 Bytes
51be264
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import fire
import torch

import moe_peft


def main(
    base_model: str,
    adapter_name: str = "lora_0",
    train_data: str = "TUDB-Labs/Dummy-MoE-PEFT",
    test_prompt: str = "Could you provide an introduction to MoE-PEFT?",
):
    """Train a LoRA adapter on ``base_model`` and compare its generations
    against the un-adapted ("default") model on a single prompt.

    Args:
        base_model: HuggingFace model id or local path of the base LLM.
        adapter_name: Name under which the LoRA adapter is registered.
        train_data: Dataset id/path passed to ``moe_peft.train``.
        test_prompt: Prompt generated with both the trained adapter and
            the base ("default") model for side-by-side comparison.
    """
    moe_peft.setup_logging("INFO")

    model: moe_peft.LLMModel = moe_peft.LLMModel.from_pretrained(
        base_model,
        device=moe_peft.executor.default_device_name(),
        load_dtype=torch.bfloat16,
    )
    tokenizer = moe_peft.Tokenizer(base_model)

    lora_config = moe_peft.LoraConfig(
        adapter_name=adapter_name,
        lora_r_=32,
        lora_alpha_=64,
        lora_dropout_=0.05,
        # Apply LoRA to all four attention projections.
        target_modules_={
            "q_proj": True,
            "k_proj": True,
            "v_proj": True,
            "o_proj": True,
        },
    )

    train_config = moe_peft.TrainConfig(
        adapter_name=adapter_name,
        data_path=train_data,
        num_epochs=10,
        batch_size=16,
        micro_batch_size=8,
        learning_rate=1e-4,
    )

    # Train, then detach the adapter; the returned config/weights are
    # re-attached below for the generation pass.
    with moe_peft.executors.no_cache():
        model.init_adapter(lora_config)
        moe_peft.train(model=model, tokenizer=tokenizer, configs=[train_config])
        lora_config, lora_weight = model.unload_adapter(adapter_name)

    # One generation config for the trained adapter, one for the bare
    # "default" (no-LoRA) model, both on the same prompt.
    generate_configs = [
        moe_peft.GenerateConfig(
            adapter_name=adapter_name,
            prompts=[test_prompt],
            stop_token="\n",
        ),
        moe_peft.GenerateConfig(
            adapter_name="default",
            prompts=[test_prompt],
            stop_token="\n",
        ),
    ]

    with moe_peft.executors.no_cache():
        model.init_adapter(lora_config, lora_weight)
        model.init_adapter(moe_peft.AdapterConfig(adapter_name="default"))
        outputs = moe_peft.generate(
            model=model,
            tokenizer=tokenizer,
            configs=generate_configs,
            max_gen_len=128,
        )

    print(f"\n{'='*10}\n")
    print(f"PROMPT: {test_prompt}\n")
    # Use a fresh loop variable; the original rebound the ``adapter_name``
    # parameter here, shadowing it for the rest of the function.
    for name, output in outputs.items():
        print(f"{name} OUTPUT:")
        print(f"{output[0]}\n")
    print(f"\n{'='*10}\n")


if __name__ == "__main__":
    # python-fire turns main()'s signature into a CLI:
    #   python this_script.py --base_model=... [--adapter_name=...] ...
    fire.Fire(main)