# Author: Yuhao
# Restructure inference and add INT4 serving
# Commit: 52a881a
from __future__ import annotations
import argparse
import time
from pathlib import Path
try:
from .model_utils import (
DEFAULT_DO_SAMPLE,
DEFAULT_MODEL_PATH,
DEFAULT_MAX_NEW_TOKENS,
DEFAULT_PROMPT,
DEFAULT_REPETITION_PENALTY,
DEFAULT_TEMPERATURE,
DEFAULT_TOP_P,
QuantizedSkinGPTModel,
build_single_turn_messages,
)
except ImportError:
from model_utils import (
DEFAULT_DO_SAMPLE,
DEFAULT_MODEL_PATH,
DEFAULT_MAX_NEW_TOKENS,
DEFAULT_PROMPT,
DEFAULT_REPETITION_PENALTY,
DEFAULT_TEMPERATURE,
DEFAULT_TOP_P,
QuantizedSkinGPTModel,
build_single_turn_messages,
)
def build_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for SkinGPT-R1 INT4 inference.

    Returns:
        An ``argparse.ArgumentParser`` configured with the model path,
        image path, prompt, and generation hyperparameters. Defaults come
        from the shared ``model_utils`` constants.
    """
    # Declare each option as (flag, kwargs) so the full CLI surface is
    # visible in one table; registration order matters for --help output.
    arg_specs = [
        ("--model_path", dict(type=str, default=DEFAULT_MODEL_PATH)),
        ("--image_path", dict(type=str, required=True, help="Path to the test image")),
        ("--prompt", dict(type=str, default=DEFAULT_PROMPT, help="Prompt for diagnosis")),
        ("--max_new_tokens", dict(type=int, default=DEFAULT_MAX_NEW_TOKENS)),
        ("--do_sample", dict(action="store_true", default=DEFAULT_DO_SAMPLE)),
        ("--temperature", dict(type=float, default=DEFAULT_TEMPERATURE)),
        ("--top_p", dict(type=float, default=DEFAULT_TOP_P)),
        ("--repetition_penalty", dict(type=float, default=DEFAULT_REPETITION_PENALTY)),
    ]
    cli = argparse.ArgumentParser(description="SkinGPT-R1 INT4 inference")
    for flag, kwargs in arg_specs:
        cli.add_argument(flag, **kwargs)
    return cli
def main() -> None:
    """Run one INT4-quantized inference pass end to end.

    Parses CLI arguments, validates the image path, loads the quantized
    model, builds a single-turn message, generates a response, and prints
    the output with per-stage timing. Prints an error and returns early
    when the image file does not exist.
    """
    args = build_parser().parse_args()

    # Bail out early on a missing image rather than failing deep inside
    # the model's preprocessing.
    if not Path(args.image_path).exists():
        print(f"Error: Image not found at {args.image_path}")
        return

    print("=== [1] Initializing INT4 Quantization ===")
    print("BitsAndBytesConfig will be applied during model loading.")

    print("=== [2] Loading Model and Processor ===")
    load_started = time.time()
    model = QuantizedSkinGPTModel(args.model_path)
    print(f"Model loaded in {time.time() - load_started:.2f} seconds.")

    print("=== [3] Preparing Input ===")
    messages = build_single_turn_messages(args.image_path, args.prompt)

    print("=== [4] Generating Response ===")
    # Collect generation hyperparameters once so the call site stays flat.
    gen_kwargs = dict(
        max_new_tokens=args.max_new_tokens,
        do_sample=args.do_sample,
        temperature=args.temperature,
        top_p=args.top_p,
        repetition_penalty=args.repetition_penalty,
    )
    infer_started = time.time()
    output_text = model.generate_response(messages, **gen_kwargs)
    print(f"Inference completed in {time.time() - infer_started:.2f} seconds.")

    print("\n================ MODEL OUTPUT ================\n")
    print(output_text)
    print("\n==============================================\n")
# Script entry point: e.g. `python <this file> --image_path path/to/img.jpg`.
if __name__ == "__main__":
    main()