# morphological-transformer / scripts / train_morphological_cuda.py
# Provenance: commit fb0b30c by akki2825 — "Initial deployment of Morphological Transformer"
# NOTE(review): these header lines were pasted from a web UI above the shebang;
# they are kept as comments so the file remains valid Python.
#!/usr/bin/env python3
"""
ULTRA-LOW-LEVEL CUDA training script for maximum speed
"""
import argparse
import contextlib
import ctypes
import gc
import json
import math
import os
import time
from pathlib import Path
from typing import Dict

import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.optim as optim
from torch.cuda.amp import GradScaler, autocast
from torch.utils.data import DataLoader

from transformer import TagTransformer, PAD_IDX, DEVICE
from morphological_dataset import MorphologicalDataset, build_vocabulary, collate_fn
# Aggressive CUDA optimizations
torch.backends.cudnn.benchmark = True  # autotune cuDNN kernels (assumes mostly fixed batch shapes — TODO confirm)
torch.backends.cudnn.deterministic = False  # trade reproducibility for speed
torch.backends.cudnn.allow_tf32 = True  # TF32 math for cuDNN ops on Ampere+ GPUs
torch.backends.cuda.matmul.allow_tf32 = True  # TF32 math for matmuls
torch.backends.cuda.enable_flash_sdp(True)  # allow FlashAttention scaled-dot-product backend
torch.backends.cuda.enable_mem_efficient_sdp(True)  # allow memory-efficient SDP backend
torch.backends.cuda.enable_math_sdp(True)  # keep the plain math SDP fallback enabled
# Disable all logging for speed
import logging
logging.disable(logging.CRITICAL)  # silences every logger at CRITICAL and below
def create_cuda_optimized_model(config: Dict, src_vocab: Dict[str, int], tgt_vocab: Dict[str, int]) -> TagTransformer:
    """Instantiate a TagTransformer, re-initialize its weights, and try to compile it.

    Angle-bracketed vocabulary entries (``<...>``) are counted as morphological
    attribute tokens and passed to the model as ``nb_attr``.
    """
    # Count feature tokens directly instead of materializing an intermediate list.
    nb_attr = sum(1 for tok in src_vocab
                  if tok.startswith('<') and tok.endswith('>'))
    model = TagTransformer(
        src_vocab_size=len(src_vocab),
        trg_vocab_size=len(tgt_vocab),
        embed_dim=config['embed_dim'],
        nb_heads=config['nb_heads'],
        src_hid_size=config['src_hid_size'],
        src_nb_layers=config['src_nb_layers'],
        trg_hid_size=config['trg_hid_size'],
        trg_nb_layers=config['trg_nb_layers'],
        dropout_p=config['dropout_p'],
        tie_trg_embed=config['tie_trg_embed'],
        label_smooth=config['label_smooth'],
        nb_attr=nb_attr,
        src_c2i=src_vocab,
        trg_c2i=tgt_vocab,
        attr_c2i={},
    )
    # Re-initialize: Xavier for weight matrices, small uniform for 1-D params (biases etc.).
    for param in model.parameters():
        if param.dim() > 1:
            nn.init.xavier_uniform_(param)
        elif param.dim() == 1:
            nn.init.uniform_(param, -0.1, 0.1)
    # Attempt the most aggressive compile mode first, then fall back gracefully.
    # NOTE(review): torch.compile is lazy — real compilation errors may only
    # surface at the first forward pass, not inside these try blocks.
    if hasattr(torch, 'compile'):
        try:
            model = torch.compile(model, mode="max-autotune", fullgraph=True)
            print("✓ Model compiled with torch.compile (fullgraph)")
        except Exception:
            try:
                model = torch.compile(model, mode="max-autotune")
                print("✓ Model compiled with torch.compile")
            except Exception as e2:
                print(f"⚠ torch.compile failed: {e2}")
    return model
