"""
최소한의 BERT 추론 예제 모음
============================
논문의 각 태스크를 가장 작은 코드로 보여주는 함수들입니다. 환경 점검이나
복사-붙여넣기용으로 유용합니다.
모든 예제 실행:
python inference_examples.py
"""
from __future__ import annotations
import torch
from transformers import (
    AutoTokenizer,
    AutoModelForMaskedLM,
    AutoModelForNextSentencePrediction,
    pipeline,
)
# ---------------------------------------------------------------------------
# 1. Masked Language Model (paper §3.1, Task #1)
# ---------------------------------------------------------------------------
def example_mlm() -> None:
    print("=" * 60)
    print("1. Masked Language Model (masked-token prediction)")
    print("=" * 60)
    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    model = AutoModelForMaskedLM.from_pretrained("bert-base-uncased")
    model.eval()
    text = "The capital of France is [MASK]."
    inputs = tokenizer(text, return_tensors="pt")
    # Find the position of the [MASK] token in the input ids.
    mask_idx = (inputs.input_ids == tokenizer.mask_token_id).nonzero(as_tuple=True)[1]
    with torch.no_grad():
        logits = model(**inputs).logits
    # Top-5 predictions for the masked position
    top_5 = torch.topk(logits[0, mask_idx], 5, dim=-1)
    print(f"Input: {text}")
    for token_id, score in zip(top_5.indices[0], top_5.values[0]):
        token = tokenizer.decode([token_id])
        print(f"  {token!r:<15} logit={score.item():.3f}")
# ---------------------------------------------------------------------------
# 2. Next Sentence Prediction (paper §3.1, Task #2)
# ---------------------------------------------------------------------------
def example_nsp() -> None:
    print("\n" + "=" * 60)
    print("2. Next Sentence Prediction")
    print("=" * 60)
    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    model = AutoModelForNextSentencePrediction.from_pretrained("bert-base-uncased")
    model.eval()
    pairs = [
        ("The man went to the store.", "He bought a gallon of milk."),
        ("The man went to the store.", "Penguins are flightless birds."),
    ]
    for sent_a, sent_b in pairs:
        inputs = tokenizer(sent_a, sent_b, return_tensors="pt")
        with torch.no_grad():
            logits = model(**inputs).logits
        probs = torch.softmax(logits, dim=-1)[0].tolist()
        # HuggingFace BERT NSP: label 0 = IsNext, label 1 = NotNext
        verdict = "IsNext" if probs[0] > probs[1] else "NotNext"
        print(f"  A: {sent_a}")
        print(f"  B: {sent_b}")
        print(f"  -> {verdict} (P(IsNext)={probs[0]:.3f}, P(NotNext)={probs[1]:.3f})\n")
# ---------------------------------------------------------------------------
# 3. Sentence-pair classification (MNLI; paper §4.1, Figure 4a)
# ---------------------------------------------------------------------------
def example_mnli() -> None:
    print("=" * 60)
    print("3. MNLI (sentence-pair classification)")
    print("=" * 60)
    nlp = pipeline(
        "text-classification",
        model="textattack/bert-base-uncased-MNLI",
    )
    premise = "A soccer game with multiple males playing."
    hypothesis = "Some men are playing a sport."
    # Pass the pair as a dict so the tokenizer builds proper segment ids;
    # splicing "[SEP]" into one string would leave token_type_ids all zero.
    result = nlp({"text": premise, "text_pair": hypothesis})
    print(f"  Premise:    {premise}")
    print(f"  Hypothesis: {hypothesis}")
    print(f"  -> {result['label']} ({result['score']:.3f})")
# ---------------------------------------------------------------------------
# 4. Single-sentence classification (SST-2; paper §4.1, Figure 4b)
# ---------------------------------------------------------------------------
def example_sst2() -> None:
    print("\n" + "=" * 60)
    print("4. SST-2 (single-sentence sentiment classification)")
    print("=" * 60)
    nlp = pipeline(
        "text-classification",
        model="textattack/bert-base-uncased-SST-2",
    )
    sentences = [
        "This movie was absolutely fantastic.",
        "What a complete waste of time.",
    ]
    for s in sentences:
        r = nlp(s)[0]
        print(f"  {s}\n  -> {r['label']} ({r['score']:.3f})")
# ---------------------------------------------------------------------------
# 5. Question answering (SQuAD; paper §4.2, Figure 4c)
# ---------------------------------------------------------------------------
def example_squad() -> None:
    print("\n" + "=" * 60)
    print("5. SQuAD v1.1 (extractive question answering)")
    print("=" * 60)
    nlp = pipeline(
        "question-answering",
        model="bert-large-uncased-whole-word-masking-finetuned-squad",
    )
    context = (
        "BERT was introduced by Google AI Language in October 2018. "
        "It is pre-trained on the BooksCorpus (800M words) and English Wikipedia "
        "(2,500M words). BERT-Large has 340 million parameters."
    )
    question = "How many parameters does BERT-Large have?"
    result = nlp(question=question, context=context)
    print(f"  Question: {question}")
    print(f"  Answer:   {result['answer']!r} (score={result['score']:.3f})")
# ---------------------------------------------------------------------------
# 6. Named entity recognition (CoNLL-2003; paper §5.3, Figure 4d)
# ---------------------------------------------------------------------------
def example_ner() -> None:
    print("\n" + "=" * 60)
    print("6. CoNLL-2003 NER (named entity recognition)")
    print("=" * 60)
    nlp = pipeline(
        "token-classification",
        model="dslim/bert-base-NER",
        aggregation_strategy="simple",
    )
    text = "Jacob Devlin works at Google in Mountain View, California."
    for ent in nlp(text):
        print(f"  {ent['word']:<20} {ent['entity_group']:<6} (score={ent['score']:.3f})")
if __name__ == "__main__":
    example_mlm()
    example_nsp()
    example_mnli()
    example_sst2()
    example_squad()
    example_ner()