|
|
|
|
|
""" |
|
|
Run inference with a base model and a saved QLoRA adapter. |
|
|
|
|
|
Example: |
|
|
python qlora_inference.py \ |
|
|
--base-model meta-llama/Meta-Llama-3.1-8B-Instruct \ |
|
|
--adapter-path saves/llama31-8b/kaggle_cyberbullying/qlora \ |
|
|
--system-prompt "You are a helpful assistant." \ |
|
|
--user-input "How should we moderate this post?" |
|
|
""" |
|
|
|
|
|
import argparse |
|
|
from typing import Optional, Tuple |
|
|
|
|
|
import torch |
|
|
from peft import PeftModel |
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig |
|
|
|
|
|
|
|
|
def parse_args() -> argparse.Namespace:
    """Build the CLI parser and return the parsed arguments."""
    parser = argparse.ArgumentParser(
        description="Generate text with a base model and LoRA/QLoRA adapter."
    )
    add = parser.add_argument  # shorthand; every option below goes through it
    add(
        "--base-model",
        default="meta-llama/Meta-Llama-3.1-8B-Instruct",
        help="Base model name or local path (default: %(default)s).",
    )
    add(
        "--adapter-path",
        required=True,
        help="Path to the trained adapter directory (containing adapter_model.safetensors).",
    )
    add(
        "--system-prompt",
        default="You are a helpful assistant.",
        help="System prompt to steer the assistant (set empty string to skip).",
    )
    add(
        "--user-input",
        required=True,
        help="User instruction or text the model should respond to.",
    )
    add(
        "--max-new-tokens",
        type=int,
        default=512,
        help="Maximum number of new tokens to generate.",
    )
    add(
        "--temperature",
        type=float,
        default=0.7,
        help="Sampling temperature; set <= 0 for deterministic decoding.",
    )
    add(
        "--top-p",
        type=float,
        default=0.9,
        help="Top-p value for nucleus sampling (ignored if temperature <= 0).",
    )
    add(
        "--no-quantization",
        action="store_true",
        help="Disable 4-bit loading; use this if you have a full-precision base model locally.",
    )
    return parser.parse_args()
|
|
|
|
|
|
|
|
def build_prompt(
    tokenizer: AutoTokenizer, system_prompt: str, user_input: str
) -> str:
    """Render a single-turn conversation as a text prompt.

    Uses the tokenizer's chat template when one is defined; otherwise
    falls back to a plain "User:/Assistant:" transcript.
    """
    conversation = [{"role": "user", "content": user_input}]
    if system_prompt:
        conversation.insert(0, {"role": "system", "content": system_prompt})

    has_template = bool(
        hasattr(tokenizer, "apply_chat_template") and tokenizer.chat_template
    )
    if has_template:
        return tokenizer.apply_chat_template(
            conversation, tokenize=False, add_generation_prompt=True
        )

    # Plain-text fallback for tokenizers without a chat template.
    prefix = f"{system_prompt}\n\n" if system_prompt else ""
    return f"{prefix}User: {user_input}\nAssistant:"
|
|
|
|
|
|
|
|
def load_model_and_tokenizer(
    base_model: str, adapter_path: str, use_quantization: bool
) -> Tuple[AutoModelForCausalLM, AutoTokenizer]:
    """Load the tokenizer, the (optionally 4-bit) base model, and the adapter.

    The tokenizer is read from the adapter directory so any tokenizer files
    saved alongside the adapter take precedence. Returns ``(model, tokenizer)``
    with the adapter attached (frozen) and the model in eval mode.
    """
    tokenizer = AutoTokenizer.from_pretrained(
        adapter_path, use_fast=True, trust_remote_code=True
    )
    # Generation needs a pad token id; fall back to EOS when none is set.
    if tokenizer.pad_token is None and tokenizer.eos_token is not None:
        tokenizer.pad_token = tokenizer.eos_token

    # Prefer bfloat16 when a GPU is present, full precision otherwise.
    compute_dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32

    bnb_config: Optional[BitsAndBytesConfig] = None
    if use_quantization:
        # NF4 double quantization — matches common QLoRA training settings.
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
        )

    base = AutoModelForCausalLM.from_pretrained(
        base_model,
        device_map="auto",
        torch_dtype=compute_dtype,
        trust_remote_code=True,
        quantization_config=bnb_config,
    )

    # Attach the trained LoRA weights on top of the base model, frozen.
    model = PeftModel.from_pretrained(base, adapter_path, is_trainable=False)
    model.eval()
    return model, tokenizer
|
|
|
|
|
|
|
|
def generate_response(
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    prompt: str,
    max_new_tokens: int,
    temperature: float,
    top_p: float,
) -> str:
    """Generate and return the model's completion for ``prompt``.

    Sampling (with nucleus ``top_p``) is used when ``temperature > 0``;
    otherwise decoding is greedy. Only the newly generated tokens — not the
    prompt — are decoded, and the result is stripped of surrounding
    whitespace.
    """
    encoded = tokenizer(prompt, return_tensors="pt")
    input_ids = encoded["input_ids"].to(model.device)
    attention_mask = encoded.get("attention_mask")
    if attention_mask is not None:
        attention_mask = attention_mask.to(model.device)

    gen_kwargs = {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "max_new_tokens": max_new_tokens,
        "pad_token_id": tokenizer.pad_token_id,
    }
    if temperature > 0:
        gen_kwargs["do_sample"] = True
        gen_kwargs["temperature"] = temperature
        gen_kwargs["top_p"] = top_p
    else:
        gen_kwargs["do_sample"] = False

    with torch.inference_mode():
        outputs = model.generate(**gen_kwargs)

    # Slice off the prompt tokens; keep only the continuation.
    continuation = outputs[0, input_ids.shape[-1]:]
    return tokenizer.decode(continuation, skip_special_tokens=True).strip()
|
|
|
|
|
|
|
|
def main() -> None:
    """CLI entry point: load the model plus adapter, print one completion."""
    args = parse_args()

    model, tokenizer = load_model_and_tokenizer(
        base_model=args.base_model,
        adapter_path=args.adapter_path,
        use_quantization=not args.no_quantization,
    )

    response = generate_response(
        model,
        tokenizer,
        build_prompt(tokenizer, args.system_prompt, args.user_input),
        args.max_new_tokens,
        args.temperature,
        args.top_p,
    )

    print("\n=== Model Response ===")
    print(response)
|
|
|
|
|
|
|
|
# Run only when executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|
|