In [1]:
from transformers import PreTrainedModel, AutoConfig, BertModel, BertTokenizerFast, BertConfig, AutoModel, AutoTokenizer
import pandas as pd
import torch
import os
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import joblib

os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [2]:
df = pd.read_csv('/home/jovyan/simson_training_bolgov/kaggle_comp/train.csv')

targets = ['Tg', 'FFV', 'Tc', 'Density', 'Rg']

In [3]:
for i in range(1, 5):
    supplement_path = f'/home/jovyan/simson_training_bolgov/kaggle_comp/train_supplement/dataset{i}.csv'
    supplement_ds = pd.read_csv(supplement_path)

    if 'TC_mean' in supplement_ds.columns:
        supplement_ds = supplement_ds.rename(columns = {'TC_mean': 'Tc'})

    df = pd.concat([df, supplement_ds], axis=0)

df = df.sample(frac=1).reset_index(drop=True)
df

Unnamed: 0,id,SMILES,Tg,FFV,Tc,Density,Rg
0,4.215886e+08,*C(=O)c1ccc2c(c1)C(=O)N(c1ccc(Oc3ccc(Oc4ccc(N5...,,0.376767,,,
1,7.984549e+08,*c1ccc2c(c1)C(=O)N(c1ccc(Oc3ccc(N4C(=O)c5ccc(C...,,0.346993,,,
2,,*CC/C=C(/*)C,,,,,
3,,*CC(*)(C)C(=O)OCCN(CC)c1ccc(/N=N/c2ccc(OC)cc2)...,,,,,
4,,*Oc1cc(OC(=O)c2ccc(OCC)cc2)c(OC(=O)CCCC(*)=O)c...,,,,,
...,...,...,...,...,...,...,...
16958,2.389975e+08,*OC(=O)Oc1ccc(S(=O)(=O)c2ccc(OC(=O)OC3CC4CC(*)...,,0.339596,,,
16959,,*c1ccc(Oc2ccc(S(=O)(=O)c3ccc(Oc4ccc(N5C(=O)c6c...,,,,,
16960,,*OC(F)(F)COC(=O)c1cc(OCCCCC)cc(C(=O)OCC(*)(F)F)c1,,,,,
16961,1.973417e+09,*C=CC1CC(*)C2C(=O)N(c3ccc(F)cc3)C(=O)C12,,0.374710,,,


In [10]:
import pandas as pd
import numpy as np
from rdkit import Chem
import random
from typing import Optional, List, Union

def augment_smiles_dataset(df: pd.DataFrame,
                               smiles_column: str = 'SMILES',
                               augmentation_strategies: List[str] = ['enumeration', 'kekulize', 'stereo_enum'],
                               n_augmentations: int = 10,
                               preserve_original: bool = True,
                               random_seed: Optional[int] = None) -> pd.DataFrame:
    """
    Advanced SMILES augmentation with multiple strategies.
    
    Parameters:
    -----------
    augmentation_strategies : List[str]
        List of augmentation strategies: 'enumeration', 'kekulize', 'stereo_enum'
    """
    
    if random_seed is not None:
        random.seed(random_seed)
        np.random.seed(random_seed)
    
    def apply_augmentation_strategy(smiles: str, strategy: str) -> List[str]:
        """Apply specific augmentation strategy"""
        try:
            mol = Chem.MolFromSmiles(smiles)
            if mol is None:
                return [smiles]
            
            augmented = []
            
            if strategy == 'enumeration':
                # Standard SMILES enumeration
                for _ in range(n_augmentations):
                    enum_smiles = Chem.MolToSmiles(mol, 
                                                 canonical=False, 
                                                 doRandom=True,
                                                 isomericSmiles=True)
                    augmented.append(enum_smiles)
            
            elif strategy == 'kekulize':
                # Kekulization variants
                try:
                    Chem.Kekulize(mol)
                    kek_smiles = Chem.MolToSmiles(mol, kekuleSmiles=True)
                    augmented.append(kek_smiles)
                except:
                    pass
            
            elif strategy == 'stereo_enum':
                # Stereochemistry enumeration
                for _ in range(n_augmentations // 2):
                    # Remove stereochemistry
                    Chem.RemoveStereochemistry(mol)
                    no_stereo = Chem.MolToSmiles(mol)
                    augmented.append(no_stereo)
            
            return list(set(augmented))  # Remove duplicates
            
        except Exception as e:
            print(f"Error in {strategy} for {smiles}: {e}")
            return [smiles]
    
    augmented_rows = []
    
    for idx, row in tqdm(df.iterrows(), total=len(df)):
        original_smiles = row[smiles_column]
        
        # Add original if requested
        if preserve_original:
            original_row = row.to_dict()
            original_row['augmentation_strategy'] = 'original'
            original_row['is_augmented'] = False
            augmented_rows.append(original_row)
        
        # Apply each augmentation strategy
        for strategy in augmentation_strategies:
            strategy_smiles = apply_augmentation_strategy(original_smiles, strategy)
            
            for aug_smiles in strategy_smiles:
                if aug_smiles != original_smiles:  # Avoid duplicating original
                    new_row = row.to_dict().copy()
                    new_row[smiles_column] = aug_smiles
                    new_row['augmentation_strategy'] = strategy
                    new_row['is_augmented'] = True
                    augmented_rows.append(new_row)
    
    augmented_df = pd.DataFrame(augmented_rows)
    augmented_df = augmented_df.reset_index(drop=True)
    
    print(f"Advanced augmentation completed:")
    print(f"Original size: {len(df)}, Augmented size: {len(augmented_df)}")
    print(f"Augmentation factor: {len(augmented_df) / len(df):.2f}x")
    
    return augmented_df.reset_index(drop=True)

def create_splits(df):
    length = len(df)
    train_length = int(0.85 * length)
    train = df.loc[:train_length]
    test = df.loc[train_length:]
    return train, test

train, test = create_splits(df)

train = train.reset_index(drop=True)
test = test.reset_index(drop=True)

train = augment_smiles_dataset(train)
test = augment_smiles_dataset(test)

100%|████████████████████████████████████| 14419/14419 [00:43<00:00, 328.78it/s]


Advanced augmentation completed:
Original size: 14419, Augmented size: 168551
Augmentation factor: 11.69x


100%|██████████████████████████████████████| 2545/2545 [00:07<00:00, 333.57it/s]

Advanced augmentation completed:
Original size: 2545, Augmented size: 29716
Augmentation factor: 11.68x





In [11]:
scalers = []

for target in targets:
    target_scaler = StandardScaler()
    train[target] = target_scaler.fit_transform(train[target].to_numpy().reshape(-1, 1))
    test[target] = target_scaler.transform(test[target].to_numpy().reshape(-1, 1))
    
    scalers.append(target_scaler)

smiles_train = train['SMILES']
smiles_test = test['SMILES']

labels_train = train[targets].values
labels_test = test[targets].values

In [6]:
joblib.dump(scalers, 'target_scalers.pkl')

['target_scalers.pkl']

In [12]:
from sklearn.metrics import mean_absolute_error
from transformers import AutoTokenizer, BertModel
import torch
from torch import nn
from transformers.activations import ACT2FN

def global_ap(x):
    return torch.mean(x.view(x.size(0), x.size(1), -1), dim=1)

class SimSonEncoder(nn.Module):
    def __init__(self, config: BertConfig, max_len: int, dropout: float = 0.1):
        super(SimSonEncoder, self).__init__()
        self.config = config
        self.max_len = max_len

        self.bert = BertModel(config, add_pooling_layer=False)

        self.linear = nn.Linear(config.hidden_size, max_len)
        self.dropout = nn.Dropout(dropout)

    def forward(self, input_ids, attention_mask=None):
        if attention_mask is None:
            attention_mask = input_ids.ne(0)
            
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask
        )

        hidden_states = outputs.last_hidden_state
        
        hidden_states = self.dropout(hidden_states)
        
        pooled = global_ap(hidden_states)
        
        out = self.linear(pooled)
        
        return out


class SimSonClassifier(nn.Module):
    def __init__(self, encoder: SimSonEncoder, num_labels: int, dropout=0.1):
        super(SimSonClassifier, self).__init__()
        self.encoder = encoder
        self.clf = nn.Linear(encoder.max_len, num_labels)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)

    def forward(self, input_ids, attention_mask=None, labels=None):
        x = self.encoder(input_ids, attention_mask)
        x = self.relu(self.dropout(x))
        x = self.clf(x)
        return x

tokenizer_path = 'DeepChem/ChemBERTa-77M-MTR'
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)

# Only the hidden size is slightly larger, everything else is the same
config = BertConfig(
        vocab_size=tokenizer.vocab_size,
        hidden_size=768,
        num_hidden_layers=4,
        num_attention_heads=12,
        intermediate_size=2048,
        max_position_embeddings=512
    )

simson_params = torch.load('/home/jovyan/simson_training_bolgov/kaggle_comp/simson_polymer_1m_uncompiled.pth')

backbone = SimSonEncoder(config=config, max_len=512)
backbone.load_state_dict(simson_params)

model = SimSonClassifier(encoder=backbone, num_labels=len(targets))

  simson_params = torch.load('/home/jovyan/simson_training_bolgov/kaggle_comp/simson_polymer_1m_uncompiled.pth')


In [13]:
import numpy as np
import torch
from torch.utils.data import Dataset, Sampler, DataLoader


class SMILESDataset(Dataset):
    def __init__(self, smiles_list, labels, tokenizer, max_length=256):
        self.smiles_list = smiles_list
        self.labels = labels  # Shape: (num_samples, 5) - already scaled
        self.tokenizer = tokenizer
        self.max_length = max_length
        
        # Create mask for valid (non-NaN) labels
        self.label_masks = ~np.isnan(self.labels)  # True where label is valid
        
        # Replace NaNs with 0 for safe tensor conversion (mask will handle exclusion)
        self.labels = np.nan_to_num(self.labels, nan=0.0)
    
    def __len__(self):
        return len(self.smiles_list)
    
    def __getitem__(self, idx):
        smiles = self.tokenizer.cls_token + self.smiles_list[idx]
        
        # Tokenize the SMILES string
        encoding = self.tokenizer(
            smiles,
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt'
        )
        
        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'labels': torch.tensor(self.labels[idx], dtype=torch.float32),
            'label_mask': torch.tensor(self.label_masks[idx], dtype=torch.float32)
        }
    
    def get_label_statistics(self):
        """Return statistics about label availability for 5 labels"""
        label_counts = self.label_masks.sum(axis=0)
        total_samples = len(self.smiles_list)
        
        stats = {
            'total_samples': total_samples,
            'label_0_count': label_counts[0],
            'label_1_count': label_counts[1],
            'label_2_count': label_counts[2],
            'label_3_count': label_counts[3],
            'label_4_count': label_counts[4],
            'label_0_ratio': label_counts[0] / total_samples,
            'label_1_ratio': label_counts[1] / total_samples,
            'label_2_ratio': label_counts[2] / total_samples,
            'label_3_ratio': label_counts[3] / total_samples,
            'label_4_ratio': label_counts[4] / total_samples,
            'all_labels_count': (self.label_masks.sum(axis=1) == 5).sum(),
            'partial_labels_count': ((self.label_masks.sum(axis=1) > 0) & (self.label_masks.sum(axis=1) < 5)).sum(),
            'no_labels_count': (self.label_masks.sum(axis=1) == 0).sum()
        }
        
        return stats


class UnderrepresentedLabelSampler(Sampler):
    """
    Custom sampler that gives higher sampling probability to samples containing under-represented labels.
    This ensures each batch contains a good mix of samples with different label availability patterns.
    """
    def __init__(self, dataset, num_labels=5, underrep_boost=2.0):
        """
        Args:
            dataset: SMILESDataset instance
            num_labels: Number of labels (5)
            underrep_boost: Multiplier to boost probability of under-represented labels
        """
        self.dataset = dataset
        self.num_samples = len(dataset)
        self.num_labels = num_labels
        self.underrep_boost = underrep_boost
        
        # Calculate label frequencies
        label_counts = dataset.label_masks.sum(axis=0)  # Count valid samples per label
        total_samples = self.num_samples
        
        # Label frequencies (proportion of samples with each label)
        label_freq = label_counts / total_samples
        
        # Inverse frequency weights (higher for under-represented labels)
        # Add small epsilon to avoid division by zero
        self.label_weights = 1.0 / (label_freq + 1e-6)
        
        # Apply boost to under-represented labels
        # Labels with frequency < median get boosted
        median_freq = np.median(label_freq)
        underrep_mask = label_freq < median_freq
        self.label_weights[underrep_mask] *= self.underrep_boost
        
        # Calculate sample weights based on which labels are present
        sample_weights = []
        for i in range(self.num_samples):
            mask = dataset.label_masks[i]  # Boolean mask for present labels
            if mask.sum() > 0:
                # Weight is average of present labels' weights
                weights = self.label_weights[mask]
                sample_weight = weights.mean()
            else:
                # If no labels present, give minimal weight
                sample_weight = 0.1
            sample_weights.append(sample_weight)
        
        self.sample_weights = torch.tensor(sample_weights, dtype=torch.double)
        
        # Print sampling statistics
        print(f"Label frequencies: {label_freq}")
        print(f"Label weights: {self.label_weights}")
        print(f"Under-represented labels (< median freq {median_freq:.3f}): {np.where(underrep_mask)[0]}")
        print(f"Sample weight range: [{self.sample_weights.min():.3f}, {self.sample_weights.max():.3f}]")
    
    def __iter__(self):
        # Sample with replacement according to calculated weights
        indices = torch.multinomial(self.sample_weights, self.num_samples, replacement=True)
        return iter(indices.tolist())
    
    def __len__(self):
        return self.num_samples


def calculate_unweighted_loss(predictions, labels, label_mask):
    """
    Calculate simple unweighted MSE loss with masking (no label weights)
    
    Args:
        predictions: Model outputs (batch_size, 5)
        labels: Ground truth labels (batch_size, 5)
        label_mask: Mask for valid labels (batch_size, 5)
    """
    loss_fn = nn.MSELoss(reduction='none')
    
    # Calculate per-sample, per-label losses
    losses = loss_fn(predictions, labels)  # Shape: (batch_size, 5)
    
    # Apply masking to exclude NaN labels
    valid_mask = label_mask.bool()
    masked_losses = losses * valid_mask.float()
    
    # Calculate final loss (only over valid predictions)
    total_loss = masked_losses.sum()
    total_valid = valid_mask.sum()
    
    return total_loss / total_valid if total_valid > 0 else torch.tensor(0.0, device=predictions.device, requires_grad=True)


def calculate_true_loss(predictions, labels, label_mask, scalers=None):
    """
    Calculate unscaled MAE loss for monitoring using separate scalers for each label
    
    Args:
        predictions (torch.Tensor): Model outputs of shape (batch_size, 5).
        labels (torch.Tensor): Ground truth labels of shape (batch_size, 5).
        label_mask (torch.Tensor): Boolean mask for valid labels of shape (batch_size, 5).
        scalers: List of scaler objects, one for each label
    
    Returns:
        float: Average MAE across all valid samples
    """
    # Detach tensors from the computation graph and move to CPU
    predictions_np = predictions.cpu().detach().numpy()
    labels_np = labels.cpu().numpy()
    label_mask_np = label_mask.cpu().numpy().astype(bool)
    
    total_mae = 0
    total_samples = 0
    
    for label_idx in range(5):
        # Get valid samples for this label
        valid_mask = label_mask_np[:, label_idx]
        
        if valid_mask.any():
            valid_preds = predictions_np[valid_mask, label_idx].reshape(-1, 1)
            valid_labels = labels_np[valid_mask, label_idx].reshape(-1, 1)
            
            if scalers is not None:
                # Unscale using the corresponding scaler for this label
                unscaled_preds = scalers[label_idx].inverse_transform(valid_preds).flatten()
                unscaled_labels = scalers[label_idx].inverse_transform(valid_labels).flatten()
            else:
                unscaled_preds = valid_preds.flatten()
                unscaled_labels = valid_labels.flatten()
            
            # Calculate MAE for this label
            mae = np.mean(np.abs(unscaled_preds - unscaled_labels))
            total_mae += mae * len(unscaled_preds)
            total_samples += len(unscaled_preds)
    
    return total_mae / total_samples if total_samples > 0 else 0.0


def calculate_individual_label_losses(predictions, labels, label_mask, scalers=None):
    """
    Calculate unscaled MAE loss for each individual label
    
    Args:
        predictions (torch.Tensor): Model outputs of shape (batch_size, 5).
        labels (torch.Tensor): Ground truth labels of shape (batch_size, 5).
        label_mask (torch.Tensor): Boolean mask for valid labels of shape (batch_size, 5).
        scalers: List of scaler objects, one for each label
    
    Returns:
        dict: Dictionary with MAE for each label
    """
    # Detach tensors from the computation graph and move to CPU
    predictions_np = predictions.cpu().detach().numpy()
    labels_np = labels.cpu().numpy()
    label_mask_np = label_mask.cpu().numpy().astype(bool)
    
    individual_losses = {}
    
    for label_idx in range(5):
        # Get valid samples for this label
        valid_mask = label_mask_np[:, label_idx]
        
        if valid_mask.any():
            valid_preds = predictions_np[valid_mask, label_idx].reshape(-1, 1)
            valid_labels = labels_np[valid_mask, label_idx].reshape(-1, 1)
            
            if scalers is not None:
                # Unscale using the corresponding scaler for this label
                unscaled_preds = scalers[label_idx].inverse_transform(valid_preds).flatten()
                unscaled_labels = scalers[label_idx].inverse_transform(valid_labels).flatten()
            else:
                unscaled_preds = valid_preds.flatten()
                unscaled_labels = valid_labels.flatten()
            
            # Calculate MAE for this label
            mae = np.mean(np.abs(unscaled_preds - unscaled_labels))
            individual_losses[f'label_{label_idx}'] = mae
        else:
            individual_losses[f'label_{label_idx}'] = None  # No valid samples for this label
    
    return individual_losses


def analyze_batch_composition(dataloader, num_batches=5):
    """
    Analyze the composition of batches to see label distribution
    """
    print("Analyzing batch composition:")
    
    for batch_idx, batch in enumerate(dataloader):
        if batch_idx >= num_batches:
            break
            
        label_mask = batch['label_mask'].numpy()
        
        # Count samples with each label in this batch
        label_counts = label_mask.sum(axis=0)
        batch_size = label_mask.shape[0]
        
        print(f"Batch {batch_idx + 1}: Size={batch_size}")
        for i in range(5):
            print(f"  Label {i}: {label_counts[i]}/{batch_size} ({label_counts[i]/batch_size:.2%})")
        print()


def train_model(model, train_dataloader, val_dataloader, 
                scalers=None, num_epochs=10, learning_rate=2e-5, device='cuda', 
                patience=3, validation_steps=500):
    """
    Train model with unweighted loss and custom sampler for five labels
    
    Args:
        model: CustomModel instance (should output 5 labels)
        train_dataloader: Training data loader with custom sampler
        val_dataloader: Validation data loader  
        scalers: List of scalers for unscaled loss monitoring
        num_epochs: Number of training epochs
        learning_rate: Learning rate
        device: Training device
        patience: Early stopping patience (in validation steps)
        validation_steps: Perform validation every N training steps
    """
    model.to(device)
    
    optimizer = AdamW(model.parameters(), lr=learning_rate, weight_decay=0.01)
    total_steps = len(train_dataloader) * num_epochs
    scheduler = LinearLR(optimizer, start_factor=1.0, end_factor=0.1, total_iters=total_steps)
    
    train_losses = []
    val_losses = []
    
    # Early stopping initialization
    best_val_loss = float('inf')
    steps_no_improve = 0
    best_model_state = None
    
    # Training tracking
    global_step = 0
    running_train_loss = 0
    running_true_train_loss = 0
    train_steps_count = 0
    
    print(f"Training with custom sampler (no label weights)")
    print(f"Validation will be performed every {validation_steps} steps")
    
    model.train()
    
    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch + 1}/{num_epochs}")
        
        train_progress = tqdm(train_dataloader, desc="Training", leave=False)
        
        for batch_idx, batch in enumerate(train_progress):
            with torch.autocast(dtype=torch.float16, device_type="cuda"):
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                labels = batch['labels'].to(device)
                label_mask = batch['label_mask'].to(device)
                
                optimizer.zero_grad()
                
                # Model forward pass
                outputs = model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                )
                
                # Calculate unweighted loss (sampler handles the balancing)
                loss = calculate_unweighted_loss(outputs, labels, label_mask)
                
                # Calculate true loss for monitoring
                true_loss = calculate_true_loss(outputs, labels, label_mask, scalers)
            
            # Accumulate losses for averaging
            running_train_loss += loss.item()
            running_true_train_loss += true_loss
            train_steps_count += 1
            
            loss.backward()
            
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            
            optimizer.step()
            scheduler.step()
            
            global_step += 1
            
            train_progress.set_postfix({
                'step': global_step,
                'loss': f'{loss.item():.4f}',
                'true_loss': f'{true_loss:.4f}',
                'lr': f'{scheduler.get_last_lr()[0]:.2e}'
            })
            
            # Perform validation every validation_steps
            if global_step % validation_steps == 0:
                # Calculate average training losses since last validation
                avg_train_loss = running_train_loss / train_steps_count
                avg_true_train_loss = running_true_train_loss / train_steps_count
                
                train_losses.append(avg_train_loss)
                
                # Reset running averages
                running_train_loss = 0
                running_true_train_loss = 0
                train_steps_count = 0
                
                # Validation
                model.eval()
                total_val_loss = 0
                total_true_val_loss = 0
                val_batches = 0
                
                # Track individual label losses across all validation batches
                accumulated_individual_losses = {f'label_{i}': [] for i in range(5)}

                with torch.no_grad():
                    for val_batch in val_dataloader:
                        with torch.autocast(dtype=torch.float16, device_type="cuda"):
                            input_ids = val_batch['input_ids'].to(device)
                            attention_mask = val_batch['attention_mask'].to(device)
                            labels = val_batch['labels'].to(device)
                            label_mask = val_batch['label_mask'].to(device)
                            
                            outputs = model(
                                input_ids=input_ids,
                                attention_mask=attention_mask,
                            )
                            
                            val_loss = calculate_unweighted_loss(outputs, labels, label_mask)
                            val_true_loss = calculate_true_loss(outputs, labels, label_mask, scalers)
                            
                            # Calculate individual label losses for this batch
                            individual_losses = calculate_individual_label_losses(outputs, labels, label_mask, scalers)
                            
                            # Accumulate individual losses
                            for label_key, loss_value in individual_losses.items():
                                if loss_value is not None:
                                    accumulated_individual_losses[label_key].append(loss_value)

                        total_val_loss += val_loss.item()
                        total_true_val_loss += val_true_loss
                        val_batches += 1
                    
                avg_val_loss = total_val_loss / val_batches
                avg_val_true_loss = total_true_val_loss / val_batches
                val_losses.append(avg_val_loss)
                
                # Calculate average individual label losses
                avg_individual_losses = {}
                for label_key, losses in accumulated_individual_losses.items():
                    if losses:
                        avg_individual_losses[label_key] = np.mean(losses)
                    else:
                        avg_individual_losses[label_key] = None
                
                # Print validation results with individual label losses
                print(f"\nStep {global_step} | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f} | True train loss: {avg_true_train_loss:.4f} | True val loss: {avg_val_true_loss:.4f}")
                print("Individual label losses (unscaled):")
                for i in range(5):
                    label_key = f'label_{i}'
                    if avg_individual_losses[label_key] is not None:
                        print(f"  Label {i}: {avg_individual_losses[label_key]:.4f}")
                    else:
                        print(f"  Label {i}: No valid samples")
                
                # Early stopping check and best model saving
                if avg_val_loss < best_val_loss:
                    best_val_loss = avg_val_loss
                    steps_no_improve = 0
                    best_model_state = model.state_dict().copy()
                    print(f"New best validation loss: {best_val_loss:.4f}")
                else:
                    steps_no_improve += 1
                    if steps_no_improve >= patience:
                        print(f"Early stopping triggered after {global_step} steps ({steps_no_improve} validation steps without improvement).")
                        # Load best model and return
                        if best_model_state is not None:
                            model.load_state_dict(best_model_state)
                            print(f"Loaded best model with validation loss: {best_val_loss:.4f}")
                        return train_losses, val_losses, best_val_loss
                
                model.train()
    
    # Handle any remaining training loss that hasn't been validated
    if train_steps_count > 0:
        avg_train_loss = running_train_loss / train_steps_count
        train_losses.append(avg_train_loss)
    
    # Load the best model state before returning
    if best_model_state is not None:
        model.load_state_dict(best_model_state)
        print(f"Loaded best model with validation loss: {best_val_loss:.4f}")
    
    return train_losses, val_losses, best_val_loss


def run_training(smiles_train, smiles_test, labels_train, labels_test, 
                model, tokenizer, scalers, num_epochs=5, learning_rate=1e-5, 
                batch_size=256, validation_steps=500, underrep_boost=2.0):
    """
    Complete training pipeline for five labels with custom sampler
    
    Args:
        smiles_train, smiles_test: Lists of SMILES strings
        labels_train, labels_test: numpy arrays of shape (num_samples, 5) - ALREADY SCALED
        model: CustomModel instance (configured for 5 outputs)
        tokenizer: Tokenizer instance
        scalers: List of 5 scalers, one for each label (for inverse transform only)
        num_epochs: Number of training epochs
        learning_rate: Learning rate
        batch_size: Batch size for training
        validation_steps: Perform validation every N training steps
        underrep_boost: Boost factor for under-represented labels in sampler
    """
    
    print("Setting up datasets for five-label training with custom sampler")
    
    # Create datasets - no scaling performed here
    train_dataset = SMILESDataset(smiles_train, labels_train, tokenizer)
    val_dataset = SMILESDataset(smiles_test, labels_test, tokenizer)
    
    # Print dataset statistics
    train_stats = train_dataset.get_label_statistics()
    val_stats = val_dataset.get_label_statistics()
    
    print("Training dataset statistics:")
    for key, value in train_stats.items():
        print(f"  {key}: {value}")
    
    print("Validation dataset statistics:")
    for key, value in val_stats.items():
        print(f"  {key}: {value}")
    
    # Create custom sampler for balanced training
    train_sampler = UnderrepresentedLabelSampler(
        train_dataset, 
        num_labels=5, 
        underrep_boost=underrep_boost
    )
    
    # Create data loaders
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        sampler=None,  # Use custom sampler instead of shuffle=True
        num_workers=4,
        pin_memory=True
    )
    
    val_dataloader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=4,
        pin_memory=True
    )
    
    # Analyze batch composition to verify sampler effectiveness
    print("\n" + "="*50)
    #analyze_batch_composition(train_dataloader, num_batches=3)
    print("="*50)
    
    # Set device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    print(f"Training steps per epoch: {len(train_dataloader)}")
    print(f"Total training steps: {len(train_dataloader) * num_epochs}")
    
    # Train the model
    train_losses, val_losses, best_val_loss = train_model(
        model=model,
        train_dataloader=train_dataloader,
        val_dataloader=val_dataloader,
        scalers=scalers,
        num_epochs=num_epochs,
        learning_rate=learning_rate,
        device=device,
        patience=10,
        validation_steps=validation_steps,
    )
    
    print('Training completed.')
    print(f'Number of validation checkpoints: {len(val_losses)}')
    print(f'Final training losses: {train_losses[-5:] if len(train_losses) >= 5 else train_losses}')
    print(f'Best validation loss: {best_val_loss:.4f}')
    
    # Save model
    torch.save(model.state_dict(), '/home/jovyan/simson_training_bolgov/kaggle_comp/checkpoints/clf_kaggle.bin')
    print("Model saved successfully!")
    
    return train_losses, val_losses, best_val_loss


In [15]:
import numpy as np
import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import LinearLR
from torch.utils.data import DataLoader
from tqdm import tqdm

BATCH_SIZE = 128

train_losses, val_losses, best_loss = run_training(
     smiles_train, smiles_test, labels_train, labels_test, 
     model, tokenizer, scalers, num_epochs=20, learning_rate=1e-4, batch_size=BATCH_SIZE, validation_steps=len(smiles_train) // BATCH_SIZE,
)

Setting up datasets for five-label training with custom sampler
Training dataset statistics:
  total_samples: 168551
  label_0_count: 5446
  label_1_count: 78850
  label_2_count: 14846
  label_3_count: 5779
  label_4_count: 5782
  label_0_ratio: 0.032310695279173664
  label_1_ratio: 0.46781092962960763
  label_2_ratio: 0.08808016564719284
  label_3_ratio: 0.03428635843157264
  label_4_ratio: 0.03430415719871137
  all_labels_count: 0
  partial_labels_count: 96406
  no_labels_count: 72145
Validation dataset statistics:
  total_samples: 29716
  label_0_count: 947
  label_1_count: 13878
  label_2_count: 2764
  label_3_count: 957
  label_4_count: 955
  label_0_ratio: 0.03186835374882218
  label_1_ratio: 0.4670211333961502
  label_2_ratio: 0.0930138645847355
  label_3_ratio: 0.03220487279580024
  label_4_ratio: 0.03213756898640463
  all_labels_count: 0
  partial_labels_count: 17016
  no_labels_count: 12700
Label frequencies: [0.0323107  0.46781093 0.08808017 0.03428636 0.03430416]
Label weig

                                                                                


Step 1316 | Train Loss: 0.6250 | Val Loss: 0.4127 | True train loss: 3.9762 | True val loss: 3.8368
Individual label losses (unscaled):
  Label 0: 76.7992
  Label 1: 0.0127
  Label 2: 0.0372
  Label 3: 0.0987
  Label 4: 3.3515
New best validation loss: 0.4127

Epoch 2/20


                                                                                


Step 2632 | Train Loss: 0.5464 | Val Loss: 0.4244 | True train loss: 3.5447 | True val loss: 3.4895
Individual label losses (unscaled):
  Label 0: 68.7228
  Label 1: 0.0130
  Label 2: 0.0379
  Label 3: 0.0952
  Label 4: 3.8732

Epoch 3/20


Training: 100%|█| 1317/1317 [01:22<00:00,  1.88it/s, step=3951, loss=0.6545, tru


Step 3948 | Train Loss: 0.5242 | Val Loss: 0.4007 | True train loss: 3.4056 | True val loss: 3.2830
Individual label losses (unscaled):
  Label 0: 63.8785
  Label 1: 0.0130
  Label 2: 0.0362
  Label 3: 0.1013
  Label 4: 3.4475
New best validation loss: 0.4007


                                                                                


Epoch 4/20


Training: 100%|▉| 1315/1317 [01:22<00:01,  1.87it/s, step=5267, loss=0.3083, tru


Step 5264 | Train Loss: 0.5011 | Val Loss: 0.3770 | True train loss: 3.1835 | True val loss: 3.3785
Individual label losses (unscaled):
  Label 0: 66.0959
  Label 1: 0.0124
  Label 2: 0.0382
  Label 3: 0.0951
  Label 4: 3.3052
New best validation loss: 0.3770


                                                                                


Epoch 5/20


Training: 100%|▉| 1315/1317 [01:22<00:01,  1.87it/s, step=6583, loss=0.2640, tru


Step 6580 | Train Loss: 0.4860 | Val Loss: 0.3498 | True train loss: 3.2743 | True val loss: 3.4532
Individual label losses (unscaled):
  Label 0: 67.9448
  Label 1: 0.0116
  Label 2: 0.0392
  Label 3: 0.0810
  Label 4: 3.3704
New best validation loss: 0.3498


                                                                                


Epoch 6/20


Training: 100%|▉| 1313/1317 [01:22<00:02,  1.87it/s, step=7899, loss=0.1156, tru


Step 7896 | Train Loss: 0.4671 | Val Loss: 0.3422 | True train loss: 3.1278 | True val loss: 3.3296
Individual label losses (unscaled):
  Label 0: 63.2215
  Label 1: 0.0117
  Label 2: 0.0362
  Label 3: 0.0827
  Label 4: 3.2292
New best validation loss: 0.3422


                                                                                


Epoch 7/20


Training: 100%|▉| 1313/1317 [01:22<00:02,  1.86it/s, step=9215, loss=0.2901, tru


Step 9212 | Train Loss: 0.4557 | Val Loss: 0.3389 | True train loss: 3.0609 | True val loss: 3.2751
Individual label losses (unscaled):
  Label 0: 63.4267
  Label 1: 0.0114
  Label 2: 0.0381
  Label 3: 0.0815
  Label 4: 2.8806
New best validation loss: 0.3389


                                                                                


Epoch 8/20


Training: 100%|▉| 1311/1317 [01:22<00:03,  1.87it/s, step=10531, loss=0.4604, tr


Step 10528 | Train Loss: 0.4474 | Val Loss: 0.3379 | True train loss: 3.0718 | True val loss: 3.2051
Individual label losses (unscaled):
  Label 0: 61.2247
  Label 1: 0.0113
  Label 2: 0.0372
  Label 3: 0.0828
  Label 4: 2.9602
New best validation loss: 0.3379


                                                                                


Epoch 9/20


Training: 100%|▉| 1311/1317 [01:21<00:03,  1.87it/s, step=11847, loss=0.2547, tr


Step 11844 | Train Loss: 0.4285 | Val Loss: 0.3416 | True train loss: 3.0075 | True val loss: 3.1697
Individual label losses (unscaled):
  Label 0: 61.3822
  Label 1: 0.0112
  Label 2: 0.0421
  Label 3: 0.0847
  Label 4: 3.3251


                                                                                


Epoch 10/20


Training:  99%|▉| 1309/1317 [01:21<00:04,  1.87it/s, step=13163, loss=0.2791, tr


Step 13160 | Train Loss: 0.4116 | Val Loss: 0.3174 | True train loss: 2.9027 | True val loss: 3.1666
Individual label losses (unscaled):
  Label 0: 59.6537
  Label 1: 0.0110
  Label 2: 0.0365
  Label 3: 0.0877
  Label 4: 3.1535
New best validation loss: 0.3174


                                                                                


Epoch 11/20


Training:  99%|▉| 1309/1317 [01:21<00:04,  1.87it/s, step=14479, loss=0.3915, tr


Step 14476 | Train Loss: 0.3983 | Val Loss: 0.3039 | True train loss: 2.8602 | True val loss: 3.1240
Individual label losses (unscaled):
  Label 0: 60.6528
  Label 1: 0.0107
  Label 2: 0.0371
  Label 3: 0.0827
  Label 4: 3.2043
New best validation loss: 0.3039


                                                                                


Epoch 12/20


Training:  99%|▉| 1307/1317 [01:21<00:05,  1.87it/s, step=15795, loss=0.1155, tr


Step 15792 | Train Loss: 0.3863 | Val Loss: 0.3050 | True train loss: 2.7796 | True val loss: 3.0697
Individual label losses (unscaled):
  Label 0: 59.8002
  Label 1: 0.0108
  Label 2: 0.0371
  Label 3: 0.0815
  Label 4: 3.1037


                                                                                


Epoch 13/20


Training:  99%|▉| 1307/1317 [01:21<00:05,  1.87it/s, step=17111, loss=0.2704, tr


Step 17108 | Train Loss: 0.3779 | Val Loss: 0.2881 | True train loss: 2.7442 | True val loss: 3.1636
Individual label losses (unscaled):
  Label 0: 61.2941
  Label 1: 0.0102
  Label 2: 0.0361
  Label 3: 0.0836
  Label 4: 3.1077
New best validation loss: 0.2881


                                                                                


Epoch 14/20


Training:  99%|▉| 1305/1317 [01:21<00:06,  1.87it/s, step=18427, loss=0.4965, tr


Step 18424 | Train Loss: 0.3645 | Val Loss: 0.2822 | True train loss: 2.6844 | True val loss: 3.1494
Individual label losses (unscaled):
  Label 0: 61.1663
  Label 1: 0.0100
  Label 2: 0.0365
  Label 3: 0.0743
  Label 4: 3.2309
New best validation loss: 0.2822


                                                                                


Epoch 15/20


                                                                                

KeyboardInterrupt: 