"""
INFERENCE & EVALUATION
Greedy Search, Beam Search, BLEU Score evaluation
"""

import math
from collections import Counter

import torch
import torch.nn.functional as F

try:
    from tqdm import tqdm
except ImportError:  # tqdm only draws progress bars; evaluation must not require it
    def tqdm(iterable, *args, **kwargs):
        """Fallback when tqdm is not installed: iterate without a progress bar."""
        return iterable


# ============================================================================
# 1. GREEDY SEARCH DECODING
# ============================================================================

def _penalize_repetitions(logits, tokens, skip_ids, repetition_penalty,
                          count_scale, no_repeat_ngram_size, ngram_power):
    """
    Return a copy of ``logits`` in which tokens that repeat earlier output are
    made less likely.

    Each penalized token gets a factor (> 1 for ``repetition_penalty`` > 1).
    Positive logits are divided by that factor and negative logits are
    multiplied by it. Dividing a negative logit would move it towards zero and
    *raise* the token's probability instead of lowering it.

    Args:
        logits: 1-D tensor of next-token logits (left unmodified)
        tokens: Token ids generated so far (starting with SOS)
        skip_ids: Ids ignored by the frequency penalty (SOS/EOS/PAD)
        repetition_penalty: Base penalty factor
        count_scale: A token seen ``c`` times gets the factor
            ``repetition_penalty ** (c * count_scale)``
        no_repeat_ngram_size: n-gram size (0 disables). If the last ``n - 1``
            tokens also occurred earlier, the token that followed them there
            gets the factor ``repetition_penalty ** ngram_power`` once per
            earlier occurrence.
        ngram_power: Exponent used for the n-gram penalty

    Returns:
        New tensor with the same shape as ``logits``.
    """
    factors = torch.ones_like(logits)
    vocab_size = logits.size(-1)

    counts = Counter(t for t in tokens if t not in skip_ids)
    for token_id, count in counts.items():
        if token_id < vocab_size:
            factors[token_id] *= repetition_penalty ** (count * count_scale)

    n = no_repeat_ngram_size
    if n > 0 and len(tokens) >= n:
        prefix_len = n - 1
        # Slice by start index: tokens[-0:] would be the whole list when n == 1.
        prefix = tokens[len(tokens) - prefix_len:]
        for start in range(len(tokens) - n + 1):
            if tokens[start:start + prefix_len] == prefix:
                repeated = tokens[start + prefix_len]
                if repeated < vocab_size:
                    factors[repeated] *= repetition_penalty ** ngram_power

    return torch.where(logits < 0, logits * factors, logits / factors)


def greedy_decode(model, src, src_vocab, tgt_vocab, device, max_len=100,
                  repetition_penalty=1.5, no_repeat_ngram_size=3):
    """
    Greedy decoding: at every step pick the highest-scoring token, after
    applying a repetition penalty.

    Args:
        model: Transformer model exposing ``encode(src)`` ->
            ``(encoder_output, src_mask)`` and
            ``decode(tgt, encoder_output, src_mask)`` -> scores indexed as
            ``[batch, position, vocab]``
        src: Source sequence, as a list of token ids, a 1-D tensor, or a
            ``[1, src_len]`` tensor
        src_vocab: Source vocabulary (unused here; kept for API symmetry)
        tgt_vocab: Target vocabulary providing SOS_IDX/EOS_IDX/PAD_IDX and
            ``decode(ids)``
        device: Device on which token tensors are created
        max_len: Maximum number of tokens to generate
        repetition_penalty: Penalty factor for repeated tokens (> 1.0 reduces
            repetition). A token seen ``c`` times is penalized by
            ``repetition_penalty ** c``.
        no_repeat_ngram_size: n-gram size whose repetition is penalized by
            ``repetition_penalty ** 2`` (0 = disabled)

    Returns:
        decoded_tokens: List of token ids. It starts with SOS and ends with EOS
            if EOS was produced within ``max_len`` steps.
        decoded_sentence: ``tgt_vocab.decode(decoded_tokens)``
    """
    model.eval()

    # Normalize the source to a [1, src_len] batch.
    if isinstance(src, list):
        src = torch.LongTensor([src]).to(device)
    elif src.dim() == 1:
        src = src.unsqueeze(0)

    special_ids = {tgt_vocab.SOS_IDX, tgt_vocab.EOS_IDX, tgt_vocab.PAD_IDX}
    tgt_tokens = [tgt_vocab.SOS_IDX]

    with torch.no_grad():
        encoder_output, src_mask = model.encode(src)

        for _ in range(max_len):
            # The whole prefix is re-decoded each step (no incremental cache).
            tgt = torch.LongTensor([tgt_tokens]).to(device)
            output = model.decode(tgt, encoder_output, src_mask)

            # Scores for the last position; the helper returns a new tensor,
            # so the model output itself is never modified.
            next_token_logits = _penalize_repetitions(
                output[0, -1, :], tgt_tokens, special_ids, repetition_penalty,
                count_scale=1, no_repeat_ngram_size=no_repeat_ngram_size,
                ngram_power=2,
            )

            next_token = next_token_logits.argmax().item()
            tgt_tokens.append(next_token)

            if next_token == tgt_vocab.EOS_IDX:
                break

    decoded_sentence = tgt_vocab.decode(tgt_tokens)
    return tgt_tokens, decoded_sentence
