|
|
import torch
|
|
|
import torch.nn as nn
|
|
|
from pathlib import Path
|
|
|
import math
|
|
|
import logging
|
|
|
import re
|
|
|
|
|
|
|
|
|
|
|
|
# Emit bare messages (no level/timestamp prefix) at INFO and above, so log
# output reads like ordinary console text next to the interactive prompts.
logging.basicConfig(level=logging.INFO, format='%(message)s')
|
|
|
|
|
|
|
|
|
# Settings consumed by the model definition and the Translator in this file.
# NOTE(review): the architecture values presumably have to match the ones the
# training script used to produce the checkpoints -- verify against that script.
CONFIG = {
    # Language codes; not read by any code visible in this file.
    "SRC_LANG": "en",
    "TGT_LANG": "zh",
    # Tokenizer JSON file loaded when a Translator is constructed.
    "TOKENIZER_FILE": "opus_en_zh_tokenizer.json",
    # Length of the positional-encoding table and cap on greedy decoding steps.
    "MAX_SEQ_LEN": 128,
    # Transformer hyperparameters passed to TranslationTransformer.
    "DIM": 256,
    "ENCODER_LAYERS": 4,
    "DECODER_LAYERS": 4,
    "N_HEADS": 8,
    "FF_DIM": 512,
    "DROPOUT": 0.1,
    # Directory scanned for "*.pt" checkpoint files.
    "CHECKPOINT_DIR": "checkpoints_translation",
}
|
|
|
|
|
|
|
|
|
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a sequence-first tensor.

    A table of shape ``(max_len, 1, dim)`` is precomputed and registered as
    the buffer ``pe``; even feature columns hold sines and odd columns hold
    cosines of ``position * 10000 ** (-i / dim)``.

    Args:
        dim: Feature size of the inputs.
        dropout: Dropout probability applied after the encodings are added.
        max_len: Number of positions stored in the table.
    """

    def __init__(self, dim, dropout, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, dim, 2) * (-math.log(10000.0) / dim))

        pe = torch.zeros(max_len, 1, dim)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        # With an odd dim there is one fewer odd column than even column, so
        # trim the frequencies; for even dim the slice is a no-op.
        pe[:, 0, 1::2] = torch.cos(position * div_term[: dim // 2])
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return ``dropout(x + pe[:seq_len])``.

        Args:
            x: Tensor whose first dimension is the sequence position; the
                table's singleton middle dimension broadcasts over dim 1.

        Raises:
            RuntimeError: If ``x`` has more positions than the table holds.
        """
        seq_len = x.size(0)
        if seq_len > self.pe.size(0):
            # Fail with an explicit message instead of an opaque broadcast
            # error (or a silent broadcast when the table has one row).
            raise RuntimeError(
                f"Sequence length {seq_len} exceeds the positional encoding "
                f"table size {self.pe.size(0)}."
            )
        x = x + self.pe[:seq_len]
        return self.dropout(x)
|
|
|
|
|
|
class TranslationTransformer(nn.Module):
    """Encoder-decoder Transformer using one embedding table for both sides.

    Token ids are embedded, scaled by ``sqrt(dim)``, given positional
    encodings, passed through a batch-first ``nn.Transformer`` and projected
    to vocabulary-sized logits.

    Args:
        vocab_size: Number of entries in the embedding table and output layer.
        dim: Model width (``d_model``).
        n_heads: Number of attention heads.
        encoder_layers: Number of encoder layers.
        decoder_layers: Number of decoder layers.
        ff_dim: Width of the feed-forward sublayers.
        dropout: Dropout probability for the transformer and the positional
            encoder.
        max_len: Number of positions in the positional-encoding table.
    """

    def __init__(self, vocab_size, dim, n_heads, encoder_layers, decoder_layers, ff_dim, dropout, max_len):
        super().__init__()
        # Derived from this instance's own width rather than a module-level
        # setting, so a model built with a different dim is scaled correctly.
        self.embed_scale = math.sqrt(dim)
        self.embedding = nn.Embedding(vocab_size, dim)
        self.pos_encoder = PositionalEncoding(dim, dropout, max_len)
        self.transformer = nn.Transformer(
            d_model=dim,
            nhead=n_heads,
            num_encoder_layers=encoder_layers,
            num_decoder_layers=decoder_layers,
            dim_feedforward=ff_dim,
            dropout=dropout,
            batch_first=True,
        )
        self.generator = nn.Linear(dim, vocab_size)

    def _generate_mask(self, src, tgt, pad_id):
        """Build the causal target mask and the boolean padding masks.

        Returns:
            ``(tgt_mask, src_padding_mask, tgt_padding_mask)``: the square
            subsequent mask sized by ``tgt.shape[1]``, and two boolean
            tensors that are True where ``src`` / ``tgt`` equal ``pad_id``.
        """
        tgt_mask = self.transformer.generate_square_subsequent_mask(tgt.shape[1], device=tgt.device)
        src_padding_mask = (src == pad_id)
        tgt_padding_mask = (tgt == pad_id)
        return tgt_mask, src_padding_mask, tgt_padding_mask

    def _embed(self, token_ids):
        """Embed ids, scale by sqrt(dim) and add positional encodings."""
        scaled = self.embedding(token_ids) * self.embed_scale
        # The positional encoder indexes positions along dim 0, so switch to
        # sequence-first for it and back to batch-first for the transformer.
        return self.pos_encoder(scaled.permute(1, 0, 2)).permute(1, 0, 2)

    def forward(self, src, tgt, pad_id):
        """Compute output logits for every target position.

        Args:
            src: Source token ids, indexed ``[batch, position]``.
            tgt: Target token ids fed to the decoder, indexed
                ``[batch, position]``.
            pad_id: Token id treated as padding when building the masks.

        Returns:
            The output of the final linear layer applied to the decoder
            output (last dimension is ``vocab_size``).
        """
        src_emb = self._embed(src)
        tgt_emb = self._embed(tgt)
        tgt_mask, src_padding_mask, tgt_padding_mask = self._generate_mask(src, tgt, pad_id)
        output = self.transformer(
            src_emb,
            tgt_emb,
            tgt_mask=tgt_mask,
            src_key_padding_mask=src_padding_mask,
            tgt_key_padding_mask=tgt_padding_mask,
            memory_key_padding_mask=src_padding_mask,
        )
        return self.generator(output)
