File size: 4,954 Bytes
0c51b93
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import argparse
import difflib
import json
import os

import torch
from jinja2 import Environment, FileSystemLoader
from peft import PeftModelForCausalLM
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig


def parse_args():
    """Build and parse the CLI arguments for the template/model test run."""
    p = argparse.ArgumentParser(description="Test a model with a template and example data")
    # Required inputs: which model to load and what data/template to drive it with.
    p.add_argument("--model_path", type=str, required=True,
                   help="Path to model or HF model name")
    p.add_argument("--template_path", type=str, required=True,
                   help="Path to Jinja template file")
    p.add_argument("--example_path", type=str, required=True,
                   help="Path to example data JSON")
    # Optional knobs.
    p.add_argument("--adapter_path", type=str, default=None,
                   help="Path to PEFT adapter (for QLoRA)")
    p.add_argument("--max_length", type=int, default=512,
                   help="Maximum output length")
    p.add_argument("--use_qlora", action="store_true",
                   help="Whether to use QLoRA (automatically enables 4-bit)")
    return p.parse_args()

def load_model_and_tokenizer(args):
    """Load tokenizer, base model (optionally 4-bit for QLoRA), and optional PEFT adapter.

    Args:
        args: parsed CLI namespace; reads `model_path`, `use_qlora`, `adapter_path`.

    Returns:
        (model, tokenizer) tuple with the model in eval mode.
    """
    print(f"Loading model: {args.model_path}")

    # Set up tokenizer
    tokenizer = AutoTokenizer.from_pretrained(args.model_path)

    # With QLoRA, we automatically use 4-bit quantization as per the training logic
    if args.use_qlora:
        print("Using QLoRA with 4-bit quantization")
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4"
        )

        # FIX: keep the non-quantized modules in the same dtype as the 4-bit
        # compute dtype (bfloat16). Previously this was float16, which
        # mismatched bnb_4bit_compute_dtype and can degrade numerics.
        base_model = AutoModelForCausalLM.from_pretrained(
            args.model_path,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            quantization_config=quantization_config
        )
    else:
        print("Using full precision model")
        base_model = AutoModelForCausalLM.from_pretrained(
            args.model_path,
            torch_dtype=torch.float16,
            device_map="auto"
        )

    # Load PEFT adapter if specified
    if args.adapter_path:
        print(f"Loading adapter from: {args.adapter_path}")
        # Check for adapter files before attempting to load, so a bad path
        # degrades to the base model instead of crashing.
        adapter_path = args.adapter_path
        if os.path.exists(os.path.join(adapter_path, 'adapter_model.safetensors')) or \
           os.path.exists(os.path.join(adapter_path, 'adapter_model.bin')):
            model = PeftModelForCausalLM.from_pretrained(base_model, adapter_path)
        else:
            print(f"No adapter found at {adapter_path}, using base model")
            model = base_model
    else:
        model = base_model

    # Inference only: disable dropout etc.
    model.eval()

    return model, tokenizer

def load_template(template_path):
    """Load a Jinja template from `template_path`, registering a `tojson` filter."""
    directory, filename = os.path.split(template_path)

    # A bare filename has an empty dirname; fall back to the current directory.
    env = Environment(loader=FileSystemLoader(directory or "."))
    env.filters['tojson'] = json.dumps
    return env.get_template(filename)

def generate_response(model, tokenizer, prompt, max_length=512):
    """Greedy-decode a completion for `prompt` and return the full decoded text.

    Args:
        model: causal LM with a `.generate` method.
        tokenizer: matching tokenizer.
        prompt: already-rendered prompt string.
        max_length: maximum total sequence length (prompt + generated tokens).

    Returns:
        The decoded sequence including the prompt and special tokens.
    """
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    with torch.no_grad():
        output = model.generate(
            input_ids=inputs.input_ids,
            attention_mask=inputs.attention_mask,
            max_length=max_length,
            do_sample=False,  # greedy decoding
            # NOTE: removed `temperature=0.0` — it is ignored when
            # do_sample=False, and recent transformers versions validate
            # that temperature be strictly positive, emitting warnings or
            # errors for 0.0.
        )

    # Keep special tokens so the prompt prefix can be stripped verbatim later.
    response = tokenizer.decode(output[0], skip_special_tokens=False)
    return response

def extract_generated_content(full_response, prompt):
    """Extract only the newly generated content by removing the prompt prefix.

    Args:
        full_response: decoded model output (prompt + generation).
        prompt: the prompt that was fed to the model.

    Returns:
        The stripped text after the prompt when the response begins with it;
        otherwise the response unchanged.
    """
    # FIX: the prompt must be a *prefix* for the slice to be valid. The old
    # `prompt in full_response` check could match mid-string and then slice
    # off unrelated leading characters.
    if full_response.startswith(prompt):
        return full_response[len(prompt):].strip()

    return full_response

def compare_with_groundtruth(generated, groundtruth):
    """Compare generated text with groundtruth and show differences."""
    gen_lines = generated.splitlines()
    expected_lines = groundtruth.splitlines()

    # Unified diff with no per-line terminator; empty when texts match.
    diff_lines = difflib.unified_diff(
        gen_lines,
        expected_lines,
        fromfile='Generated',
        tofile='Expected',
        lineterm='',
    )
    return '\n'.join(diff_lines)

def main():
    """Render each example through the template, generate, and print model output vs groundtruth."""
    args = parse_args()

    model, tokenizer = load_model_and_tokenizer(args)

    # Examples are a JSON list of {"input": ..., "output": ...} records.
    with open(args.example_path, 'r') as f:
        examples = json.load(f)

    template = load_template(args.template_path)

    total = len(examples)
    for idx, example in enumerate(examples, start=1):
        print(f"\n===== EXAMPLE {idx}/{total} =====")

        prompt = template.render(
            messages=[{"role": "user", "content": example["input"]}],
            add_generation_prompt=True,  # important for inference
        )

        response = generate_response(model, tokenizer, prompt, args.max_length)
        generated = extract_generated_content(response, prompt)

        print("\nMODEL OUTPUT:")
        print(generated)

        print("\nGROUNDTRUTH:")
        print(example['output'])


if __name__ == "__main__":
    main()