# main.py — Transformer translation model: model definition, tokenization,
# data pipeline, and training entry point.
import numpy as np
import torch
import math
from torch import nn
import torch.nn.functional as F
def get_device():
    """Return the preferred torch device: CUDA when available, else CPU."""
    return torch.device('cuda' if torch.cuda.is_available() else 'cpu')
def scaled_dot_product(q, k, v, mask=None):
    """Scaled dot-product attention.

    Args:
        q, k, v: query/key/value tensors of shape [..., seq_len, d_k].
        mask: optional additive mask; large negative entries suppress
            the corresponding positions after softmax.
    Returns:
        (values, attention): attended values and the attention weights.
    """
    head_dim = q.size(-1)
    scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(head_dim)
    if mask is not None:
        # Additive mask: -1e9-style entries become ~0 after softmax.
        scores = scores + mask
    weights = F.softmax(scores, dim=-1)
    return torch.matmul(weights, v), weights
class PositionalEncoding(nn.Module):
    """Sinusoidal positional encodings (Vaswani et al., 2017).

    forward() returns a [max_sequence_length, d_model] tensor where even
    feature indices hold sin terms and odd indices hold cos terms.
    """
    def __init__(self, d_model, max_sequence_length):
        super().__init__()
        self.max_sequence_length = max_sequence_length
        self.d_model = d_model

    def forward(self):
        half_idx = torch.arange(0, self.d_model, 2).float()
        inv_freq = torch.pow(10000, half_idx / self.d_model)
        pos = torch.arange(self.max_sequence_length).reshape(-1, 1)
        angles = pos / inv_freq
        # Interleave sin/cos along the feature dimension: [sin, cos, sin, cos, ...]
        interleaved = torch.stack([torch.sin(angles), torch.cos(angles)], dim=2)
        return torch.flatten(interleaved, start_dim=1, end_dim=2)
class SentenceEmbedding(nn.Module):
    """Token embedding + sinusoidal positional encoding + dropout.

    Expects pre-tokenized integer sequences; tokenization is handled
    externally (by TranslationTokenizer).
    """
    def __init__(self, max_sequence_length, d_model, vocab_size):
        super().__init__()
        self.vocab_size = vocab_size
        self.max_sequence_length = max_sequence_length
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.position_encoder = PositionalEncoding(d_model, max_sequence_length)
        self.dropout = nn.Dropout(p=0.1)

    def forward(self, x):
        """Embed token IDs and add positional information.

        Args:
            x: integer token IDs [batch_size, seq_len]
        Returns:
            [batch_size, seq_len, d_model] embeddings with positions added.
        """
        embedded = self.embedding(x)
        # The positional table is built on CPU each call; move it to
        # wherever the embeddings live before adding.
        positions = self.position_encoder().to(embedded.device)
        return self.dropout(embedded + positions)
class MultiHeadAttention(nn.Module):
    """Multi-head self-attention with a fused QKV projection."""
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.qkv_layer = nn.Linear(d_model, 3 * d_model)
        self.linear_layer = nn.Linear(d_model, d_model)

    def forward(self, x, mask):
        batch, seq_len, _ = x.size()
        # Project once, then split into per-head query/key/value.
        fused = self.qkv_layer(x)
        fused = fused.reshape(batch, seq_len, self.num_heads, 3 * self.head_dim)
        fused = fused.permute(0, 2, 1, 3)
        q, k, v = fused.chunk(3, dim=-1)
        attended, _ = scaled_dot_product(q, k, v, mask)
        # Merge heads back into the model dimension.
        attended = attended.permute(0, 2, 1, 3).reshape(batch, seq_len, self.num_heads * self.head_dim)
        return self.linear_layer(attended)
class LayerNormalization(nn.Module):
    """Layer norm over the trailing `parameters_shape` dimensions,
    with learnable scale (gamma) and shift (beta)."""
    def __init__(self, parameters_shape, eps=1e-5):
        super().__init__()
        self.parameters_shape = parameters_shape
        self.eps = eps
        self.gamma = nn.Parameter(torch.ones(parameters_shape))
        self.beta = nn.Parameter(torch.zeros(parameters_shape))

    def forward(self, inputs):
        # Normalize over the last len(parameters_shape) dimensions.
        norm_dims = [-(axis + 1) for axis in range(len(self.parameters_shape))]
        mu = inputs.mean(dim=norm_dims, keepdim=True)
        variance = ((inputs - mu) ** 2).mean(dim=norm_dims, keepdim=True)
        normalized = (inputs - mu) / (variance + self.eps).sqrt()
        return self.gamma * normalized + self.beta
class PositionwiseFeedForward(nn.Module):
    """Two-layer position-wise MLP: Linear -> ReLU -> Dropout -> Linear."""
    def __init__(self, d_model, hidden, drop_prob=0.1):
        super().__init__()
        self.linear1 = nn.Linear(d_model, hidden)
        self.linear2 = nn.Linear(hidden, d_model)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=drop_prob)

    def forward(self, x):
        expanded = self.relu(self.linear1(x))
        return self.linear2(self.dropout(expanded))