# ============================================================================
# 2. BEAM SEARCH DECODING
# ============================================================================

class BeamSearchNode:
    """
    One hypothesis (partial translation) in beam search.

    Attributes:
        tokens: Token ids generated so far (starting with SOS)
        log_prob: Accumulated log-probability of the generated tokens (after
            the repetition penalties applied by ``beam_search_decode``)
        length: Hypothesis length; ``beam_search_decode`` keeps it equal to
            ``len(tokens)``
    """

    def __init__(self, tokens, log_prob, length):
        self.tokens = tokens
        self.log_prob = log_prob
        self.length = length

    def eval(self, alpha=0.6):
        """
        Return the length-normalized score ``log_prob / length ** alpha``.

        Args:
            alpha: Length-penalty exponent. 0 disables normalization; larger
                values make long (more negative log_prob) hypotheses score
                relatively better.
        """
        return self.log_prob / (self.length ** alpha)


def _beam_penalize_repetitions(logits, tokens, skip_ids, repetition_penalty,
                               count_scale, no_repeat_ngram_size, ngram_power):
    """
    Return a copy of ``logits`` in which tokens that repeat earlier output are
    made less likely.

    Each penalized token gets a factor (> 1 for ``repetition_penalty`` > 1).
    Positive logits are divided by that factor and negative logits are
    multiplied by it. Dividing a negative logit would move it towards zero and
    *raise* the token's probability instead of lowering it.

    Args:
        logits: 1-D tensor of next-token logits (left unmodified)
        tokens: Token ids of the hypothesis so far (starting with SOS)
        skip_ids: Ids ignored by the frequency penalty (SOS/EOS/PAD)
        repetition_penalty: Base penalty factor
        count_scale: A token seen ``c`` times gets the factor
            ``repetition_penalty ** (c * count_scale)``
        no_repeat_ngram_size: n-gram size (0 disables). If the last ``n - 1``
            tokens also occurred earlier, the token that followed them there
            gets the factor ``repetition_penalty ** ngram_power`` once per
            earlier occurrence.
        ngram_power: Exponent used for the n-gram penalty

    Returns:
        New tensor with the same shape as ``logits``.
    """
    factors = torch.ones_like(logits)
    vocab_size = logits.size(-1)

    counts = Counter(t for t in tokens if t not in skip_ids)
    for token_id, count in counts.items():
        if token_id < vocab_size:
            factors[token_id] *= repetition_penalty ** (count * count_scale)

    n = no_repeat_ngram_size
    if n > 0 and len(tokens) >= n:
        prefix_len = n - 1
        # Slice by start index: tokens[-0:] would be the whole list when n == 1.
        prefix = tokens[len(tokens) - prefix_len:]
        for start in range(len(tokens) - n + 1):
            if tokens[start:start + prefix_len] == prefix:
                repeated = tokens[start + prefix_len]
                if repeated < vocab_size:
                    factors[repeated] *= repetition_penalty ** ngram_power

    return torch.where(logits < 0, logits * factors, logits / factors)


