vi-en-transformer-25m / src /inference_evaluation.py
Cong123779's picture
Upload model source code
51b3b77 verified
"""
INFERENCE & EVALUATION
Greedy Search, Beam Search, BLEU Score evaluation
"""
import torch
import torch.nn.functional as F
from collections import Counter
import math
from tqdm import tqdm
# ============================================================================
# 1. GREEDY SEARCH DECODING
# ============================================================================
def greedy_decode(model, src, src_vocab, tgt_vocab, device, max_len=100, repetition_penalty=1.5, no_repeat_ngram_size=3):
"""
Greedy Decoding - Chọn token có xác suất cao nhất mỗi bước với Repetition Penalty
Args:
model: Transformer model
src: Source sequence [1, src_len] hoặc list of tokens
src_vocab: Source vocabulary
tgt_vocab: Target vocabulary
device: Device
max_len: Maximum length để generate
repetition_penalty: Penalty factor cho tokens lặp lại (>1.0 để giảm repetition)
no_repeat_ngram_size: Kích thước n-gram để tránh lặp lại (0 = tắt)
Returns:
decoded_tokens: List of decoded token indices
decoded_sentence: Decoded sentence (string)
"""
model.eval()
# Nếu src là list, convert sang tensor
if isinstance(src, list):
src = torch.LongTensor([src]).to(device)
elif src.dim() == 1:
src = src.unsqueeze(0)
with torch.no_grad():
# Encode source
encoder_output, src_mask = model.encode(src)
# Khởi tạo target với <sos> token
tgt_tokens = [tgt_vocab.SOS_IDX]
for step in range(max_len):
# Tạo target tensor
tgt = torch.LongTensor([tgt_tokens]).to(device)
# Decode
output = model.decode(tgt, encoder_output, src_mask)
# Lấy prediction cho token cuối cùng
next_token_logits = output[0, -1, :]
# Áp dụng repetition penalty cho từng token
token_counts = {}
for token_id in tgt_tokens:
if token_id not in [tgt_vocab.SOS_IDX, tgt_vocab.EOS_IDX, tgt_vocab.PAD_IDX]:
token_counts[token_id] = token_counts.get(token_id, 0) + 1
# Áp dụng penalty: giảm logit của các tokens đã xuất hiện
for token_id, count in token_counts.items():
if count > 0 and token_id < len(next_token_logits):
# Penalty tăng theo số lần lặp lại
penalty = repetition_penalty ** count
next_token_logits[token_id] = next_token_logits[token_id] / penalty
# N-gram repetition penalty: tránh lặp cụm từ
if no_repeat_ngram_size > 0 and len(tgt_tokens) >= no_repeat_ngram_size:
# Lấy n-gram cuối cùng
last_ngram = tuple(tgt_tokens[-(no_repeat_ngram_size-1):])
# Kiểm tra xem n-gram này đã xuất hiện trước đó chưa
for i in range(len(tgt_tokens) - no_repeat_ngram_size + 1):
ngram = tuple(tgt_tokens[i:i+no_repeat_ngram_size-1])
if ngram == last_ngram:
# Nếu n-gram đã xuất hiện, giảm logit của token tiếp theo
if i + no_repeat_ngram_size - 1 < len(tgt_tokens):
repeated_token = tgt_tokens[i + no_repeat_ngram_size - 1]
if repeated_token < len(next_token_logits):
next_token_logits[repeated_token] = next_token_logits[repeated_token] / (repetition_penalty ** 2)
# Lấy token có xác suất cao nhất
next_token = next_token_logits.argmax().item()
# Thêm vào sequence
tgt_tokens.append(next_token)
# Dừng nếu gặp <eos>
if next_token == tgt_vocab.EOS_IDX:
break
# Decode thành sentence
decoded_sentence = tgt_vocab.decode(tgt_tokens)
return tgt_tokens, decoded_sentence
# ============================================================================
# 2. BEAM SEARCH DECODING
# ============================================================================
class BeamSearchNode:
"""
Node trong Beam Search
"""
def __init__(self, tokens, log_prob, length):
self.tokens = tokens
self.log_prob = log_prob
self.length = length
def eval(self, alpha=0.6):
"""
Tính score với length normalization
Score = log_prob / (length^alpha)
Args:
alpha: Length penalty factor
"""
return self.log_prob / (self.length ** alpha)
def beam_search_decode(model, src, src_vocab, tgt_vocab, device, beam_size=5, max_len=100, alpha=0.6, repetition_penalty=1.5, no_repeat_ngram_size=3):
"""
Beam Search Decoding - Giữ top-k candidates tốt nhất với Repetition Penalty
Args:
model: Transformer model
src: Source sequence
src_vocab: Source vocabulary
tgt_vocab: Target vocabulary
device: Device
beam_size: Beam size (số lượng candidates)
max_len: Maximum length
alpha: Length penalty factor
repetition_penalty: Penalty factor cho tokens lặp lại (>1.0 để giảm repetition)
no_repeat_ngram_size: Kích thước n-gram để tránh lặp lại (0 = tắt)
Returns:
best_tokens: List of best token indices
best_sentence: Best decoded sentence
"""
model.eval()
# Nếu src là list, convert sang tensor
if isinstance(src, list):
src = torch.LongTensor([src]).to(device)
elif src.dim() == 1:
src = src.unsqueeze(0)
with torch.no_grad():
# Encode source
encoder_output, src_mask = model.encode(src)
# Khởi tạo beam với <sos> token
beams = [BeamSearchNode(
tokens=[tgt_vocab.SOS_IDX],
log_prob=0.0,
length=1
)]
completed_beams = []
for step in range(max_len):
candidates = []
for beam in beams:
# Nếu đã kết thúc, thêm vào completed
if beam.tokens[-1] == tgt_vocab.EOS_IDX:
completed_beams.append(beam)
continue
# Tạo target tensor
tgt = torch.LongTensor([beam.tokens]).to(device)
# Decode
output = model.decode(tgt, encoder_output, src_mask)
# Lấy log probabilities cho token cuối
next_token_logits = output[0, -1, :]
# Áp dụng repetition penalty
# Đếm số lần xuất hiện của mỗi token trong sequence hiện tại
token_counts = {}
for token_id in beam.tokens:
if token_id not in [tgt_vocab.SOS_IDX, tgt_vocab.EOS_IDX, tgt_vocab.PAD_IDX]:
token_counts[token_id] = token_counts.get(token_id, 0) + 1
# Áp dụng penalty: giảm logit của các tokens đã xuất hiện (mạnh hơn)
for token_id, count in token_counts.items():
if count > 0 and token_id < len(next_token_logits):
# Penalty tăng mạnh theo số lần lặp lại
# count=1: penalty nhẹ, count=2+: penalty rất mạnh
penalty = repetition_penalty ** (count * 1.5) # Tăng penalty mạnh hơn
next_token_logits[token_id] = next_token_logits[token_id] / penalty
# N-gram repetition penalty: tránh lặp cụm từ
if no_repeat_ngram_size > 0 and len(beam.tokens) >= no_repeat_ngram_size:
# Lấy n-gram cuối cùng
last_ngram = tuple(beam.tokens[-(no_repeat_ngram_size-1):])
# Kiểm tra xem n-gram này đã xuất hiện trước đó chưa
for i in range(len(beam.tokens) - no_repeat_ngram_size + 1):
ngram = tuple(beam.tokens[i:i+no_repeat_ngram_size-1])
if ngram == last_ngram:
# Nếu n-gram đã xuất hiện, giảm logit của token tiếp theo rất mạnh
if i + no_repeat_ngram_size - 1 < len(beam.tokens):
repeated_token = beam.tokens[i + no_repeat_ngram_size - 1]
if repeated_token < len(next_token_logits):
# Penalty rất mạnh cho n-gram repetition
next_token_logits[repeated_token] = next_token_logits[repeated_token] / (repetition_penalty ** 3)
log_probs = F.log_softmax(next_token_logits, dim=-1)
# Lấy top-k tokens
top_log_probs, top_tokens = torch.topk(log_probs, beam_size)
# Tạo candidates mới
for log_prob, token in zip(top_log_probs, top_tokens):
new_beam = BeamSearchNode(
tokens=beam.tokens + [token.item()],
log_prob=beam.log_prob + log_prob.item(),
length=beam.length + 1
)
candidates.append(new_beam)
# Nếu không còn candidates, dừng
if not candidates:
break
# Chọn top beam_size candidates tốt nhất
beams = sorted(candidates, key=lambda x: x.eval(alpha), reverse=True)[:beam_size]
# Nếu tất cả beams đã complete, dừng
if len(completed_beams) >= beam_size:
break
# Thêm các beams chưa complete vào completed
completed_beams.extend(beams)
# Chọn beam tốt nhất
best_beam = max(completed_beams, key=lambda x: x.eval(alpha))
# Decode thành sentence
best_sentence = tgt_vocab.decode(best_beam.tokens)
return best_beam.tokens, best_sentence
# ============================================================================
# 3. BLEU SCORE CALCULATION
# ============================================================================
def calculate_ngrams(tokens, n):
"""
Tính n-grams từ list of tokens
Args:
tokens: List of tokens
n: n-gram size
Returns:
ngrams: Counter of n-grams
"""
ngrams = []
for i in range(len(tokens) - n + 1):
ngram = tuple(tokens[i:i+n])
ngrams.append(ngram)
return Counter(ngrams)
def calculate_bleu_score(references, hypotheses, max_n=4, weights=None):
"""
Tính BLEU score
BLEU = BP * exp(sum(w_n * log(p_n)))
Args:
references: List of reference sentences (list of token lists)
hypotheses: List of hypothesis sentences (list of token lists)
max_n: Maximum n-gram size (mặc định 4)
weights: Weights cho mỗi n-gram (mặc định uniform)
Returns:
bleu_score: BLEU score (0-100)
"""
if weights is None:
weights = [1.0/max_n] * max_n
# Tính precision cho mỗi n-gram
precisions = []
for n in range(1, max_n + 1):
matched = 0
total = 0
for ref, hyp in zip(references, hypotheses):
# Tính n-grams
ref_ngrams = calculate_ngrams(ref, n)
hyp_ngrams = calculate_ngrams(hyp, n)
# Đếm matched n-grams
for ngram, count in hyp_ngrams.items():
matched += min(count, ref_ngrams.get(ngram, 0))
total += max(len(hyp) - n + 1, 0)
# Tính precision
if total > 0:
precision = matched / total
else:
precision = 0
precisions.append(precision)
# Tính Brevity Penalty (BP)
ref_length = sum(len(ref) for ref in references)
hyp_length = sum(len(hyp) for hyp in hypotheses)
if hyp_length > ref_length:
bp = 1.0
elif hyp_length == 0:
bp = 0.0
else:
bp = math.exp(1 - ref_length / hyp_length)
# Tính BLEU score
if min(precisions) > 0:
log_precisions = [w * math.log(p) for w, p in zip(weights, precisions)]
bleu = bp * math.exp(sum(log_precisions))
else:
bleu = 0.0
return bleu * 100 # Convert sang 0-100 scale
# ============================================================================
# 4. EVALUATE ON TEST SET
# ============================================================================
def evaluate_model(model, test_loader, src_vocab, tgt_vocab, device,
use_beam_search=True, beam_size=5, max_len=100):
"""
Đánh giá model trên test set
Args:
model: Transformer model
test_loader: Test DataLoader
src_vocab: Source vocabulary
tgt_vocab: Target vocabulary
device: Device
use_beam_search: Sử dụng beam search hay greedy search
beam_size: Beam size (nếu dùng beam search)
max_len: Maximum decode length
Returns:
bleu_score: BLEU score
translations: List of (source, reference, hypothesis) tuples
"""
model.eval()
references = []
hypotheses = []
translations = []
print(f"\n{'='*70}")
print(f"ĐÁNH GIÁ TRÊN TEST SET")
print(f"{'='*70}")
print(f"Decoding method: {'Beam Search' if use_beam_search else 'Greedy Search'}")
if use_beam_search:
print(f"Beam size: {beam_size}")
print(f"{'='*70}\n")
with torch.no_grad():
for src, tgt, _, _ in tqdm(test_loader, desc='Evaluating'):
src = src.to(device)
batch_size = src.size(0)
for i in range(batch_size):
src_seq = src[i]
tgt_seq = tgt[i]
# Decode
if use_beam_search:
_, hypothesis = beam_search_decode(
model, src_seq, src_vocab, tgt_vocab,
device, beam_size, max_len, alpha=0.6, repetition_penalty=1.5, no_repeat_ngram_size=3
)
else:
_, hypothesis = greedy_decode(
model, src_seq, src_vocab, tgt_vocab,
device, max_len, repetition_penalty=1.5, no_repeat_ngram_size=3
)
# Reference
reference = tgt_vocab.decode(tgt_seq.tolist())
# Source
source = src_vocab.decode(src_seq.tolist())
# Tokenize để tính BLEU
ref_tokens = reference.split()
hyp_tokens = hypothesis.split()
references.append(ref_tokens)
hypotheses.append(hyp_tokens)
translations.append((source, reference, hypothesis))
# Tính BLEU score
bleu = calculate_bleu_score(references, hypotheses)
print(f"\n{'='*70}")
print(f"KẾT QUẢ ĐÁNH GIÁ")
print(f"{'='*70}")
print(f"BLEU Score: {bleu:.2f}")
print(f"{'='*70}\n")
return bleu, translations
# ============================================================================
# 5. PRINT SAMPLE TRANSLATIONS
# ============================================================================
def print_sample_translations(translations, num_samples=10):
"""
In một số ví dụ dịch
Args:
translations: List of (source, reference, hypothesis) tuples
num_samples: Số lượng samples để in
"""
print(f"\n{'='*70}")
print(f"MỘT SỐ VÍ DỤ DỊCH")
print(f"{'='*70}\n")
for i, (src, ref, hyp) in enumerate(translations[:num_samples], 1):
print(f"Ví dụ {i}:")
print(f" Source: {src}")
print(f" Reference: {ref}")
print(f" Hypothesis: {hyp}")
print()
# ============================================================================
# 6. TRANSLATE SINGLE SENTENCE
# ============================================================================
def translate_sentence(model, sentence, src_vocab, tgt_vocab, device,
use_beam_search=True, beam_size=5, max_len=100, src_lang='vi',
repetition_penalty=1.5, no_repeat_ngram_size=3):
"""
Dịch một câu đơn
Args:
model: Transformer model
sentence: Source sentence (string)
src_vocab: Source vocabulary
tgt_vocab: Target vocabulary
device: Device
use_beam_search: Sử dụng beam search hay greedy
beam_size: Beam size
max_len: Maximum length
src_lang: Ngôn ngữ source ('vi' hoặc 'en')
Returns:
translation: Translated sentence
"""
model.eval()
# Tiền xử lý câu
from data_preprocessing import clean_text
sentence = clean_text(sentence, src_lang)
# Encode
tokens = src_vocab.encode(sentence)
src = torch.LongTensor([tokens]).to(device)
# Decode
if use_beam_search:
_, translation = beam_search_decode(
model, src, src_vocab, tgt_vocab,
device, beam_size, max_len, alpha=0.6,
repetition_penalty=repetition_penalty,
no_repeat_ngram_size=no_repeat_ngram_size
)
else:
_, translation = greedy_decode(
model, src, src_vocab, tgt_vocab,
device, max_len,
repetition_penalty=repetition_penalty,
no_repeat_ngram_size=no_repeat_ngram_size
)
return translation
# ============================================================================
# 7. INTERACTIVE TRANSLATION
# ============================================================================
def interactive_translation(model, src_vocab, tgt_vocab, device,
use_beam_search=True, beam_size=5, direction='vi2en'):
"""
Chế độ dịch tương tác
Args:
model: Transformer model
src_vocab: Source vocabulary
tgt_vocab: Target vocabulary
device: Device
use_beam_search: Sử dụng beam search
beam_size: Beam size
direction: 'vi2en' hoặc 'en2vi'
"""
print("\n" + "="*70)
print("CHẾ ĐỘ DỊCH TƯƠNG TÁC")
print("="*70)
if direction == 'vi2en':
print("Nhập câu tiếng Việt để dịch sang tiếng Anh")
else:
print("Nhập câu tiếng Anh để dịch sang tiếng Việt")
print("Gõ 'quit' hoặc 'exit' để thoát")
print("="*70 + "\n")
src_lang = 'en' if direction == 'en2vi' else 'vi'
src_label = "Tiếng Anh" if direction == 'en2vi' else "Tiếng Việt"
tgt_label = "Tiếng Việt" if direction == 'en2vi' else "Tiếng Anh"
while True:
sentence = input(f"{src_label}: ").strip()
if sentence.lower() in ['quit', 'exit', '']:
print("Tạm biệt!")
break
translation = translate_sentence(
model, sentence, src_vocab, tgt_vocab, device,
use_beam_search, beam_size, src_lang=src_lang, max_len=100
)
print(f"{tgt_label}: {translation}\n")
# ============================================================================
# 8. SAVE TRANSLATIONS
# ============================================================================
def save_translations(translations, output_file='translations.txt'):
"""
Lưu translations ra file
Args:
translations: List of (source, reference, hypothesis) tuples
output_file: Output file path
"""
with open(output_file, 'w', encoding='utf-8') as f:
for i, (src, ref, hyp) in enumerate(translations, 1):
f.write(f"Example {i}:\n")
f.write(f"Source: {src}\n")
f.write(f"Reference: {ref}\n")
f.write(f"Hypothesis: {hyp}\n")
f.write("\n")
print(f"✓ Saved translations to {output_file}")