DarcyCheng committed on
Commit c4a5de3 · verified · 1 Parent(s): a10de78

Create README.md
# RNN-based Neural Machine Translation (NMT)

A PyTorch implementation of an RNN-based Neural Machine Translation system for Chinese-to-English translation, featuring an LSTM encoder-decoder architecture with attention mechanisms.

## Introduction

This repository implements an RNN-based Neural Machine Translation system with the following key components:

**Model**: An LSTM-based sequence-to-sequence model in which both the encoder and decoder consist of unidirectional layers.

**Attention mechanism**: An attention layer used to investigate the impact of different alignment functions (dot-product, multiplicative, and additive) on model performance.

**Training policy**: A comparison of the Teacher Forcing and Free Running training strategies.

**Decoding policy**: A comparison of greedy and beam-search decoding strategies.

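The three alignment functions can be sketched as a single PyTorch module. This is an illustrative sketch, not the repository's actual code; the class name, argument names, and tensor shapes (`dec_h` of shape `(batch, hidden)`, `enc_hs` of shape `(batch, src_len, hidden)`) are assumptions:

```python
import torch
import torch.nn as nn


class Attention(nn.Module):
    """Sketch of the three alignment functions compared in this repo."""

    def __init__(self, hidden, kind="dot-product"):
        super().__init__()
        self.kind = kind
        if kind == "multiplicative":
            self.W = nn.Linear(hidden, hidden, bias=False)
        elif kind == "additive":
            self.W = nn.Linear(2 * hidden, hidden)
            self.v = nn.Linear(hidden, 1, bias=False)

    def forward(self, dec_h, enc_hs):
        if self.kind == "dot-product":
            # score(h_dec, h_enc) = h_dec . h_enc
            scores = torch.bmm(enc_hs, dec_h.unsqueeze(2)).squeeze(2)
        elif self.kind == "multiplicative":
            # score(h_dec, h_enc) = h_dec . (W h_enc)
            scores = torch.bmm(self.W(enc_hs), dec_h.unsqueeze(2)).squeeze(2)
        else:
            # additive (Bahdanau): score = v . tanh(W [h_dec; h_enc])
            expanded = dec_h.unsqueeze(1).expand_as(enc_hs)
            scores = self.v(
                torch.tanh(self.W(torch.cat([expanded, enc_hs], dim=2)))
            ).squeeze(2)
        weights = torch.softmax(scores, dim=1)                # (batch, src_len)
        context = torch.bmm(weights.unsqueeze(1), enc_hs).squeeze(1)
        return context, weights
```

The dot-product variant has no extra parameters; the multiplicative and additive variants trade a few extra weights for more flexible alignment.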
### Key Features

- **Encoder**: Unidirectional LSTM encoder for the source language (Chinese)
- **Decoder**: Unidirectional LSTM decoder with attention mechanism for the target language (English)
- **Attention Types**:
  - Dot-product attention
  - Multiplicative attention
  - Additive attention (Bahdanau-style)
- **Tokenization**:
  - Chinese: Jieba word segmentation
  - English: SentencePiece subword tokenization
- **Training Strategies**:
  - Teacher Forcing (configurable ratio)
  - Free Running
- **Decoding Strategies**:
  - Greedy decoding
  - Beam search decoding (configurable beam size)

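The two training strategies differ only in which token is fed to the decoder at the next step. A minimal sketch, assuming a hypothetical `decoder_step` callable that maps the previous token batch to vocabulary logits (not the repository's actual API):

```python
import random

import torch


def decode_with_teacher_forcing(decoder_step, targets, bos_id,
                                teacher_forcing_ratio=0.7):
    """Illustrative decode loop: with probability `teacher_forcing_ratio`
    the next input is the gold token (Teacher Forcing); otherwise it is
    the model's own prediction (Free Running)."""
    batch, tgt_len = targets.shape
    inp = torch.full((batch,), bos_id, dtype=torch.long)
    logits_per_step = []
    for t in range(tgt_len):
        logits = decoder_step(inp)                # (batch, vocab)
        logits_per_step.append(logits)
        use_gold = random.random() < teacher_forcing_ratio
        inp = targets[:, t] if use_gold else logits.argmax(dim=1)
    return torch.stack(logits_per_step, dim=1)    # (batch, tgt_len, vocab)
```

A ratio of 1.0 is pure Teacher Forcing and 0.0 is pure Free Running; intermediate values mix the two per step.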
## Data Preparation

The compressed package contains four JSONL files: the small training set, large training set, validation set, and test set, with 10k, 100k, 500, and 200 samples respectively. Each line in a JSONL file contains one parallel sentence pair. Final model performance is evaluated on the test set.

### Data Format

Each line in the JSONL files follows this format:
```json
{"chinese": "中文句子", "english": "English sentence"}
```

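Records in this format can be read with a few lines of standard-library Python; `load_pairs` below is an illustrative helper, not part of this repository:

```python
import json


def load_pairs(path):
    """Read (chinese, english) sentence pairs from a JSONL file."""
    pairs = []
    with open(path, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if line:                      # skip blank lines
                rec = json.loads(line)
                pairs.append((rec["chinese"], rec["english"]))
    return pairs
```
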
### Data Directory Structure

```
translation_dataset_zh_en/
├── train_small.jsonl   # 10k samples
├── train_large.jsonl   # 100k samples
├── dev.jsonl           # 500 samples
└── test.jsonl          # 200 samples
```

### Preprocessing

The data preprocessing pipeline includes:
1. Chinese text segmentation using Jieba
2. English text tokenization using SentencePiece
3. Vocabulary construction with frequency cutoff
4. Sentence padding and batching

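Step 3 (frequency cutoff) can be sketched as follows; the special-symbol set is an assumption for illustration and not necessarily what `dataset/vocab.py` uses:

```python
from collections import Counter


def build_vocab(tokenized_sents, min_freq=2,
                specials=("<pad>", "<unk>", "<s>", "</s>")):
    """Map tokens to ids, keeping only tokens whose corpus frequency
    is at least `min_freq`; rarer tokens fall back to <unk>."""
    counts = Counter(tok for sent in tokenized_sents for tok in sent)
    vocab = {tok: i for i, tok in enumerate(specials)}
    for tok, freq in counts.most_common():
        if freq >= min_freq and tok not in vocab:
            vocab[tok] = len(vocab)
    return vocab
```
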
## Environment

### Requirements

- **Python**: Python 3.9.25
- **PyTorch**: torch 2.0.1+cu118 (or compatible version)
- **CUDA**: CUDA 11.8 (optional, for GPU acceleration)

### Installation

1. Clone the repository:
```bash
git clone <repository-url>
cd RNN_NMT
```

2. Install dependencies:
```bash
pip install -r requirement.txt
```

3. Download NLTK data (required for BLEU score calculation):
```python
import nltk
nltk.download('punkt')
```

### Dependencies

Key dependencies include:
- `torch>=1.12.0` - Deep learning framework
- `numpy>=1.21.0` - Numerical computing
- `hydra-core>=1.3.0` - Configuration management
- `omegaconf>=2.2.0` - Configuration objects
- `sentencepiece>=0.1.96` - English subword tokenization
- `jieba>=0.42.1` - Chinese word segmentation
- `nltk>=3.7` - BLEU score evaluation
- `tqdm>=4.62.0` - Progress bars

## Training and Evaluation

### Training

Train the model using the default configuration:

```bash
python train.py
```

The training script uses Hydra for configuration management. You can override configuration parameters via the command line:

```bash
python train.py attention_type=additive teacher_forcing_ratio=0.7 decoding_strategy=beam-search beam_size=5
```

### Configuration

Main training parameters can be configured in `configs/train.yaml`:

- `attention_type`: "dot-product", "multiplicative", or "additive"
- `teacher_forcing_ratio`: Probability of using teacher forcing at each step (0.0-1.0)
- `decoding_strategy`: "greedy" or "beam-search"
- `beam_size`: Beam size for beam search (default: 5)
- `learning_rate`: Initial learning rate (default: 5e-5)
- `batch_size`: Batch size (default: 128)
- `max_epochs`: Maximum number of training epochs (default: 50)

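For intuition, the beam-search strategy can be sketched over a toy `step_fn(prefix) -> {token: probability}` interface. This is not the repository's implementation (which operates on batched decoder states); it only illustrates the bookkeeping that `beam_size` controls:

```python
import math


def beam_search(step_fn, bos_id, eos_id, beam_size=5, max_len=20):
    """Keep the `beam_size` highest log-probability prefixes each step;
    prefixes ending in `eos_id` are moved to the finished set."""
    beams = [([bos_id], 0.0)]
    finished = []
    for _ in range(max_len):
        candidates = []
        for prefix, score in beams:
            for tok, p in step_fn(prefix).items():
                candidates.append((prefix + [tok], score + math.log(p)))
        candidates.sort(key=lambda c: c[1], reverse=True)
        beams = []
        for prefix, score in candidates[:beam_size]:
            if prefix[-1] == eos_id:
                finished.append((prefix, score))
            else:
                beams.append((prefix, score))
        if not beams:
            break
    return max(finished + beams, key=lambda c: c[1])[0]
```

Greedy decoding is the `beam_size=1` special case; larger beams explore more hypotheses at a linear cost in decoding time.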
### Evaluation

Evaluate a trained model on the test set:

```bash
python eval.py
```

Or with custom parameters:

```bash
python eval.py --model_path <path_to_model> --data_path <path_to_data> --decoding_strategy beam-search --beam_size 5
```

Alternatively, you can use `inference.py` directly (same functionality):

```bash
python inference.py --model_path <path_to_model> --data_path <path_to_data> --decoding_strategy beam-search --beam_size 5
```

The evaluation script outputs:
- Perplexity (PPL) on the test set
- BLEU-1, BLEU-2, BLEU-3, BLEU-4 scores
- Detailed translation examples

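The BLEU-1 through BLEU-4 numbers can be computed with NLTK roughly as below. This is an illustrative sketch; the repository's exact smoothing and tokenization choices may differ:

```python
from nltk.translate.bleu_score import SmoothingFunction, corpus_bleu


def bleu_scores(references, hypotheses):
    """Corpus-level BLEU-1..4. `references` is a list with one list of
    reference token lists per sentence; `hypotheses` is a list of
    hypothesis token lists."""
    smooth = SmoothingFunction().method1
    scores = {}
    for n in range(1, 5):
        weights = tuple(1.0 / n for _ in range(n))   # uniform n-gram weights
        scores[f"BLEU-{n}"] = corpus_bleu(
            references, hypotheses,
            weights=weights, smoothing_function=smooth,
        )
    return scores
```
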
### Model Checkpoints

During training, the model saves:
- **Best model**: `save_dir/model_rnn_best.pt` (best validation perplexity)
- **Last model**: `save_dir/model_rnn_last.pt` (most recent checkpoint)
- **Optimizer state**: Saved alongside model files (`.optim` extension)

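The layout above can be sketched with plain `torch.save`/`torch.load`; the exact on-disk format used by this repository may differ:

```python
import torch


def save_checkpoint(model, optimizer, path):
    """Model weights at `path`, optimizer state alongside with `.optim`."""
    torch.save(model.state_dict(), path)
    torch.save(optimizer.state_dict(), path + ".optim")


def load_checkpoint(model, optimizer, path):
    """Restore both files saved by `save_checkpoint`."""
    model.load_state_dict(torch.load(path, map_location="cpu"))
    optimizer.load_state_dict(torch.load(path + ".optim", map_location="cpu"))
```
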
### Resuming Training

To resume training from a checkpoint:

```yaml
# In configs/train.yaml
resume_from_model: "save_dir/model_rnn_last.pt"
```

## Project Structure

```
RNN_NMT/
├── configs/
│   └── train.yaml            # Training configuration
├── dataset/
│   └── vocab.py              # Vocabulary management
├── models/
│   ├── rnn_nmt.py            # Main NMT model
│   ├── model_embeddings.py   # Embedding layers
│   └── char_decoder.py       # Character-level decoder
├── utils/
│   ├── utils.py              # Utility functions (BLEU, batching, etc.)
│   └── preprocess_data.py    # Data preprocessing
├── train.py                  # Training script
├── inference.py              # Evaluation script
├── eval.py                   # Evaluation script (alias for inference.py)
├── requirement.txt           # Python dependencies
└── README.md                 # This file
```

## Experimental Results

Model performance is evaluated using:
- **Perplexity (PPL)**: Lower is better
- **BLEU Score**: Higher is better (BLEU-4 as primary metric)

Training metrics are automatically saved to `training_metrics.json` for visualization and analysis.

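A reader for `training_metrics.json` might look like the sketch below. Note that the key names (`epoch`, `val_ppl`) are hypothetical, since the actual schema is defined by `train.py`:

```python
import json


def best_epoch(path):
    """Return (epoch, val_ppl) of the epoch with the lowest validation
    perplexity, assuming a list of per-epoch records with hypothetical
    keys `epoch` and `val_ppl`."""
    with open(path, encoding="utf-8") as f:
        history = json.load(f)
    best = min(history, key=lambda rec: rec["val_ppl"])
    return best["epoch"], best["val_ppl"]
```
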
## Acknowledgement

Thanks to the following repositories:

1. **Jieba** (Chinese word segmentation tool): [https://github.com/fxsjy/jieba](https://github.com/fxsjy/jieba)

2. **SentencePiece** (English and multilingual subword tokenization tool): [https://github.com/google/sentencepiece](https://github.com/google/sentencepiece)

3. **RNN Machine Translation**: [https://github.com/pi-tau/machine-translation](https://github.com/pi-tau/machine-translation)

## License

[Add your license information here]

## Contact

[Add your contact information here]