File size: 7,598 Bytes
fb0b30c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
#!/usr/bin/env python3
"""
Decoding script for morphological reinflection using TagTransformer
"""

import argparse
import json
import logging
import os
from typing import Dict, List, Tuple

import torch
import torch.nn.functional as F

from transformer import TagTransformer, PAD_IDX, DEVICE
from morphological_dataset import build_vocabulary, tokenize_sequence

# Configure logging once at import time: INFO level, timestamped records.
# Everything in this script logs through the module-level ``logger``.
_LOG_FORMAT = '%(asctime)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger(__name__)

def load_model(checkpoint_path: str, config: Dict, src_vocab: Dict[str, int],
               tgt_vocab: Dict[str, int]) -> TagTransformer:
    """Build a TagTransformer from ``config`` and restore its trained weights.

    Args:
        checkpoint_path: File read with ``torch.load``; the loaded object must
            expose the weights under the ``'model_state_dict'`` key.
        config: Hyper-parameters. The keys ``embed_dim``, ``nb_heads``,
            ``src_hid_size``, ``src_nb_layers``, ``trg_hid_size``,
            ``trg_nb_layers`` and ``tie_trg_embed`` are required.
        src_vocab: Source symbol -> index mapping.
        tgt_vocab: Target symbol -> index mapping.

    Returns:
        The model, moved to ``DEVICE`` and switched to evaluation mode.
    """
    # Any source symbol wrapped in angle brackets is counted as a feature tag.
    # NOTE(review): this also counts bracketed special symbols, if the
    # vocabulary contains any -- confirm it matches the count used in training.
    nb_attr = sum(
        1
        for symbol in src_vocab
        if symbol.startswith('<') and symbol.endswith('>')
    )

    # Architecture settings are copied verbatim from the configuration.
    architecture_keys = (
        'embed_dim',
        'nb_heads',
        'src_hid_size',
        'src_nb_layers',
        'trg_hid_size',
        'trg_nb_layers',
        'tie_trg_embed',
    )
    architecture = {key: config[key] for key in architecture_keys}

    model = TagTransformer(
        src_vocab_size=len(src_vocab),
        trg_vocab_size=len(tgt_vocab),
        dropout_p=0.0,     # regularisation is switched off for inference
        label_smooth=0.0,  # label smoothing is switched off for inference
        nb_attr=nb_attr,
        src_c2i=src_vocab,
        trg_c2i=tgt_vocab,
        attr_c2i={},
        **architecture,
    )

    # Restore the trained parameters, mapping tensors onto the target device.
    saved = torch.load(checkpoint_path, map_location=DEVICE)
    model.load_state_dict(saved['model_state_dict'])
    model = model.to(DEVICE)
    model.eval()

    logger.info(f"Model loaded from {checkpoint_path}")
    return model

def beam_search(model: TagTransformer, src_tokens: List[str], src_vocab: Dict[str, int],
                tgt_vocab: Dict[str, int], beam_width: int = 5, max_length: int = 100) -> List[str]:
    """Decode one source sequence with beam search and return the best hypothesis.

    Args:
        model: Model providing ``encode(src, src_mask)`` and
            ``decode(encoded, src_mask, tgt, tgt_mask)``.
        src_tokens: Source symbols (characters and feature tags).
        src_vocab: Source symbol -> index mapping.
        tgt_vocab: Target symbol -> index mapping; must contain ``'<BOS>'``
            and ``'<EOS>'``.
        beam_width: Number of hypotheses kept after every step (>= 1).
        max_length: Maximum number of decoding steps; also passed to
            ``tokenize_sequence`` for the source side.

    Returns:
        Target symbols of the highest-scoring hypothesis, without the leading
        BOS and without the trailing EOS when one was produced.

    Raises:
        ValueError: If ``beam_width`` is smaller than 1.
    """
    if beam_width < 1:
        # An empty beam would otherwise surface later as an opaque IndexError.
        raise ValueError(f"beam_width must be >= 1, got {beam_width}")

    bos_idx = tgt_vocab['<BOS>']
    eos_idx = tgt_vocab['<EOS>']

    # Tokenize source and lay it out as [seq_len, batch_size] with batch_size 1.
    src_indices, _ = tokenize_sequence(src_tokens, src_vocab, max_length, add_bos_eos=True)
    src_tensor = torch.tensor([src_indices], dtype=torch.long).to(DEVICE).t()

    # NOTE(review): the mask is all False, i.e. no position is singled out.
    # If tokenize_sequence pads to max_length, those pad positions are not
    # masked here -- confirm the mask convention expected by model.encode.
    src_mask = torch.zeros(src_tensor.size(0), src_tensor.size(1), dtype=torch.bool).to(DEVICE)

    with torch.no_grad():
        encoded = model.encode(src_tensor, src_mask)

        # Each hypothesis is (token index sequence, accumulated score).
        beam = [([bos_idx], 0.0)]

        for _ in range(max_length):
            candidates = []

            for sequence, score in beam:
                # Finished hypotheses are carried over unchanged.
                if sequence[-1] == eos_idx:
                    candidates.append((sequence, score))
                    continue

                tgt_tensor = torch.tensor([sequence], dtype=torch.long).to(DEVICE).t()
                tgt_mask = torch.zeros(tgt_tensor.size(0), tgt_tensor.size(1), dtype=torch.bool).to(DEVICE)

                output = model.decode(encoded, src_mask, tgt_tensor, tgt_mask)

                # Scores for the next symbol: last time step, only batch item.
                # Scores are summed along a hypothesis, so they are presumably
                # log-probabilities -- confirm against TagTransformer.decode.
                next_token_scores = output[-1, 0]  # [vocab_size]

                # torch.topk raises when k exceeds the dimension size, so a
                # beam wider than the vocabulary is clamped.
                k = min(beam_width, next_token_scores.size(-1))
                top_scores, top_indices = torch.topk(next_token_scores, k)

                for token_score, token_idx in zip(top_scores.tolist(), top_indices.tolist()):
                    candidates.append((sequence + [token_idx], score + token_score))

            # Keep the beam_width best hypotheses (stable sort, highest first).
            candidates.sort(key=lambda x: x[1], reverse=True)
            beam = candidates[:beam_width]

            # Stop early once every surviving hypothesis has ended.
            if all(seq[-1] == eos_idx for seq, _ in beam):
                break

    best_sequence, _ = beam[0]

    # Drop BOS; drop the final symbol only when it really is EOS, so that a
    # hypothesis cut off at max_length keeps its last predicted symbol.
    best_indices = best_sequence[1:]
    if best_indices and best_indices[-1] == eos_idx:
        best_indices = best_indices[:-1]

    # Convert indices to tokens
    idx_to_token = {idx: token for token, idx in tgt_vocab.items()}
    return [idx_to_token[idx] for idx in best_indices]