class _PicklableCollate:
    """Picklable wrapper that binds vocabularies and max length to collate_fn.

    NOTE(fix): the original passed a lambda as ``collate_fn`` while also using
    ``multiprocessing_context='spawn'`` with ``num_workers > 0``.  Spawned
    workers must pickle the collate function, and lambdas are not picklable,
    so every epoch crashed at iteration start.  A module-level class instance
    pickles cleanly.
    """

    def __init__(self, src_vocab, tgt_vocab, max_length):
        self.src_vocab = src_vocab
        self.tgt_vocab = tgt_vocab
        self.max_length = max_length

    def __call__(self, batch):
        return collate_fn(batch, self.src_vocab, self.tgt_vocab, self.max_length)


def create_cuda_dataloader(dataset, config: Dict, src_vocab: Dict, tgt_vocab: Dict):
    """Create a throughput-oriented DataLoader (pinned memory, many workers).

    Args:
        dataset: any map-style dataset.
        config: must contain 'batch_size' and 'max_length'.
        src_vocab / tgt_vocab: token -> index mappings forwarded to collate_fn.
    """
    # Use maximum workers for CPU preprocessing.
    num_workers = min(32, os.cpu_count() or 1)
    dataloader = DataLoader(
        dataset,
        batch_size=config['batch_size'],
        shuffle=True,
        collate_fn=_PicklableCollate(src_vocab, tgt_vocab, config['max_length']),
        num_workers=num_workers,
        pin_memory=True,
        persistent_workers=True,
        prefetch_factor=8,  # maximum prefetching
        drop_last=True,
        generator=torch.Generator(device='cpu'),
        multiprocessing_context='spawn',  # more stable than fork
    )
    return dataloader
def train_epoch_cuda(model: "TagTransformer",
                     dataloader: DataLoader,
                     optimizer: optim.Optimizer,
                     device: torch.device,
                     epoch: int,
                     config: Dict,
                     scaler: GradScaler) -> float:
    """Train for one epoch with AMP and per-batch optimizer steps.

    Args:
        model: model exposing ``forward(src, src_mask, tgt, tgt_mask)`` and ``loss(out, tgt)``.
        config: must contain 'batch_size', 'gradient_clip'; optional 'use_amp'.
        scaler: GradScaler used for mixed-precision loss scaling.

    Returns:
        Average loss over the epoch (0.0 for an empty dataloader).
    """
    model.train()
    total_loss = 0.0
    num_batches = 0

    # NOTE(fix): the original unconditionally created a CUDA stream, which
    # raises when running on CPU; fall back to a no-op context there.
    if device.type == 'cuda':
        stream = torch.cuda.Stream()
        stream_ctx = torch.cuda.stream(stream)
    else:
        stream = None
        stream_ctx = contextlib.nullcontext()

    # set_to_none=True is the faster way to clear gradients.
    optimizer.zero_grad(set_to_none=True)
    start_time = time.time()

    with stream_ctx:
        for batch_idx, (src, src_mask, tgt, tgt_mask) in enumerate(dataloader):
            # Asynchronous transfer to the target device.
            # NOTE(fix): the original passed memory_format=torch.channels_last,
            # but channels_last requires rank-4 tensors; these token batches are
            # 2-D, so .to(...) raised a RuntimeError on the very first batch.
            src = src.to(device, non_blocking=True)
            src_mask = src_mask.to(device, non_blocking=True)
            tgt = tgt.to(device, non_blocking=True)
            tgt_mask = tgt_mask.to(device, non_blocking=True)

            # Mixed precision forward pass.
            with autocast(enabled=config.get('use_amp', True)):
                output = model(src, src_mask, tgt, tgt_mask)
                # Teacher-forcing shift: predictions vs next tokens
                # (assumes sequence-first layout — TODO confirm against TagTransformer).
                loss = model.loss(output[:-1], tgt[1:])

            # Backward + step every batch (no gradient accumulation for speed).
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=config['gradient_clip'])
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad(set_to_none=True)

            total_loss += loss.item()
            num_batches += 1

            # Minimal logging - only every 200 batches.
            if batch_idx % 200 == 0:
                elapsed = time.time() - start_time
                samples_per_sec = (batch_idx + 1) * config['batch_size'] / elapsed
                print(f'Epoch {epoch}, Batch {batch_idx}, Loss: {loss.item():.4f}, Speed: {samples_per_sec:.0f} samples/sec')

    if stream is not None:
        stream.synchronize()

    # Guard against an empty dataloader (e.g. dataset smaller than
    # batch_size with drop_last=True) instead of dividing by zero.
    return total_loss / num_batches if num_batches else 0.0