class EncoderLayer(nn.Module):
    """Transformer encoder block: self-attention then FFN, each followed
    by dropout, a residual connection, and post-layer-norm."""
    def __init__(self, d_model, ffn_hidden, num_heads, drop_prob):
        super().__init__()
        self.attention = MultiHeadAttention(d_model=d_model, num_heads=num_heads)
        self.norm1 = LayerNormalization(parameters_shape=[d_model])
        self.dropout1 = nn.Dropout(p=drop_prob)
        self.ffn = PositionwiseFeedForward(d_model=d_model, hidden=ffn_hidden, drop_prob=drop_prob)
        self.norm2 = LayerNormalization(parameters_shape=[d_model])
        self.dropout2 = nn.Dropout(p=drop_prob)

    def forward(self, x, self_attention_mask):
        # Self-attention sub-layer with residual + post-norm.
        skip = x.clone()
        x = self.attention(x, mask=self_attention_mask)
        x = self.norm1(self.dropout1(x) + skip)
        # Feed-forward sub-layer with residual + post-norm.
        skip = x.clone()
        x = self.ffn(x)
        x = self.norm2(self.dropout2(x) + skip)
        return x
class SequentialEncoder(nn.Sequential):
    """nn.Sequential variant whose children take (x, mask); threads x
    through each layer in order while passing the same mask to all."""
    def forward(self, *inputs):
        x, self_attention_mask = inputs
        for layer in self:
            x = layer(x, self_attention_mask)
        return x
class Encoder(nn.Module):
    """Stack of EncoderLayers over embedded, pre-tokenized input."""
    def __init__(self,
                 d_model,
                 ffn_hidden,
                 num_heads,
                 drop_prob,
                 num_layers,
                 max_sequence_length,
                 vocab_size):
        super().__init__()
        # Tokenization happens externally; only vocab_size is needed here.
        self.sentence_embedding = SentenceEmbedding(max_sequence_length, d_model, vocab_size)
        self.layers = SequentialEncoder(
            *[EncoderLayer(d_model, ffn_hidden, num_heads, drop_prob) for _ in range(num_layers)]
        )

    def forward(self, x, self_attention_mask):
        """
        Args:
            x: pre-tokenized integer sequences [batch_size, seq_len]
            self_attention_mask: additive attention mask [batch_size, 1, 1, seq_len]
        """
        embedded = self.sentence_embedding(x)
        return self.layers(embedded, self_attention_mask)
class MultiHeadCrossAttention(nn.Module):
    """Multi-head cross-attention: queries come from the decoder stream (y),
    keys/values from the encoder output (x).

    Fixed: the original derived the query sequence length from x, which
    broke the reshape (and crashed) whenever the source and target
    sequence lengths differed. Queries now use y's length; keys/values
    use x's length.
    """
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.kv_layer = nn.Linear(d_model, 2 * d_model)
        self.q_layer = nn.Linear(d_model, d_model)
        self.linear_layer = nn.Linear(d_model, d_model)

    def forward(self, x, y, mask):
        """
        Args:
            x: encoder output [batch_size, src_len, d_model]
            y: decoder hidden states [batch_size, tgt_len, d_model]
            mask: optional additive attention mask
        Returns:
            [batch_size, tgt_len, d_model]
        """
        batch_size, src_len, d_model = x.size()
        tgt_len = y.size(1)
        kv = self.kv_layer(x)
        q = self.q_layer(y)
        # Split into heads; kv keeps the source length, q the target length.
        kv = kv.reshape(batch_size, src_len, self.num_heads, 2 * self.head_dim)
        q = q.reshape(batch_size, tgt_len, self.num_heads, self.head_dim)
        kv = kv.permute(0, 2, 1, 3)
        q = q.permute(0, 2, 1, 3)
        k, v = kv.chunk(2, dim=-1)
        values, attention = scaled_dot_product(q, k, v, mask)
        # Merge heads; the output follows the query (target) length.
        values = values.permute(0, 2, 1, 3).reshape(batch_size, tgt_len, d_model)
        out = self.linear_layer(values)
        return out
class DecoderLayer(nn.Module):
    """Transformer decoder block: masked self-attention, encoder-decoder
    cross-attention, then FFN — each followed by dropout, a residual
    connection, and post-layer-norm."""
    def __init__(self, d_model, ffn_hidden, num_heads, drop_prob):
        super().__init__()
        self.self_attention = MultiHeadAttention(d_model=d_model, num_heads=num_heads)
        self.layer_norm1 = LayerNormalization(parameters_shape=[d_model])
        self.dropout1 = nn.Dropout(p=drop_prob)
        self.encoder_decoder_attention = MultiHeadCrossAttention(d_model=d_model, num_heads=num_heads)
        self.layer_norm2 = LayerNormalization(parameters_shape=[d_model])
        self.dropout2 = nn.Dropout(p=drop_prob)
        self.ffn = PositionwiseFeedForward(d_model=d_model, hidden=ffn_hidden, drop_prob=drop_prob)
        self.layer_norm3 = LayerNormalization(parameters_shape=[d_model])
        self.dropout3 = nn.Dropout(p=drop_prob)

    def forward(self, x, y, self_attention_mask, cross_attention_mask):
        # Masked self-attention over the target stream.
        skip = y.clone()
        y = self.self_attention(y, mask=self_attention_mask)
        y = self.layer_norm1(self.dropout1(y) + skip)
        # Cross-attention: queries from y, keys/values from encoder output x.
        skip = y.clone()
        y = self.encoder_decoder_attention(x, y, mask=cross_attention_mask)
        y = self.layer_norm2(self.dropout2(y) + skip)
        # Position-wise feed-forward.
        skip = y.clone()
        y = self.ffn(y)
        y = self.layer_norm3(self.dropout3(y) + skip)
        return y