def main():
    """Command-line entry point for decoding with a trained TagTransformer.

    With both ``--src_file`` and ``--output_file`` the source file is decoded
    line by line and one prediction line is written per source line (an empty
    line for a blank source line, ``<ERROR>`` when decoding a line fails).
    With neither option an interactive prompt is started; it ends on
    ``quit``, Ctrl-C, or end of input.
    """
    parser = argparse.ArgumentParser(description='Decode using trained TagTransformer')
    parser.add_argument('--checkpoint', type=str, required=True, help='Path to model checkpoint')
    parser.add_argument('--config', type=str, default='./models/config.json', help='Path to config file')
    parser.add_argument('--src_file', type=str, help='Source file for decoding')
    parser.add_argument('--output_file', type=str, help='Output file for predictions')
    parser.add_argument('--beam_width', type=int, default=5, help='Beam width for decoding')
    args = parser.parse_args()

    # File mode needs both paths. Reject a lone option instead of silently
    # dropping into the interactive prompt.
    if bool(args.src_file) != bool(args.output_file):
        parser.error('--src_file and --output_file must be given together')

    # Load configuration
    with open(args.config, 'r', encoding='utf-8') as f:
        config = json.load(f)

    # Data file paths for vocabulary building
    # NOTE(review): the dev/ and test/ paths reuse the "train." file-name
    # prefix; this looks like a copy-paste slip -- verify against the files
    # on disk and against the paths the training script used.
    train_src = '10L_90NL/train/run1/train.10L_90NL_1_1.src'
    train_tgt = '10L_90NL/train/run1/train.10L_90NL_1_1.tgt'
    dev_src = '10L_90NL/dev/run1/train.10L_90NL_1_1.src'
    dev_tgt = '10L_90NL/dev/run1/train.10L_90NL_1_1.tgt'
    test_src = '10L_90NL/test/run1/train.10L_90NL_1_1.src'
    test_tgt = '10L_90NL/test/run1/train.10L_90NL_1_1.tgt'

    # Build vocabularies
    logger.info("Building vocabulary...")
    src_vocab = build_vocabulary([train_src, dev_src, test_src])
    tgt_vocab = build_vocabulary([train_tgt, dev_tgt, test_tgt])

    # Load model
    model = load_model(args.checkpoint, config, src_vocab, tgt_vocab)

    if args.src_file and args.output_file:
        # Batch decoding from file
        logger.info(f"Decoding from {args.src_file} to {args.output_file}")

        with open(args.src_file, 'r', encoding='utf-8') as src_f, \
             open(args.output_file, 'w', encoding='utf-8') as out_f:

            for line_num, line in enumerate(src_f, 1):
                src_tokens = line.split()
                if not src_tokens:
                    # Keep the output line-aligned with the source file.
                    out_f.write('\n')
                    continue

                try:
                    result = beam_search(model, src_tokens, src_vocab, tgt_vocab, args.beam_width)
                    out_f.write(' '.join(result) + '\n')

                    if line_num % 100 == 0:
                        logger.info(f"Processed {line_num} lines")

                except Exception as e:
                    # Best effort: record the failure, keep alignment, go on.
                    logger.error(f"Error processing line {line_num}: {e}")
                    out_f.write('<ERROR>\n')

        logger.info("Decoding completed!")

    else:
        # Interactive decoding
        logger.info("Interactive decoding mode. Type 'quit' to exit.")

        while True:
            try:
                user_input = input("Enter source sequence: ").strip()
                if user_input.lower() == 'quit':
                    break

                if not user_input:
                    continue

                src_tokens = user_input.split()
                result = beam_search(model, src_tokens, src_vocab, tgt_vocab, args.beam_width)

                print(f"Prediction: {' '.join(result)}")
                print()

            except (KeyboardInterrupt, EOFError):
                # EOFError is raised by input() once stdin is exhausted; it
                # must end the loop, otherwise the generic handler below
                # would log it and retry forever.
                break
            except Exception as e:
                logger.error(f"Error: {e}")

# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == '__main__':
    main()