File size: 2,059 Bytes
29d1fb6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
from __future__ import annotations

import sys
from pathlib import Path

PROJECT_ROOT = Path(__file__).resolve().parents[1]
SRC_DIR = PROJECT_ROOT / "src"
sys.path.insert(0, str(SRC_DIR))

from transformers import AutoTokenizer

from hf_processor_practice.utils import SAVED_PROCESSOR_DIR, ensure_dirs, load_tokenizer_with_fallback, print_title


def main() -> None:
    ensure_dirs()
    print_title("01. AutoTokenizer Practice")

    # 1. 모델 이름으로 토크나이저를 자동 로드한다.
    # 인터넷 연결이 없으면 로컬 tiny tokenizer로 fallback한다.
    tokenizer = load_tokenizer_with_fallback()
    print("Tokenizer type:", type(tokenizer))
    print("Using fast tokenizer:", getattr(tokenizer, "is_fast", None))

    # 2. 텍스트 배치를 모델 입력 딕셔너리로 변환한다.
    batch = tokenizer(
        ["hello world", "this is a test"],
        padding=True,
        truncation=True,
        return_tensors="pt",
    )

    print("\nBatch keys:", list(batch.keys()))
    for key, value in batch.items():
        print(f"{key}: shape={tuple(value.shape)}")

    # 3. 토큰 ID를 다시 문자열로 디코딩하여 전처리 결과를 확인한다.
    decoded = tokenizer.batch_decode(batch["input_ids"], skip_special_tokens=False)
    decoded_clean = tokenizer.batch_decode(batch["input_ids"], skip_special_tokens=True)
    print("\nDecoded with special tokens:", decoded)
    print("Decoded clean:", decoded_clean)

    # 4. save_pretrained로 tokenizer 관련 파일을 저장하고 다시 로드한다.
    save_dir = SAVED_PROCESSOR_DIR / "tmp_tok"
    tokenizer.save_pretrained(save_dir)
    tokenizer2 = AutoTokenizer.from_pretrained(save_dir)

    batch2 = tokenizer2(["hello world"], return_tensors="pt")
    print("\nReloaded tokenizer type:", type(tokenizer2))
    print("Reloaded vocab size:", tokenizer2.vocab_size)
    print("Reloaded input_ids shape:", tuple(batch2["input_ids"].shape))
    print("Saved files:", sorted(p.name for p in save_dir.iterdir()))


if __name__ == "__main__":
    main()