from typing import Optional

import fire
import torch
from mixlora import MixLoraModelForCausalLM, Prompter
from mixlora.utils import infer_device
from transformers import AutoTokenizer
from transformers.utils import is_torch_bf16_available_on_device
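

# Command-line entry point: load a trained MixLoRA adapter together with its
# base model and print the model's answer to a single instruction.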
def main(
    adapter_model: str,
    instruction: str,
    template: str = "alpaca",
    device: Optional[str] = None,
):
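    """Generate a reply to a single instruction using a MixLoRA adapter."""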
    if device is None:
        device = infer_device()
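    # load the base model with the MixLoRA adapter applied; use bfloat16 when
    # the target device supports it, otherwise fall back to float16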
    model, config = MixLoraModelForCausalLM.from_pretrained(
        adapter_model,
        torch_dtype=(
            torch.bfloat16
            if is_torch_bf16_available_on_device(device)
            else torch.float16
        ),
        device_map=device,
    )
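    # the adapter config records which base model it was trained on, so the
    # matching tokenizer can be loaded without extra arguments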
    tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
    prompter = Prompter(template)
    input_kwargs = tokenizer(prompter.generate_prompt(instruction), return_tensors="pt")
    # move the tokenized inputs onto the target device
    for key, value in input_kwargs.items():
        if isinstance(value, torch.Tensor):
            input_kwargs[key] = value.to(device)
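    # inference_mode disables autograd tracking for faster generation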
    with torch.inference_mode():
        outputs = model.generate(
            **input_kwargs,
            max_new_tokens=100,
        )
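    # decode the full sequence (prompt + completion), then drop a prefix of
    # characters equal to the prompt's token count; the remaining text still
    # contains the template's response section for get_response to extract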
    output = tokenizer.batch_decode(
        outputs.detach().cpu().numpy(), skip_special_tokens=True
    )[0][input_kwargs["input_ids"].shape[-1] :]
    print(f"\nOutput: {prompter.get_response(output)}\n")


if __name__ == "__main__":
    fire.Fire(main)
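
# Example invocation (the adapter name below is illustrative):
#   python generate.py --adapter_model TUDB-Labs/alpaca-mixlora-7b \
#       --instruction "What is MixLoRA?"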