def beam_search_decode(model, src, src_vocab, tgt_vocab, device, beam_size=5,
                       max_len=100, alpha=0.6, repetition_penalty=1.5,
                       no_repeat_ngram_size=3):
    """
    Beam search decoding: keep the ``beam_size`` best hypotheses at each step,
    with a repetition penalty.

    Args:
        model: Transformer model exposing ``encode(src)`` ->
            ``(encoder_output, src_mask)`` and
            ``decode(tgt, encoder_output, src_mask)`` -> logits indexed as
            ``[batch, position, vocab]``
        src: Source sequence, as a list of token ids, a 1-D tensor, or a
            ``[1, src_len]`` tensor
        src_vocab: Source vocabulary (unused here; kept for API symmetry)
        tgt_vocab: Target vocabulary providing SOS_IDX/EOS_IDX/PAD_IDX and
            ``decode(ids)``
        device: Device on which token tensors are created
        beam_size: Number of hypotheses kept per step
        max_len: Maximum number of decoding steps
        alpha: Length-penalty exponent used by ``BeamSearchNode.eval``
        repetition_penalty: Penalty factor for repeated tokens (> 1.0 reduces
            repetition). A token seen ``c`` times is penalized by
            ``repetition_penalty ** (1.5 * c)``.
        no_repeat_ngram_size: n-gram size whose repetition is penalized by
            ``repetition_penalty ** 3`` (0 = disabled)

    Returns:
        best_tokens: Token ids of the best-scoring hypothesis (starting with SOS)
        best_sentence: ``tgt_vocab.decode(best_tokens)``
    """
    model.eval()

    # Normalize the source to a [1, src_len] batch.
    if isinstance(src, list):
        src = torch.LongTensor([src]).to(device)
    elif src.dim() == 1:
        src = src.unsqueeze(0)

    eos_idx = tgt_vocab.EOS_IDX
    special_ids = {tgt_vocab.SOS_IDX, eos_idx, tgt_vocab.PAD_IDX}

    def score(node):
        return node.eval(alpha)

    with torch.no_grad():
        encoder_output, src_mask = model.encode(src)

        beams = [BeamSearchNode(tokens=[tgt_vocab.SOS_IDX], log_prob=0.0, length=1)]
        completed_beams = []

        for _ in range(max_len):
            candidates = []
            for beam in beams:
                # A hypothesis that already emitted EOS is finished.
                if beam.tokens[-1] == eos_idx:
                    completed_beams.append(beam)
                    continue

                tgt = torch.LongTensor([beam.tokens]).to(device)
                output = model.decode(tgt, encoder_output, src_mask)

                # The helper returns a new tensor; the model output is untouched.
                next_token_logits = _beam_penalize_repetitions(
                    output[0, -1, :], beam.tokens, special_ids, repetition_penalty,
                    count_scale=1.5, no_repeat_ngram_size=no_repeat_ngram_size,
                    ngram_power=3,
                )
                log_probs = F.log_softmax(next_token_logits, dim=-1)

                # topk raises if asked for more entries than the vocabulary has.
                k = min(beam_size, log_probs.size(-1))
                top_log_probs, top_tokens = torch.topk(log_probs, k)
                for log_prob, token in zip(top_log_probs.tolist(), top_tokens.tolist()):
                    candidates.append(BeamSearchNode(
                        tokens=beam.tokens + [token],
                        log_prob=beam.log_prob + log_prob,
                        length=beam.length + 1,
                    ))

            if not candidates:
                # Every beam was finished and is already in completed_beams;
                # clear them so they are not added a second time below.
                beams = []
                break

            beams = sorted(candidates, key=score, reverse=True)[:beam_size]

            if len(completed_beams) >= beam_size:
                break

        # Unfinished hypotheses (max_len reached or early stop) still compete.
        completed_beams.extend(beams)
        best_beam = max(completed_beams, key=score)

    best_sentence = tgt_vocab.decode(best_beam.tokens)
    return best_beam.tokens, best_sentence


