#!/usr/bin/env python3
"""
Decoding script for morphological reinflection using TagTransformer
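
Example usage (script and file names are illustrative):

    python decode.py --checkpoint ./models/best.pt --beam_width 5 \
        --src_file test.src --output_file pred.tgt

Omit --src_file/--output_file to decode interactively from stdin.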
"""
import argparse
import json
import logging
import os
from typing import Dict, List, Tuple
import torch
import torch.nn.functional as F
from transformer import TagTransformer, DEVICE
from morphological_dataset import build_vocabulary, tokenize_sequence
# Set up logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
def load_model(checkpoint_path: str, config: Dict, src_vocab: Dict[str, int],
tgt_vocab: Dict[str, int]) -> TagTransformer:
"""Load trained model from checkpoint"""
    # Count feature tokens: everything of the form <...>, which also includes
    # special symbols such as <BOS>/<EOS>; this must match the counting used at training time
feature_tokens = [token for token in src_vocab.keys()
if token.startswith('<') and token.endswith('>')]
nb_attr = len(feature_tokens)
# Create model
model = TagTransformer(
src_vocab_size=len(src_vocab),
trg_vocab_size=len(tgt_vocab),
embed_dim=config['embed_dim'],
nb_heads=config['nb_heads'],
src_hid_size=config['src_hid_size'],
src_nb_layers=config['src_nb_layers'],
trg_hid_size=config['trg_hid_size'],
trg_nb_layers=config['trg_nb_layers'],
dropout_p=0.0, # No dropout during inference
tie_trg_embed=config['tie_trg_embed'],
label_smooth=0.0, # No label smoothing during inference
nb_attr=nb_attr,
src_c2i=src_vocab,
trg_c2i=tgt_vocab,
attr_c2i={},
)
# Load checkpoint
checkpoint = torch.load(checkpoint_path, map_location=DEVICE)
model.load_state_dict(checkpoint['model_state_dict'])
model = model.to(DEVICE)
model.eval()
logger.info(f"Model loaded from {checkpoint_path}")
return model
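
# Example config.json matching the keys read in load_model (values are
# illustrative, not the ones used for any particular checkpoint):
# {
#     "embed_dim": 256, "nb_heads": 4,
#     "src_hid_size": 1024, "src_nb_layers": 4,
#     "trg_hid_size": 1024, "trg_nb_layers": 4,
#     "tie_trg_embed": true
# }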
def beam_search(model: TagTransformer, src_tokens: List[str], src_vocab: Dict[str, int],
tgt_vocab: Dict[str, int], beam_width: int = 5, max_length: int = 100) -> List[str]:
"""Perform beam search decoding"""
# Tokenize source
src_indices, _ = tokenize_sequence(src_tokens, src_vocab, max_length, add_bos_eos=True)
src_tensor = torch.tensor([src_indices], dtype=torch.long).to(DEVICE).t() # [seq_len, batch_size]
    # Source padding mask: all False (nothing masked), since the single input sequence is unpadded
    src_mask = torch.zeros(src_tensor.size(0), src_tensor.size(1), dtype=torch.bool).to(DEVICE)
# Encode source
with torch.no_grad():
encoded = model.encode(src_tensor, src_mask)
    # Initialize beam with BOS; each hypothesis is (token indices, summed log-probability)
    beam = [([tgt_vocab['<BOS>']], 0.0)]
for step in range(max_length):
candidates = []
for sequence, score in beam:
if sequence[-1] == tgt_vocab['<EOS>']:
candidates.append((sequence, score))
continue
# Prepare target input
tgt_tensor = torch.tensor([sequence], dtype=torch.long).to(DEVICE).t()
tgt_mask = torch.zeros(tgt_tensor.size(0), tgt_tensor.size(1), dtype=torch.bool).to(DEVICE)
# Decode
with torch.no_grad():
output = model.decode(encoded, src_mask, tgt_tensor, tgt_mask)
            # Convert the final-step decoder output to log-probabilities so beam
            # scores can be summed (log_softmax is a no-op if model.decode already
            # returns log-probabilities, so this is safe either way)
            next_token_logprobs = F.log_softmax(output[-1, 0], dim=-1)  # [vocab_size]
            # Expand this hypothesis with its top-k continuations
            top_k_logprobs, top_k_indices = torch.topk(next_token_logprobs, beam_width)
            for logp, idx in zip(top_k_logprobs, top_k_indices):
                new_sequence = sequence + [idx.item()]
                new_score = score + logp.item()
                candidates.append((new_sequence, new_score))
# Select top beam_width candidates
candidates.sort(key=lambda x: x[1], reverse=True)
beam = candidates[:beam_width]
# Check if all sequences end with EOS
if all(seq[-1] == tgt_vocab['<EOS>'] for seq, _ in beam):
break
    # Return the best hypothesis
    best_sequence, _ = beam[0]
    # Convert indices back to tokens, stripping BOS and, if decoding terminated, EOS
    # (a sequence that hit max_length without emitting EOS keeps its last real token)
    idx_to_token = {idx: token for token, idx in tgt_vocab.items()}
    end = -1 if best_sequence[-1] == tgt_vocab['<EOS>'] else len(best_sequence)
    result_tokens = [idx_to_token[idx] for idx in best_sequence[1:end]]
return result_tokens
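
# The beam scores above are raw summed log-probabilities, which biases the
# search toward shorter outputs. A common refinement (not applied in
# beam_search above) is length normalization; a minimal sketch of the idea:
def length_normalized_score(sequence: List[int], score: float, alpha: float = 0.6) -> float:
    """Down-weight the short-sequence bias by dividing by len(sequence) ** alpha."""
    return score / (len(sequence) ** alpha)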
def main():
parser = argparse.ArgumentParser(description='Decode using trained TagTransformer')
parser.add_argument('--checkpoint', type=str, required=True, help='Path to model checkpoint')
parser.add_argument('--config', type=str, default='./models/config.json', help='Path to config file')
    parser.add_argument('--src_file', type=str, help='Source file for batch decoding (requires --output_file)')
    parser.add_argument('--output_file', type=str, help='Output file for predictions (requires --src_file)')
parser.add_argument('--beam_width', type=int, default=5, help='Beam width for decoding')
args = parser.parse_args()
# Load configuration
with open(args.config, 'r') as f:
config = json.load(f)
    # Data file paths for vocabulary building (hard-coded; vocabularies must be
    # built from the same files used at training time so indices match the checkpoint)
train_src = '10L_90NL/train/run1/train.10L_90NL_1_1.src'
train_tgt = '10L_90NL/train/run1/train.10L_90NL_1_1.tgt'
dev_src = '10L_90NL/dev/run1/train.10L_90NL_1_1.src'
dev_tgt = '10L_90NL/dev/run1/train.10L_90NL_1_1.tgt'
test_src = '10L_90NL/test/run1/train.10L_90NL_1_1.src'
test_tgt = '10L_90NL/test/run1/train.10L_90NL_1_1.tgt'
# Build vocabularies
logger.info("Building vocabulary...")
src_vocab = build_vocabulary([train_src, dev_src, test_src])
tgt_vocab = build_vocabulary([train_tgt, dev_tgt, test_tgt])
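    # Quick sanity check: these sizes must match the embedding tables in the checkpoint
    logger.info(f"Source vocab: {len(src_vocab)} tokens, target vocab: {len(tgt_vocab)} tokens")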
# Load model
model = load_model(args.checkpoint, config, src_vocab, tgt_vocab)
# Decode
if args.src_file and args.output_file:
# Batch decoding from file
logger.info(f"Decoding from {args.src_file} to {args.output_file}")
with open(args.src_file, 'r', encoding='utf-8') as src_f, \
open(args.output_file, 'w', encoding='utf-8') as out_f:
for line_num, line in enumerate(src_f, 1):
src_tokens = line.strip().split()
if not src_tokens:
continue
try:
result = beam_search(model, src_tokens, src_vocab, tgt_vocab, args.beam_width)
out_f.write(' '.join(result) + '\n')
if line_num % 100 == 0:
logger.info(f"Processed {line_num} lines")
except Exception as e:
logger.error(f"Error processing line {line_num}: {e}")
out_f.write('<ERROR>\n')
logger.info("Decoding completed!")
else:
# Interactive decoding
logger.info("Interactive decoding mode. Type 'quit' to exit.")
while True:
try:
user_input = input("Enter source sequence: ").strip()
if user_input.lower() == 'quit':
break
if not user_input:
continue
src_tokens = user_input.split()
result = beam_search(model, src_tokens, src_vocab, tgt_vocab, args.beam_width)
print(f"Prediction: {' '.join(result)}")
print()
except KeyboardInterrupt:
break
except Exception as e:
logger.error(f"Error: {e}")
if __name__ == '__main__':
main()