def validate_cuda(model: "TagTransformer",
                  dataloader: DataLoader,
                  device: torch.device,
                  config: Dict) -> float:
    """Compute the average validation loss with gradients disabled.

    Args:
        model: model exposing ``forward(src, src_mask, tgt, tgt_mask)`` and ``loss(out, tgt)``.
        config: optional 'use_amp' flag (defaults to True).

    Returns:
        Average loss over the dataloader (0.0 for an empty dataloader).
    """
    model.eval()
    total_loss = 0.0
    num_batches = 0
    with torch.no_grad():
        for src, src_mask, tgt, tgt_mask in dataloader:
            # NOTE(fix): the original passed memory_format=torch.channels_last,
            # which requires rank-4 tensors; these 2-D batches raised RuntimeError.
            src = src.to(device, non_blocking=True)
            src_mask = src_mask.to(device, non_blocking=True)
            tgt = tgt.to(device, non_blocking=True)
            tgt_mask = tgt_mask.to(device, non_blocking=True)
            with autocast(enabled=config.get('use_amp', True)):
                output = model(src, src_mask, tgt, tgt_mask)
                # Same teacher-forcing shift as training
                # (assumes sequence-first layout — TODO confirm).
                loss = model.loss(output[:-1], tgt[1:])
            total_loss += loss.item()
            num_batches += 1
    # Guard against an empty dataloader instead of dividing by zero.
    return total_loss / num_batches if num_batches else 0.0
def save_checkpoint_cuda(model: "TagTransformer",
                         optimizer: optim.Optimizer,
                         epoch: int,
                         loss: float,
                         save_path: str,
                         scaler: GradScaler = None):
    """Save model/optimizer (and optionally scaler) state to ``save_path``.

    Args:
        epoch: epoch number stored in the checkpoint for resuming.
        loss: loss value stored alongside the weights.
        scaler: if given, its state is saved under 'scaler_state_dict'.
    """
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
    }
    if scaler is not None:
        checkpoint['scaler_state_dict'] = scaler.state_dict()
    # NOTE(fix): the original passed _use_new_zipfile_serialization_for_torch_save,
    # which is not a torch.save parameter — every save raised TypeError.  The
    # default (zipfile) serialization is used here.
    torch.save(checkpoint, save_path)
    print(f'Checkpoint saved: {save_path}')
def load_checkpoint_cuda(model: TagTransformer,
                         optimizer: optim.Optimizer,
                         checkpoint_path: str,
                         scaler: GradScaler = None) -> int:
    """Restore model/optimizer (and optionally scaler) state from a checkpoint.

    Returns the epoch number stored in the checkpoint so training can resume.
    """
    # weights_only=False implies full pickle loading — only load trusted checkpoints.
    state = torch.load(checkpoint_path, map_location=DEVICE, weights_only=False)
    model.load_state_dict(state['model_state_dict'])
    optimizer.load_state_dict(state['optimizer_state_dict'])
    has_scaler_state = scaler is not None and 'scaler_state_dict' in state
    if has_scaler_state:
        scaler.load_state_dict(state['scaler_state_dict'])
    epoch = state['epoch']
    loss = state['loss']
    print(f'Checkpoint loaded: {checkpoint_path}, Epoch: {epoch}, Loss: {loss:.4f}')
    return epoch
