import fire
import torch

import moe_peft
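

# Streaming callback handed to moe_peft.generate below: invoked during
# decoding with the current token position and each adapter's output so far.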
def inference_callback(cur_pos, outputs):
    print(f"Position: {cur_pos}")
    for adapter_name, output in outputs.items():
        print(f"{adapter_name} output: {output[0]}")
def main(
    base_model: str,
    instruction: str,
    input: str = None,
    template: str = None,
    lora_weights: str = None,
    load_16bit: bool = True,
    load_8bit: bool = False,
    load_4bit: bool = False,
    flash_attn: bool = False,
    max_seq_len: int = None,
    stream: bool = False,
    device: str = moe_peft.executor.default_device_name(),
):
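    # Load the base model; the CLI flags select the attention implementation
    # (FlashAttention vs. eager), quantization bits, and load dtype.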
    model = moe_peft.LLMModel.from_pretrained(
        base_model,
        device=device,
        attn_impl="flash_attn" if flash_attn else "eager",
        bits=(8 if load_8bit else (4 if load_4bit else None)),
        load_dtype=torch.bfloat16 if load_16bit else torch.float32,
    )
    tokenizer = moe_peft.Tokenizer(base_model)
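
    # Use trained LoRA weights when a path is given; otherwise initialize a
    # plain default adapter so generation can still run.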
    if lora_weights:
        adapter_name = model.load_adapter(lora_weights)
    else:
        adapter_name = model.init_adapter(
            moe_peft.AdapterConfig(adapter_name="default")
        )
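
    # One GenerateConfig per adapter: each prompt is an (instruction, input)
    # pair, optionally rendered through the given prompt template.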
    generate_params = moe_peft.GenerateConfig(
        adapter_name=adapter_name,
        prompt_template=template,
        prompts=[(instruction, input)],
    )
    output = moe_peft.generate(
        model,
        tokenizer,
        [generate_params],
        max_gen_len=max_seq_len,
        stream_callback=inference_callback if stream else None,
    )
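
    # generate returns a dict keyed by adapter name; print each generated
    # text for the adapter configured above.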
    for prompt in output[adapter_name]:
        print(f"\n{'=' * 10}\n")
        print(prompt)
        print(f"\n{'=' * 10}\n")

if __name__ == "__main__":
    fire.Fire(main)
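
# Example invocation (hypothetical file name, model ID, and adapter path):
#   python inference.py \
#     --base_model meta-llama/Llama-2-7b-hf \
#     --instruction "Tell me about MoE-PEFT." \
#     --lora_weights ./output_adapter \
#     --stream True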