class SequentialDecoder(nn.Sequential):
    """nn.Sequential variant for decoder layers taking
    (x, y, self_mask, cross_mask); threads y through each layer while
    reusing the same encoder output x and masks."""
    def forward(self, *inputs):
        x, y, self_attention_mask, cross_attention_mask = inputs
        for layer in self:
            y = layer(x, y, self_attention_mask, cross_attention_mask)
        return y
class Decoder(nn.Module):
    """Stack of DecoderLayers over embedded, pre-tokenized target input."""
    def __init__(self,
                 d_model,
                 ffn_hidden,
                 num_heads,
                 drop_prob,
                 num_layers,
                 max_sequence_length,
                 vocab_size):
        super().__init__()
        # Tokenization happens externally; only vocab_size is needed here.
        self.sentence_embedding = SentenceEmbedding(max_sequence_length, d_model, vocab_size)
        self.layers = SequentialDecoder(
            *[DecoderLayer(d_model, ffn_hidden, num_heads, drop_prob) for _ in range(num_layers)]
        )

    def forward(self, x, y, self_attention_mask, cross_attention_mask):
        """
        Args:
            x: encoder output [batch_size, src_len, d_model]
            y: pre-tokenized target sequences [batch_size, tgt_len]
            self_attention_mask: causal decoder mask [batch_size, 1, tgt_len, tgt_len]
            cross_attention_mask: cross-attention mask [batch_size, 1, 1, src_len]
        """
        embedded = self.sentence_embedding(y)
        return self.layers(x, embedded, self_attention_mask, cross_attention_mask)
class Transformer(nn.Module):
    """Encoder-decoder Transformer over pre-tokenized integer sequences.

    Args:
        d_model: model (embedding) dimension
        ffn_hidden: hidden width of the position-wise FFN
        num_heads: number of attention heads
        drop_prob: dropout probability
        num_layers: number of encoder and decoder layers
        max_sequence_length: maximum sequence length
        src_vocab_size: source vocabulary size
        tgt_vocab_size: target vocabulary size
    """
    def __init__(self,
                 d_model,
                 ffn_hidden,
                 num_heads,
                 drop_prob,
                 num_layers,
                 max_sequence_length,
                 src_vocab_size,
                 tgt_vocab_size):
        super().__init__()
        self.encoder = Encoder(d_model, ffn_hidden, num_heads, drop_prob, num_layers,
                               max_sequence_length, src_vocab_size)
        self.decoder = Decoder(d_model, ffn_hidden, num_heads, drop_prob, num_layers,
                               max_sequence_length, tgt_vocab_size)
        self.linear = nn.Linear(d_model, tgt_vocab_size)
        # NOTE(review): stored but never read by forward(); kept for callers.
        self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    def forward(self,
                x,
                y,
                encoder_self_attention_mask=None,
                decoder_self_attention_mask=None,
                decoder_cross_attention_mask=None):
        """Run the full encoder-decoder stack and project to logits.

        Args:
            x: source token IDs [batch_size, src_len]
            y: target token IDs [batch_size, tgt_len]
            encoder_self_attention_mask: [batch_size, 1, 1, src_len]
            decoder_self_attention_mask: [batch_size, 1, tgt_len, tgt_len]
            decoder_cross_attention_mask: [batch_size, 1, 1, src_len]
        Returns:
            Output logits [batch_size, tgt_len, tgt_vocab_size]
        """
        memory = self.encoder(x, encoder_self_attention_mask)
        decoded = self.decoder(memory, y, decoder_self_attention_mask, decoder_cross_attention_mask)
        return self.linear(decoded)