def setup_cuda_environment():
    """Apply aggressive process-wide CUDA settings.

    Returns:
        True when CUDA is available and configured, False otherwise.
    """
    if not torch.cuda.is_available():
        print("CUDA not available!")
        return False

    # Claim nearly all GPU memory and start from a clean allocator state.
    torch.cuda.set_per_process_memory_fraction(0.98)
    torch.cuda.empty_cache()
    gc.collect()

    # Enable every scaled-dot-product-attention backend.
    for enable_backend in (torch.backends.cuda.enable_math_sdp,
                           torch.backends.cuda.enable_flash_sdp,
                           torch.backends.cuda.enable_mem_efficient_sdp):
        enable_backend(True)

    # Report the active device's capabilities.
    dev_index = torch.cuda.current_device()
    info = torch.cuda.get_device_properties(dev_index)
    print(f"✓ CUDA Device: {info.name}")
    print(f"✓ CUDA Memory: {info.total_memory / 1024**3:.1f} GB")
    print(f"✓ CUDA Compute Capability: {info.major}.{info.minor}")
    print(f"✓ CUDA Multiprocessors: {info.multi_processor_count}")

    # NOTE(review): CUDA_LAUNCH_BLOCKING is read at CUDA initialization, so
    # setting it here (after CUDA is already in use) is best-effort at most.
    os.environ['CUDA_LAUNCH_BLOCKING'] = '0'
    os.environ['TORCH_CUDNN_V8_API_ENABLED'] = '1'
    return True