|
|
|
|
|
|
|
|
|
from tokenizers import Tokenizer
|
|
|
|
|
|
class Translator:
    """Load a tokenizer and a trained model, and translate single sentences.

    Construction reads the tokenizer file named in the config and builds an
    untrained ``TranslationTransformer`` on the selected device;
    ``load_best_checkpoint`` then fills in trained weights.

    Args:
        config: Mapping with the keys ``TOKENIZER_FILE``, ``DIM``,
            ``N_HEADS``, ``ENCODER_LAYERS``, ``DECODER_LAYERS``, ``FF_DIM``,
            ``DROPOUT``, ``MAX_SEQ_LEN`` and ``CHECKPOINT_DIR``.

    Raises:
        FileNotFoundError: If the tokenizer file does not exist.
        ValueError: If the tokenizer lacks one of ``<s>``, ``</s>``, ``<pad>``.
    """

    # Checkpoint file names carry their validation loss, e.g. "..valloss_1.2345.pt".
    _VAL_LOSS_PATTERN = re.compile(r'valloss_([\d.]+)\.pt')

    def __init__(self, config):
        self.config = config
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logging.info(f"Using device: {self.device}")

        tokenizer_path = Path(self.config["TOKENIZER_FILE"])
        if not tokenizer_path.exists():
            raise FileNotFoundError(f"Tokenizer file not found at {tokenizer_path}. Please run the training script first.")
        self.tokenizer = Tokenizer.from_file(str(tokenizer_path))

        self.bos_id = self.tokenizer.token_to_id("<s>")
        self.eos_id = self.tokenizer.token_to_id("</s>")
        self.pad_id = self.tokenizer.token_to_id("<pad>")
        # token_to_id yields None for unknown tokens; catch that here rather
        # than failing later with an obscure tensor-construction error.
        missing = [
            token
            for token, token_id in (("<s>", self.bos_id), ("</s>", self.eos_id), ("<pad>", self.pad_id))
            if token_id is None
        ]
        if missing:
            raise ValueError(f"Tokenizer at {tokenizer_path} is missing special tokens: {', '.join(missing)}")

        self.model = TranslationTransformer(
            vocab_size=self.tokenizer.get_vocab_size(),
            dim=self.config["DIM"],
            n_heads=self.config["N_HEADS"],
            encoder_layers=self.config["ENCODER_LAYERS"],
            decoder_layers=self.config["DECODER_LAYERS"],
            ff_dim=self.config["FF_DIM"],
            dropout=self.config["DROPOUT"],
            max_len=self.config["MAX_SEQ_LEN"],
        )
        self.model.to(self.device)

    def load_best_checkpoint(self):
        """Find and load the checkpoint with the lowest validation loss.

        Scans ``CHECKPOINT_DIR`` for ``*.pt`` files whose names contain
        ``valloss_<number>.pt``, loads the ``model_state_dict`` entry of the
        one with the smallest number, and puts the model in eval mode.

        Raises:
            FileNotFoundError: If the directory is missing or holds no file
                with a parsable validation loss in its name.
        """
        checkpoint_dir = Path(self.config["CHECKPOINT_DIR"])
        if not checkpoint_dir.exists():
            raise FileNotFoundError(f"Checkpoint directory not found at {checkpoint_dir}.")

        best_loss = float('inf')
        best_checkpoint_path = None
        for chk_path in checkpoint_dir.glob("*.pt"):
            match = self._VAL_LOSS_PATTERN.search(chk_path.name)
            if not match:
                continue
            try:
                val_loss = float(match.group(1))
            except ValueError:
                # The pattern admits strings such as "1.2.3"; skip those
                # files instead of aborting the whole search.
                continue
            if val_loss < best_loss:
                best_loss = val_loss
                best_checkpoint_path = chk_path

        if best_checkpoint_path is None:
            raise FileNotFoundError(f"No valid checkpoints found in {checkpoint_dir}. Checkpoint names must be like '...valloss_x.xxxx.pt'.")

        logging.info(f"Loading best model from: {best_checkpoint_path} (Validation Loss: {best_loss:.4f})")
        # NOTE(security): torch.load unpickles the file. Only load checkpoints
        # from a trusted source; consider weights_only=True where the
        # installed PyTorch version supports it.
        checkpoint = torch.load(best_checkpoint_path, map_location=self.device)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.model.eval()

    def translate(self, src_sentence: str):
        """Translates a single English sentence to Chinese using greedy decoding.

        The source is wrapped in ``<s>`` / ``</s>`` ids and cut down to
        ``MAX_SEQ_LEN`` tokens if needed, because the model's positional
        table holds only that many positions. Decoding stops at ``</s>`` or
        after ``MAX_SEQ_LEN`` steps.

        Args:
            src_sentence: Text to translate.

        Returns:
            The decoded text with special tokens removed, or ``""`` for a
            blank input.
        """
        if not src_sentence.strip():
            return ""

        max_len = self.config["MAX_SEQ_LEN"]
        token_ids = self.tokenizer.encode(src_sentence).ids
        # Reserve two positions for the <s> and </s> ids added below.
        budget = max(max_len - 2, 0)
        if len(token_ids) > budget:
            logging.warning(f"Input has {len(token_ids)} tokens; truncating to {budget} to fit the model's maximum length.")
            token_ids = token_ids[:budget]

        src_tokens = [self.bos_id] + token_ids + [self.eos_id]
        src = torch.tensor(src_tokens, dtype=torch.long).unsqueeze(0).to(self.device)

        tgt_tokens = [self.bos_id]
        with torch.no_grad():
            for _ in range(max_len):
                tgt_input = torch.tensor(tgt_tokens, dtype=torch.long).unsqueeze(0).to(self.device)
                logits = self.model(src, tgt_input, self.pad_id)
                # Greedy choice: most likely token at the last target position.
                next_token_id = logits[:, -1, :].argmax(dim=-1).item()
                tgt_tokens.append(next_token_id)
                if next_token_id == self.eos_id:
                    break

        translated_text = self.tokenizer.decode(tgt_tokens, skip_special_tokens=True)
        return translated_text
