File size: 5,862 Bytes
49f69db
05c8e16
 
 
 
49f69db
05c8e16
49f69db
 
 
 
 
 
 
 
 
 
 
 
 
 
 
05c8e16
49f69db
 
 
05c8e16
49f69db
 
 
 
 
 
 
 
 
 
 
 
 
05c8e16
49f69db
05c8e16
49f69db
 
05c8e16
49f69db
 
 
05c8e16
49f69db
 
 
05c8e16
49f69db
 
 
 
 
 
 
 
 
 
 
 
 
 
 
05c8e16
49f69db
 
 
 
 
 
 
05c8e16
49f69db
 
 
05c8e16
49f69db
 
 
 
 
 
 
 
 
05c8e16
 
49f69db
 
 
 
05c8e16
49f69db
 
 
05c8e16
49f69db
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
05c8e16
49f69db
 
 
05c8e16
49f69db
 
 
 
 
 
 
 
 
 
 
 
 
05c8e16
 
49f69db
 
 
05c8e16
49f69db
 
 
05c8e16
49f69db
 
 
 
 
 
 
 
 
05c8e16
49f69db
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
"""
최소한의 BERT 추론 예제 모음
============================
논문의 각 태스크를 가장 작은 코드로 보여주는 함수들입니다. 환경 점검이나
복사-붙여넣기용으로 유용합니다.

모든 예제 실행:
    python inference_examples.py
"""

from __future__ import annotations

import torch
from transformers import (
    AutoTokenizer,
    AutoModelForMaskedLM,
    AutoModelForNextSentencePrediction,
    pipeline,
)


# ---------------------------------------------------------------------------
# 1. Masked Language Model (논문 §3.1, Task #1)
# ---------------------------------------------------------------------------
def example_mlm() -> None:
    """Predict the masked token of a fixed sentence and print the top five.

    Loads the ``bert-base-uncased`` tokenizer and masked-LM model, runs one
    forward pass without gradient tracking, and prints the five
    highest-scoring vocabulary entries (with their raw logits) at the
    ``[MASK]`` position.
    """
    rule = "=" * 60
    print(rule)
    print("1. Masked Language Model (마스크 토큰 예측)")
    print(rule)

    checkpoint = "bert-base-uncased"
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = AutoModelForMaskedLM.from_pretrained(checkpoint)
    model.eval()

    text = "The capital of France is [MASK]."
    inputs = tokenizer(text, return_tensors="pt")

    # nonzero(as_tuple=True) yields (batch indices, sequence indices); only
    # the sequence positions of the mask token are needed.
    is_mask = inputs.input_ids == tokenizer.mask_token_id
    _, mask_positions = is_mask.nonzero(as_tuple=True)

    with torch.no_grad():
        logits = model(**inputs).logits

    # Top-5 predictions at the mask position.
    mask_logits = logits[0, mask_positions]
    best = torch.topk(mask_logits, 5, dim=-1)
    candidate_ids = best.indices[0].tolist()
    candidate_scores = best.values[0].tolist()

    print(f"입력: {text}")
    for token_id, score in zip(candidate_ids, candidate_scores):
        token = tokenizer.decode([token_id])
        print(f"  {token!r:<15} 로짓={score:.3f}")


# ---------------------------------------------------------------------------
# 2. Next Sentence Prediction (논문 §3.1, Task #2)
# ---------------------------------------------------------------------------
def example_nsp() -> None:
    """Run next-sentence prediction on two fixed sentence pairs.

    Loads the ``bert-base-uncased`` tokenizer and NSP model, scores each
    (sentence A, sentence B) pair, and prints both sentences together with
    the verdict and the softmax probability of each class.
    """
    rule = "=" * 60
    print("\n" + rule)
    print("2. Next Sentence Prediction (다음 문장 예측)")
    print(rule)

    checkpoint = "bert-base-uncased"
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = AutoModelForNextSentencePrediction.from_pretrained(checkpoint)
    model.eval()

    pairs = [
        ("The man went to the store.", "He bought a gallon of milk."),
        ("The man went to the store.", "Penguins are flightless birds."),
    ]

    # Inference only: one no_grad scope covers every pair.
    with torch.no_grad():
        for sent_a, sent_b in pairs:
            encoded = tokenizer(sent_a, sent_b, return_tensors="pt")
            scores = model(**encoded).logits
            # HuggingFace BERT NSP: label 0 = IsNext, label 1 = NotNext.
            p_is_next, p_not_next = torch.softmax(scores, dim=-1)[0].tolist()

            if p_is_next > p_not_next:
                verdict = "IsNext"
            else:
                verdict = "NotNext"

            print(f"  A: {sent_a}")
            print(f"  B: {sent_b}")
            print(f"  -> {verdict}  (P(IsNext)={p_is_next:.3f}, P(NotNext)={p_not_next:.3f})\n")


