import fire
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

from typing import Optional


def main(
    base_model: str,
    instruction: str,
    lora_weights: Optional[str] = None,
    device: str = "cuda:0",
) -> None:
    """Run single-prompt greedy generation with a causal LM, optionally LoRA-adapted.

    Args:
        base_model: Hugging Face model id or local path of the base causal LM.
        instruction: Prompt text fed verbatim to the tokenizer.
        lora_weights: Optional path/id of a PEFT LoRA adapter. When None,
            the base model is used as-is.
        device: Torch device string the model and inputs are placed on.
    """
    tokenizer = AutoTokenizer.from_pretrained(base_model)
    model = AutoModelForCausalLM.from_pretrained(
        base_model,
        torch_dtype=torch.bfloat16,
        device_map=device,
    )

    # Only wrap the model when an adapter is actually given. The original
    # called PeftModel.from_pretrained unconditionally and crashed whenever
    # lora_weights was left at its None default.
    if lora_weights is not None:
        model = PeftModel.from_pretrained(
            model, lora_weights, torch_dtype=torch.float32
        )

    input_ids = tokenizer(instruction, return_tensors="pt").input_ids.to(device)

    with torch.inference_mode():
        outputs = model.generate(
            input_ids=input_ids,
            max_new_tokens=128,
        )

    # Strip the prompt in *token* space, then decode only the new tokens.
    # The original decoded the full sequence and sliced the resulting string
    # by the token count — but token count != character count, so the printed
    # text had an arbitrary prefix removed.
    generated_tokens = outputs[0][input_ids.shape[-1]:]
    output = tokenizer.decode(generated_tokens, skip_special_tokens=True)

    print(f"Prompt:\n{instruction}\n")
    print(f"Generated:\n{output}")


if __name__ == "__main__":
    fire.Fire(main)