yb1n's picture
Upload folder using huggingface_hub
29d1fb6 verified
Raw
History Blame Contribute Delete
2.61 kB
from __future__ import annotations
import sys
from pathlib import Path
# Put ``<project root>/src`` at the front of the import path so the local
# ``hf_processor_practice`` package can be imported when this file is run
# directly, without installing the package. The project root is taken to be
# the parent of the directory that contains this file.
PROJECT_ROOT = Path(__file__).resolve().parents[1]
SRC_DIR = PROJECT_ROOT.joinpath("src")
sys.path.insert(0, f"{SRC_DIR}")
from PIL import Image
from transformers import AutoProcessor
from hf_processor_practice.utils import (
SAVED_PROCESSOR_DIR,
get_cat_image_path,
get_dog_image_path,
load_clip_processor_with_fallback,
print_title,
)
def _load_rgb(path: str | Path) -> Image.Image:
    """Open the image file at *path* and return an RGB copy of it.

    The file is opened in a ``with`` block so the handle held by
    ``Image.open`` is released deterministically. ``convert`` returns a new
    image object, so the result stays usable after the source is closed.
    """
    with Image.open(path) as img:
        return img.convert("RGB")


def _reload_processor(processor, save_dir: Path):
    """Save *processor* to *save_dir*, then load it back and return the copy.

    ``AutoProcessor.from_pretrained`` is tried first. If it raises, the reason
    is printed and the directory is loaded through the concrete class of the
    original processor (``type(processor).from_pretrained``) instead.
    """
    processor.save_pretrained(save_dir)
    try:
        return AutoProcessor.from_pretrained(save_dir)
    except Exception as exc:  # deliberately broad: any failure triggers the fallback
        # Original author's note: with some transformers versions, AutoProcessor
        # can fail on a locally saved fallback processor that has no
        # config.json, so reload through the same processor class directly.
        print(f"AutoProcessor local reload failed, using processor class reload: {exc}")
        return type(processor).from_pretrained(save_dir)


def main() -> None:
    """Demonstrate loading, running, saving, and reloading a CLIP processor.

    Steps:
      1. Load two sample images (cat and dog) as RGB and print their sizes.
      2. Load a CLIP processor and print the types of the processor and of
         its ``tokenizer`` / ``image_processor`` components.
      3. Encode two captions and both images in one call, then print the
         shape and dtype of every returned tensor.
      4. Save the processor under ``SAVED_PROCESSOR_DIR / "tmp_clip_proc"``,
         reload it, run it on one caption/image pair, and list the saved files.
    """
    print_title("03. AutoProcessor / CLIP Practice")

    # 1. Prepare two sample images. Per the original note, the path helpers
    #    produce placeholder images when there is no internet access
    #    (behavior of the helpers is not visible here).
    cat_path = get_cat_image_path()
    dog_path = get_dog_image_path()
    cat = _load_rgb(cat_path)
    dog = _load_rgb(dog_path)
    print("Cat image size:", cat.size)
    print("Dog image size:", dog.size)

    # 2. CLIP is a multimodal model over text and images; its processor
    #    bundles a tokenizer (text side) with an image processor (pixel side).
    processor = load_clip_processor_with_fallback()
    print("Processor type:", type(processor))
    print("Tokenizer type:", type(processor.tokenizer))
    print("ImageProcessor type:", type(processor.image_processor))

    # 3. Feed text and images together to build the model-input mapping.
    #    padding=True lets captions of different token lengths share one tensor.
    out = processor(
        text=["a photo of a cat", "a photo of a dog"],
        images=[cat, dog],
        padding=True,
        return_tensors="pt",
    )
    print("\nOutput keys:", list(out.keys()))
    for key, value in out.items():
        print(f"{key}: shape={tuple(value.shape)}, dtype={value.dtype}")

    # 4. Save / reload round trip for the processor.
    save_dir = SAVED_PROCESSOR_DIR / "tmp_clip_proc"
    processor2 = _reload_processor(processor, save_dir)
    out2 = processor2(
        text=["a photo of a cat"],
        images=[cat],
        padding=True,
        return_tensors="pt",
    )
    print("\nReloaded processor type:", type(processor2))
    print("Reloaded keys:", list(out2.keys()))
    print("Saved files:", sorted(p.name for p in save_dir.iterdir()))


if __name__ == "__main__":
    main()