# ============================================================================
# 3. BLEU SCORE CALCULATION
# ============================================================================

def calculate_ngrams(tokens, n):
    """
    Count the n-grams of a token sequence.

    Args:
        tokens: Sequence of tokens
        n: n-gram size

    Returns:
        Counter mapping each n-gram (tuple of tokens) to its count.
    """
    return Counter(tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1))


def calculate_bleu_score(references, hypotheses, max_n=4, weights=None):
    """
    Corpus-level BLEU score (single reference per hypothesis, no smoothing).

    BLEU = BP * exp(sum_n w_n * log(p_n))

    ``p_n`` is the clipped n-gram precision with counts pooled over the whole
    corpus. ``BP`` is the brevity penalty computed from total reference and
    hypothesis lengths. If any ``p_n`` is 0 the score is 0.

    Args:
        references: List of reference sentences (list of token lists)
        hypotheses: List of hypothesis sentences (list of token lists)
        max_n: Maximum n-gram size (default 4)
        weights: Weight per n-gram order (default uniform ``1 / max_n``)

    Returns:
        BLEU score on a 0-100 scale.

    Raises:
        ValueError: If ``references`` and ``hypotheses`` differ in length.
    """
    # zip() would silently drop unmatched sentences while the brevity penalty
    # below still counted them, giving an inconsistent score.
    if len(references) != len(hypotheses):
        raise ValueError(
            f"got {len(references)} references but {len(hypotheses)} hypotheses"
        )

    if weights is None:
        weights = [1.0 / max_n] * max_n

    precisions = []
    for n in range(1, max_n + 1):
        matched = 0
        total = 0
        for ref, hyp in zip(references, hypotheses):
            hyp_ngrams = calculate_ngrams(hyp, n)
            # Counter '&' keeps min(hyp_count, ref_count): the clipped matches.
            matched += sum((hyp_ngrams & calculate_ngrams(ref, n)).values())
            total += sum(hyp_ngrams.values())
        precisions.append(matched / total if total > 0 else 0)

    # Brevity penalty (BP)
    ref_length = sum(len(ref) for ref in references)
    hyp_length = sum(len(hyp) for hyp in hypotheses)

    if hyp_length > ref_length:
        bp = 1.0
    elif hyp_length == 0:
        bp = 0.0
    else:
        bp = math.exp(1 - ref_length / hyp_length)

    if min(precisions) > 0:
        log_precisions = [w * math.log(p) for w, p in zip(weights, precisions)]
        bleu = bp * math.exp(sum(log_precisions))
    else:
        bleu = 0.0

    return bleu * 100  # 0-100 scale


# ============================================================================
# 4. EVALUATE ON TEST SET
# ============================================================================

