import fire
import torch
import moe_peft
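

# End-to-end MoE-PEFT example: fine-tune a LoRA adapter on a dummy dataset,
# then generate with both the fine-tuned adapter and a plain "default" adapter.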
def main(
    base_model: str,
    adapter_name: str = "lora_0",
    train_data: str = "TUDB-Labs/Dummy-MoE-PEFT",
    test_prompt: str = "Could you provide an introduction to MoE-PEFT?",
):
    moe_peft.setup_logging("INFO")
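
    # Load the base model in bfloat16 on the executor's default device, then
    # build the matching tokenizer from the same checkpoint.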
    model: moe_peft.LLMModel = moe_peft.LLMModel.from_pretrained(
        base_model,
        device=moe_peft.executor.default_device_name(),
        load_dtype=torch.bfloat16,
    )
    tokenizer = moe_peft.Tokenizer(base_model)
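
    # LoRA hyperparameters: rank 32, alpha 64, dropout 0.05, applied to the
    # attention projections (q_proj, k_proj, v_proj, o_proj).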
    lora_config = moe_peft.LoraConfig(
        adapter_name=adapter_name,
        lora_r_=32,
        lora_alpha_=64,
        lora_dropout_=0.05,
        target_modules_={
            "q_proj": True,
            "k_proj": True,
            "v_proj": True,
            "o_proj": True,
        },
    )
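
    # Training hyperparameters: 10 epochs with a batch size of 16, processed in
    # micro-batches of 8, at a learning rate of 1e-4.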
    train_config = moe_peft.TrainConfig(
        adapter_name=adapter_name,
        data_path=train_data,
        num_epochs=10,
        batch_size=16,
        micro_batch_size=8,
        learning_rate=1e-4,
    )
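
    # Attach the adapter to the model, fine-tune it on the dataset, then unload
    # the trained weights so they can be re-attached for inference below.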
    with moe_peft.executors.no_cache():
        model.init_adapter(lora_config)
        moe_peft.train(model=model, tokenizer=tokenizer, configs=[train_config])
        lora_config, lora_weight = model.unload_adapter(adapter_name)
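
    # Build one generation request per adapter: the fine-tuned LoRA adapter and
    # a second adapter named "default", both answering the same test prompt.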
    generate_configs = [
        moe_peft.GenerateConfig(
            adapter_name=adapter_name,
            prompts=[test_prompt],
            stop_token="\n",
        ),
        moe_peft.GenerateConfig(
            adapter_name="default",
            prompts=[test_prompt],
            stop_token="\n",
        ),
    ]
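
    # Re-attach the trained LoRA weights, initialize the "default" adapter, and
    # run both generation requests through a single moe_peft.generate call.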
    with moe_peft.executors.no_cache():
        model.init_adapter(lora_config, lora_weight)
        model.init_adapter(moe_peft.AdapterConfig(adapter_name="default"))
        outputs = moe_peft.generate(
            model=model,
            tokenizer=tokenizer,
            configs=generate_configs,
            max_gen_len=128,
        )
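
    # Print each adapter's completion for the test prompt.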
print(f"\n{'='*10}\n")
print(f"PROMPT: {test_prompt}\n")
for adapter_name, output in outputs.items():
print(f"{adapter_name} OUTPUT:")
print(f"{output[0]}\n")
print(f"\n{'='*10}\n")


if __name__ == "__main__":
    fire.Fire(main)