File size: 8,214 Bytes
c093feb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
import torch
import torch.nn as nn
from pathlib import Path
import math
import logging
import re

# --- Setup ---
# Inference output is kept terse: bare messages, INFO level and above.
logging.basicConfig(format='%(message)s', level=logging.INFO)

# --- Configuration ---
# The architecture values (DIM, ENCODER_LAYERS, DECODER_LAYERS, N_HEADS, FF_DIM,
# MAX_SEQ_LEN) fix the shapes of the model's parameters and its positional-encoding
# buffer, so they must equal the values used by the training script or loading the
# checkpoint's state_dict will fail.
CONFIG = dict(
    SRC_LANG="en",
    TGT_LANG="zh",
    TOKENIZER_FILE="opus_en_zh_tokenizer.json",
    MAX_SEQ_LEN=128,
    DIM=256,
    ENCODER_LAYERS=4,
    DECODER_LAYERS=4,
    N_HEADS=8,
    FF_DIM=512,
    DROPOUT=0.1,
    CHECKPOINT_DIR="checkpoints_translation",
)


class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a sequence, then apply dropout.

    The encoding table is a buffer of shape ``(max_len, 1, dim)`` indexed along
    dimension 0, so ``forward`` expects sequence-first input ``(seq_len, batch, dim)``
    with ``seq_len <= max_len``; callers holding batch-first tensors must permute
    before and after the call.

    Args:
        dim: Model/embedding width. Odd widths are supported.
        dropout: Dropout probability applied after adding the encodings.
        max_len: Number of positions precomputed in the table.
    """

    def __init__(self, dim, dropout, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        position = torch.arange(max_len).unsqueeze(1)  # (max_len, 1)
        # Frequencies 1 / 10000^(2i/dim), one per even feature index 2i.
        div_term = torch.exp(torch.arange(0, dim, 2) * (-math.log(10000.0) / dim))
        pe = torch.zeros(max_len, 1, dim)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        # For odd ``dim`` there is one fewer odd column than even column, so trim
        # div_term to match (a no-op for even ``dim``); the untrimmed version raised
        # a shape-mismatch error.
        pe[:, 0, 1::2] = torch.cos(position * div_term[: dim // 2])
        # Persistent buffer: follows .to(device) and is saved in the state_dict, so
        # its name and shape must stay stable for existing checkpoints to load.
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return ``dropout(x + pe[:seq_len])`` for sequence-first ``x``."""
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)

class TranslationTransformer(nn.Module):
    """Batch-first encoder-decoder Transformer for sequence-to-sequence translation.

    Source and target tokens share one embedding table and one positional encoder.
    The final linear ``generator`` projects decoder outputs to vocabulary logits.

    Args:
        vocab_size: Size of the shared token vocabulary.
        dim: Model width (``d_model``); also the embedding scale factor ``sqrt(dim)``.
        n_heads: Attention heads per layer.
        encoder_layers: Number of encoder layers.
        decoder_layers: Number of decoder layers.
        ff_dim: Feed-forward hidden width.
        dropout: Dropout probability used throughout.
        max_len: Longest sequence the positional encoder can handle.
    """

    def __init__(self, vocab_size, dim, n_heads, encoder_layers, decoder_layers, ff_dim, dropout, max_len):
        super().__init__()
        # Plain int attribute (not a parameter or buffer), so the state_dict and thus
        # checkpoint compatibility are unchanged.
        self.dim = dim
        self.embedding = nn.Embedding(vocab_size, dim)
        self.pos_encoder = PositionalEncoding(dim, dropout, max_len)
        self.transformer = nn.Transformer(
            d_model=dim, nhead=n_heads, num_encoder_layers=encoder_layers,
            num_decoder_layers=decoder_layers, dim_feedforward=ff_dim,
            dropout=dropout, batch_first=True
        )
        self.generator = nn.Linear(dim, vocab_size)

    def _embed(self, tokens):
        """Embed ``(batch, seq)`` token ids, scale by ``sqrt(dim)`` and add positions.

        Returns a ``(batch, seq, dim)`` tensor.
        """
        # Scale by this model's own width; the previous code read the module-level
        # CONFIG["DIM"], which mis-scaled any model built with a different ``dim``.
        scaled = self.embedding(tokens) * math.sqrt(self.dim)
        # The positional encoder works sequence-first, so permute in and back out.
        return self.pos_encoder(scaled.permute(1, 0, 2)).permute(1, 0, 2)

    def _generate_mask(self, src, tgt, pad_id):
        """Build the attention masks for one forward pass.

        Returns:
            ``(tgt_mask, src_padding_mask, tgt_padding_mask)``: the square causal mask
            for the target length, and boolean masks that are True wherever ``src`` /
            ``tgt`` equal ``pad_id``.
        """
        tgt_mask = self.transformer.generate_square_subsequent_mask(tgt.shape[1], device=tgt.device)
        src_padding_mask = (src == pad_id)
        tgt_padding_mask = (tgt == pad_id)
        return tgt_mask, src_padding_mask, tgt_padding_mask

    def forward(self, src, tgt, pad_id):
        """Compute next-token logits.

        Args:
            src: ``(batch, src_len)`` source token ids.
            tgt: ``(batch, tgt_len)`` target token ids (decoder input).
            pad_id: Padding token id; matching positions are excluded from attention.

        Returns:
            ``(batch, tgt_len, vocab_size)`` unnormalized logits.
        """
        src_emb = self._embed(src)
        tgt_emb = self._embed(tgt)
        tgt_mask, src_padding_mask, tgt_padding_mask = self._generate_mask(src, tgt, pad_id)
        output = self.transformer(
            src_emb, tgt_emb, tgt_mask=tgt_mask, src_key_padding_mask=src_padding_mask,
            # Decoder cross-attention must also skip padded source positions.
            tgt_key_padding_mask=tgt_padding_mask, memory_key_padding_mask=src_padding_mask
        )
        return self.generator(output)