def evaluate_model(model, test_loader, src_vocab, tgt_vocab, device,
                   use_beam_search=True, beam_size=5, max_len=100,
                   alpha=0.6, repetition_penalty=1.5, no_repeat_ngram_size=3):
    """
    Translate every example of a test set and report corpus BLEU.

    Each source row is decoded on its own. Reference and hypothesis strings are
    produced by the vocabularies' ``decode`` and are split on whitespace for
    BLEU.

    Args:
        model: Transformer model
        test_loader: Iterable of batches unpacked as ``(src, tgt, _, _)``
        src_vocab: Source vocabulary
        tgt_vocab: Target vocabulary
        device: Device
        use_beam_search: Use beam search (True) or greedy search (False)
        beam_size: Beam size (beam search only)
        max_len: Maximum decode length
        alpha: Length-penalty exponent (beam search only)
        repetition_penalty: Forwarded to the decoder
        no_repeat_ngram_size: Forwarded to the decoder

    Returns:
        bleu_score: BLEU score (0-100)
        translations: List of (source, reference, hypothesis) tuples
    """
    model.eval()

    references = []
    hypotheses = []
    translations = []

    print(f"\n{'='*70}")
    print(f"ĐÁNH GIÁ TRÊN TEST SET")
    print(f"{'='*70}")
    print(f"Decoding method: {'Beam Search' if use_beam_search else 'Greedy Search'}")
    if use_beam_search:
        print(f"Beam size: {beam_size}")
    print(f"{'='*70}\n")

    with torch.no_grad():
        for src, tgt, _, _ in tqdm(test_loader, desc='Evaluating'):
            src = src.to(device)

            for i in range(src.size(0)):
                src_seq = src[i]
                tgt_seq = tgt[i]

                if use_beam_search:
                    _, hypothesis = beam_search_decode(
                        model, src_seq, src_vocab, tgt_vocab, device,
                        beam_size, max_len, alpha=alpha,
                        repetition_penalty=repetition_penalty,
                        no_repeat_ngram_size=no_repeat_ngram_size,
                    )
                else:
                    _, hypothesis = greedy_decode(
                        model, src_seq, src_vocab, tgt_vocab, device, max_len,
                        repetition_penalty=repetition_penalty,
                        no_repeat_ngram_size=no_repeat_ngram_size,
                    )

                reference = tgt_vocab.decode(tgt_seq.tolist())
                source = src_vocab.decode(src_seq.tolist())

                # BLEU works on whitespace tokens of the decoded strings.
                references.append(reference.split())
                hypotheses.append(hypothesis.split())
                translations.append((source, reference, hypothesis))

    bleu = calculate_bleu_score(references, hypotheses)

    print(f"\n{'='*70}")
    print(f"KẾT QUẢ ĐÁNH GIÁ")
    print(f"{'='*70}")
    print(f"BLEU Score: {bleu:.2f}")
    print(f"{'='*70}\n")

    return bleu, translations


# ============================================================================
# 5. PRINT SAMPLE TRANSLATIONS
# ============================================================================

def print_sample_translations(translations, num_samples=10):
    """
    Print the first ``num_samples`` translations to stdout.

    Args:
        translations: List of (source, reference, hypothesis) tuples
        num_samples: Number of examples to print
    """
    print(f"\n{'='*70}")
    print(f"MỘT SỐ VÍ DỤ DỊCH")
    print(f"{'='*70}\n")

    for i, (src, ref, hyp) in enumerate(translations[:num_samples], 1):
        print(f"Ví dụ {i}:")
        print(f"  Source:     {src}")
        print(f"  Reference:  {ref}")
        print(f"  Hypothesis: {hyp}")
        print()


# ============================================================================
# 6. TRANSLATE SINGLE SENTENCE
# ============================================================================

def translate_sentence(model, sentence, src_vocab, tgt_vocab, device,
                       use_beam_search=True, beam_size=5, max_len=100,
                       src_lang='vi', repetition_penalty=1.5,
                       no_repeat_ngram_size=3, alpha=0.6):
    """
    Translate a single raw sentence.

    The sentence is normalized with ``data_preprocessing.clean_text``, encoded
    with ``src_vocab.encode`` and decoded with beam or greedy search.

    Args:
        model: Transformer model
        sentence: Source sentence (string)
        src_vocab: Source vocabulary
        tgt_vocab: Target vocabulary
        device: Device
        use_beam_search: Use beam search (True) or greedy search (False)
        beam_size: Beam size
        max_len: Maximum length
        src_lang: Source language ('vi' or 'en'), passed to ``clean_text``
        repetition_penalty: Forwarded to the decoder
        no_repeat_ngram_size: Forwarded to the decoder
        alpha: Length-penalty exponent (beam search only)

    Returns:
        The translated sentence (string).
    """
    model.eval()

    # Project-local text normalizer, imported lazily inside the function.
    from data_preprocessing import clean_text
    sentence = clean_text(sentence, src_lang)

    tokens = src_vocab.encode(sentence)
    src = torch.LongTensor([tokens]).to(device)

    if use_beam_search:
        _, translation = beam_search_decode(
            model, src, src_vocab, tgt_vocab, device, beam_size, max_len,
            alpha=alpha, repetition_penalty=repetition_penalty,
            no_repeat_ngram_size=no_repeat_ngram_size,
        )
    else:
        _, translation = greedy_decode(
            model, src, src_vocab, tgt_vocab, device, max_len,
            repetition_penalty=repetition_penalty,
            no_repeat_ngram_size=no_repeat_ngram_size,
        )

    return translation