# Module-level default device, used as the default argument of create_masks.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau
import numpy as np
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
from tokenizers.pre_tokenizers import Whitespace, ByteLevel
from tokenizers.decoders import ByteLevel as ByteLevelDecoder
import numpy as np
from typing import List, Tuple
import json
# Import your transformer model and the tokenizer utilities
# from your_transformer_file import Transformer
# from tokenizer_data_prep import (
# prepare_data, create_masks, TranslationTokenizer,
# START_TOKEN, END_TOKEN, PADDING_TOKEN
# )
# Special tokens: sentinel strings registered with both tokenizers.
# Their integer IDs are looked up from the trained vocabularies
# (see vocab_info in prepare_data); the strings themselves never appear
# inside normal BPE-merged tokens.
START_TOKEN = '<START>'
END_TOKEN = '<END>'
PADDING_TOKEN = '<PAD>'
UNKNOWN_TOKEN = '<UNK>'
class TranslationTokenizer:
    """
    BPE tokenizer for machine translation with special-tokens support.

    Wraps a HuggingFace `tokenizers` BPE model with byte-level
    pre-tokenization and decoding so any script (including non-Latin
    text such as Kannada) round-trips through encode/decode.
    """
    def __init__(self, vocab_size: int = 10000):
        self.vocab_size = vocab_size
        self.special_tokens = [PADDING_TOKEN, UNKNOWN_TOKEN, START_TOKEN, END_TOKEN]
        self.tokenizer = None  # set by train() or load()

    def train(self, texts: List[str], save_path: str = None):
        """
        Train the tokenizer on a corpus.

        Args:
            texts: sentences to train on
            save_path: optional path to save the trained tokenizer
        Returns:
            self (fluent interface)
        """
        self.tokenizer = Tokenizer(BPE(unk_token=UNKNOWN_TOKEN))
        # Byte-level pre-tokenization maps raw bytes to visible characters,
        # sidestepping whitespace handling issues for non-Latin scripts.
        self.tokenizer.pre_tokenizer = ByteLevel(add_prefix_space=False)
        # A matching byte-level decoder is required to reconstruct the
        # original text from byte-level tokens.
        self.tokenizer.decoder = ByteLevelDecoder()
        # min_frequency=2 keeps rare characters from becoming immediate UNKs.
        trainer = BpeTrainer(
            vocab_size=self.vocab_size,
            special_tokens=self.special_tokens,
            min_frequency=2,
            show_progress=True
        )
        self.tokenizer.train_from_iterator(texts, trainer)
        if save_path:
            self.tokenizer.save(save_path)
        return self

    def load(self, path: str):
        """Load a pre-trained tokenizer from disk; returns self."""
        self.tokenizer = Tokenizer.from_file(path)
        return self

    def _require_tokenizer(self):
        """Raise if neither train() nor load() has been called yet."""
        if self.tokenizer is None:
            raise ValueError("Tokenizer not trained or loaded")

    def encode(self, text: str) -> List[int]:
        """Encode text to a list of token IDs."""
        self._require_tokenizer()
        return self.tokenizer.encode(text).ids

    def decode(self, ids: List[int]) -> str:
        """Decode a list of token IDs back to text."""
        self._require_tokenizer()
        return self.tokenizer.decode(ids)

    def get_vocab(self) -> dict:
        """Return the token -> id vocabulary mapping."""
        self._require_tokenizer()
        return self.tokenizer.get_vocab()

    def get_vocab_size(self) -> int:
        """Return the vocabulary size."""
        self._require_tokenizer()
        return self.tokenizer.get_vocab_size()
class TranslationDataset(Dataset):
    """
    Machine-translation dataset that pre-tokenizes everything up front.

    All encoding, truncation, and padding happens once in __init__, so
    __getitem__ is a plain tensor lookup — far faster than tokenizing
    on-the-fly during training.
    """
    def __init__(
        self,
        source_texts: List[str],
        target_texts: List[str],
        source_tokenizer: TranslationTokenizer,
        target_tokenizer: TranslationTokenizer,
        max_length: int = 200
    ):
        """
        Args:
            source_texts: source-language sentences
            target_texts: target-language sentences (parallel to source)
            source_tokenizer: trained tokenizer for the source language
            target_tokenizer: trained tokenizer for the target language
            max_length: fixed sequence length after truncation/padding
        """
        assert len(source_texts) == len(target_texts), "Source and target must have same length"
        self.max_length = max_length
        # Look up the special-token IDs from the trained vocabularies.
        src_vocab = source_tokenizer.get_vocab()
        tgt_vocab = target_tokenizer.get_vocab()
        self.src_pad_idx = src_vocab[PADDING_TOKEN]
        self.tgt_pad_idx = tgt_vocab[PADDING_TOKEN]
        self.tgt_start_idx = tgt_vocab[START_TOKEN]
        self.tgt_end_idx = tgt_vocab[END_TOKEN]
        print(f"Pre-tokenizing {len(source_texts)} samples...")
        self.source_ids = []
        self.target_inputs = []
        self.target_labels = []
        pairs = tqdm(zip(source_texts, target_texts), total=len(source_texts), desc="Tokenizing")
        for src_text, tgt_text in pairs:
            src_ids = source_tokenizer.encode(src_text)[:max_length]
            # Leave one slot for the START/END token added below.
            tgt_ids = target_tokenizer.encode(tgt_text)[:max_length - 1]
            # Teacher forcing: decoder input starts with START,
            # labels are shifted by one and end with END.
            tgt_input = [self.tgt_start_idx] + tgt_ids
            tgt_label = tgt_ids + [self.tgt_end_idx]
            self.source_ids.append(self._pad(src_ids, self.src_pad_idx))
            self.target_inputs.append(self._pad(tgt_input, self.tgt_pad_idx))
            self.target_labels.append(self._pad(tgt_label, self.tgt_pad_idx))
        print(f"✓ Pre-tokenization complete!")

    def _pad(self, ids, pad_idx):
        """Right-pad `ids` to max_length and return a long tensor."""
        padded = ids + [pad_idx] * (self.max_length - len(ids))
        return torch.tensor(padded, dtype=torch.long)

    def __len__(self):
        return len(self.source_ids)

    def __getitem__(self, idx):
        """Return the pre-tokenized, padded tensors for one sample."""
        return {
            'source': self.source_ids[idx],
            'target': self.target_inputs[idx],
            'labels': self.target_labels[idx]
        }