# ---------------------------------------------------------------------------
# 3. 문장 쌍 분류 (MNLI; 논문 §4.1, Figure 4a)
# ---------------------------------------------------------------------------
def example_mnli() -> None:
    """Classify one premise/hypothesis pair with an MNLI-finetuned BERT.

    Builds a ``text-classification`` pipeline for
    ``textattack/bert-base-uncased-MNLI`` and prints the premise, the
    hypothesis, and the predicted label with its score.

    The two sentences are handed to the pipeline as a real sentence pair
    (``text`` / ``text_pair``) so the tokenizer encodes them as two segments,
    rather than as one string containing a literal ``[SEP]``, which would be
    encoded as a single sequence.
    """
    print("=" * 60)
    print("3. MNLI (문장 쌍 분류)")
    print("=" * 60)

    nlp = pipeline(
        "text-classification",
        model="textattack/bert-base-uncased-MNLI",
    )
    premise = "A soccer game with multiple males playing."
    hypothesis = "Some men are playing a sport."

    # Sentence-pair input: the tokenizer inserts the separator itself and
    # marks the hypothesis as the second segment.
    output = nlp({"text": premise, "text_pair": hypothesis})
    # A single dict input may come back as a bare dict or as a one-element
    # list depending on the transformers version; accept both.
    result = output[0] if isinstance(output, list) else output

    print(f"  전제:   {premise}")
    print(f"  가설:   {hypothesis}")
    print(f"  -> {result['label']}  ({result['score']:.3f})")


# ---------------------------------------------------------------------------
# 4. 단일 문장 분류 (SST-2; 논문 §4.1, Figure 4b)
# ---------------------------------------------------------------------------
def example_sst2() -> None:
    """Classify two fixed sentences with an SST-2-finetuned BERT.

    Builds a ``text-classification`` pipeline for
    ``textattack/bert-base-uncased-SST-2`` and prints each sentence followed
    by its predicted label and score.
    """
    rule = "=" * 60
    print("\n" + rule)
    print("4. SST-2 (단일 문장 감성 분류)")
    print(rule)

    classifier = pipeline(
        "text-classification",
        model="textattack/bert-base-uncased-SST-2",
    )
    sentences = (
        "This movie was absolutely fantastic.",
        "What a complete waste of time.",
    )
    for sentence in sentences:
        # A single string input yields a list; its first entry is the prediction.
        prediction = classifier(sentence)[0]
        label = prediction["label"]
        score = prediction["score"]
        print(f"  {sentence}\n    -> {label}  ({score:.3f})")


# ---------------------------------------------------------------------------
# 5. 질의응답 (SQuAD; 논문 §4.2, Figure 4c)
# ---------------------------------------------------------------------------
def example_squad() -> None:
    """Answer one fixed question from a short context paragraph.

    Builds a ``question-answering`` pipeline for
    ``bert-large-uncased-whole-word-masking-finetuned-squad`` and prints the
    question together with the extracted answer and its score.
    """
    rule = "=" * 60
    print("\n" + rule)
    print("5. SQuAD v1.1 (추출형 질의응답)")
    print(rule)

    question = "How many parameters does BERT-Large have?"
    context = (
        "BERT was introduced by Google AI Language in October 2018. "
        "It is pre-trained on the BooksCorpus (800M words) and English Wikipedia "
        "(2,500M words). BERT-Large has 340 million parameters."
    )

    reader = pipeline(
        "question-answering",
        model="bert-large-uncased-whole-word-masking-finetuned-squad",
    )
    prediction = reader(question=question, context=context)
    answer = prediction["answer"]
    score = prediction["score"]

    print(f"  질문: {question}")
    print(f"  답변: {answer!r}  (점수={score:.3f})")


# ---------------------------------------------------------------------------
# 6. 개체명 인식 (CoNLL-2003; 논문 §5.3, Figure 4d)
# ---------------------------------------------------------------------------
def example_ner() -> None:
    """Tag named entities in one fixed sentence and print them.

    Builds a ``token-classification`` pipeline for ``dslim/bert-base-NER``
    with ``aggregation_strategy="simple"`` and prints one line per returned
    entity: its text, its entity group, and its score.
    """
    rule = "=" * 60
    print("\n" + rule)
    print("6. CoNLL-2003 NER (개체명 인식)")
    print(rule)

    tagger = pipeline(
        "token-classification",
        model="dslim/bert-base-NER",
        aggregation_strategy="simple",
    )
    text = "Jacob Devlin works at Google in Mountain View, California."
    entities = tagger(text)
    for entity in entities:
        word = entity["word"]
        group = entity["entity_group"]
        score = entity["score"]
        print(f"  {word:<20} {group:<6} (점수={score:.3f})")


if __name__ == "__main__":
    # Run every example in the order the sections are numbered in this file.
    for run_example in (
        example_mlm,
        example_nsp,
        example_mnli,
        example_sst2,
        example_squad,
        example_ner,
    ):
        run_example()