def main():
    """Entry point: parse args, build data/model/optimizer, run the training loop."""
    parser = argparse.ArgumentParser(description='ULTRA-LOW-LEVEL CUDA training')
    parser.add_argument('--resume', type=str, help='Path to checkpoint to resume from')
    parser.add_argument('--output_dir', type=str, default='./models', help='Output directory')
    parser.add_argument('--no_amp', action='store_true', help='Disable mixed precision training')
    args = parser.parse_args()

    # Ultra-aggressive configuration for maximum speed
    config = {
        'embed_dim': 256,
        'nb_heads': 4,
        'src_hid_size': 1024,
        'src_nb_layers': 4,
        'trg_hid_size': 1024,
        'trg_nb_layers': 4,
        'dropout_p': 0.1,
        'tie_trg_embed': True,
        'label_smooth': 0.1,
        'batch_size': 1024,  # Maximum batch size for GPU utilization
        'learning_rate': 0.001,
        'max_epochs': 1000,
        'max_updates': 10000,
        'warmup_steps': 4000,
        'weight_decay': 0.01,
        'gradient_clip': 1.0,
        'save_every': 50,  # Save very infrequently for speed
        'eval_every': 20,  # Evaluate very infrequently for speed
        'max_length': 100,
        'use_amp': not args.no_amp,
        'gradient_accumulation_steps': 1,  # No accumulation for maximum speed
    }

    # Create output directories and persist the config for reproducibility.
    os.makedirs(args.output_dir, exist_ok=True)
    os.makedirs(os.path.join(args.output_dir, 'checkpoints'), exist_ok=True)
    with open(os.path.join(args.output_dir, 'config.json'), 'w') as f:
        json.dump(config, f, indent=2)

    # Setup CUDA environment; bail out early on CPU-only machines.
    if not setup_cuda_environment():
        return
    device = DEVICE
    print(f'Using device: {device}')

    # Data file paths (relative to the scripts/ directory — TODO confirm layout).
    train_src = '../10L_90NL/train/run1/train.10L_90NL_1_1.src'
    train_tgt = '../10L_90NL/train/run1/train.10L_90NL_1_1.tgt'
    dev_src = '../10L_90NL/dev/run1/dev.10L_90NL_1_1.src'
    dev_tgt = '../10L_90NL/dev/run1/dev.10L_90NL_1_1.tgt'

    # Build vocabulary efficiently
    print("Building vocabulary...")
    src_vocab = build_vocabulary([train_src, dev_src])
    tgt_vocab = build_vocabulary([train_tgt, dev_tgt])
    print(f"Source vocabulary size: {len(src_vocab)}")
    print(f"Target vocabulary size: {len(tgt_vocab)}")

    # Datasets and throughput-oriented dataloaders.
    train_dataset = MorphologicalDataset(train_src, train_tgt, src_vocab, tgt_vocab, config['max_length'])
    dev_dataset = MorphologicalDataset(dev_src, dev_tgt, src_vocab, tgt_vocab, config['max_length'])
    train_loader = create_cuda_dataloader(train_dataset, config, src_vocab, tgt_vocab)
    dev_loader = create_cuda_dataloader(dev_dataset, config, src_vocab, tgt_vocab)

    # Model.  NOTE(fix): dropped memory_format=torch.channels_last — it only
    # affects rank-4 (conv) parameters, which this transformer does not have.
    model = create_cuda_optimized_model(config, src_vocab, tgt_vocab)
    model = model.to(device)

    total_params = model.count_nb_params()
    print(f'Total parameters: {total_params:,}')

    # NOTE(fix): the original passed both foreach=True and fused=True, which
    # AdamW rejects as mutually exclusive.  Keep the fused CUDA implementation
    # (valid here: setup_cuda_environment guarantees CUDA is available).
    optimizer = optim.AdamW(
        model.parameters(),
        lr=config['learning_rate'],
        weight_decay=config['weight_decay'],
        betas=(0.9, 0.999),
        eps=1e-8,
        fused=True,
    )

    # Linear warmup followed by cosine decay.
    # NOTE(fix): torch.cos requires a tensor argument; the original crashed
    # with TypeError once warmup ended — use math.cos on the float progress.
    # NOTE(review): scheduler.step() is called once per EPOCH below, while
    # warmup_steps / max_updates are expressed in update units — confirm the
    # intended schedule granularity.
    def lr_lambda(step):
        if step < config['warmup_steps']:
            return float(step) / float(max(1, config['warmup_steps']))
        progress = (step - config['warmup_steps']) / (config['max_updates'] - config['warmup_steps'])
        return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))
    scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

    # Mixed precision training
    scaler = GradScaler(enabled=config['use_amp'])
    if config['use_amp']:
        print("✓ Mixed precision training enabled")

    # Resume from checkpoint if specified
    start_epoch = 0
    if args.resume:
        start_epoch = load_checkpoint_cuda(model, optimizer, args.resume, scaler)

    # Training loop
    best_val_loss = float('inf')
    global_step = 0
    updates = 0
    print(f"\nStarting CUDA-optimized training with {len(train_loader)} batches per epoch")
    print(f"Batch size: {config['batch_size']}")

    for epoch in range(start_epoch, config['max_epochs']):
        epoch_start_time = time.time()

        train_loss = train_epoch_cuda(
            model, train_loader, optimizer, device, epoch, config, scaler
        )

        # Update learning rate (once per epoch — see NOTE(review) above).
        scheduler.step()
        current_lr = scheduler.get_last_lr()[0]

        # Validate very infrequently for speed; track the best model.
        if epoch % config['eval_every'] == 0:
            val_loss = validate_cuda(model, dev_loader, device, config)
            print(f'Epoch {epoch}: Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, LR: {current_lr:.6f}')
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                save_checkpoint_cuda(
                    model, optimizer, epoch, val_loss,
                    os.path.join(args.output_dir, 'checkpoints', 'best_model.pth'),
                    scaler
                )
        else:
            print(f'Epoch {epoch}: Train Loss: {train_loss:.4f}, LR: {current_lr:.6f}')

        # Save checkpoint very infrequently for speed
        if epoch % config['save_every'] == 0:
            save_checkpoint_cuda(
                model, optimizer, epoch, train_loss,
                os.path.join(args.output_dir, 'checkpoints', f'checkpoint_epoch_{epoch}.pth'),
                scaler
            )

        epoch_time = time.time() - epoch_start_time
        samples_per_sec = len(train_loader) * config['batch_size'] / epoch_time
        print(f'Epoch {epoch} completed in {epoch_time:.2f}s ({samples_per_sec:.0f} samples/sec)')

        # Count optimizer updates (one per batch) and stop at the budget.
        updates += len(train_loader)
        global_step += len(train_loader)
        if updates >= config['max_updates']:
            print(f'Reached maximum updates ({config["max_updates"]}), stopping training')
            break

        # Clear cache and synchronize between epochs.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.synchronize()

    # Save final model
    save_checkpoint_cuda(
        model, optimizer, epoch, train_loss,
        os.path.join(args.output_dir, 'checkpoints', 'final_model.pth'),
        scaler
    )
    print('CUDA-optimized training completed!')
# Script entry point: only run training when executed directly, not on import.
if __name__ == '__main__':
    main()