# Cache of causal masks keyed by (tgt_len, device); built lazily below.
_causal_mask_cache = {}
def create_masks(src, tgt, src_pad_idx, tgt_pad_idx, device=device):
    """
    Build the three additive attention masks used by the Transformer.

    Args:
        src: source sequences [batch_size, src_len]
        tgt: target sequences [batch_size, tgt_len]
        src_pad_idx: padding index for source
        tgt_pad_idx: padding index for target
        device: device to create masks on
    Returns:
        (src_mask, tgt_mask, cross_mask): encoder self-attention mask,
        decoder causal self-attention mask, and cross-attention mask.
        Masked positions hold -1e9; attended positions hold 0.
    """
    tgt_len = tgt.size(1)
    # Padding masks: True where the token is padding -> [B, 1, 1, L].
    src_padding = (src == src_pad_idx).unsqueeze(1).unsqueeze(2)
    tgt_padding = (tgt == tgt_pad_idx).unsqueeze(1).unsqueeze(2)
    # Build the causal mask once per (length, device) and reuse it;
    # recreating torch.triu every batch is wasteful.
    key = (tgt_len, device)
    causal = _causal_mask_cache.get(key)
    if causal is None:
        causal = torch.triu(torch.ones(tgt_len, tgt_len, device=device), diagonal=1).bool()
        causal = causal.unsqueeze(0).unsqueeze(0)  # [1, 1, tgt_len, tgt_len]
        _causal_mask_cache[key] = causal
    # Convert boolean masks to additive form.
    src_mask = src_padding.float() * -1e9
    tgt_mask = (tgt_padding | causal).float() * -1e9
    cross_mask = src_padding.float() * -1e9
    return src_mask, tgt_mask, cross_mask
# Data preparation pipeline: tokenizers + split + DataLoaders.
def prepare_data(
    source_texts: List[str],
    target_texts: List[str],
    source_vocab_size: int = 10000,
    target_vocab_size: int = 10000,
    max_length: int = 200,
    batch_size: int = 64,
    train_split: float = 0.9
):
    """
    Complete data preparation pipeline: trains both tokenizers (saving
    them to disk), splits the corpus, and builds train/val DataLoaders.

    Args:
        source_texts: source-language sentences
        target_texts: target-language sentences
        source_vocab_size: BPE vocabulary size for the source language
        target_vocab_size: BPE vocabulary size for the target language
        max_length: fixed sequence length after padding/truncation
        batch_size: DataLoader batch size
        train_split: fraction of samples used for training
    Returns:
        (train_loader, val_loader, source_tokenizer, target_tokenizer,
        vocab_info) where vocab_info holds sizes, special-token IDs, and
        the full token->id mappings.
    """
    print("Training tokenizers...")
    source_tokenizer = TranslationTokenizer(vocab_size=source_vocab_size)
    source_tokenizer.train(source_texts, save_path='source_tokenizer.json')
    target_tokenizer = TranslationTokenizer(vocab_size=target_vocab_size)
    target_tokenizer.train(target_texts, save_path='target_tokenizer.json')
    print(f"Source vocabulary size: {source_tokenizer.get_vocab_size()}")
    print(f"Target vocabulary size: {target_tokenizer.get_vocab_size()}")
    # Random train/validation split over sample indices.
    n_samples = len(source_texts)
    n_train = int(n_samples * train_split)
    shuffled = np.random.permutation(n_samples)
    train_idx, val_idx = shuffled[:n_train], shuffled[n_train:]
    train_src = [source_texts[i] for i in train_idx]
    train_tgt = [target_texts[i] for i in train_idx]
    val_src = [source_texts[i] for i in val_idx]
    val_tgt = [target_texts[i] for i in val_idx]
    print(f"Training samples: {len(train_src)}")
    print(f"Validation samples: {len(val_src)}")
    train_dataset = TranslationDataset(
        train_src, train_tgt, source_tokenizer, target_tokenizer, max_length
    )
    val_dataset = TranslationDataset(
        val_src, val_tgt, source_tokenizer, target_tokenizer, max_length
    )
    # Parallel workers keep the GPU fed; pin_memory speeds host->GPU copies;
    # persistent_workers avoids respawning workers every epoch.
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=4,
        pin_memory=torch.cuda.is_available(),
        persistent_workers=True
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=2,
        pin_memory=torch.cuda.is_available(),
        persistent_workers=True
    )
    # Bundle vocabulary sizes, special-token IDs, and full mappings.
    src_vocab = source_tokenizer.get_vocab()
    tgt_vocab = target_tokenizer.get_vocab()
    vocab_info = {
        'source_vocab_size': source_tokenizer.get_vocab_size(),
        'target_vocab_size': target_tokenizer.get_vocab_size(),
        'src_pad_idx': src_vocab[PADDING_TOKEN],
        'tgt_pad_idx': tgt_vocab[PADDING_TOKEN],
        'src_start_idx': src_vocab[START_TOKEN],
        'tgt_start_idx': tgt_vocab[START_TOKEN],
        'src_end_idx': src_vocab[END_TOKEN],
        'tgt_end_idx': tgt_vocab[END_TOKEN],
        'source_to_index': src_vocab,
        'target_to_index': tgt_vocab
    }
    return train_loader, val_loader, source_tokenizer, target_tokenizer, vocab_info
