| | """ |
| | INFERENCE & EVALUATION |
| | Greedy Search, Beam Search, BLEU Score evaluation |
| | """ |
| |
|
| | import torch |
| | import torch.nn.functional as F |
| | from collections import Counter |
| | import math |
| | from tqdm import tqdm |
| |
|
| | |
| | |
| | |
| |
|
def greedy_decode(model, src, src_vocab, tgt_vocab, device, max_len=100, repetition_penalty=1.5, no_repeat_ngram_size=3):
    """
    Greedy decoding: pick the highest-scoring token at each step, with a
    repetition penalty and no-repeat n-gram blocking to curb degenerate loops.

    Args:
        model: Transformer model exposing ``encode(src)`` ->
            (encoder_output, src_mask) and ``decode(tgt, encoder_output, src_mask)``
        src: Source sequence — [1, src_len] tensor, 1-D tensor, or list of token ids
        src_vocab: Source vocabulary (kept for interface parity; unused here)
        tgt_vocab: Target vocabulary exposing SOS_IDX/EOS_IDX/PAD_IDX and decode()
        device: torch device to run on
        max_len: maximum number of tokens to generate
        repetition_penalty: factor > 1.0 discouraging already-emitted tokens
        no_repeat_ngram_size: forbid repeating any n-gram of this size (0 = off)

    Returns:
        decoded_tokens: list of generated token ids (including SOS and EOS)
        decoded_sentence: decoded string from ``tgt_vocab.decode``
    """
    model.eval()

    # Normalize input to a [1, src_len] batch tensor.
    if isinstance(src, list):
        src = torch.LongTensor([src]).to(device)
    elif src.dim() == 1:
        src = src.unsqueeze(0)

    with torch.no_grad():
        encoder_output, src_mask = model.encode(src)

        tgt_tokens = [tgt_vocab.SOS_IDX]

        for _ in range(max_len):
            tgt = torch.LongTensor([tgt_tokens]).to(device)
            output = model.decode(tgt, encoder_output, src_mask)

            # Logits for the next position: [vocab_size]
            next_token_logits = output[0, -1, :]

            # --- Repetition penalty (CTRL-style, sign-aware) ---
            # BUG FIX: plain division (logit / penalty) moves a NEGATIVE
            # logit toward zero, making the repeated token MORE likely.
            # Correct form: divide positive logits, multiply negative ones.
            token_counts = Counter(
                t for t in tgt_tokens
                if t not in (tgt_vocab.SOS_IDX, tgt_vocab.EOS_IDX, tgt_vocab.PAD_IDX)
            )
            for token_id, count in token_counts.items():
                if token_id < len(next_token_logits):
                    penalty = repetition_penalty ** count
                    if next_token_logits[token_id] > 0:
                        next_token_logits[token_id] /= penalty
                    else:
                        next_token_logits[token_id] *= penalty

            # --- No-repeat n-gram blocking ---
            # Ban (set to -inf) any token that would complete an n-gram
            # already present in the generated prefix. A hard ban replaces
            # the previous soft "penalty", which had the same sign bug as
            # above; this also handles no_repeat_ngram_size == 1 correctly.
            if no_repeat_ngram_size > 0 and len(tgt_tokens) >= no_repeat_ngram_size:
                prefix = tuple(tgt_tokens[len(tgt_tokens) - no_repeat_ngram_size + 1:])
                for i in range(len(tgt_tokens) - no_repeat_ngram_size + 1):
                    if tuple(tgt_tokens[i:i + no_repeat_ngram_size - 1]) == prefix:
                        banned = tgt_tokens[i + no_repeat_ngram_size - 1]
                        if banned < len(next_token_logits):
                            next_token_logits[banned] = float('-inf')

            # Greedy choice: argmax over the (penalized) logits.
            next_token = next_token_logits.argmax().item()
            tgt_tokens.append(next_token)

            if next_token == tgt_vocab.EOS_IDX:
                break

    decoded_sentence = tgt_vocab.decode(tgt_tokens)

    return tgt_tokens, decoded_sentence
| |
|
| | |
| | |
| | |
| |
|
class BeamSearchNode:
    """
    A partial hypothesis tracked during beam search.

    Attributes:
        tokens: token ids generated so far (starts with SOS)
        log_prob: cumulative log-probability of the sequence
        length: number of tokens in the sequence
    """

    def __init__(self, tokens, log_prob, length):
        self.tokens = tokens
        self.log_prob = log_prob
        self.length = length

    def eval(self, alpha=0.6):
        """
        Length-normalized score: log_prob / (length ** alpha).

        Args:
            alpha: length-penalty exponent (0 disables normalization)
        """
        return self.log_prob / (self.length ** alpha)


