File size: 4,954 Bytes
0c51b93 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 |
import argparse
import difflib
import json
import os
import torch
from jinja2 import Environment, FileSystemLoader
from peft import PeftModelForCausalLM
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
def parse_args():
    """Parse command-line options for the template/model smoke-test script."""
    cli = argparse.ArgumentParser(description="Test a model with a template and example data")
    add = cli.add_argument
    add("--model_path", type=str, required=True, help="Path to model or HF model name")
    add("--adapter_path", type=str, default=None, help="Path to PEFT adapter (for QLoRA)")
    add("--template_path", type=str, required=True, help="Path to Jinja template file")
    add("--example_path", type=str, required=True, help="Path to example data JSON")
    add("--max_length", type=int, default=512, help="Maximum output length")
    add("--use_qlora", action="store_true", help="Whether to use QLoRA (automatically enables 4-bit)")
    return cli.parse_args()
def load_model_and_tokenizer(args):
    """Load the tokenizer and causal-LM, optionally 4-bit quantized, then
    attach a PEFT adapter when one is found on disk.

    Returns (model, tokenizer) with the model switched to eval mode.

    NOTE(review): under --use_qlora the model is loaded with
    torch_dtype=float16 while the 4-bit compute dtype is bfloat16 —
    presumably mirroring the training configuration, but worth confirming.
    """
    print(f"Loading model: {args.model_path}")
    tokenizer = AutoTokenizer.from_pretrained(args.model_path)

    # Build the from_pretrained kwargs once; QLoRA only adds quantization.
    load_kwargs = {"torch_dtype": torch.float16, "device_map": "auto"}
    if args.use_qlora:
        # QLoRA implies 4-bit NF4 quantization, matching the training logic.
        print("Using QLoRA with 4-bit quantization")
        load_kwargs["quantization_config"] = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
        )
    else:
        print("Using full precision model")
    base_model = AutoModelForCausalLM.from_pretrained(args.model_path, **load_kwargs)

    # Wrap with the PEFT adapter only if its weight file actually exists;
    # otherwise fall back to the bare base model.
    model = base_model
    if args.adapter_path:
        print(f"Loading adapter from: {args.adapter_path}")
        adapter_dir = args.adapter_path
        weight_names = ('adapter_model.safetensors', 'adapter_model.bin')
        if any(os.path.exists(os.path.join(adapter_dir, name)) for name in weight_names):
            model = PeftModelForCausalLM.from_pretrained(base_model, adapter_dir)
        else:
            print(f"No adapter found at {adapter_dir}, using base model")

    model.eval()
    return model, tokenizer
def load_template(template_path):
    """Load a Jinja template, rooting the loader at the template's directory.

    Registers a 'tojson' filter (plain json.dumps) so chat templates that use
    it render under this standalone Environment too.
    """
    directory, filename = os.path.split(template_path)
    env = Environment(loader=FileSystemLoader(directory or "."))
    env.filters['tojson'] = lambda obj: json.dumps(obj)
    return env.get_template(filename)
def generate_response(model, tokenizer, prompt, max_length=512):
    """Greedy-decode a continuation of `prompt` and return the decoded text.

    Note: `max_length` bounds prompt + generated tokens combined, and the
    returned string still contains the prompt and any special tokens
    (skip_special_tokens=False) so the caller can strip the prompt afterwards.

    Fix: the original passed temperature=0.0 alongside do_sample=False;
    transformers ignores temperature under greedy decoding (emitting a
    warning), and 0.0 is an invalid temperature should sampling ever be
    enabled — so the dead/invalid parameter is removed.
    """
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        output = model.generate(
            input_ids=inputs.input_ids,
            attention_mask=inputs.attention_mask,
            max_length=max_length,
            do_sample=False,  # deterministic greedy decoding
        )
    return tokenizer.decode(output[0], skip_special_tokens=False)
def extract_generated_content(full_response, prompt):
    """Return only the newly generated text that follows the prompt.

    The response is decoded with special tokens kept, so the prompt may be
    preceded by e.g. a BOS token rather than sitting at index 0. We therefore
    locate the prompt and cut everything up to and including it — the original
    sliced off len(prompt) characters from the front, which mangled the output
    whenever the prompt occurred but was not an exact prefix.
    """
    idx = full_response.find(prompt)
    if idx != -1:
        return full_response[idx + len(prompt):].strip()
    # Prompt not found verbatim (e.g. tokenizer normalization altered it):
    # fall back to returning the full response unchanged.
    return full_response
def compare_with_groundtruth(generated, groundtruth):
    """Render a unified diff ('Generated' vs 'Expected') of the two texts.

    Returns an empty string when the texts are line-for-line identical.
    """
    left = generated.splitlines()
    right = groundtruth.splitlines()
    delta = difflib.unified_diff(
        left, right, fromfile='Generated', tofile='Expected', lineterm='')
    return '\n'.join(delta)
def main():
    """End-to-end check: load model + template, then print the model's output
    next to the groundtruth for every example in the JSON file."""
    args = parse_args()
    model, tokenizer = load_model_and_tokenizer(args)
    with open(args.example_path, 'r') as f:
        examples = json.load(f)
    template = load_template(args.template_path)

    total = len(examples)
    for idx, example in enumerate(examples, start=1):
        print(f"\n===== EXAMPLE {idx}/{total} =====")
        prompt = template.render(
            messages=[{"role": "user", "content": example["input"]}],
            add_generation_prompt=True,  # important for inference
        )
        raw = generate_response(model, tokenizer, prompt, args.max_length)
        answer = extract_generated_content(raw, prompt)
        print("\nMODEL OUTPUT:")
        print(answer)
        print("\nGROUNDTRUTH:")
        print(example['output'])
# Script entry point: all configuration comes from the CLI flags in parse_args.
if __name__ == "__main__":
    main()
|