def greedy_decode(
    model,
    src_sentence,
    source_tokenizer,
    target_tokenizer,
    vocab_info,
    max_length=75,
    device=None
):
    """
    Greedily decode a translation for a single source sentence.

    Fixed: TranslationTokenizer.encode() already returns a list of IDs,
    so the previous `.ids` access on its result raised AttributeError.

    Args:
        model: trained Transformer
        src_sentence: source sentence (string)
        source_tokenizer: TranslationTokenizer for the source language
        target_tokenizer: TranslationTokenizer for the target language
        vocab_info: vocabulary info dict from prepare_data
        max_length: maximum decoding length (also the padded length)
        device: device to run on; auto-detected when None
    Returns:
        Translated sentence (string)
    """
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.eval()
    model = model.to(device)
    # Encode source sentence — encode() returns List[int] directly.
    src_ids = source_tokenizer.encode(src_sentence.lower())
    src_ids = src_ids[:max_length]
    src_padded = src_ids + [vocab_info['src_pad_idx']] * (max_length - len(src_ids))
    src = torch.tensor([src_padded], dtype=torch.long).to(device)
    # Additive source padding mask [1, 1, 1, src_len] on the model's device.
    src_mask = (src == vocab_info['src_pad_idx']).unsqueeze(1).unsqueeze(2)
    src_mask = src_mask.float() * -1e9
    # Start decoding from the START token.
    tgt_ids = [vocab_info['tgt_start_idx']]
    with torch.no_grad():
        for _ in range(max_length):
            # Pad the partial target out to max_length.
            tgt_padded = tgt_ids + [vocab_info['tgt_pad_idx']] * (max_length - len(tgt_ids))
            tgt = torch.tensor([tgt_padded], dtype=torch.long).to(device)
            # Combined padding + causal mask for the decoder.
            tgt_len = len(tgt_ids)
            tgt_padding_mask = (tgt == vocab_info['tgt_pad_idx']).unsqueeze(1).unsqueeze(2)
            tgt_causal_mask = torch.triu(torch.ones(max_length, max_length, device=device), diagonal=1).bool()
            tgt_causal_mask = tgt_causal_mask.unsqueeze(0).unsqueeze(0)
            tgt_mask = (tgt_padding_mask | tgt_causal_mask).float() * -1e9
            # Cross-attention masks out source padding only.
            cross_mask = src_mask
            outputs = model(
                src,
                tgt,
                encoder_self_attention_mask=src_mask,
                decoder_self_attention_mask=tgt_mask,
                decoder_cross_attention_mask=cross_mask
            )
            # Greedy choice: highest logit at the last generated position.
            next_token_logits = outputs[0, tgt_len - 1, :]
            next_token = next_token_logits.argmax().item()
            tgt_ids.append(next_token)
            # Stop as soon as the END token is produced.
            if next_token == vocab_info['tgt_end_idx']:
                break
    # Strip special tokens before detokenizing.
    tgt_ids = [t for t in tgt_ids if t not in [
        vocab_info['tgt_start_idx'],
        vocab_info['tgt_end_idx'],
        vocab_info['tgt_pad_idx']
    ]]
    translated = target_tokenizer.decode(tgt_ids)
    return translated