# ============================================================================
# 7. INTERACTIVE TRANSLATION
# ============================================================================

def interactive_translation(model, src_vocab, tgt_vocab, device,
                            use_beam_search=True, beam_size=5,
                            direction='vi2en', max_len=100):
    """
    Interactive read-translate-print loop on stdin/stdout.

    The loop ends when the user types 'quit' or 'exit', enters an empty line,
    or closes input (EOF / Ctrl-C).

    Args:
        model: Transformer model
        src_vocab: Source vocabulary
        tgt_vocab: Target vocabulary
        device: Device
        use_beam_search: Use beam search (True) or greedy search (False)
        beam_size: Beam size
        direction: 'vi2en' or 'en2vi'
        max_len: Maximum decode length per sentence

    Raises:
        ValueError: If ``direction`` is not 'vi2en' or 'en2vi'.
    """
    # An unknown direction used to print the English prompt while treating the
    # input as Vietnamese; reject it instead.
    if direction not in ('vi2en', 'en2vi'):
        raise ValueError(f"direction must be 'vi2en' or 'en2vi', got {direction!r}")

    print("\n" + "="*70)
    print("CHẾ ĐỘ DỊCH TƯƠNG TÁC")
    print("="*70)
    if direction == 'vi2en':
        print("Nhập câu tiếng Việt để dịch sang tiếng Anh")
    else:
        print("Nhập câu tiếng Anh để dịch sang tiếng Việt")
    print("Gõ 'quit' hoặc 'exit' để thoát")
    print("="*70 + "\n")

    src_lang = 'en' if direction == 'en2vi' else 'vi'
    src_label = "Tiếng Anh" if direction == 'en2vi' else "Tiếng Việt"
    tgt_label = "Tiếng Việt" if direction == 'en2vi' else "Tiếng Anh"

    while True:
        try:
            sentence = input(f"{src_label}: ").strip()
        except (EOFError, KeyboardInterrupt):
            # Closed stdin or Ctrl-C ends the session like 'quit' instead of
            # crashing with a traceback.
            print()
            sentence = ''

        if sentence.lower() in ['quit', 'exit', '']:
            print("Tạm biệt!")
            break

        translation = translate_sentence(
            model, sentence, src_vocab, tgt_vocab, device,
            use_beam_search, beam_size, src_lang=src_lang, max_len=max_len
        )
        print(f"{tgt_label}: {translation}\n")
# ============================================================================
# 8. SAVE TRANSLATIONS
# ============================================================================

def save_translations(translations, output_file='translations.txt'):
    """
    Write (source, reference, hypothesis) triples to a UTF-8 text file.

    Each triple becomes a numbered "Example i:" record (numbering from 1),
    followed by its Source / Reference / Hypothesis lines and a blank line.
    Any existing file at ``output_file`` is overwritten. A confirmation
    message is printed afterwards.

    Args:
        translations: Iterable of (source, reference, hypothesis) tuples
        output_file: Path of the output file
    """
    records = (
        f"Example {idx}:\n"
        f"Source: {source}\n"
        f"Reference: {reference}\n"
        f"Hypothesis: {hypothesis}\n"
        "\n"
        for idx, (source, reference, hypothesis) in enumerate(translations, start=1)
    )

    with open(output_file, 'w', encoding='utf-8') as handle:
        handle.writelines(records)

    print(f"✓ Saved translations to {output_file}")