"""Command-line script that runs one instruction through a MixLoRA model."""

from typing import Optional

import fire
import torch
from mixlora import MixLoraModelForCausalLM, Prompter
from mixlora.utils import infer_device
from transformers import AutoTokenizer
from transformers.utils import is_torch_bf16_available_on_device


def main(
    adapter_model: str,
    instruction: str,
    template: str = "alpaca",
    device: Optional[str] = None,
):
    """Generate a response to ``instruction`` and print it.

    Args:
        adapter_model: Value handed to
            ``MixLoraModelForCausalLM.from_pretrained`` to load the model.
        instruction: Text inserted into the prompt template.
        template: Template name given to ``Prompter``.
        device: Device used for the model (as ``device_map``) and for the
            input tensors. When ``None``, ``infer_device()`` picks one.
    """
    device = infer_device() if device is None else device

    # Use bfloat16 when the target device reports support, float16 otherwise.
    if is_torch_bf16_available_on_device(device):
        dtype = torch.bfloat16
    else:
        dtype = torch.float16

    model, config = MixLoraModelForCausalLM.from_pretrained(
        adapter_model,
        torch_dtype=dtype,
        device_map=device,
    )
    tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
    prompter = Prompter(template)

    prompt = prompter.generate_prompt(instruction)
    encoded = tokenizer(prompt, return_tensors="pt")

    # Tensors must live on the same device as the model; other entries are
    # forwarded untouched.
    model_inputs = {
        name: item.to(device) if isinstance(item, torch.Tensor) else item
        for name, item in encoded.items()
    }

    with torch.inference_mode():
        generated = model.generate(**model_inputs, max_new_tokens=100)

    token_ids = generated.detach().cpu().numpy()
    texts = tokenizer.batch_decode(token_ids, skip_special_tokens=True)

    # NOTE(review): ``prompt_len`` counts prompt *tokens*, yet it is applied
    # as an offset into the decoded *string* — confirm this is intended.
    prompt_len = model_inputs["input_ids"].shape[-1]
    decoded = texts[0][prompt_len:]

    response = prompter.get_response(decoded)
    print(f"\nOutput: {response}\n")


if __name__ == "__main__":
    fire.Fire(main)