class TransformerTrainer:
    """
    Trainer for the Transformer translation model.

    Supports automatic mixed precision (AMP), gradient accumulation,
    gradient clipping, label smoothing, and ReduceLROnPlateau scheduling.
    Checkpoints the best model (by validation loss) via train().
    """
    def __init__(
        self,
        model,
        train_loader,
        val_loader,
        vocab_info,
        device=None,
        learning_rate=0.0001,
        label_smoothing=0.1,
        use_amp=True,  # Automatic Mixed Precision (only takes effect on CUDA)
        gradient_accumulation_steps=1  # optimizer step every N batches
    ):
        """
        Args:
            model: Transformer instance to train
            train_loader: DataLoader yielding dicts with 'source'/'target'/'labels'
            val_loader: validation DataLoader with the same batch layout
            vocab_info: dict with at least 'src_pad_idx' and 'tgt_pad_idx'
            device: torch device; auto-detected when None
            learning_rate: Adam learning rate
            label_smoothing: label-smoothing factor for cross-entropy
            use_amp: enable mixed precision (disabled automatically on CPU)
            gradient_accumulation_steps: batches per optimizer step
        """
        # Set device (auto-detect when not given)
        if device is None:
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        else:
            self.device = device
        print(f"Using device: {self.device}")
        # Move model to device
        self.model = model.to(self.device)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.vocab_info = vocab_info
        self.gradient_accumulation_steps = gradient_accumulation_steps
        # AMP is only meaningful on CUDA; the GradScaler protects fp16
        # gradients from underflow by scaling the loss before backward().
        self.use_amp = use_amp and torch.cuda.is_available()
        self.scaler = torch.cuda.amp.GradScaler() if self.use_amp else None
        print(f"Mixed Precision Training: {'Enabled' if self.use_amp else 'Disabled'}")
        print(f"Gradient Accumulation Steps: {gradient_accumulation_steps}")
        # Loss function with label smoothing; padding positions are ignored.
        self.criterion = nn.CrossEntropyLoss(
            ignore_index=vocab_info['tgt_pad_idx'],
            label_smoothing=label_smoothing
        )
        # Optimizer (betas/eps match the original Transformer paper settings)
        self.optimizer = Adam(model.parameters(), lr=learning_rate, betas=(0.9, 0.98), eps=1e-9)
        # Halve the LR when validation loss plateaus for `patience` epochs.
        self.scheduler = ReduceLROnPlateau(
            self.optimizer,
            mode='min',
            factor=0.5,
            patience=2
        )
        # Per-epoch loss histories, appended by train().
        self.train_losses = []
        self.val_losses = []

    def train_epoch(self):
        """Run one training epoch; returns the average per-batch loss
        (reported unscaled, i.e. multiplied back by the accumulation steps)."""
        self.model.train()
        total_loss = 0
        pbar = tqdm(self.train_loader, desc='Training')
        for batch_idx, batch in enumerate(pbar):
            try:
                # non_blocking=True overlaps host->GPU copies with compute
                # (effective together with pin_memory in the DataLoader).
                src = batch['source'].to(self.device, non_blocking=True)
                tgt = batch['target'].to(self.device, non_blocking=True)
                labels = batch['labels'].to(self.device, non_blocking=True)
                # Masks must live on the same device as the inputs.
                src_mask, tgt_mask, cross_mask = create_masks(
                    src, tgt,
                    self.vocab_info['src_pad_idx'],
                    self.vocab_info['tgt_pad_idx'],
                    self.device
                )
                # Forward pass under autocast (lower precision where safe).
                with torch.cuda.amp.autocast(enabled=self.use_amp):
                    outputs = self.model(
                        src,
                        tgt,
                        encoder_self_attention_mask=src_mask,
                        decoder_self_attention_mask=tgt_mask,
                        decoder_cross_attention_mask=cross_mask
                    )
                    # Flatten to [batch*seq, vocab] vs [batch*seq] for CE.
                    loss = self.criterion(
                        outputs.reshape(-1, outputs.size(-1)),
                        labels.reshape(-1)
                    )
                    # Divide so accumulated gradients average over the
                    # accumulation window instead of summing.
                    loss = loss / self.gradient_accumulation_steps
                # Scaled backward keeps small fp16 gradients from underflowing.
                if self.use_amp:
                    self.scaler.scale(loss).backward()
                else:
                    loss.backward()
                # Step the optimizer only at accumulation boundaries.
                if (batch_idx + 1) % self.gradient_accumulation_steps == 0:
                    # Unscale before clipping so max_norm applies to the
                    # true gradient magnitudes, not the scaled ones.
                    if self.use_amp:
                        self.scaler.unscale_(self.optimizer)
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
                    # scaler.step() skips the update if grads contain inf/nan.
                    if self.use_amp:
                        self.scaler.step(self.optimizer)
                        self.scaler.update()
                    else:
                        self.optimizer.step()
                    self.optimizer.zero_grad()
                # Undo the accumulation scaling for reporting purposes.
                total_loss += loss.item() * self.gradient_accumulation_steps
                pbar.set_postfix({'loss': f'{loss.item() * self.gradient_accumulation_steps:.4f}'})
            except RuntimeError as e:
                # Surface shape/device info before re-raising to ease debugging.
                print(f"\nError in batch {batch_idx}: {str(e)}")
                print(f"Source shape: {src.shape}, device: {src.device}")
                print(f"Target shape: {tgt.shape}, device: {tgt.device}")
                raise e
        avg_loss = total_loss / len(self.train_loader)
        return avg_loss

    def validate(self):
        """Evaluate on the validation set; returns the average batch loss."""
        self.model.eval()
        total_loss = 0
        with torch.no_grad():
            pbar = tqdm(self.val_loader, desc='Validation')
            for batch in pbar:
                # Move batch to device (non_blocking pairs with pin_memory).
                src = batch['source'].to(self.device, non_blocking=True)
                tgt = batch['target'].to(self.device, non_blocking=True)
                labels = batch['labels'].to(self.device, non_blocking=True)
                # Create masks on the correct device.
                src_mask, tgt_mask, cross_mask = create_masks(
                    src, tgt,
                    self.vocab_info['src_pad_idx'],
                    self.vocab_info['tgt_pad_idx'],
                    self.device
                )
                # Forward pass under autocast for speed; no backward here.
                with torch.cuda.amp.autocast(enabled=self.use_amp):
                    outputs = self.model(
                        src,
                        tgt,
                        encoder_self_attention_mask=src_mask,
                        decoder_self_attention_mask=tgt_mask,
                        decoder_cross_attention_mask=cross_mask
                    )
                    # Same flattened cross-entropy as in training.
                    loss = self.criterion(
                        outputs.reshape(-1, outputs.size(-1)),
                        labels.reshape(-1)
                    )
                total_loss += loss.item()
                pbar.set_postfix({'loss': f'{loss.item():.4f}'})
        avg_loss = total_loss / len(self.val_loader)
        return avg_loss

    def train(self, num_epochs, save_path='best_model.pt'):
        """
        Train for `num_epochs`, checkpointing the best model by val loss.

        Args:
            num_epochs: number of epochs to train
            save_path: path for the best-model checkpoint
        Returns:
            (train_losses, val_losses): per-epoch loss histories.
        """
        best_val_loss = float('inf')
        for epoch in range(num_epochs):
            print(f"\nEpoch {epoch + 1}/{num_epochs}")
            print("-" * 50)
            # Train
            train_loss = self.train_epoch()
            self.train_losses.append(train_loss)
            print(f"Training Loss: {train_loss:.4f}")
            # Validate
            val_loss = self.validate()
            self.val_losses.append(val_loss)
            print(f"Validation Loss: {val_loss:.4f}")
            # ReduceLROnPlateau keys off the validation loss.
            self.scheduler.step(val_loss)
            # Checkpoint whenever validation improves; vocab_info is saved
            # alongside the weights so inference can rebuild the masks.
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': self.model.state_dict(),
                    'optimizer_state_dict': self.optimizer.state_dict(),
                    'train_loss': train_loss,
                    'val_loss': val_loss,
                    'vocab_info': self.vocab_info
                }, save_path)
                print(f"✓ Model saved with validation loss: {val_loss:.4f}")
        print("\nTraining completed!")
        print(f"Best validation loss: {best_val_loss:.4f}")
        return self.train_losses, self.val_losses
