| import fire | |
| import torch | |
| import moe_peft | |
def main(
    base_model: str,
    adapter_name: str = "lora_0",
    train_data: str = "TUDB-Labs/Dummy-MoE-PEFT",
    test_prompt: str = "Could you provide an introduction to MoE-PEFT?",
    lora_r: int = 32,
    lora_alpha: int = 64,
    lora_dropout: float = 0.05,
    num_epochs: int = 10,
    batch_size: int = 16,
    micro_batch_size: int = 8,
    learning_rate: float = 1e-4,
    max_gen_len: int = 128,
):
    """Train a LoRA adapter with MoE-PEFT, then compare generations.

    The LoRA adapter is trained on ``train_data`` and unloaded. It is then
    re-attached next to a bare ``"default"`` adapter. The output each adapter
    produces for ``test_prompt`` is printed.

    Args:
        base_model: Model name/path forwarded to ``LLMModel.from_pretrained``
            and ``moe_peft.Tokenizer``.
        adapter_name: Name of the trained LoRA adapter. Must not be
            ``"default"``, which is reserved for the comparison adapter.
        train_data: Forwarded to ``TrainConfig`` as ``data_path``.
        test_prompt: Prompt used for generation after training.
        lora_r: Forwarded to ``LoraConfig`` as ``lora_r_``.
        lora_alpha: Forwarded to ``LoraConfig`` as ``lora_alpha_``.
        lora_dropout: Forwarded to ``LoraConfig`` as ``lora_dropout_``.
        num_epochs: Forwarded to ``TrainConfig``.
        batch_size: Forwarded to ``TrainConfig``.
        micro_batch_size: Forwarded to ``TrainConfig`` (presumably the
            per-step size used for gradient accumulation; confirm against
            ``moe_peft.TrainConfig``).
        learning_rate: Forwarded to ``TrainConfig``.
        max_gen_len: Forwarded to ``moe_peft.generate``.

    Raises:
        ValueError: If ``adapter_name`` equals the reserved comparison name.
    """
    # A second, bare adapter with this name is created below for comparison.
    # Reusing the name for the trained LoRA would make both adapters (and both
    # generate configs) share one name, so reject it before doing any work.
    baseline_name = "default"
    if adapter_name == baseline_name:
        raise ValueError(
            f"adapter_name must not be {baseline_name!r}; that name is "
            "reserved for the comparison adapter"
        )

    moe_peft.setup_logging("INFO")

    model: moe_peft.LLMModel = moe_peft.LLMModel.from_pretrained(
        base_model,
        device=moe_peft.executor.default_device_name(),
        load_dtype=torch.bfloat16,
    )
    tokenizer = moe_peft.Tokenizer(base_model)

    lora_config = moe_peft.LoraConfig(
        adapter_name=adapter_name,
        lora_r_=lora_r,
        lora_alpha_=lora_alpha,
        lora_dropout_=lora_dropout,
        target_modules_={
            "q_proj": True,
            "k_proj": True,
            "v_proj": True,
            "o_proj": True,
        },
    )

    train_config = moe_peft.TrainConfig(
        adapter_name=adapter_name,
        data_path=train_data,
        num_epochs=num_epochs,
        batch_size=batch_size,
        micro_batch_size=micro_batch_size,
        learning_rate=learning_rate,
    )

    # Train, then unload the adapter. The returned config and weights are
    # re-attached for generation below.
    with moe_peft.executors.no_cache():
        model.init_adapter(lora_config)
        moe_peft.train(model=model, tokenizer=tokenizer, configs=[train_config])
        lora_config, lora_weight = model.unload_adapter(adapter_name)

    # One generation request per adapter: the trained LoRA first, then the
    # comparison adapter.
    generate_configs = [
        moe_peft.GenerateConfig(
            adapter_name=name,
            prompts=[test_prompt],
            stop_token="\n",
        )
        for name in (adapter_name, baseline_name)
    ]

    with moe_peft.executors.no_cache():
        model.init_adapter(lora_config, lora_weight)
        # An AdapterConfig carrying only a name. Presumably this acts as a
        # pass-through baseline (confirm in moe_peft).
        model.init_adapter(moe_peft.AdapterConfig(adapter_name=baseline_name))
        outputs = moe_peft.generate(
            model=model,
            tokenizer=tokenizer,
            configs=generate_configs,
            max_gen_len=max_gen_len,
        )

    separator = f"\n{'=' * 10}\n"
    print(separator)
    print(f"PROMPT: {test_prompt}\n")
    # Use a distinct loop name so the ``adapter_name`` parameter is not shadowed.
    for name, output in outputs.items():
        print(f"{name} OUTPUT:")
        print(f"{output[0]}\n")
    print(separator)
if __name__ == "__main__":
    # Run as a script: python-fire exposes main()'s parameters as
    # command-line arguments.
    fire.Fire(component=main)