|
|
|
|
|
|
def interactive_session():
    """Runs the main interactive translation loop.

    Builds a Translator from CONFIG, loads its best checkpoint, then reads
    sentences from stdin and prints their translations until the user types
    'quit', 'exit' or 'q', presses Ctrl-C, or stdin reaches end-of-file.
    Initialization failures are logged and the function returns.
    """
    try:
        translator = Translator(CONFIG)
        translator.load_best_checkpoint()
    except (FileNotFoundError, ValueError) as e:
        logging.error(f"Error initializing translator: {e}")
        logging.error("Please make sure you have run the training script and have a valid tokenizer and checkpoint file.")
        return

    print("\n--- ZHEN - 1 Translator ---")
    print("Type an English sentence and press Enter.")
    print("Type 'quit' or 'exit' to close the program.")

    while True:
        try:
            source_text = input("\nEnglish > ")

            # Surrounding whitespace is ignored when recognizing commands
            # and blank lines; the sentence itself is passed on unchanged.
            command = source_text.strip().lower()
            if command in ('quit', 'exit', 'q'):
                print("Exiting translator. Goodbye!")
                break
            if not command:
                continue

            translated_text = translator.translate(source_text)
            print(f"Chinese < {translated_text}")
        except (KeyboardInterrupt, EOFError):
            # EOFError must be handled before the generic handler below:
            # once stdin is closed every input() call raises it again, which
            # would otherwise loop forever logging errors.
            print("\nExiting translator. Goodbye!")
            break
        except Exception as e:
            # Keep the session alive when a single translation fails.
            logging.error(f"An unexpected error occurred: {e}")
|
|
|
|
|
|
|
|
|
# Start the interactive translator only when run as a script, not on import.
if __name__ == "__main__":
    interactive_session()