"""Train a SMILES -> IUPAC name translation Transformer with PyTorch Lightning."""

import torch
import torch.nn as nn
from torch.nn import Transformer
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
import math
import os
import pandas as pd
from sklearn.model_selection import train_test_split
import time
import wandb
from tokenizers import Tokenizer, models, pre_tokenizers, decoders, trainers
import logging
import gc

logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)

# --- Model hyperparameters ---
SRC_VOCAB_SIZE_ESTIMATE = 10000
TGT_VOCAB_SIZE_ESTIMATE = 14938
EMB_SIZE = 2048
NHEAD = 8
FFN_HID_DIM = 4096
NUM_ENCODER_LAYERS = 12
NUM_DECODER_LAYERS = 12
DROPOUT = 0.1
MAX_LEN = 384

# --- Training configuration ---
ACCELERATOR = "gpu"
DEVICES = 6
STRATEGY = "ddp"
PRECISION = "16-mixed"
BATCH_SIZE_PER_GPU = 48
ACCUMULATE_GRAD_BATCHES = 1
NUM_EPOCHS = 50
LEARNING_RATE = 5e-5
WEIGHT_DECAY = 1e-2
GRAD_CLIP_NORM = 1.0
VALIDATION_SPLIT = 0.05
RANDOM_SEED = 42
PATIENCE = 5
NUM_WORKERS = 8

# --- Special token indices ---
# These must match the order of the special_tokens list passed to the
# tokenizer trainers below: ["<pad>", "<sos>", "<eos>", "<unk>"].
PAD_IDX = 0
SOS_IDX = 1
EOS_IDX = 2
UNK_IDX = 3

# --- Tokenizer and input data files ---
SMILES_TOKENIZER_FILE = "smiles_bytelevel_bpe_tokenizer_scaled.json"
IUPAC_TOKENIZER_FILE = "iupac_unigram_tokenizer_scaled.json"
INPUT_CSV_FILE = "data_clean.csv"

# --- Split data files and checkpointing ---
TRAIN_SMILES_FILE = "train.smi"
TRAIN_IUPAC_FILE = "train.iupac"
VAL_SMILES_FILE = "val.smi"
VAL_IUPAC_FILE = "val.iupac"
CHECKPOINT_DIR = "checkpoints"
BEST_MODEL_FILENAME = "smiles-to-iupac-transformer-best"

# --- Weights & Biases configuration ---
WANDB_PROJECT = "SMILES-to-IUPAC-Large-BPE"
WANDB_ENTITY = "adrianmirza"
WANDB_RUN_NAME = f"transformer_BPE_E{EMB_SIZE}_H{NHEAD}_L{NUM_ENCODER_LAYERS}_BS{BATCH_SIZE_PER_GPU * DEVICES}_LR{LEARNING_RATE}"

# Hyperparameters logged to WandB
hparams = {
    "src_tokenizer_type": "ByteLevelBPE",
    "tgt_tokenizer_type": "Unigram",
    "src_vocab_size_estimate": SRC_VOCAB_SIZE_ESTIMATE,
    "tgt_vocab_size_estimate": TGT_VOCAB_SIZE_ESTIMATE,
    "emb_size": EMB_SIZE,
    "nhead": NHEAD,
    "ffn_hid_dim": FFN_HID_DIM,
    "num_encoder_layers": NUM_ENCODER_LAYERS,
    "num_decoder_layers": NUM_DECODER_LAYERS,
    "dropout": DROPOUT,
    "max_len": MAX_LEN,
    "batch_size_per_gpu": BATCH_SIZE_PER_GPU,
    "effective_batch_size": BATCH_SIZE_PER_GPU * DEVICES * ACCUMULATE_GRAD_BATCHES,
    "num_epochs": NUM_EPOCHS,
    "learning_rate": LEARNING_RATE,
    "weight_decay": WEIGHT_DECAY,
    "grad_clip_norm": GRAD_CLIP_NORM,
    "validation_split": VALIDATION_SPLIT,
    "random_seed": RANDOM_SEED,
    "patience": PATIENCE,
    "precision": PRECISION,
    "gpus": DEVICES,
    "strategy": STRATEGY,
    "num_workers": NUM_WORKERS,
}
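
# With the values above, the effective batch size per optimizer step is
# BATCH_SIZE_PER_GPU (48) * DEVICES (6) * ACCUMULATE_GRAD_BATCHES (1) = 288.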


def get_smiles_tokenizer(
    train_files=None,
    vocab_size=30000,
    min_frequency=2,
    tokenizer_path=SMILES_TOKENIZER_FILE,
):
    """Creates or loads a Byte-Level BPE tokenizer for SMILES."""
    if os.path.exists(tokenizer_path):
        logging.info(f"Loading existing SMILES tokenizer from {tokenizer_path}")
        try:
            tokenizer = Tokenizer.from_file(tokenizer_path)
            # Sanity-check that the special token ids match the global constants.
            if (
                tokenizer.token_to_id("<pad>") != PAD_IDX
                or tokenizer.token_to_id("<sos>") != SOS_IDX
                or tokenizer.token_to_id("<eos>") != EOS_IDX
                or tokenizer.token_to_id("<unk>") != UNK_IDX
            ):
                logging.warning(
                    "Special token ID mismatch after loading SMILES tokenizer. Re-check config."
                )
            if not isinstance(tokenizer.model, models.BPE):
                logging.warning(
                    f"Loaded tokenizer from {tokenizer_path} is not a BPE model. Retraining."
                )
                raise TypeError("Incorrect tokenizer model type loaded.")
            return tokenizer
        except Exception as e:
            logging.error(f"Failed to load SMILES tokenizer: {e}. Retraining...")

    logging.info("Creating and training SMILES Byte-Level BPE tokenizer...")
    tokenizer = Tokenizer(models.BPE(unk_token="<unk>"))
    # Byte-level pre-tokenization works directly on the raw SMILES characters;
    # no prefix space is added since SMILES strings are not natural language.
    tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)
    tokenizer.decoder = decoders.ByteLevel()

    special_tokens = ["<pad>", "<sos>", "<eos>", "<unk>"]
    trainer = trainers.BpeTrainer(
        vocab_size=vocab_size,
        min_frequency=min_frequency,
        special_tokens=special_tokens,
    )

    if train_files and all(os.path.exists(f) for f in train_files):
        logging.info(f"Training SMILES BPE tokenizer on: {train_files}")
        tokenizer.train(files=train_files, trainer=trainer)
        logging.info(
            f"SMILES BPE tokenizer trained. Final vocab size: {tokenizer.get_vocab_size()}"
        )
        if (
            tokenizer.token_to_id("<pad>") != PAD_IDX
            or tokenizer.token_to_id("<sos>") != SOS_IDX
            or tokenizer.token_to_id("<eos>") != EOS_IDX
            or tokenizer.token_to_id("<unk>") != UNK_IDX
        ):
            logging.warning(
                "Special token ID mismatch after training SMILES BPE tokenizer. Check trainer setup."
            )
        try:
            tokenizer.save(tokenizer_path)
            logging.info(f"SMILES BPE tokenizer saved to {tokenizer_path}")
        except Exception as e:
            logging.error(f"Failed to save SMILES BPE tokenizer: {e}")
    else:
        logging.error(
            "Training files not provided or not found for SMILES tokenizer. Cannot train."
        )
        # At least register the special tokens so their ids exist.
        tokenizer.add_special_tokens(special_tokens)

    return tokenizer
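
# Quick usage sketch (illustrative only; "CCO" is an arbitrary example SMILES):
#     tok = get_smiles_tokenizer(train_files=[TRAIN_SMILES_FILE])
#     enc = tok.encode("CCO")
#     enc.tokens  ->  byte-level BPE pieces
#     enc.ids     ->  integer ids fed to the model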


def get_iupac_tokenizer(
    train_files=None,
    vocab_size=30000,
    min_frequency=2,  # unused by this Unigram setup; kept for parity with the BPE helper
    tokenizer_path=IUPAC_TOKENIZER_FILE,
):
    """Creates or loads a Unigram tokenizer for IUPAC names."""
    if os.path.exists(tokenizer_path):
        logging.info(f"Loading existing IUPAC tokenizer from {tokenizer_path}")
        try:
            tokenizer = Tokenizer.from_file(tokenizer_path)
            if (
                tokenizer.token_to_id("<pad>") != PAD_IDX
                or tokenizer.token_to_id("<sos>") != SOS_IDX
                or tokenizer.token_to_id("<eos>") != EOS_IDX
                or tokenizer.token_to_id("<unk>") != UNK_IDX
            ):
                logging.warning(
                    "Special token ID mismatch after loading IUPAC tokenizer. Re-check config."
                )
            return tokenizer
        except Exception as e:
            logging.error(f"Failed to load IUPAC tokenizer: {e}. Retraining...")

    logging.info("Creating and training IUPAC Unigram tokenizer...")
    tokenizer = Tokenizer(models.Unigram())
    # Split on whitespace, punctuation, and individual digits, mirroring the
    # locant/punctuation structure of IUPAC names.
    pre_tokenizer_list = [
        pre_tokenizers.WhitespaceSplit(),
        pre_tokenizers.Punctuation(),
        pre_tokenizers.Digits(individual_digits=True),
    ]
    tokenizer.pre_tokenizer = pre_tokenizers.Sequence(pre_tokenizer_list)
    tokenizer.decoder = decoders.Metaspace()
    special_tokens = ["<pad>", "<sos>", "<eos>", "<unk>"]
    trainer = trainers.UnigramTrainer(
        vocab_size=vocab_size,
        special_tokens=special_tokens,
        unk_token="<unk>",
    )

    if train_files and all(os.path.exists(f) for f in train_files):
        logging.info(f"Training IUPAC tokenizer on: {train_files}")
        tokenizer.train(files=train_files, trainer=trainer)
        logging.info(
            f"IUPAC tokenizer trained. Final vocab size: {tokenizer.get_vocab_size()}"
        )
        if (
            tokenizer.token_to_id("<pad>") != PAD_IDX
            or tokenizer.token_to_id("<sos>") != SOS_IDX
            or tokenizer.token_to_id("<eos>") != EOS_IDX
            or tokenizer.token_to_id("<unk>") != UNK_IDX
        ):
            logging.warning(
                "Special token ID mismatch after training IUPAC tokenizer. Check trainer setup."
            )
        try:
            tokenizer.save(tokenizer_path)
            logging.info(f"IUPAC tokenizer saved to {tokenizer_path}")
        except Exception as e:
            logging.error(f"Failed to save IUPAC tokenizer: {e}")
    else:
        logging.error(
            "Training files not provided or not found for IUPAC tokenizer. Cannot train."
        )
        tokenizer.add_special_tokens(special_tokens)

    return tokenizer


# --- Model components ---
class PositionalEncoding(nn.Module):
    """Injects positional information into the input embeddings."""

    def __init__(self, emb_size: int, dropout: float, maxlen: int = 5000):
        super().__init__()
        # den = 10000^(-2i / emb_size) for the even dimension indices 2i.
        den = torch.exp(-torch.arange(0, emb_size, 2) * math.log(10000) / emb_size)
        pos = torch.arange(0, maxlen).reshape(maxlen, 1)
        pos_embedding = torch.zeros((maxlen, emb_size))
        pos_embedding[:, 0::2] = torch.sin(pos * den)
        pos_embedding[:, 1::2] = torch.cos(pos * den)
        pos_embedding = pos_embedding.unsqueeze(0)  # (1, maxlen, emb_size)
        self.dropout = nn.Dropout(dropout)
        # A buffer moves with the module across devices but is not trained.
        self.register_buffer("pos_embedding", pos_embedding)

    def forward(self, token_embedding: torch.Tensor):
        # token_embedding: (batch, seq_len, emb_size)
        seq_len = token_embedding.size(1)
        if seq_len > self.pos_embedding.size(1):
            logging.warning(
                f"Input sequence length ({seq_len}) exceeds PositionalEncoding maxlen "
                f"({self.pos_embedding.size(1)}). Truncating the input to maxlen."
            )
            output = (
                token_embedding[:, : self.pos_embedding.size(1), :] + self.pos_embedding
            )
        else:
            output = token_embedding + self.pos_embedding[:, :seq_len, :]
        return self.dropout(output)
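
# PositionalEncoding implements the sinusoidal scheme of Vaswani et al. (2017):
#     PE(pos, 2i)   = sin(pos / 10000^(2i / emb_size))
#     PE(pos, 2i+1) = cos(pos / 10000^(2i / emb_size))
# Each dimension oscillates at a different frequency, so relative offsets are
# expressible as linear functions of the encodings.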


class TokenEmbedding(nn.Module):
    """Converts token indices to embeddings."""

    def __init__(self, vocab_size: int, emb_size):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, emb_size, padding_idx=PAD_IDX)
        self.emb_size = emb_size

    def forward(self, tokens: torch.Tensor):
        # Scale by sqrt(emb_size), as in the original Transformer, so the token
        # embeddings are not drowned out by the positional encodings.
        return self.embedding(tokens.long()) * math.sqrt(self.emb_size)


class Seq2SeqTransformer(nn.Module):
    """The main encoder-decoder Transformer model."""

    def __init__(
        self,
        num_encoder_layers: int,
        num_decoder_layers: int,
        emb_size: int,
        nhead: int,
        src_vocab_size: int,
        tgt_vocab_size: int,
        dim_feedforward: int,
        dropout: float = 0.1,
        max_len: int = MAX_LEN,
    ):
        super().__init__()

        if emb_size % nhead != 0:
            raise ValueError(
                f"Embedding size ({emb_size}) must be divisible by the number of heads ({nhead})"
            )

        self.src_tok_emb = TokenEmbedding(src_vocab_size, emb_size)
        self.tgt_tok_emb = TokenEmbedding(tgt_vocab_size, emb_size)

        # Size the positional table to at least 5000 positions (>= max_len).
        pe_maxlen = max(max_len, 5000)
        self.positional_encoding = PositionalEncoding(
            emb_size, dropout=dropout, maxlen=pe_maxlen
        )

        self.transformer = Transformer(
            d_model=emb_size,
            nhead=nhead,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True,  # all tensors are (batch, seq, feature)
        )

        self.generator = nn.Linear(emb_size, tgt_vocab_size)
        self._init_weights()

    def _init_weights(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(
        self,
        src: torch.Tensor,
        trg: torch.Tensor,
        tgt_mask: torch.Tensor,
        src_padding_mask: torch.Tensor,
        tgt_padding_mask: torch.Tensor,
        memory_key_padding_mask: torch.Tensor,
    ):
        # Keep all masks on the same device as their inputs.
        src_padding_mask = src_padding_mask.to(src.device)
        tgt_padding_mask = tgt_padding_mask.to(trg.device)
        memory_key_padding_mask = memory_key_padding_mask.to(src.device)
        tgt_mask = tgt_mask.to(trg.device)

        src_emb = self.positional_encoding(self.src_tok_emb(src))
        tgt_emb = self.positional_encoding(self.tgt_tok_emb(trg))

        outs = self.transformer(
            src=src_emb,
            tgt=tgt_emb,
            src_mask=None,
            tgt_mask=tgt_mask,
            memory_mask=None,
            src_key_padding_mask=src_padding_mask,
            tgt_key_padding_mask=tgt_padding_mask,
            memory_key_padding_mask=memory_key_padding_mask,
        )
        # Project decoder states to target-vocabulary logits.
        return self.generator(outs)

    def encode(self, src: torch.Tensor, src_padding_mask: torch.Tensor):
        src_padding_mask = src_padding_mask.to(src.device)
        src_emb = self.positional_encoding(self.src_tok_emb(src))
        memory = self.transformer.encoder(
            src_emb, mask=None, src_key_padding_mask=src_padding_mask
        )
        return memory

    def decode(
        self,
        tgt: torch.Tensor,
        memory: torch.Tensor,
        tgt_mask: torch.Tensor,
        tgt_padding_mask: torch.Tensor,
        memory_key_padding_mask: torch.Tensor,
    ):
        tgt_mask = tgt_mask.to(tgt.device)
        tgt_padding_mask = tgt_padding_mask.to(tgt.device)
        memory_key_padding_mask = memory_key_padding_mask.to(memory.device)

        tgt_emb = self.positional_encoding(self.tgt_tok_emb(tgt))
        output = self.transformer.decoder(
            tgt=tgt_emb,
            memory=memory,
            tgt_mask=tgt_mask,
            memory_mask=None,
            tgt_key_padding_mask=tgt_padding_mask,
            memory_key_padding_mask=memory_key_padding_mask,
        )
        return output
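
# Shape sanity check (illustrative only; tiny hypothetical sizes, not the
# training configuration above):
#     model = Seq2SeqTransformer(2, 2, 64, 4, 100, 120, 128)
#     src = torch.randint(4, 100, (2, 10))     # (batch=2, src_len=10)
#     tgt_in = torch.randint(4, 120, (2, 12))  # (batch=2, tgt_len=12)
#     masks = create_masks(src, tgt_in, PAD_IDX, torch.device("cpu"))
#     model(src, tgt_in, *masks).shape         # -> torch.Size([2, 12, 120])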


# --- Masking helpers ---
def generate_square_subsequent_mask(sz: int, device: torch.device) -> torch.Tensor:
    """Generates an additive causal mask: 0.0 on and below the diagonal, -inf above."""
    mask = (torch.triu(torch.ones((sz, sz), device=device)) == 1).transpose(0, 1)
    mask = (
        mask.float()
        .masked_fill(mask == 0, float("-inf"))
        .masked_fill(mask == 1, float(0.0))
    )
    return mask
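
# For sz=3 the additive mask is
#     [[0., -inf, -inf],
#      [0.,   0., -inf],
#      [0.,   0.,   0.]]
# i.e. position t may attend only to positions <= t.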


def create_masks(
    src: torch.Tensor, tgt: torch.Tensor, pad_idx: int, device: torch.device
):
    """
    Creates all masks needed by the Transformer forward pass.
    Assumes src and tgt are forward-pass inputs (tgt includes SOS, excludes EOS).
    The padding masks are boolean (True = padding, to be ignored); the causal
    tgt_mask is additive (float, with -inf blocking attention).
    """
    tgt_seq_len = tgt.shape[1]

    # Causal mask so each target position only attends to earlier positions.
    tgt_mask = generate_square_subsequent_mask(tgt_seq_len, device)

    # Boolean padding masks: True where the token is <pad>.
    src_padding_mask = src == pad_idx
    tgt_padding_mask = tgt == pad_idx
    # The encoder memory shares the source's padding layout.
    memory_key_padding_mask = src_padding_mask

    return tgt_mask, src_padding_mask, tgt_padding_mask, memory_key_padding_mask
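
# Example: with pad_idx=0 and src = [[5, 6, 0]], src_padding_mask is
# [[False, False, True]] -- only the trailing <pad> position is masked out.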


# --- Data handling ---
class SmilesIupacDataset(Dataset):
    """Dataset of SMILES-IUPAC pairs, read from pre-split text files."""

    def __init__(self, smiles_file: str, iupac_file: str):
        logging.info(f"Loading data from {smiles_file} and {iupac_file}")
        try:
            with open(smiles_file, "r", encoding="utf-8") as f_smi:
                self.smiles = [line.strip() for line in f_smi if line.strip()]
            with open(iupac_file, "r", encoding="utf-8") as f_iupac:
                self.iupac = [line.strip() for line in f_iupac if line.strip()]

            if len(self.smiles) != len(self.iupac):
                logging.warning(
                    f"Mismatch in number of lines: {smiles_file} ({len(self.smiles)}) "
                    f"vs {iupac_file} ({len(self.iupac)}). Trimming to the shorter file."
                )
                min_len = min(len(self.smiles), len(self.iupac))
                self.smiles = self.smiles[:min_len]
                self.iupac = self.iupac[:min_len]

            logging.info(
                f"Loaded {len(self.smiles)} pairs from {smiles_file}/{iupac_file}."
            )
            if len(self.smiles) == 0:
                logging.warning("Loaded 0 data pairs. Check files.")

        except FileNotFoundError:
            logging.error(
                f"Error: One or both files not found: {smiles_file}, {iupac_file}"
            )
            raise
        except Exception as e:
            logging.error(f"Error loading data: {e}")
            raise

    def __len__(self):
        return len(self.smiles)

    def __getitem__(self, idx):
        return self.smiles[idx], self.iupac[idx]


def collate_fn(
    batch, smiles_tokenizer, iupac_tokenizer, pad_idx, sos_idx, eos_idx, max_len
):
    """Tokenizes, truncates, and pads (SMILES, IUPAC) samples into batch tensors."""
    src_batch, tgt_batch = [], []
    skipped_count = 0
    for src_sample, tgt_sample in batch:
        try:
            # Source: tokenize and truncate to max_len (no SOS/EOS on the source side).
            src_encoded = smiles_tokenizer.encode(src_sample)
            src_ids = src_encoded.ids[:max_len]
            if not src_ids:
                skipped_count += 1
                continue
            src_tensor = torch.tensor(src_ids, dtype=torch.long)

            # Target: truncate to max_len - 2 to leave room for SOS and EOS.
            tgt_encoded = iupac_tokenizer.encode(tgt_sample)
            tgt_ids = tgt_encoded.ids[: max_len - 2]
            if not tgt_ids:
                skipped_count += 1
                continue
            tgt_tensor = torch.tensor([sos_idx] + tgt_ids + [eos_idx], dtype=torch.long)

            src_batch.append(src_tensor)
            tgt_batch.append(tgt_tensor)
        except Exception:
            # Silently skip samples that fail to tokenize.
            skipped_count += 1
            continue

    if not src_batch or not tgt_batch:
        # The whole batch was skipped; the training/validation steps guard on this.
        return torch.tensor([]), torch.tensor([])

    try:
        # Pad to the longest sequence in the batch: (batch, max_seq_len_in_batch).
        src_batch_padded = pad_sequence(
            src_batch, batch_first=True, padding_value=pad_idx
        )
        tgt_batch_padded = pad_sequence(
            tgt_batch, batch_first=True, padding_value=pad_idx
        )
    except Exception as e:
        logging.error(
            f"Error during padding: {e}. Src lengths: {[len(s) for s in src_batch]}, "
            f"Tgt lengths: {[len(t) for t in tgt_batch]}"
        )
        return torch.tensor([]), torch.tensor([])

    return src_batch_padded, tgt_batch_padded
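
# Note: the __main__ block below binds the tokenizers into this function with a
# small closure (collate_fn_partial); functools.partial(collate_fn, ...) would
# be an equivalent way to express the same binding.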


# --- Lightning module wrapping the Seq2Seq Transformer ---
class SmilesIupacLitModule(pl.LightningModule):
    def __init__(self, src_vocab_size: int, tgt_vocab_size: int, hparams_dict: dict):
        super().__init__()
        # Stores hparams_dict under self.hparams and in the checkpoint.
        self.save_hyperparameters(hparams_dict)

        self.model = Seq2SeqTransformer(
            num_encoder_layers=self.hparams.num_encoder_layers,
            num_decoder_layers=self.hparams.num_decoder_layers,
            emb_size=self.hparams.emb_size,
            nhead=self.hparams.nhead,
            src_vocab_size=src_vocab_size,
            tgt_vocab_size=tgt_vocab_size,
            dim_feedforward=self.hparams.ffn_hid_dim,
            dropout=self.hparams.dropout,
            max_len=self.hparams.max_len,
        )

        # Padding positions contribute nothing to the loss.
        self.criterion = torch.nn.CrossEntropyLoss(ignore_index=PAD_IDX)

        total_params = sum(p.numel() for p in self.model.parameters())
        trainable_params = sum(
            p.numel() for p in self.model.parameters() if p.requires_grad
        )
        logging.info("Model Initialized:")
        logging.info(f"  Total Parameters: {total_params / 1_000_000:.2f} M")
        logging.info(f"  Trainable Parameters: {trainable_params / 1_000_000:.2f} M")
        self.hparams.total_params_M = round(total_params / 1_000_000, 2)
        self.hparams.trainable_params_M = round(trainable_params / 1_000_000, 2)

    def forward(self, src, tgt):
        # Teacher forcing: the decoder input is the target shifted right
        # (keep SOS, drop the final token).
        tgt_input = tgt[:, :-1]
        tgt_mask, src_padding_mask, tgt_padding_mask, memory_key_padding_mask = (
            create_masks(src, tgt_input, PAD_IDX, self.device)
        )
        logits = self.model(
            src,
            tgt_input,
            tgt_mask,
            src_padding_mask,
            tgt_padding_mask,
            memory_key_padding_mask,
        )
        return logits

    def training_step(self, batch, batch_idx):
        src, tgt = batch
        if src.numel() == 0 or tgt.numel() == 0:
            # collate_fn returned an empty batch; skip it.
            return None

        # Shifted input/output pair for teacher forcing.
        tgt_input = tgt[:, :-1]
        tgt_out = tgt[:, 1:]

        tgt_mask, src_padding_mask, tgt_padding_mask, memory_key_padding_mask = (
            create_masks(src, tgt_input, PAD_IDX, self.device)
        )

        try:
            logits = self.model(
                src=src,
                trg=tgt_input,
                tgt_mask=tgt_mask,
                src_padding_mask=src_padding_mask,
                tgt_padding_mask=tgt_padding_mask,
                memory_key_padding_mask=memory_key_padding_mask,
            )

            # Flatten to (batch * seq_len, vocab) vs. (batch * seq_len,).
            loss = self.criterion(
                logits.reshape(-1, logits.shape[-1]), tgt_out.reshape(-1)
            )

            if not torch.isfinite(loss):
                logging.warning(
                    f"Non-finite loss encountered in training step {batch_idx}: {loss.item()}. Skipping update."
                )
                return None

            self.log(
                "train_loss",
                loss,
                on_step=True,
                on_epoch=True,
                prog_bar=True,
                logger=True,
                sync_dist=True,
                batch_size=src.size(0),
            )

            return loss

        except RuntimeError as e:
            if "CUDA out of memory" in str(e):
                logging.warning(
                    f"CUDA OOM error during training step {batch_idx} with shape src: {src.shape}, tgt: {tgt.shape}. Skipping batch."
                )
                gc.collect()
                torch.cuda.empty_cache()
                return None
            else:
                logging.error(f"Runtime error during training step {batch_idx}: {e}")
                logging.error(f"Shapes - src: {src.shape}, tgt: {tgt.shape}")
                return None

    def validation_step(self, batch, batch_idx):
        src, tgt = batch
        if src.numel() == 0 or tgt.numel() == 0:
            return None

        tgt_input = tgt[:, :-1]
        tgt_out = tgt[:, 1:]

        tgt_mask, src_padding_mask, tgt_padding_mask, memory_key_padding_mask = (
            create_masks(src, tgt_input, PAD_IDX, self.device)
        )

        try:
            logits = self.model(
                src,
                tgt_input,
                tgt_mask,
                src_padding_mask,
                tgt_padding_mask,
                memory_key_padding_mask,
            )
            loss = self.criterion(
                logits.reshape(-1, logits.shape[-1]), tgt_out.reshape(-1)
            )

            if torch.isfinite(loss):
                self.log(
                    "val_loss",
                    loss,
                    on_step=False,
                    on_epoch=True,
                    prog_bar=True,
                    logger=True,
                    sync_dist=True,
                    batch_size=src.size(0),
                )
            else:
                logging.warning(
                    f"Non-finite loss encountered during validation step {batch_idx}: {loss.item()}."
                )

        except RuntimeError as e:
            logging.error(f"Runtime error during validation step {batch_idx}: {e}")
            if "CUDA out of memory" in str(e):
                logging.warning(
                    f"CUDA OOM error during validation step {batch_idx} with shape src: {src.shape}, tgt: {tgt.shape}. Skipping batch."
                )
                gc.collect()
                torch.cuda.empty_cache()
            else:
                logging.error(f"Shapes - src: {src.shape}, tgt: {tgt.shape}")

        return None

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=self.hparams.learning_rate,
            weight_decay=self.hparams.weight_decay,
        )

        # Linear warmup + decay schedule, if the transformers library is available.
        try:
            from transformers import get_linear_schedule_with_warmup

            try:
                num_training_steps = self.trainer.estimated_stepping_batches
                logging.info(
                    f"Estimated stepping batches for LR schedule: {num_training_steps}"
                )
                if num_training_steps is None or num_training_steps <= 0:
                    logging.warning(
                        "Could not estimate stepping batches, using fallback for LR schedule."
                    )
                    num_training_steps = 1_000_000
            except AttributeError:
                logging.warning(
                    "self.trainer not available yet in configure_optimizers. Using fallback step count for LR schedule."
                )
                num_training_steps = 1_000_000

            # Warm up over the first 5% of training steps.
            num_warmup_steps = int(0.05 * num_training_steps)
            logging.info(
                f"LR Scheduler: Total steps ~{num_training_steps}, Warmup steps: {num_warmup_steps}"
            )

            scheduler = get_linear_schedule_with_warmup(
                optimizer,
                num_warmup_steps=num_warmup_steps,
                num_training_steps=num_training_steps,
            )

            lr_scheduler_config = {
                "scheduler": scheduler,
                "interval": "step",  # step the scheduler every batch, not every epoch
                "frequency": 1,
                "name": "linear_warmup_decay_lr",
            }
            logging.info("Using Linear Warmup/Decay LR Scheduler.")
            return {"optimizer": optimizer, "lr_scheduler": lr_scheduler_config}

        except ImportError:
            logging.warning(
                "'transformers' library not found. Cannot create linear warmup scheduler. Using constant LR."
            )
            return optimizer
        except Exception as e:
            logging.error(
                f"Error setting up LR scheduler: {e}. Using constant LR.", exc_info=True
            )
            return optimizer
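
# Resulting LR shape (when the transformers scheduler is available): the rate
# ramps linearly from 0 to LEARNING_RATE over the first 5% of steps, then
# decays linearly back to 0 at num_training_steps.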


# --- Inference helpers ---
def greedy_decode(
    model: pl.LightningModule,
    src: torch.Tensor,
    src_padding_mask: torch.Tensor,
    max_len: int,
    sos_idx: int,
    eos_idx: int,
    device: torch.device,
) -> torch.Tensor:
    """Performs greedy decoding using the LightningModule's underlying model."""
    transformer_model = model.model

    try:
        with torch.no_grad():
            # Encode the source once; reuse the memory for every decode step.
            memory = transformer_model.encode(src, src_padding_mask)
            memory = memory.to(device)
            memory_key_padding_mask = src_padding_mask.to(memory.device)

            # Start the target sequence with <sos>.
            ys = torch.ones(1, 1).fill_(sos_idx).type(torch.long).to(device)

            for i in range(max_len - 1):
                tgt_seq_len = ys.shape[1]
                tgt_mask = generate_square_subsequent_mask(tgt_seq_len, device).to(
                    device
                )
                # The generated prefix contains no padding, so this is all False.
                tgt_padding_mask = torch.zeros(ys.shape, dtype=torch.bool).to(device)

                out = transformer_model.decode(
                    ys, memory, tgt_mask, tgt_padding_mask, memory_key_padding_mask
                )

                # Greedy choice: highest-logit token at the last position.
                last_token_logits = transformer_model.generator(out[:, -1, :])
                prob = last_token_logits
                _, next_word = torch.max(prob, dim=1)
                next_word = next_word.item()

                ys = torch.cat(
                    [ys, torch.ones(1, 1).type_as(src.data).fill_(next_word)], dim=1
                )

                # Stop as soon as <eos> is produced.
                if next_word == eos_idx:
                    break

        # Strip the leading <sos> before returning.
        return ys[:, 1:]

    except RuntimeError as e:
        logging.error(f"Runtime error during greedy decode: {e}")
        if "CUDA out of memory" in str(e):
            gc.collect()
            torch.cuda.empty_cache()
        return torch.tensor([[]], dtype=torch.long, device=device)
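
# Greedy decoding commits to the argmax token at every step. A beam search over
# a handful of hypotheses would be a natural drop-in refinement here, trading
# decode speed for potentially better names, without touching the model itself.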


def translate(
    model: pl.LightningModule,
    src_sentence: str,
    smiles_tokenizer,
    iupac_tokenizer,
    device: torch.device,
    max_len: int,
    sos_idx: int,
    eos_idx: int,
    pad_idx: int,
) -> str:
    """Translates a single SMILES string into an IUPAC name using the LightningModule."""
    model.eval()

    try:
        src_encoded = smiles_tokenizer.encode(src_sentence)
        if not src_encoded or len(src_encoded.ids) == 0:
            logging.warning(f"Encoding failed for SMILES: {src_sentence}")
            return "[Encoding Error]"
        src_ids = src_encoded.ids[:max_len]
        if not src_ids:
            logging.warning(
                f"Source sequence empty after truncation for SMILES: {src_sentence}"
            )
            return "[Encoding Error - Empty Src]"

    except Exception as e:
        logging.error(f"Error tokenizing SMILES '{src_sentence}': {e}")
        return "[Encoding Error]"

    # Shape (1, src_len): a single-sequence batch.
    src = torch.tensor(src_ids, dtype=torch.long).unsqueeze(0).to(device)
    src_padding_mask = src == pad_idx

    tgt_tokens_tensor = greedy_decode(
        model=model,
        src=src,
        src_padding_mask=src_padding_mask,
        max_len=max_len,
        sos_idx=sos_idx,
        eos_idx=eos_idx,
        device=device,
    )

    if tgt_tokens_tensor.numel() > 0:
        tgt_tokens = tgt_tokens_tensor.flatten().cpu().numpy().tolist()
        try:
            # skip_special_tokens drops <eos> and any stray special ids.
            translation = iupac_tokenizer.decode(tgt_tokens, skip_special_tokens=True)
            return translation
        except Exception as e:
            logging.error(f"Error decoding target tokens {tgt_tokens}: {e}")
            return "[Decoding Error]"
    else:
        return "[Decoding Error - Empty Output]"
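
# Example call (illustrative; "CCO" is an arbitrary SMILES and assumes trained
# tokenizers and a trained model):
#     name = translate(final_model, "CCO", smiles_tokenizer, iupac_tokenizer,
#                      torch.device("cpu"), MAX_LEN, SOS_IDX, EOS_IDX, PAD_IDX)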


# --- Entry point: data preparation, tokenizers, training, evaluation ---
if __name__ == "__main__":
    pl.seed_everything(RANDOM_SEED, workers=True)

    os.makedirs(CHECKPOINT_DIR, exist_ok=True)

    # Load the raw CSV, clean it, and split it into train/validation files.
    logging.info(f"Loading and splitting data from {INPUT_CSV_FILE}...")
    try:
        df = pd.read_csv(INPUT_CSV_FILE, dtype={"SMILES": str, "Systematic": str})
        logging.info(f"Initial rows loaded: {len(df)}")
        if "SMILES" not in df.columns:
            raise ValueError("CSV must contain 'SMILES' column.")
        if "Systematic" not in df.columns:
            raise ValueError("CSV must contain 'Systematic' (IUPAC name) column.")
        df.rename(columns={"Systematic": "IUPAC"}, inplace=True)

        initial_rows = len(df)
        df.dropna(subset=["SMILES", "IUPAC"], inplace=True)
        rows_after_na = len(df)
        if initial_rows > rows_after_na:
            logging.info(
                f"Dropped {initial_rows - rows_after_na} rows with missing values."
            )
        # Drop rows that are empty after stripping whitespace.
        df = df[df["SMILES"].str.strip().astype(bool)]
        df = df[df["IUPAC"].str.strip().astype(bool)]
        df["SMILES"] = df["SMILES"].str.strip()
        df["IUPAC"] = df["IUPAC"].str.strip()
        rows_after_empty = len(df)
        if rows_after_na > rows_after_empty:
            logging.info(
                f"Dropped {rows_after_na - rows_after_empty} rows with empty strings after stripping."
            )

        smiles_data = df["SMILES"].tolist()
        iupac_data = df["IUPAC"].tolist()
        logging.info(f"Loaded {len(smiles_data)} valid pairs from CSV.")
        del df
        gc.collect()

        if len(smiles_data) < 10:
            raise ValueError(
                f"Not enough valid data ({len(smiles_data)}) for split. Need at least 10."
            )

        train_smi, val_smi, train_iupac, val_iupac = train_test_split(
            smiles_data,
            iupac_data,
            test_size=VALIDATION_SPLIT,
            random_state=RANDOM_SEED,
        )
        logging.info(f"Split: {len(train_smi)} train, {len(val_smi)} validation.")
        del smiles_data, iupac_data
        gc.collect()

        logging.info("Writing split data to files...")
        with open(TRAIN_SMILES_FILE, "w", encoding="utf-8") as f:
            f.write("\n".join(train_smi))
        with open(TRAIN_IUPAC_FILE, "w", encoding="utf-8") as f:
            f.write("\n".join(train_iupac))
        with open(VAL_SMILES_FILE, "w", encoding="utf-8") as f:
            f.write("\n".join(val_smi))
        with open(VAL_IUPAC_FILE, "w", encoding="utf-8") as f:
            f.write("\n".join(val_iupac))
        logging.info(
            f"Split files written: {TRAIN_SMILES_FILE}, {TRAIN_IUPAC_FILE}, {VAL_SMILES_FILE}, {VAL_IUPAC_FILE}"
        )
        del train_smi, val_smi, train_iupac, val_iupac
        gc.collect()

    except FileNotFoundError:
        logging.error(f"Fatal error: Input CSV file not found at {INPUT_CSV_FILE}")
        exit(1)
    except ValueError as ve:
        logging.error(f"Fatal error during data preparation: {ve}")
        exit(1)
    except Exception as e:
        logging.error(f"Fatal error during data preparation: {e}", exc_info=True)
        exit(1)

    # --- Tokenizers ---
    logging.info("Initializing Tokenizers...")
    if not os.path.exists(TRAIN_SMILES_FILE) or not os.path.exists(TRAIN_IUPAC_FILE):
        logging.error(
            f"Training files ({TRAIN_SMILES_FILE}, {TRAIN_IUPAC_FILE}) not found. Cannot train tokenizers."
        )
        exit(1)

    smiles_tokenizer = get_smiles_tokenizer(
        train_files=[TRAIN_SMILES_FILE],
        vocab_size=SRC_VOCAB_SIZE_ESTIMATE,
        tokenizer_path=SMILES_TOKENIZER_FILE,
    )
    iupac_tokenizer = get_iupac_tokenizer(
        train_files=[TRAIN_IUPAC_FILE],
        vocab_size=TGT_VOCAB_SIZE_ESTIMATE,
        tokenizer_path=IUPAC_TOKENIZER_FILE,
    )

    ACTUAL_SRC_VOCAB_SIZE = smiles_tokenizer.get_vocab_size()
    ACTUAL_TGT_VOCAB_SIZE = iupac_tokenizer.get_vocab_size()
    logging.info(f"Actual SMILES Vocab Size: {ACTUAL_SRC_VOCAB_SIZE}")
    logging.info(f"Actual IUPAC Vocab Size: {ACTUAL_TGT_VOCAB_SIZE}")
    # Record the true vocab sizes so the WandB config matches the model.
    hparams["actual_src_vocab_size"] = ACTUAL_SRC_VOCAB_SIZE
    hparams["actual_tgt_vocab_size"] = ACTUAL_TGT_VOCAB_SIZE

    # --- WandB logger ---
    if WANDB_ENTITY is None:
        logging.warning(
            "WANDB_ENTITY not set. WandB will log to your default entity. Set WANDB_ENTITY='your_username_or_team' to specify."
        )

    wandb_logger = WandbLogger(
        project=WANDB_PROJECT,
        entity=WANDB_ENTITY,
        name=WANDB_RUN_NAME,
        config=hparams,
    )

    # --- Datasets and DataLoaders ---
    logging.info("Creating Datasets and DataLoaders...")
    try:
        train_dataset = SmilesIupacDataset(TRAIN_SMILES_FILE, TRAIN_IUPAC_FILE)
        val_dataset = SmilesIupacDataset(VAL_SMILES_FILE, VAL_IUPAC_FILE)
        if len(train_dataset) == 0 or len(val_dataset) == 0:
            logging.error(
                "Training or validation dataset is empty. Check data splitting and file content."
            )
            exit(1)
    except Exception as e:
        logging.error(f"Error creating Datasets: {e}", exc_info=True)
        exit(1)

    # Bind the tokenizers and special indices into the collate function.
    def collate_fn_partial(batch):
        return collate_fn(
            batch,
            smiles_tokenizer,
            iupac_tokenizer,
            PAD_IDX,
            SOS_IDX,
            EOS_IDX,
            hparams["max_len"],
        )

    persistent_workers = NUM_WORKERS > 0 and STRATEGY == "ddp"

    train_dataloader = DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE_PER_GPU,
        shuffle=True,
        collate_fn=collate_fn_partial,
        num_workers=NUM_WORKERS,
        pin_memory=True,
        persistent_workers=persistent_workers,
        drop_last=True,  # drop the smaller final batch during training
    )
    val_dataloader = DataLoader(
        val_dataset,
        batch_size=BATCH_SIZE_PER_GPU,
        shuffle=False,
        collate_fn=collate_fn_partial,
        num_workers=NUM_WORKERS,
        pin_memory=True,
        persistent_workers=persistent_workers,
        drop_last=False,
    )

    # --- Model ---
    logging.info("Initializing Lightning Module...")
    model = SmilesIupacLitModule(
        src_vocab_size=ACTUAL_SRC_VOCAB_SIZE,
        tgt_vocab_size=ACTUAL_TGT_VOCAB_SIZE,
        hparams_dict=hparams,
    )

    # --- Callbacks: keep the best checkpoint, stop early if val_loss stalls ---
    checkpoint_callback = ModelCheckpoint(
        dirpath=CHECKPOINT_DIR,
        filename=BEST_MODEL_FILENAME + "-{epoch:02d}-{val_loss:.4f}",
        save_top_k=1,
        verbose=True,
        monitor="val_loss",
        mode="min",
        save_last=True,
    )
    early_stopping_callback = EarlyStopping(
        monitor="val_loss",
        patience=PATIENCE,
        verbose=True,
        mode="min",
    )

    # --- Trainer ---
    logging.info(
        f"Initializing PyTorch Lightning Trainer (GPUs={DEVICES}, Strategy='{STRATEGY}', Precision='{PRECISION}')..."
    )
    trainer = pl.Trainer(
        accelerator=ACCELERATOR,
        devices=DEVICES,
        strategy=STRATEGY,
        precision=PRECISION,
        max_epochs=NUM_EPOCHS,
        logger=wandb_logger,
        callbacks=[checkpoint_callback, early_stopping_callback],
        gradient_clip_val=GRAD_CLIP_NORM,
        accumulate_grad_batches=ACCUMULATE_GRAD_BATCHES,
        log_every_n_steps=50,
    )

    # --- Training ---
    logging.info(
        f"Starting training with Effective Batch Size: {hparams['effective_batch_size']}..."
    )
    start_time = time.time()
    try:
        trainer.fit(model, train_dataloader, val_dataloader)
        training_duration = time.time() - start_time
        logging.info(
            f"Training finished in {training_duration / 3600:.2f} hours ({training_duration:.2f} seconds)."
        )

        best_path = checkpoint_callback.best_model_path
        best_score = checkpoint_callback.best_model_score
        if best_score is not None:
            logging.info(
                f"Best model checkpoint saved at: {best_path} with val_loss: {best_score.item():.4f}"
            )
            wandb_logger.experiment.summary["best_val_loss"] = best_score.item()
            wandb_logger.experiment.summary["best_model_path"] = best_path
        else:
            logging.warning(
                "Could not retrieve best model score from checkpoint callback."
            )

    except Exception as e:
        logging.error(f"Fatal error during training: {e}", exc_info=True)
        if wandb.run is not None:
            wandb.finish(exit_code=1)
        exit(1)
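
    # Since save_last=True, a "last.ckpt" is also written alongside the best
    # checkpoint, so an interrupted run could be resumed with something like:
    #     trainer.fit(model, train_dataloader, val_dataloader,
    #                 ckpt_path=os.path.join(CHECKPOINT_DIR, "last.ckpt"))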

    # --- Load the best checkpoint for example translations ---
    best_model_path_to_load = checkpoint_callback.best_model_path
    logging.info(
        f"\nLoading best model from {best_model_path_to_load} for translation examples..."
    )
    final_model = None
    if best_model_path_to_load and os.path.exists(best_model_path_to_load):
        try:
            # The vocab sizes and hparams dict are constructor arguments, so
            # they must be passed again when restoring from the checkpoint.
            final_model = SmilesIupacLitModule.load_from_checkpoint(
                best_model_path_to_load,
                src_vocab_size=ACTUAL_SRC_VOCAB_SIZE,
                tgt_vocab_size=ACTUAL_TGT_VOCAB_SIZE,
                hparams_dict=hparams,
            )
            # torch.device expects "cuda", not the Lightning accelerator name
            # "gpu", so the device string is mapped explicitly here.
            inference_device = torch.device(
                "cuda:0"
                if ACCELERATOR == "gpu" and torch.cuda.is_available()
                else "cpu"
            )
            final_model = final_model.to(inference_device)
            final_model.eval()
            final_model.freeze()
            logging.info(
                f"Best model loaded successfully to {inference_device} for final translation."
            )
        except Exception as e:
            logging.error(
                f"Error loading saved model from {best_model_path_to_load}: {e}",
                exc_info=True,
            )
            final_model = None
    else:
        logging.error(
            f"Error: Best model checkpoint path not found or invalid: '{best_model_path_to_load}'. Cannot perform final translation."
        )

    # --- Example translations on the first validation samples ---
    if final_model:
        logging.info("\n--- Example Translations (using validation data) ---")
        num_examples = 20
        try:
            val_smi_examples = []
            val_iupac_examples = []
            if os.path.exists(VAL_SMILES_FILE) and os.path.exists(VAL_IUPAC_FILE):
                with (
                    open(VAL_SMILES_FILE, "r", encoding="utf-8") as f_smi,
                    open(VAL_IUPAC_FILE, "r", encoding="utf-8") as f_iupac,
                ):
                    for i, (smi_line, iupac_line) in enumerate(zip(f_smi, f_iupac)):
                        if i >= num_examples:
                            break
                        val_smi_examples.append(smi_line.strip())
                        val_iupac_examples.append(iupac_line.strip())
            else:
                logging.warning(
                    f"Validation files ({VAL_SMILES_FILE}, {VAL_IUPAC_FILE}) not found. Cannot show examples."
                )

            if len(val_smi_examples) > 0:
                print("\n" + "=" * 40)
                print(
                    f"Example Translations (First {len(val_smi_examples)} Validation Samples)"
                )
                print("=" * 40)
                inference_device = next(final_model.parameters()).device
                translation_examples = []
                for i in range(len(val_smi_examples)):
                    smi = val_smi_examples[i]
                    true_iupac = val_iupac_examples[i]
                    predicted_iupac = translate(
                        model=final_model,
                        src_sentence=smi,
                        smiles_tokenizer=smiles_tokenizer,
                        iupac_tokenizer=iupac_tokenizer,
                        device=inference_device,
                        max_len=hparams["max_len"],
                        sos_idx=SOS_IDX,
                        eos_idx=EOS_IDX,
                        pad_idx=PAD_IDX,
                    )
                    print(f"\nExample {i + 1}:")
                    print(f"  SMILES:          {smi}")
                    print(f"  True IUPAC:      {true_iupac}")
                    print(f"  Predicted IUPAC: {predicted_iupac}")
                    print("-" * 30)
                    translation_examples.append([smi, true_iupac, predicted_iupac])

                print("=" * 40 + "\n")

                # Log the examples to a WandB table for later inspection.
                try:
                    columns = ["SMILES", "True IUPAC", "Predicted IUPAC"]
                    wandb_table = wandb.Table(
                        data=translation_examples, columns=columns
                    )
                    wandb_logger.experiment.log(
                        {"validation_translations": wandb_table}
                    )
                    logging.info("Logged translation examples to WandB Table.")
                except Exception as wb_err:
                    logging.error(
                        f"Failed to log translation examples to WandB: {wb_err}"
                    )

            else:
                logging.warning("Could not load validation samples for examples.")
        except Exception as e:
            logging.error(f"Error during example translation phase: {e}", exc_info=True)
    else:
        logging.warning(
            "Skipping final translation examples as the best model could not be loaded."
        )

    # --- Finish the WandB run ---
    if wandb.run is not None:
        wandb.finish()
        logging.info("WandB run finished.")
    else:
        logging.info("No active WandB run to finish.")

    logging.info("Script finished.")