#!/usr/bin/env python3
"""
Simplified training script for TagTransformer (without TensorBoard)
"""
import argparse
import json
import logging
import math
import os
import random
import time
from typing import Dict, Tuple
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from transformer import TagTransformer, PAD_IDX, DEVICE
# Set up logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
class DummyDataset(Dataset):
"""Dummy dataset for demonstration - replace with your actual dataset"""
def __init__(self, num_samples=1000, max_seq_len=50, vocab_size=1000, nb_attr=100):
self.num_samples = num_samples
self.max_seq_len = max_seq_len
self.vocab_size = vocab_size
self.nb_attr = nb_attr
def __len__(self):
return self.num_samples
def __getitem__(self, idx):
# Generate random source and target sequences
src_len = random.randint(10, self.max_seq_len)
trg_len = random.randint(10, self.max_seq_len)
# Source sequence with some attribute tokens at the end
src = torch.randint(0, self.vocab_size - self.nb_attr, (src_len,))
# Add some attribute tokens
if self.nb_attr > 0:
num_attr = random.randint(1, min(5, self.nb_attr))
attr_tokens = torch.randint(self.vocab_size - self.nb_attr, self.vocab_size, (num_attr,))
src = torch.cat([src, attr_tokens])
# Target sequence
trg = torch.randint(0, self.vocab_size, (trg_len,))
# Create masks
src_mask = torch.ones(src.size(0), dtype=torch.bool)
trg_mask = torch.ones(trg.size(0), dtype=torch.bool)
return src, src_mask, trg, trg_mask
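# For real morphological data, __getitem__ would presumably map characters and
# tag strings to indices through the src_c2i/trg_c2i/attr_c2i vocabularies that
# TagTransformer expects. The helper below is a hypothetical sketch of that
# encoding (lemma characters plus attribute tags -> source; inflected form ->
# target); the argument names and data format are assumptions, not this
# project's actual API.
def encode_example(lemma: str, tags: list, form: str,
                   src_c2i: Dict[str, int], attr_c2i: Dict[str, int],
                   trg_c2i: Dict[str, int]) -> Tuple[torch.Tensor, torch.Tensor]:
    """Hypothetical encoder: character/tag strings -> index tensors"""
    src = [src_c2i[c] for c in lemma] + [attr_c2i[t] for t in tags]
    trg = [trg_c2i[c] for c in form]
    return torch.tensor(src, dtype=torch.long), torch.tensor(trg, dtype=torch.long)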
def collate_fn(batch):
"""Collate function for DataLoader"""
src_batch, src_masks, trg_batch, trg_masks = zip(*batch)
# Pad sequences to max length in batch
max_src_len = max(len(src) for src in src_batch)
max_trg_len = max(len(trg) for trg in trg_batch)
# Pad source sequences
padded_src = []
padded_src_masks = []
for src, mask in zip(src_batch, src_masks):
padding_len = max_src_len - len(src)
if padding_len > 0:
            src = torch.cat([src, torch.full((padding_len,), PAD_IDX, dtype=torch.long)])
mask = torch.cat([mask, torch.zeros(padding_len, dtype=torch.bool)])
padded_src.append(src)
padded_src_masks.append(mask)
# Pad target sequences
padded_trg = []
padded_trg_masks = []
for trg, mask in zip(trg_batch, trg_masks):
padding_len = max_trg_len - len(trg)
if padding_len > 0:
            trg = torch.cat([trg, torch.full((padding_len,), PAD_IDX, dtype=torch.long)])
mask = torch.cat([mask, torch.zeros(padding_len, dtype=torch.bool)])
padded_trg.append(trg)
padded_trg_masks.append(mask)
# Stack and transpose for transformer input format [seq_len, batch_size]
src_batch = torch.stack(padded_src).t()
src_masks = torch.stack(padded_src_masks).t()
trg_batch = torch.stack(padded_trg).t()
trg_masks = torch.stack(padded_trg_masks).t()
return src_batch, src_masks, trg_batch, trg_masks
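# Shape sketch: with batch size B and max source/target lengths S and T in the
# batch, collate_fn returns src [S, B], src_masks [S, B], trg [T, B] and
# trg_masks [T, B] -- the sequence-first layout the transformer module expects.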
def create_model(config: Dict) -> TagTransformer:
"""Create and initialize the TagTransformer model"""
model = TagTransformer(
src_vocab_size=config['src_vocab_size'],
trg_vocab_size=config['trg_vocab_size'],
embed_dim=config['embed_dim'],
nb_heads=config['nb_heads'],
src_hid_size=config['src_hid_size'],
src_nb_layers=config['src_nb_layers'],
trg_hid_size=config['trg_hid_size'],
trg_nb_layers=config['trg_nb_layers'],
dropout_p=config['dropout_p'],
tie_trg_embed=config['tie_trg_embed'],
label_smooth=config['label_smooth'],
nb_attr=config['nb_attr'],
src_c2i={}, # Placeholder - replace with actual mappings
trg_c2i={}, # Placeholder - replace with actual mappings
attr_c2i={}, # Placeholder - replace with actual mappings
)
    # Xavier-initialize every parameter with more than one dimension
    # (weight matrices and embeddings); biases keep PyTorch defaults
for p in model.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)
return model
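# Usage sketch (assumes a config dict shaped like the one built in main() below):
#   model = create_model(config).to(DEVICE)
# The empty c2i placeholders suffice for shape/throughput tests, but real
# vocabulary mappings are needed before decoding model output back to strings.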
def train_epoch(model: TagTransformer,
                dataloader: DataLoader,
                optimizer: optim.Optimizer,
                scheduler: optim.lr_scheduler.LambdaLR,
                device: torch.device,
                epoch: int) -> float:
"""Train for one epoch"""
model.train()
total_loss = 0.0
num_batches = 0
for batch_idx, (src, src_mask, trg, trg_mask) in enumerate(dataloader):
src, src_mask, trg, trg_mask = (
src.to(device), src_mask.to(device),
trg.to(device), trg_mask.to(device)
)
optimizer.zero_grad()
# Forward pass
output = model(src, src_mask, trg, trg_mask)
# Compute loss (shift sequences for next-token prediction)
loss = model.loss(output[:-1], trg[1:])
# Backward pass
loss.backward()
# Gradient clipping
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        scheduler.step()  # step per batch: warmup/decay are defined in batches
total_loss += loss.item()
num_batches += 1
if batch_idx % 100 == 0:
logger.info(f'Epoch {epoch}, Batch {batch_idx}, Loss: {loss.item():.4f}')
avg_loss = total_loss / num_batches
return avg_loss
def validate(model: TagTransformer,
dataloader: DataLoader,
device: torch.device) -> float:
"""Validate the model"""
model.eval()
total_loss = 0.0
num_batches = 0
with torch.no_grad():
for src, src_mask, trg, trg_mask in dataloader:
src, src_mask, trg, trg_mask = (
src.to(device), src_mask.to(device),
trg.to(device), trg_mask.to(device)
)
# Forward pass
output = model(src, src_mask, trg, trg_mask)
# Compute loss
loss = model.loss(output[:-1], trg[1:])
total_loss += loss.item()
num_batches += 1
avg_loss = total_loss / num_batches
return avg_loss
def save_checkpoint(model: TagTransformer,
optimizer: optim.Optimizer,
epoch: int,
loss: float,
save_path: str):
"""Save model checkpoint"""
checkpoint = {
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': loss,
}
torch.save(checkpoint, save_path)
logger.info(f'Checkpoint saved to {save_path}')
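# Counterpart to save_checkpoint -- a minimal sketch for resuming training. The
# key names match the dict written above; map_location (and, if the architecture
# changes, strict=False loading) may need adjusting for your setup.
def load_checkpoint(model: TagTransformer,
                    optimizer: optim.Optimizer,
                    checkpoint_path: str) -> Tuple[int, float]:
    """Restore model/optimizer state; returns (epoch, loss) from the checkpoint"""
    checkpoint = torch.load(checkpoint_path, map_location=DEVICE)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    logger.info(f'Checkpoint loaded from {checkpoint_path}')
    return checkpoint['epoch'], checkpoint['loss']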
def main():
parser = argparse.ArgumentParser(description='Train TagTransformer (Simplified)')
parser.add_argument('--output_dir', type=str, default='./outputs', help='Output directory')
parser.add_argument('--num_epochs', type=int, default=10, help='Number of training epochs')
parser.add_argument('--batch_size', type=int, default=16, help='Batch size')
args = parser.parse_args()
# Configuration
config = {
'src_vocab_size': 1000,
'trg_vocab_size': 1000,
'embed_dim': 256,
'nb_heads': 8,
'src_hid_size': 1024,
'src_nb_layers': 3,
'trg_hid_size': 1024,
'trg_nb_layers': 3,
'dropout_p': 0.1,
'tie_trg_embed': True,
'label_smooth': 0.1,
'nb_attr': 50,
'batch_size': args.batch_size,
'learning_rate': 0.0001,
'num_epochs': args.num_epochs,
'warmup_steps': 100,
'weight_decay': 0.01,
}
# Create output directory
os.makedirs(args.output_dir, exist_ok=True)
# Save config
with open(os.path.join(args.output_dir, 'config.json'), 'w') as f:
json.dump(config, f, indent=2)
# Set device
device = DEVICE
logger.info(f'Using device: {device}')
# Create datasets
train_dataset = DummyDataset(
num_samples=1000,
max_seq_len=50,
vocab_size=config['src_vocab_size'],
nb_attr=config['nb_attr']
)
val_dataset = DummyDataset(
num_samples=100,
max_seq_len=50,
vocab_size=config['src_vocab_size'],
nb_attr=config['nb_attr']
)
# Create dataloaders
train_loader = DataLoader(
train_dataset,
batch_size=config['batch_size'],
shuffle=True,
collate_fn=collate_fn,
num_workers=2
)
val_loader = DataLoader(
val_dataset,
batch_size=config['batch_size'],
shuffle=False,
collate_fn=collate_fn,
num_workers=2
)
# Create model
model = create_model(config)
model = model.to(device)
# Count parameters
total_params = model.count_nb_params()
logger.info(f'Total parameters: {total_params:,}')
# Create optimizer
optimizer = optim.AdamW(
model.parameters(),
lr=config['learning_rate'],
weight_decay=config['weight_decay']
)
    # Learning rate schedule: linear warmup, then cosine decay to zero.
    # The scheduler is stepped once per batch inside train_epoch, so
    # warmup_steps and total_steps are both counted in batches.
    total_steps = len(train_loader) * config['num_epochs']
    def lr_lambda(step):
        if step < config['warmup_steps']:
            return float(step) / float(max(1, config['warmup_steps']))
        progress = (step - config['warmup_steps']) / max(1, total_steps - config['warmup_steps'])
        return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))
    scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
# Training loop
best_val_loss = float('inf')
for epoch in range(config['num_epochs']):
start_time = time.time()
        # Train (the scheduler is stepped per batch inside train_epoch)
        train_loss = train_epoch(model, train_loader, optimizer, scheduler, device, epoch)
        current_lr = scheduler.get_last_lr()[0]
# Validate
val_loss = validate(model, val_loader, device)
logger.info(f'Epoch {epoch}: Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, LR: {current_lr:.6f}')
# Save best model
if val_loss < best_val_loss:
best_val_loss = val_loss
save_checkpoint(
model, optimizer, epoch, val_loss,
os.path.join(args.output_dir, 'best_model.pth')
)
# Save checkpoint periodically
if epoch % 5 == 0:
save_checkpoint(
model, optimizer, epoch, train_loss,
os.path.join(args.output_dir, f'checkpoint_epoch_{epoch}.pth')
)
epoch_time = time.time() - start_time
logger.info(f'Epoch {epoch} completed in {epoch_time:.2f}s')
# Save final model
save_checkpoint(
model, optimizer, config['num_epochs'] - 1, train_loss,
os.path.join(args.output_dir, 'final_model.pth')
)
logger.info('Training completed!')
if __name__ == '__main__':
main()
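# Example invocation (file name assumed; adjust to wherever this script lives):
#   python train_simple.py --output_dir ./outputs --num_epochs 10 --batch_size 16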