|
|
from typing import Optional |
|
|
|
|
|
import fire |
|
|
import torch |
|
|
from mixlora import MixLoraModelForCausalLM, Prompter |
|
|
from mixlora.utils import infer_device |
|
|
from transformers import AutoTokenizer |
|
|
from transformers.utils import is_torch_bf16_available_on_device |
|
|
|
|
|
|
|
|
def main(
    adapter_model: str,
    instruction: str,
    template: str = "alpaca",
    device: Optional[str] = None,
):
    """Run single-prompt inference with a MixLoRA adapter and print the response.

    Args:
        adapter_model: Path or hub id of the MixLoRA adapter to load; the base
            model is resolved from the adapter config's ``base_model_name_or_path``.
        instruction: Instruction text rendered through the prompt template.
        template: Prompter template name (default ``"alpaca"``).
        device: Target device string; inferred from the environment when ``None``.
    """
    if device is None:
        device = infer_device()

    # Prefer bf16 on devices that support it, otherwise fall back to fp16.
    model, config = MixLoraModelForCausalLM.from_pretrained(
        adapter_model,
        torch_dtype=(
            torch.bfloat16
            if is_torch_bf16_available_on_device(device)
            else torch.float16
        ),
        device_map=device,
    )
    tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
    prompter = Prompter(template)

    input_kwargs = tokenizer(prompter.generate_prompt(instruction), return_tensors="pt")

    # Move every tensor in the tokenizer output onto the target device.
    for key, value in input_kwargs.items():
        if isinstance(value, torch.Tensor):
            input_kwargs[key] = value.to(device)

    with torch.inference_mode():
        outputs = model.generate(
            **input_kwargs,
            max_new_tokens=100,
        )
    # BUG FIX: the original sliced the *decoded string* by the number of input
    # *tokens* (``[input_ids.shape[-1]:]``). Characters != tokens, so that cut
    # off an arbitrary-length prefix of the text. Decode the full sequence
    # instead and let ``prompter.get_response`` extract the response portion.
    # NOTE(review): assumes get_response splits on the template's response
    # marker, as in the alpaca-lora convention — confirm in mixlora's Prompter.
    output = tokenizer.batch_decode(
        outputs.detach().cpu().numpy(), skip_special_tokens=True
    )[0]

    print(f"\nOutput: {prompter.get_response(output)}\n")
|
|
|
|
|
|
|
|
# Script entry point: expose ``main``'s parameters as CLI flags via python-fire.
if __name__ == "__main__":


    fire.Fire(main)
|
|
|