# We need to import the Tokenizer class to load the tokenizer file
from tokenizers import Tokenizer

class Translator:
    """Inference wrapper: loads the tokenizer, builds the model and translates greedily.

    ``config`` must provide the keys read here: TOKENIZER_FILE, DIM, N_HEADS,
    ENCODER_LAYERS, DECODER_LAYERS, FF_DIM, DROPOUT, MAX_SEQ_LEN and CHECKPOINT_DIR.
    """

    def __init__(self, config):
        """Load the tokenizer and build the model on CUDA if available, else CPU.

        The model starts with untrained weights; call ``load_best_checkpoint`` next.

        Raises:
            FileNotFoundError: if ``config["TOKENIZER_FILE"]`` does not exist.
        """
        self.config = config
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logging.info(f"Using device: {self.device}")

        # Load the trained tokenizer
        tokenizer_path = Path(self.config["TOKENIZER_FILE"])
        if not tokenizer_path.exists():
            raise FileNotFoundError(f"Tokenizer file not found at {tokenizer_path}. Please run the training script first.")
        self.tokenizer = Tokenizer.from_file(str(tokenizer_path))

        # Special token IDs (token_to_id returns None for tokens not in the vocabulary).
        self.bos_id = self.tokenizer.token_to_id("<s>")
        self.eos_id = self.tokenizer.token_to_id("</s>")
        self.pad_id = self.tokenizer.token_to_id("<pad>")

        # Initialize the model structure
        self.model = TranslationTransformer(
            vocab_size=self.tokenizer.get_vocab_size(),
            dim=self.config["DIM"], n_heads=self.config["N_HEADS"],
            encoder_layers=self.config["ENCODER_LAYERS"], decoder_layers=self.config["DECODER_LAYERS"],
            ff_dim=self.config["FF_DIM"], dropout=self.config["DROPOUT"], max_len=self.config["MAX_SEQ_LEN"]
        )
        self.model.to(self.device)
        # This class is inference-only: keep dropout off even if translate() is called
        # before load_best_checkpoint().
        self.model.eval()

    def load_best_checkpoint(self):
        """Load the ``*.pt`` checkpoint whose filename reports the lowest validation loss.

        Only files whose names contain ``valloss_<number>.pt`` are considered. Names
        whose number does not parse as a float (e.g. ``valloss_1.2.3.pt``) are skipped
        with a warning instead of aborting the search. Files are scanned in sorted
        order, so ties resolve deterministically to the first name. The model is put
        in eval mode afterwards.

        Raises:
            FileNotFoundError: if the checkpoint directory is missing or contains no
                usable checkpoint.
        """
        checkpoint_dir = Path(self.config["CHECKPOINT_DIR"])
        if not checkpoint_dir.exists():
            raise FileNotFoundError(f"Checkpoint directory not found at {checkpoint_dir}.")

        best_loss = float('inf')
        best_checkpoint_path = None

        for chk_path in sorted(checkpoint_dir.glob("*.pt")):
            # Use regex to find the validation loss in the filename
            match = re.search(r'valloss_([\d.]+)\.pt', chk_path.name)
            if match is None:
                continue
            try:
                val_loss = float(match.group(1))
            except ValueError:
                # [\d.]+ also accepts strings such as "1.2.3" that float() rejects.
                logging.warning(f"Skipping checkpoint with unparsable loss: {chk_path.name}")
                continue
            if val_loss < best_loss:
                best_loss = val_loss
                best_checkpoint_path = chk_path

        if best_checkpoint_path is None:
            raise FileNotFoundError(f"No valid checkpoints found in {checkpoint_dir}. Checkpoint names must be like '...valloss_x.xxxx.pt'.")

        logging.info(f"Loading best model from: {best_checkpoint_path} (Validation Loss: {best_loss:.4f})")
        # NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
        checkpoint = torch.load(best_checkpoint_path, map_location=self.device)
        self.model.load_state_dict(checkpoint['model_state_dict'])

        # Evaluation mode disables dropout for deterministic inference.
        self.model.eval()

    def _encode_source(self, src_sentence):
        """Tokenize ``src_sentence`` into a ``(1, seq_len)`` LongTensor on ``self.device``.

        The ids are wrapped as ``<s> ... </s>`` and the body is truncated so the total
        length never exceeds ``config["MAX_SEQ_LEN"]``. That is the size of the model's
        positional-encoding table; longer inputs would fail inside the model.
        """
        max_body = max(self.config["MAX_SEQ_LEN"] - 2, 0)  # leave room for <s> and </s>
        body_ids = self.tokenizer.encode(src_sentence).ids[:max_body]
        src_tokens = [self.bos_id, *body_ids, self.eos_id]
        return torch.tensor(src_tokens, dtype=torch.long).unsqueeze(0).to(self.device)

    def translate(self, src_sentence: str):
        """Translate a single English sentence to Chinese using greedy decoding.

        Decoding starts from ``<s>`` and appends the arg-max token each step. It stops
        when ``</s>`` is produced or after ``config["MAX_SEQ_LEN"]`` steps, so the
        decoder input never exceeds MAX_SEQ_LEN tokens.

        Returns:
            The decoded text with special tokens removed, or ``""`` for blank input.
        """
        if not src_sentence.strip():
            return ""

        src = self._encode_source(src_sentence)
        tgt_tokens = [self.bos_id]

        with torch.no_grad():  # Disable gradient calculation for efficiency
            # NOTE: the full model (encoder included) is re-run every step; caching the
            # encoder output would avoid that repeated work.
            for _ in range(self.config["MAX_SEQ_LEN"]):
                tgt_input = torch.tensor(tgt_tokens, dtype=torch.long).unsqueeze(0).to(self.device)
                logits = self.model(src, tgt_input, self.pad_id)

                # Greedy decoding: take the most likely token at the last position.
                next_token_id = logits[:, -1, :].argmax(dim=-1).item()
                tgt_tokens.append(next_token_id)

                if next_token_id == self.eos_id:
                    break

        # Decode the generated token IDs back to a string
        return self.tokenizer.decode(tgt_tokens, skip_special_tokens=True)

