from __future__ import annotations import sys from pathlib import Path PROJECT_ROOT = Path(__file__).resolve().parents[1] SRC_DIR = PROJECT_ROOT / "src" sys.path.insert(0, str(SRC_DIR)) from PIL import Image from transformers import AutoProcessor from hf_processor_practice.utils import ( SAVED_PROCESSOR_DIR, get_cat_image_path, get_dog_image_path, load_clip_processor_with_fallback, print_title, ) def main() -> None: print_title("03. AutoProcessor / CLIP Practice") # 1. 샘플 이미지 2장을 준비한다. 인터넷이 없으면 placeholder 이미지를 만든다. cat_path = get_cat_image_path() dog_path = get_dog_image_path() cat = Image.open(cat_path).convert("RGB") dog = Image.open(dog_path).convert("RGB") print("Cat image size:", cat.size) print("Dog image size:", dog.size) # 2. CLIP은 텍스트와 이미지를 함께 다루는 멀티모달 모델이다. # AutoProcessor는 내부적으로 tokenizer + image_processor를 묶어서 제공한다. processor = load_clip_processor_with_fallback() print("Processor type:", type(processor)) print("Tokenizer type:", type(processor.tokenizer)) print("ImageProcessor type:", type(processor.image_processor)) # 3. 텍스트와 이미지를 동시에 입력하여 모델 입력 딕셔너리를 만든다. out = processor( text=["a photo of a cat", "a photo of a dog"], images=[cat, dog], padding=True, return_tensors="pt", ) print("\nOutput keys:", list(out.keys())) for key, value in out.items(): print(f"{key}: shape={tuple(value.shape)}, dtype={value.dtype}") # 4. Processor 저장/복원 실습 save_dir = SAVED_PROCESSOR_DIR / "tmp_clip_proc" processor.save_pretrained(save_dir) try: processor2 = AutoProcessor.from_pretrained(save_dir) except Exception as exc: # 일부 transformers 버전에서는 local fallback processor에 config.json이 없으면 # AutoProcessor가 실패할 수 있어 같은 Processor 클래스에서 직접 로드한다. print(f"AutoProcessor local reload failed, using processor class reload: {exc}") processor2 = type(processor).from_pretrained(save_dir) out2 = processor2( text=["a photo of a cat"], images=[cat], padding=True, return_tensors="pt", ) print("\nReloaded processor type:", type(processor2)) print("Reloaded keys:", list(out2.keys())) print("Saved files:", sorted(p.name for p in save_dir.iterdir())) if __name__ == "__main__": main()