File size: 2,078 Bytes
29d1fb6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
from __future__ import annotations

import sys
from pathlib import Path

# Put the project's ``src`` directory at the front of the import path so the
# local ``hf_processor_practice`` package can be imported without installing it.
PROJECT_ROOT = Path(__file__).resolve().parents[1]
SRC_DIR = PROJECT_ROOT.joinpath("src")
sys.path[:0] = [str(SRC_DIR)]

from PIL import Image
from transformers import AutoImageProcessor

from hf_processor_practice.utils import SAVED_PROCESSOR_DIR, get_cat_image_path, load_vit_image_processor_with_fallback, print_title


def main() -> None:
    """Load, apply, save, and reload an image processor on a sample image.

    Steps:
      1. Load the sample cat image whose path ``get_cat_image_path`` returns
         and convert it to RGB.
      2. Obtain an image processor via ``load_vit_image_processor_with_fallback``.
      3. Run the processor on the image with ``return_tensors="pt"`` and print
         each returned tensor's shape and dtype.
      4. Save the processor with ``save_pretrained`` and reload it, first
         through ``AutoImageProcessor`` and, if that raises, through the
         processor's own class; then re-run it and list the saved files.
    """
    print_title("02. AutoImageProcessor Practice")

    # 1. Prepare the sample image. Per the original note, the helper creates a
    #    placeholder image when there is no internet connection.
    cat_path = get_cat_image_path()
    # Open inside a context manager so the file handle is released promptly.
    # convert() loads the pixel data and returns an independent copy, so
    # `image` remains usable after the source file is closed.
    with Image.open(cat_path) as raw_image:
        image = raw_image.convert("RGB")
    print("Image path:", cat_path)
    print("Original image size:", image.size)

    # 2. Load the ImageProcessor automatically from the model name.
    #    Per the original note, the helper falls back to a local
    #    ViTImageProcessor when there is no internet connection.
    image_processor = load_vit_image_processor_with_fallback()
    print("ImageProcessor type:", type(image_processor))

    # 3. Convert the image into model-ready inputs (pixel_values).
    batch = image_processor(images=[image], return_tensors="pt")
    print("\nBatch keys:", list(batch.keys()))
    for key, value in batch.items():
        print(f"{key}: shape={tuple(value.shape)}, dtype={value.dtype}")

    # 4. Persist the processor (the original note says this writes
    #    preprocessor_config.json) and load it back from disk.
    save_dir = SAVED_PROCESSOR_DIR / "tmp_imgproc"
    image_processor.save_pretrained(save_dir)
    try:
        image_processor2 = AutoImageProcessor.from_pretrained(save_dir)
    except Exception as exc:
        # Deliberately broad: this is a demo boundary, and any failure of the
        # Auto* resolution falls back to reloading via the concrete class.
        print(f"AutoImageProcessor local reload failed, using direct class reload: {exc}")
        image_processor2 = type(image_processor).from_pretrained(save_dir)

    batch2 = image_processor2(images=[image], return_tensors="pt")
    print("\nReloaded ImageProcessor type:", type(image_processor2))
    print("Reloaded pixel_values shape:", tuple(batch2["pixel_values"].shape))
    print("Saved files:", sorted(p.name for p in save_dir.iterdir()))


if __name__ == "__main__":
    # Run the walkthrough only when executed as a script, not on import.
    main()