File size: 2,793 Bytes
52a881a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
from __future__ import annotations

import argparse
import time
from pathlib import Path

try:
    from .model_utils import (
        DEFAULT_DO_SAMPLE,
        DEFAULT_MODEL_PATH,
        DEFAULT_MAX_NEW_TOKENS,
        DEFAULT_PROMPT,
        DEFAULT_REPETITION_PENALTY,
        DEFAULT_TEMPERATURE,
        DEFAULT_TOP_P,
        QuantizedSkinGPTModel,
        build_single_turn_messages,
    )
except ImportError:
    from model_utils import (
        DEFAULT_DO_SAMPLE,
        DEFAULT_MODEL_PATH,
        DEFAULT_MAX_NEW_TOKENS,
        DEFAULT_PROMPT,
        DEFAULT_REPETITION_PENALTY,
        DEFAULT_TEMPERATURE,
        DEFAULT_TOP_P,
        QuantizedSkinGPTModel,
        build_single_turn_messages,
    )


def build_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the INT4 inference script.

    Every option except ``--image_path`` is optional and falls back to the
    corresponding ``DEFAULT_*`` constant imported from ``model_utils``.

    Returns:
        A configured ``argparse.ArgumentParser``. The parsed namespace has
        the attributes ``model_path``, ``image_path``, ``prompt``,
        ``max_new_tokens``, ``do_sample``, ``temperature``, ``top_p`` and
        ``repetition_penalty``.
    """
    parser = argparse.ArgumentParser(description="SkinGPT-R1 INT4 inference")
    parser.add_argument("--model_path", type=str, default=DEFAULT_MODEL_PATH)
    parser.add_argument("--image_path", type=str, required=True, help="Path to the test image")
    parser.add_argument("--prompt", type=str, default=DEFAULT_PROMPT, help="Prompt for diagnosis")
    parser.add_argument("--max_new_tokens", type=int, default=DEFAULT_MAX_NEW_TOKENS)
    parser.add_argument(
        "--do_sample",
        action="store_true",
        default=DEFAULT_DO_SAMPLE,
        help="Enable sampling during generation",
    )
    # A lone store_true flag can only switch sampling on; if the imported
    # default is truthy there would be no way to turn it off from the CLI.
    # Both switches write to the same destination and share the same default,
    # so omitting both keeps the original behavior.
    parser.add_argument(
        "--no_do_sample",
        dest="do_sample",
        action="store_false",
        default=DEFAULT_DO_SAMPLE,
        help="Disable sampling during generation",
    )
    parser.add_argument("--temperature", type=float, default=DEFAULT_TEMPERATURE)
    parser.add_argument("--top_p", type=float, default=DEFAULT_TOP_P)
    parser.add_argument("--repetition_penalty", type=float, default=DEFAULT_REPETITION_PENALTY)
    return parser


def main() -> None:
    """Run a single-image, single-prompt inference from the command line.

    Parses the CLI arguments, checks that the image file exists, constructs
    ``QuantizedSkinGPTModel`` from ``--model_path``, builds the messages with
    ``build_single_turn_messages`` and prints the generated text along with
    the elapsed load and inference times.

    Raises:
        SystemExit: If ``--image_path`` does not point to an existing file
            (exit status 1, message on stderr), or from argparse when the
            arguments are invalid.
    """
    args = build_parser().parse_args()

    # Validate before constructing the model. Raising SystemExit with a string
    # writes it to stderr and exits with status 1, so shell callers can
    # detect the failure instead of seeing a successful exit.
    if not Path(args.image_path).is_file():
        raise SystemExit(f"Error: Image not found at {args.image_path}")

    print("=== [1] Initializing INT4 Quantization ===")
    print("BitsAndBytesConfig will be applied during model loading.")

    print("=== [2] Loading Model and Processor ===")
    # perf_counter is monotonic, so the reported duration cannot be skewed
    # by system clock adjustments the way time.time() differences can.
    load_started = time.perf_counter()
    model = QuantizedSkinGPTModel(args.model_path)
    print(f"Model loaded in {time.perf_counter() - load_started:.2f} seconds.")

    print("=== [3] Preparing Input ===")
    messages = build_single_turn_messages(args.image_path, args.prompt)

    print("=== [4] Generating Response ===")
    infer_started = time.perf_counter()
    output_text = model.generate_response(
        messages,
        max_new_tokens=args.max_new_tokens,
        do_sample=args.do_sample,
        temperature=args.temperature,
        top_p=args.top_p,
        repetition_penalty=args.repetition_penalty,
    )

    print(f"Inference completed in {time.perf_counter() - infer_started:.2f} seconds.")
    print("\n================ MODEL OUTPUT ================\n")
    print(output_text)
    print("\n==============================================\n")

# Run the CLI only when this file is executed as a script; importing the
# module (e.g. to reuse build_parser or main) must not trigger inference.
if __name__ == "__main__":
    main()