def beam_search_decode(model, src, src_vocab, tgt_vocab, device, beam_size=5, max_len=100, alpha=0.6, repetition_penalty=1.5, no_repeat_ngram_size=3):
    """
    Beam search decoding: keep the top-k partial hypotheses each step,
    with a repetition penalty and no-repeat n-gram blocking.

    Args:
        model: Transformer model exposing ``encode(src)`` and
            ``decode(tgt, encoder_output, src_mask)``
        src: Source sequence — [1, src_len] tensor, 1-D tensor, or list of token ids
        src_vocab: Source vocabulary (kept for interface parity; unused here)
        tgt_vocab: Target vocabulary exposing SOS_IDX/EOS_IDX/PAD_IDX and decode()
        device: torch device to run on
        beam_size: number of hypotheses kept per step
        max_len: maximum number of decoding steps
        alpha: length-penalty exponent for hypothesis scoring
        repetition_penalty: factor > 1.0 discouraging already-emitted tokens
        no_repeat_ngram_size: forbid repeating any n-gram of this size (0 = off)

    Returns:
        best_tokens: token ids of the best hypothesis (including SOS/EOS)
        best_sentence: decoded string from ``tgt_vocab.decode``
    """
    model.eval()

    # Normalize input to a [1, src_len] batch tensor.
    if isinstance(src, list):
        src = torch.LongTensor([src]).to(device)
    elif src.dim() == 1:
        src = src.unsqueeze(0)

    with torch.no_grad():
        encoder_output, src_mask = model.encode(src)

        beams = [BeamSearchNode(tokens=[tgt_vocab.SOS_IDX], log_prob=0.0, length=1)]
        completed_beams = []

        for _ in range(max_len):
            candidates = []

            for beam in beams:
                # A beam ending in EOS is finished; move it aside.
                if beam.tokens[-1] == tgt_vocab.EOS_IDX:
                    completed_beams.append(beam)
                    continue

                tgt = torch.LongTensor([beam.tokens]).to(device)
                output = model.decode(tgt, encoder_output, src_mask)
                next_token_logits = output[0, -1, :]

                # --- Repetition penalty (sign-aware; beam search keeps the
                # original, stronger count * 1.5 exponent) ---
                # BUG FIX: plain division (logit / penalty) moves a NEGATIVE
                # logit toward zero, making the repeated token MORE likely.
                # Correct form: divide positive logits, multiply negative ones.
                token_counts = Counter(
                    t for t in beam.tokens
                    if t not in (tgt_vocab.SOS_IDX, tgt_vocab.EOS_IDX, tgt_vocab.PAD_IDX)
                )
                for token_id, count in token_counts.items():
                    if token_id < len(next_token_logits):
                        penalty = repetition_penalty ** (count * 1.5)
                        if next_token_logits[token_id] > 0:
                            next_token_logits[token_id] /= penalty
                        else:
                            next_token_logits[token_id] *= penalty

                # --- No-repeat n-gram blocking ---
                # Ban (set to -inf) any token that would complete an n-gram
                # already present in this beam. A hard ban replaces the
                # previous soft "penalty", which had the same sign bug.
                if no_repeat_ngram_size > 0 and len(beam.tokens) >= no_repeat_ngram_size:
                    prefix = tuple(beam.tokens[len(beam.tokens) - no_repeat_ngram_size + 1:])
                    for i in range(len(beam.tokens) - no_repeat_ngram_size + 1):
                        if tuple(beam.tokens[i:i + no_repeat_ngram_size - 1]) == prefix:
                            banned = beam.tokens[i + no_repeat_ngram_size - 1]
                            if banned < len(next_token_logits):
                                next_token_logits[banned] = float('-inf')

                log_probs = F.log_softmax(next_token_logits, dim=-1)

                # Expand this beam with its top-k continuations.
                top_log_probs, top_tokens = torch.topk(log_probs, beam_size)
                for log_prob, token in zip(top_log_probs, top_tokens):
                    candidates.append(BeamSearchNode(
                        tokens=beam.tokens + [token.item()],
                        log_prob=beam.log_prob + log_prob.item(),
                        length=beam.length + 1,
                    ))

            # All beams finished — nothing left to expand.
            if not candidates:
                break

            # Keep the top beam_size candidates by normalized score.
            beams = sorted(candidates, key=lambda b: b.eval(alpha), reverse=True)[:beam_size]

            # Early stop once enough hypotheses have finished.
            if len(completed_beams) >= beam_size:
                break

        # Beams still alive at the end compete too. Skip EOS-terminated
        # beams here — those were already moved to completed_beams above
        # (the previous version could add them twice).
        completed_beams.extend(b for b in beams if b.tokens[-1] != tgt_vocab.EOS_IDX)

        best_beam = max(completed_beams, key=lambda b: b.eval(alpha))

        best_sentence = tgt_vocab.decode(best_beam.tokens)

    return best_beam.tokens, best_sentence
