# c2cite/tests/mixlora_generation.py
# loadingy — first push (commit 51be264)
from typing import Optional
import fire
import torch
from mixlora import MixLoraModelForCausalLM, Prompter
from mixlora.utils import infer_device
from transformers import AutoTokenizer
from transformers.utils import is_torch_bf16_available_on_device
def main(
    adapter_model: str,
    instruction: str,
    template: str = "alpaca",
    device: Optional[str] = None,
):
    """Generate and print a response from a MixLoRA adapter model.

    Args:
        adapter_model: Path or hub id of the MixLoRA adapter to load.
        instruction: Instruction text to answer.
        template: Prompt template name understood by ``Prompter``.
        device: Target device string; auto-detected when omitted.
    """
    if device is None:
        device = infer_device()
    # Prefer bfloat16 where the device supports it; fall back to float16.
    model, config = MixLoraModelForCausalLM.from_pretrained(
        adapter_model,
        torch_dtype=(
            torch.bfloat16
            if is_torch_bf16_available_on_device(device)
            else torch.float16
        ),
        device_map=device,
    )
    tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
    prompter = Prompter(template)
    input_kwargs = tokenizer(prompter.generate_prompt(instruction), return_tensors="pt")
    # Move input tensors onto the same device as the model.
    for key, value in input_kwargs.items():
        if isinstance(value, torch.Tensor):
            input_kwargs[key] = value.to(device)
    with torch.inference_mode():
        outputs = model.generate(
            **input_kwargs,
            max_new_tokens=100,
        )
    # FIX: strip the prompt in *token* space before decoding. The original
    # sliced the decoded string by the number of prompt tokens, which mixes
    # token indices with character indices and cuts the wrong prefix.
    prompt_len = input_kwargs["input_ids"].shape[-1]
    output = tokenizer.batch_decode(
        outputs[:, prompt_len:].detach().cpu().numpy(), skip_special_tokens=True
    )[0]
    print(f"\nOutput: {prompter.get_response(output)}\n")
# Script entry point: expose `main`'s signature as a CLI via python-fire.
if __name__ == "__main__":
    fire.Fire(main)