| """ |
| 최소한의 BERT 추론 예제 모음 |
| ============================ |
| 논문의 각 태스크를 가장 작은 코드로 보여주는 함수들입니다. 환경 점검이나 |
| 복사-붙여넣기용으로 유용합니다. |
| |
| 모든 예제 실행: |
| python inference_examples.py |
| """ |
|
|
| from __future__ import annotations |
|
|
| import torch |
| from transformers import ( |
| AutoTokenizer, |
| AutoModelForMaskedLM, |
| AutoModelForNextSentencePrediction, |
| pipeline, |
| ) |
|
|
|
|
| |
| |
| |
def example_mlm() -> None:
    """Predict the token hidden behind [MASK] using BERT's masked-LM head."""
    print("=" * 60)
    print("1. Masked Language Model (마스크 토큰 예측)")
    print("=" * 60)

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    model = AutoModelForMaskedLM.from_pretrained("bert-base-uncased")
    model.eval()

    text = "The capital of France is [MASK]."
    encoded = tokenizer(text, return_tensors="pt")
    # Column indices of every [MASK] position in the (single) input row.
    mask_positions = (encoded.input_ids == tokenizer.mask_token_id).nonzero(as_tuple=True)[1]

    with torch.no_grad():
        predictions = model(**encoded).logits

    # Top-5 vocabulary candidates at the first (and only) mask position.
    best = torch.topk(predictions[0, mask_positions], 5, dim=-1)
    print(f"입력: {text}")
    for tok_id, logit in zip(best.indices[0], best.values[0]):
        decoded = tokenizer.decode([tok_id])
        print(f" {decoded!r:<15} 로짓={logit.item():.3f}")
|
|
|
|
| |
| |
| |
def example_nsp() -> None:
    """Score whether sentence B plausibly follows sentence A (NSP head)."""
    print("\n" + "=" * 60)
    print("2. Next Sentence Prediction (다음 문장 예측)")
    print("=" * 60)

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    model = AutoModelForNextSentencePrediction.from_pretrained("bert-base-uncased")
    model.eval()

    sentence_pairs = [
        ("The man went to the store.", "He bought a gallon of milk."),
        ("The man went to the store.", "Penguins are flightless birds."),
    ]
    for first, second in sentence_pairs:
        encoded = tokenizer(first, second, return_tensors="pt")
        with torch.no_grad():
            scores = model(**encoded).logits
        # Logit index 0 is the "IsNext" class, index 1 is "NotNext".
        p_is_next, p_not_next = torch.softmax(scores, dim=-1)[0].tolist()

        verdict = "IsNext" if p_is_next > p_not_next else "NotNext"
        print(f" A: {first}")
        print(f" B: {second}")
        print(f" -> {verdict} (P(IsNext)={p_is_next:.3f}, P(NotNext)={p_not_next:.3f})\n")
|
|
|
|
| |
| |
| |
def example_mnli() -> None:
    """Classify a premise/hypothesis pair with a BERT model fine-tuned on MNLI.

    Bug fix: the pair was previously joined into a single string with a
    literal " [SEP] ", which the tokenizer encodes as ONE segment (all
    token_type_ids = 0). BERT sentence-pair models were fine-tuned on a
    proper two-segment encoding, so passing text/text_pair is required for
    the segment embeddings to be correct.
    """
    print("=" * 60)
    print("3. MNLI (문장 쌍 분류)")
    print("=" * 60)

    nlp = pipeline(
        "text-classification",
        model="textattack/bert-base-uncased-MNLI",
    )
    premise = "A soccer game with multiple males playing."
    hypothesis = "Some men are playing a sport."
    # A list containing one {"text", "text_pair"} dict yields one result per
    # item, each a {"label", "score"} dict.
    result = nlp([{"text": premise, "text_pair": hypothesis}])[0]
    print(f" 전제: {premise}")
    print(f" 가설: {hypothesis}")
    print(f" -> {result['label']} ({result['score']:.3f})")
|
|
|
|
| |
| |
| |
def example_sst2() -> None:
    """Run single-sentence sentiment classification (SST-2 fine-tune)."""
    print("\n" + "=" * 60)
    print("4. SST-2 (단일 문장 감성 분류)")
    print("=" * 60)

    nlp = pipeline(
        "text-classification",
        model="textattack/bert-base-uncased-SST-2",
    )
    # One clearly positive and one clearly negative example.
    for sentence in (
        "This movie was absolutely fantastic.",
        "What a complete waste of time.",
    ):
        prediction = nlp(sentence)[0]
        print(f" {sentence}\n -> {prediction['label']} ({prediction['score']:.3f})")
|
|
|
|
| |
| |
| |
def example_squad() -> None:
    """Extractive question answering over a short context (SQuAD v1.1 style)."""
    print("\n" + "=" * 60)
    print("5. SQuAD v1.1 (추출형 질의응답)")
    print("=" * 60)

    nlp = pipeline(
        "question-answering",
        model="bert-large-uncased-whole-word-masking-finetuned-squad",
    )
    # The answer span ("340 million") is contained verbatim in this context.
    context = (
        "BERT was introduced by Google AI Language in October 2018. "
        "It is pre-trained on the BooksCorpus (800M words) and English Wikipedia "
        "(2,500M words). BERT-Large has 340 million parameters."
    )
    question = "How many parameters does BERT-Large have?"
    answer = nlp(question=question, context=context)
    print(f" 질문: {question}")
    print(f" 답변: {answer['answer']!r} (점수={answer['score']:.3f})")
|
|
|
|
| |
| |
| |
def example_ner() -> None:
    """Named-entity recognition with word-level grouping (CoNLL-2003 labels)."""
    print("\n" + "=" * 60)
    print("6. CoNLL-2003 NER (개체명 인식)")
    print("=" * 60)

    nlp = pipeline(
        "token-classification",
        model="dslim/bert-base-NER",
        # Merge sub-word tokens into whole-word entity spans.
        aggregation_strategy="simple",
    )
    sample = "Jacob Devlin works at Google in Mountain View, California."
    for entity in nlp(sample):
        print(f" {entity['word']:<20} {entity['entity_group']:<6} (점수={entity['score']:.3f})")
|
|
|
|
if __name__ == "__main__":
    # Run every demo in the order the tasks appear in the paper.
    for demo in (
        example_mlm,
        example_nsp,
        example_mnli,
        example_sst2,
        example_squad,
        example_ner,
    ):
        demo()
|
|