if __name__ == "__main__":
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    # Load the parallel English/Kannada corpus (one sentence per line).
    with open(f'dataset/train.en', 'r') as file:
        english_texts = file.readlines()
    with open(f'dataset/train.kn', 'r') as file:
        kannada_texts = file.readlines()
    # Strip trailing newlines; lowercase English only (Kannada is caseless).
    english_texts = [sentence.rstrip('\n').lower() for sentence in english_texts]
    kannada_texts = [sentence.rstrip('\n') for sentence in kannada_texts]
    # Prepare data. The output projection is tgt_vocab_size x d_model, so
    # the vocabulary sizes below directly drive training cost.
    train_loader, val_loader, src_tok, tgt_tok, vocab_info = prepare_data(
        english_texts,
        kannada_texts,
        source_vocab_size=50000,  # English BPE vocabulary size
        target_vocab_size=32000,  # Kannada BPE vocabulary size
        max_length=75,  # fixed padded sequence length
        batch_size=500  # sentences per batch
    )
    # Initialize the model. d_model must be divisible by num_heads.
    model = Transformer(
        d_model=384,
        ffn_hidden=1536,  # keeps the conventional 4x d_model ratio
        num_heads=6,  # 384 / 6 = 64-dim heads
        drop_prob=0.1,
        num_layers=4,
        max_sequence_length=75,  # must match max_length from prepare_data
        src_vocab_size=vocab_info['source_vocab_size'],
        tgt_vocab_size=vocab_info['target_vocab_size']
    )
    # Trainer with mixed precision enabled (silently disabled on CPU).
    trainer = TransformerTrainer(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        vocab_info=vocab_info,
        device=device,
        learning_rate=0.0001,
        use_amp=True,
        gradient_accumulation_steps=1  # increase if you hit OOM errors
    )
    # Train; best checkpoint (by validation loss) lands in best_model.pt.
    train_losses, val_losses = trainer.train(num_epochs=50, save_path='best_model.pt')
    # Inference example on the freshly trained model.
    test_sentence = "Hello, how are you?"
    translation = greedy_decode(
        model,
        test_sentence,
        src_tok,
        tgt_tok,
        vocab_info,
        device=device  # explicitly pass device
    )
    print(f"Source: {test_sentence}")
    print(f"Translation: {translation}")
    print("Training pipeline ready with fixed device handling!")