def interactive_session():
    """Run the interactive English -> Chinese translation loop on stdin/stdout.

    Builds a ``Translator`` from ``CONFIG`` and loads its best checkpoint. If a
    required file is missing, the problem is logged and the function returns.

    The loop exits on any of:
      - ``quit``, ``exit`` or ``q`` (case-insensitive, surrounding whitespace ignored);
      - Ctrl-C;
      - end of input (Ctrl-D or exhausted piped stdin).

    Blank lines are ignored. Other per-sentence errors are logged and the loop continues.
    """
    try:
        translator = Translator(CONFIG)
        translator.load_best_checkpoint()
    except FileNotFoundError as e:
        logging.error(f"Error initializing translator: {e}")
        logging.error("Please make sure you have run the training script and have a valid tokenizer and checkpoint file.")
        return

    print("\n--- ZHEN - 1 Translator ---")
    print("Type an English sentence and press Enter.")
    print("Type 'quit' or 'exit' to close the program.")

    while True:
        try:
            source_text = input("\nEnglish > ")
            command = source_text.strip()
            if command.lower() in ('quit', 'exit', 'q'):
                print("Exiting translator. Goodbye!")
                break

            if not command:
                continue

            translated_text = translator.translate(source_text)
            print(f"Chinese < {translated_text}")

        except (KeyboardInterrupt, EOFError):
            # EOFError (closed stdin) subclasses Exception. Without this clause the
            # generic handler below would swallow it, and input() would fail again on
            # every iteration, looping forever.
            print("\nExiting translator. Goodbye!")
            break
        except Exception as e:
            logging.error(f"An unexpected error occurred: {e}")


if __name__ == "__main__":
    interactive_session()