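"""Compare base vs. fine-tuned LoRA outputs on the same prompt.

Illustrative invocation (the script name, model id, and prompt below are
examples, not fixed by this file; the adapter path matches the --lora_dir
help text):

    python compare_lora.py \
        --base_model TinyLlama/TinyLlama-1.1B-Chat-v1.0 \
        --lora_dir outputs/tinyllama-lora \
        --prompt "Explain LoRA in one sentence." \
        --device cpu
"""
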
import argparse
from typing import Optional

import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer


def parse_args():
    p = argparse.ArgumentParser(description="Compare base vs. fine-tuned LoRA outputs.")
    p.add_argument("--base_model", required=True, help="Base model id or local path.")
    p.add_argument("--lora_dir", required=True, help="Path to LoRA adapter folder (e.g., outputs/tinyllama-lora).")
    p.add_argument("--prompt", required=True, help="Prompt to generate with.")
    p.add_argument("--max_new_tokens", type=int, default=128)
    p.add_argument("--temperature", type=float, default=0.7)
    p.add_argument("--top_p", type=float, default=0.9)
    p.add_argument(
        "--torch_dtype",
        default="auto",
        choices=["auto", "float16", "bfloat16", "float32"],
        help="Force dtype for model load.",
    )
    p.add_argument(
        "--device",
        default="auto",
        choices=["auto", "cpu", "cuda", "mps"],
        help="Force device map; on CPU use this to keep everything on host.",
    )
    p.add_argument(
        "--offload_dir",
        default=None,
        help="Optional offload directory when using device_map='auto' on constrained devices.",
    )
    return p.parse_args()


def resolve_dtype(name: str) -> Optional[torch.dtype]:
    if name == "auto":
        return None
    return {"float16": torch.float16, "bfloat16": torch.bfloat16, "float32": torch.float32}[name]
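# Note: bfloat16 requires hardware support (e.g., NVIDIA Ampere-or-newer GPUs).
# If unsure, a guard like the following sketch (assumes a CUDA device) could
# fall back to float16:
#
#     if dtype is torch.bfloat16 and not torch.cuda.is_bf16_supported():
#         dtype = torch.float16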


def resolve_device_map(device: str):
    # None keeps the whole model on the host; anything else defers placement
    # to Accelerate's automatic device map.
    return None if device == "cpu" else "auto"


def generate(model, tokenizer, prompt: str, max_new_tokens: int, temperature: float, top_p: float) -> str:
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.inference_mode():
        output = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
        )
    return tokenizer.decode(output[0], skip_special_tokens=True)


def main():
    args = parse_args()
    force_cpu = args.device == "cpu"
    # fp16/bf16 kernels are slow or missing for many CPU ops, so pin fp32 on CPU.
    torch_dtype = torch.float32 if force_cpu else resolve_dtype(args.torch_dtype)
    # resolve_device_map already returns "auto" for every non-CPU choice.
    device_map = resolve_device_map(args.device)

    # The adapter folder is assumed to contain the tokenizer files saved at
    # training time.
    tokenizer = AutoTokenizer.from_pretrained(args.lora_dir, use_fast=False)
    if tokenizer.pad_token is None:
        # Many causal LMs ship without a pad token; reuse EOS so generate()
        # does not warn about a missing pad_token_id.
        tokenizer.pad_token = tokenizer.eos_token

    base_model = AutoModelForCausalLM.from_pretrained(
        args.base_model,
        device_map=device_map,
        torch_dtype=torch_dtype,
        offload_folder=args.offload_dir,
    )
    # Load a second, independent copy to wrap with the adapter: PEFT injects
    # LoRA layers into the model it is given, so sharing one instance would
    # contaminate the baseline comparison.
    lora_wrapped = AutoModelForCausalLM.from_pretrained(
        args.base_model,
        device_map=device_map,
        torch_dtype=torch_dtype,
        offload_folder=args.offload_dir,
    )
    lora_wrapped = PeftModel.from_pretrained(lora_wrapped, args.lora_dir)
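    # Optional: the adapter could be folded into the base weights for faster
    # inference (a sketch; it drops the ability to detach the adapter later):
    #     lora_wrapped = lora_wrapped.merge_and_unload()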
    if force_cpu:
        # Avoid Accelerate dispatch/offload; keep everything on CPU.
        base_model.to("cpu")
        lora_wrapped.to("cpu")
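    # Sampling (do_sample=True) is stochastic, so outputs vary run to run. For
    # a repeatable comparison, seed torch's RNG before each call, e.g.:
    #     torch.manual_seed(0)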

    base_out = generate(
        base_model,
        tokenizer,
        args.prompt,
        max_new_tokens=args.max_new_tokens,
        temperature=args.temperature,
        top_p=args.top_p,
    )
    lora_out = generate(
        lora_wrapped,
        tokenizer,
        args.prompt,
        max_new_tokens=args.max_new_tokens,
        temperature=args.temperature,
        top_p=args.top_p,
    )

    print("=== Base model ===")
    print(base_out)
    print("\n=== LoRA model ===")
    print(lora_out)


if __name__ == "__main__":
    main()