| from __future__ import annotations |
|
|
| import argparse |
| import time |
| from pathlib import Path |
|
|
| try: |
| from .model_utils import ( |
| DEFAULT_DO_SAMPLE, |
| DEFAULT_MODEL_PATH, |
| DEFAULT_MAX_NEW_TOKENS, |
| DEFAULT_PROMPT, |
| DEFAULT_REPETITION_PENALTY, |
| DEFAULT_TEMPERATURE, |
| DEFAULT_TOP_P, |
| QuantizedSkinGPTModel, |
| build_single_turn_messages, |
| ) |
| except ImportError: |
| from model_utils import ( |
| DEFAULT_DO_SAMPLE, |
| DEFAULT_MODEL_PATH, |
| DEFAULT_MAX_NEW_TOKENS, |
| DEFAULT_PROMPT, |
| DEFAULT_REPETITION_PENALTY, |
| DEFAULT_TEMPERATURE, |
| DEFAULT_TOP_P, |
| QuantizedSkinGPTModel, |
| build_single_turn_messages, |
| ) |
|
|
|
|
def build_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the SkinGPT-R1 INT4 inference script.

    ``--image_path`` is the only required option. Every other option falls
    back to the matching ``DEFAULT_*`` constant imported from ``model_utils``.

    Sampling can be switched in both directions: ``--do_sample`` enables it and
    ``--no_do_sample`` disables it. Without the negative flag, a truthy
    ``DEFAULT_DO_SAMPLE`` could never be overridden from the command line,
    because a ``store_true`` option can only ever set the value to ``True``.

    Returns:
        The configured ``argparse.ArgumentParser``.
    """
    parser = argparse.ArgumentParser(description="SkinGPT-R1 INT4 inference")
    parser.add_argument("--model_path", type=str, default=DEFAULT_MODEL_PATH)
    parser.add_argument("--image_path", type=str, required=True, help="Path to the test image")
    parser.add_argument("--prompt", type=str, default=DEFAULT_PROMPT, help="Prompt for diagnosis")
    parser.add_argument("--max_new_tokens", type=int, default=DEFAULT_MAX_NEW_TOKENS)

    # Both flags write to the same destination and carry the same default, so
    # the value seen by callers is DEFAULT_DO_SAMPLE unless a flag is given.
    parser.add_argument(
        "--do_sample",
        dest="do_sample",
        action="store_true",
        default=DEFAULT_DO_SAMPLE,
        help="Enable sampling during generation",
    )
    parser.add_argument(
        "--no_do_sample",
        dest="do_sample",
        action="store_false",
        default=DEFAULT_DO_SAMPLE,
        help="Disable sampling during generation",
    )

    parser.add_argument("--temperature", type=float, default=DEFAULT_TEMPERATURE)
    parser.add_argument("--top_p", type=float, default=DEFAULT_TOP_P)
    parser.add_argument("--repetition_penalty", type=float, default=DEFAULT_REPETITION_PENALTY)
    return parser
|
|
|
|
def main() -> None:
    """Run a single image + prompt inference pass from the command line.

    Steps, in order:

    1. Parse the CLI arguments produced by ``build_parser``.
    2. Check that ``--image_path`` points at an existing file.
    3. Construct ``QuantizedSkinGPTModel`` from ``--model_path``.
    4. Build the messages with ``build_single_turn_messages``.
    5. Call ``generate_response`` with the generation options and print the
       returned text, together with the load and inference durations.

    Raises:
        SystemExit: If ``--image_path`` is not an existing file. The message is
            written to stderr and the process exit status is 1, so shell
            callers can detect the failure. ``argparse`` also raises it for
            invalid arguments.
    """
    args = build_parser().parse_args()

    # Validate the input before loading the model. ``is_file`` also rejects a
    # directory, which ``exists`` would have let through.
    if not Path(args.image_path).is_file():
        raise SystemExit(f"Error: Image not found at {args.image_path}")

    print("=== [1] Initializing INT4 Quantization ===")
    print("BitsAndBytesConfig will be applied during model loading.")

    print("=== [2] Loading Model and Processor ===")
    # perf_counter is monotonic, so the reported duration cannot be skewed by
    # wall-clock adjustments while the model loads.
    load_started = time.perf_counter()
    model = QuantizedSkinGPTModel(args.model_path)
    print(f"Model loaded in {time.perf_counter() - load_started:.2f} seconds.")

    print("=== [3] Preparing Input ===")
    messages = build_single_turn_messages(args.image_path, args.prompt)

    print("=== [4] Generating Response ===")
    infer_started = time.perf_counter()
    output_text = model.generate_response(
        messages,
        max_new_tokens=args.max_new_tokens,
        do_sample=args.do_sample,
        temperature=args.temperature,
        top_p=args.top_p,
        repetition_penalty=args.repetition_penalty,
    )
    print(f"Inference completed in {time.perf_counter() - infer_started:.2f} seconds.")

    print("\n================ MODEL OUTPUT ================\n")
    print(output_text)
    print("\n==============================================\n")
|
|
|
# Run the CLI only when this file is executed as a script, not when it is
# imported as a module.
if __name__ == "__main__":
    main()
|
|