#!/usr/bin/env python3
"""
Run inference with a base model and a saved QLoRA adapter.
Example:
python qlora_inference.py \
--base-model meta-llama/Meta-Llama-3.1-8B-Instruct \
--adapter-path saves/llama31-8b/kaggle_cyberbullying/qlora \
--system-prompt "You are a helpful assistant." \
--user-input "How should we moderate this post?"
"""
import argparse
from typing import Optional, Tuple
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
def parse_args() -> argparse.Namespace:
    """Define and parse the command-line interface for this script."""
    ap = argparse.ArgumentParser(
        description="Generate text with a base model and LoRA/QLoRA adapter."
    )
    # Bind once; every option below goes through the same parser.
    add = ap.add_argument
    add(
        "--base-model",
        default="meta-llama/Meta-Llama-3.1-8B-Instruct",
        help="Base model name or local path (default: %(default)s).",
    )
    add(
        "--adapter-path",
        required=True,
        help="Path to the trained adapter directory (containing adapter_model.safetensors).",
    )
    add(
        "--system-prompt",
        default="You are a helpful assistant.",
        help="System prompt to steer the assistant (set empty string to skip).",
    )
    add(
        "--user-input",
        required=True,
        help="User instruction or text the model should respond to.",
    )
    add(
        "--max-new-tokens",
        type=int,
        default=512,
        help="Maximum number of new tokens to generate.",
    )
    add(
        "--temperature",
        type=float,
        default=0.7,
        help="Sampling temperature; set <= 0 for deterministic decoding.",
    )
    add(
        "--top-p",
        type=float,
        default=0.9,
        help="Top-p value for nucleus sampling (ignored if temperature <= 0).",
    )
    add(
        "--no-quantization",
        action="store_true",
        help="Disable 4-bit loading; use this if you have a full-precision base model locally.",
    )
    return ap.parse_args()
def build_prompt(
    tokenizer: AutoTokenizer, system_prompt: str, user_input: str
) -> str:
    """Build the final prompt string for generation.

    Prefers the tokenizer's chat template when one is defined; otherwise
    emits a plain "User:/Assistant:" transcript.
    """
    conversation = [{"role": "user", "content": user_input}]
    if system_prompt:
        # System message, when present, always leads the conversation.
        conversation.insert(0, {"role": "system", "content": system_prompt})
    has_template = hasattr(tokenizer, "apply_chat_template") and tokenizer.chat_template
    if has_template:
        return tokenizer.apply_chat_template(
            conversation, tokenize=False, add_generation_prompt=True
        )
    # No chat template: fall back to a minimal hand-rolled transcript.
    prefix = f"{system_prompt}\n\n" if system_prompt else ""
    return f"{prefix}User: {user_input}\nAssistant:"
def load_model_and_tokenizer(
    base_model: str, adapter_path: str, use_quantization: bool
) -> Tuple[AutoModelForCausalLM, AutoTokenizer]:
    """Load the tokenizer and (optionally 4-bit) base model, then attach the adapter.

    The tokenizer is read from the adapter directory so it matches whatever
    the adapter was trained with; the adapter itself is loaded frozen
    (``is_trainable=False``) and the model is switched to eval mode.
    """
    tokenizer = AutoTokenizer.from_pretrained(
        adapter_path, use_fast=True, trust_remote_code=True
    )
    # Reuse EOS as the pad token when none is defined (common for Llama-family models).
    if tokenizer.pad_token is None and tokenizer.eos_token is not None:
        tokenizer.pad_token = tokenizer.eos_token

    dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32

    bnb_config: Optional[BitsAndBytesConfig] = None
    if use_quantization:
        # NF4 + double quantization mirrors the standard QLoRA loading recipe.
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
        )

    base = AutoModelForCausalLM.from_pretrained(
        base_model,
        device_map="auto",
        torch_dtype=dtype,
        trust_remote_code=True,
        quantization_config=bnb_config,
    )
    wrapped = PeftModel.from_pretrained(base, adapter_path, is_trainable=False)
    wrapped.eval()
    return wrapped, tokenizer
def generate_response(
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    prompt: str,
    max_new_tokens: int,
    temperature: float,
    top_p: float,
) -> str:
    """Generate a completion for ``prompt`` and return only the new text.

    Args:
        model: Causal LM (optionally wrapped in a PEFT adapter) used for decoding.
        tokenizer: Tokenizer paired with ``model``.
        prompt: Fully formatted prompt string (e.g. from ``build_prompt``).
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature: Sampling temperature; values <= 0 switch to greedy decoding.
        top_p: Nucleus-sampling cutoff (only used when sampling is enabled).

    Returns:
        The decoded continuation with special tokens removed and surrounding
        whitespace stripped.
    """
    inputs = tokenizer(prompt, return_tensors="pt")
    input_ids = inputs["input_ids"].to(model.device)
    attention_mask = inputs.get("attention_mask")
    if attention_mask is not None:
        attention_mask = attention_mask.to(model.device)
    generation_kwargs = {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "max_new_tokens": max_new_tokens,
        "pad_token_id": tokenizer.pad_token_id,
    }
    if temperature > 0:
        # Nucleus sampling: top_p is only meaningful when do_sample is on.
        generation_kwargs.update(
            {
                "do_sample": True,
                "temperature": temperature,
                "top_p": top_p,
            }
        )
    else:
        # temperature <= 0 requests deterministic (greedy) decoding.
        generation_kwargs.update({"do_sample": False})
    with torch.inference_mode():
        outputs = model.generate(**generation_kwargs)
    # Slice off the prompt tokens so only the newly generated text is decoded.
    generated_ids = outputs[0, input_ids.shape[-1] :]
    return tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
def main() -> None:
    """CLI entry point: load model + adapter, build the prompt, print the reply."""
    args = parse_args()
    model, tokenizer = load_model_and_tokenizer(
        base_model=args.base_model,
        adapter_path=args.adapter_path,
        use_quantization=not args.no_quantization,
    )
    rendered_prompt = build_prompt(tokenizer, args.system_prompt, args.user_input)
    reply = generate_response(
        model=model,
        tokenizer=tokenizer,
        prompt=rendered_prompt,
        max_new_tokens=args.max_new_tokens,
        temperature=args.temperature,
        top_p=args.top_p,
    )
    print("\n=== Model Response ===")
    print(reply)


if __name__ == "__main__":
    main()
|