| |
|
| | |
| | |
| | |
| |
|
def calculate_ngrams(tokens, n):
    """
    Count the n-grams of a token sequence.

    Args:
        tokens: list of tokens
        n: n-gram order

    Returns:
        Counter mapping each n-gram tuple to its frequency
        (empty when the sequence is shorter than n)
    """
    return Counter(tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1))
| |
|
def calculate_bleu_score(references, hypotheses, max_n=4, weights=None):
    """
    Compute corpus-level BLEU (no smoothing).

    BLEU = BP * exp(sum(w_n * log(p_n)))

    Args:
        references: list of reference sentences (each a list of tokens)
        hypotheses: list of hypothesis sentences (each a list of tokens)
        max_n: largest n-gram order (default 4)
        weights: per-order weights (default: uniform 1/max_n)

    Returns:
        BLEU score scaled to the 0-100 range
    """
    if weights is None:
        weights = [1.0 / max_n] * max_n

    # Modified (clipped) n-gram precision for each order 1..max_n.
    precisions = []
    for n in range(1, max_n + 1):
        matched = 0
        total = 0
        for ref, hyp in zip(references, hypotheses):
            ref_counts = Counter(tuple(ref[i:i + n]) for i in range(len(ref) - n + 1))
            hyp_counts = Counter(tuple(hyp[i:i + n]) for i in range(len(hyp) - n + 1))
            matched += sum(min(c, ref_counts.get(g, 0)) for g, c in hyp_counts.items())
            total += max(len(hyp) - n + 1, 0)
        precisions.append(matched / total if total > 0 else 0)

    # Brevity penalty: punish hypotheses shorter than the references.
    ref_length = sum(map(len, references))
    hyp_length = sum(map(len, hypotheses))
    if hyp_length > ref_length:
        bp = 1.0
    elif hyp_length == 0:
        bp = 0.0
    else:
        bp = math.exp(1 - ref_length / hyp_length)

    # Geometric mean of the precisions; any zero precision makes the
    # whole (un-smoothed) score zero.
    if min(precisions) > 0:
        bleu = bp * math.exp(sum(w * math.log(p) for w, p in zip(weights, precisions)))
    else:
        bleu = 0.0

    return bleu * 100
| |
|
| | |
| | |
| | |
| |
|
def evaluate_model(model, test_loader, src_vocab, tgt_vocab, device,
                   use_beam_search=True, beam_size=5, max_len=100):
    """
    Evaluate the model on a test set and report corpus BLEU.

    Args:
        model: Transformer model
        test_loader: test DataLoader yielding (src, tgt, _, _) batches
        src_vocab: source vocabulary
        tgt_vocab: target vocabulary
        device: torch device
        use_beam_search: beam search when True, greedy search otherwise
        beam_size: beam size (only used with beam search)
        max_len: maximum decode length

    Returns:
        bleu_score: corpus BLEU (0-100)
        translations: list of (source, reference, hypothesis) tuples
    """
    model.eval()

    references = []
    hypotheses = []
    translations = []

    bar = '=' * 70
    print(f"\n{bar}")
    print("ĐÁNH GIÁ TRÊN TEST SET")
    print(bar)
    print(f"Decoding method: {'Beam Search' if use_beam_search else 'Greedy Search'}")
    if use_beam_search:
        print(f"Beam size: {beam_size}")
    print(f"{bar}\n")

    with torch.no_grad():
        for src, tgt, _, _ in tqdm(test_loader, desc='Evaluating'):
            src = src.to(device)

            # Decode each sentence in the batch individually.
            for idx in range(src.size(0)):
                src_seq = src[idx]
                tgt_seq = tgt[idx]

                if use_beam_search:
                    _, hypothesis = beam_search_decode(
                        model, src_seq, src_vocab, tgt_vocab,
                        device, beam_size, max_len, alpha=0.6,
                        repetition_penalty=1.5, no_repeat_ngram_size=3
                    )
                else:
                    _, hypothesis = greedy_decode(
                        model, src_seq, src_vocab, tgt_vocab,
                        device, max_len, repetition_penalty=1.5, no_repeat_ngram_size=3
                    )

                reference = tgt_vocab.decode(tgt_seq.tolist())
                source = src_vocab.decode(src_seq.tolist())

                # BLEU works on whitespace-tokenized sentences.
                references.append(reference.split())
                hypotheses.append(hypothesis.split())
                translations.append((source, reference, hypothesis))

    bleu = calculate_bleu_score(references, hypotheses)

    print(f"\n{bar}")
    print("KẾT QUẢ ĐÁNH GIÁ")
    print(bar)
    print(f"BLEU Score: {bleu:.2f}")
    print(f"{bar}\n")

    return bleu, translations
| |
|
| | |
| | |
| | |
| |
|
def print_sample_translations(translations, num_samples=10):
    """
    Print a handful of example translations to stdout.

    Args:
        translations: list of (source, reference, hypothesis) tuples
        num_samples: how many examples to print
    """
    bar = '=' * 70
    print(f"\n{bar}")
    print("MỘT SỐ VÍ DỤ DỊCH")
    print(f"{bar}\n")

    for idx, (source, reference, hypothesis) in enumerate(translations[:num_samples], 1):
        print(f"Ví dụ {idx}:")
        print(f"  Source: {source}")
        print(f"  Reference: {reference}")
        print(f"  Hypothesis: {hypothesis}")
        print()
| |
|
| | |
| | |
| | |
| |
|
def translate_sentence(model, sentence, src_vocab, tgt_vocab, device,
                       use_beam_search=True, beam_size=5, max_len=100, src_lang='vi',
                       repetition_penalty=1.5, no_repeat_ngram_size=3):
    """
    Translate a single raw sentence string.

    Args:
        model: Transformer model
        sentence: source sentence (string)
        src_vocab: source vocabulary
        tgt_vocab: target vocabulary
        device: torch device
        use_beam_search: beam search when True, greedy otherwise
        beam_size: beam size
        max_len: maximum decode length
        src_lang: source language ('vi' or 'en')
        repetition_penalty: repetition penalty passed to the decoder
        no_repeat_ngram_size: n-gram blocking size passed to the decoder

    Returns:
        translation: translated sentence (string)
    """
    model.eval()

    # Normalize the raw text the same way the training data was cleaned.
    from data_preprocessing import clean_text
    cleaned = clean_text(sentence, src_lang)

    # Encode to token ids and shape as a [1, src_len] batch.
    src_tensor = torch.LongTensor([src_vocab.encode(cleaned)]).to(device)

    if use_beam_search:
        _, translation = beam_search_decode(
            model, src_tensor, src_vocab, tgt_vocab,
            device, beam_size, max_len, alpha=0.6,
            repetition_penalty=repetition_penalty,
            no_repeat_ngram_size=no_repeat_ngram_size
        )
    else:
        _, translation = greedy_decode(
            model, src_tensor, src_vocab, tgt_vocab,
            device, max_len,
            repetition_penalty=repetition_penalty,
            no_repeat_ngram_size=no_repeat_ngram_size
        )

    return translation
| |
|
| | |
| | |
| | |
| |
|
def interactive_translation(model, src_vocab, tgt_vocab, device,
                            use_beam_search=True, beam_size=5, direction='vi2en'):
    """
    Interactive translation REPL: read a sentence from stdin, print its
    translation, until the user types 'quit'/'exit' or an empty line.

    Args:
        model: Transformer model
        src_vocab: source vocabulary
        tgt_vocab: target vocabulary
        device: torch device
        use_beam_search: beam search when True, greedy otherwise
        beam_size: beam size
        direction: 'vi2en' or 'en2vi'
    """
    bar = "=" * 70
    print("\n" + bar)
    print("CHẾ ĐỘ DỊCH TƯƠNG TÁC")
    print(bar)
    if direction == 'vi2en':
        print("Nhập câu tiếng Việt để dịch sang tiếng Anh")
    else:
        print("Nhập câu tiếng Anh để dịch sang tiếng Việt")
    print("Gõ 'quit' hoặc 'exit' để thoát")
    print(bar + "\n")

    if direction == 'en2vi':
        src_lang, src_label, tgt_label = 'en', "Tiếng Anh", "Tiếng Việt"
    else:
        src_lang, src_label, tgt_label = 'vi', "Tiếng Việt", "Tiếng Anh"

    while True:
        sentence = input(f"{src_label}: ").strip()

        # Exit on the sentinel commands or an empty line.
        if sentence.lower() in ('quit', 'exit', ''):
            print("Tạm biệt!")
            return

        result = translate_sentence(
            model, sentence, src_vocab, tgt_vocab, device,
            use_beam_search, beam_size, src_lang=src_lang, max_len=100
        )
        print(f"{tgt_label}: {result}\n")
| |
|
| | |
| | |
| | |
| |
|
def save_translations(translations, output_file='translations.txt'):
    """
    Write translations to a UTF-8 text file, one blank-line-separated
    example per entry.

    Args:
        translations: list of (source, reference, hypothesis) tuples
        output_file: destination file path
    """
    with open(output_file, 'w', encoding='utf-8') as out:
        for idx, (source, reference, hypothesis) in enumerate(translations, 1):
            out.write(
                f"Example {idx}:\n"
                f"Source: {source}\n"
                f"Reference: {reference}\n"
                f"Hypothesis: {hypothesis}\n"
                "\n"
            )

    print(f"✓ Saved translations to {output_file}")