---
license: other
license_name: j.lacoma
license_link: LICENSE
base_model:
- google-bert/bert-base-uncased
---

# Transformer-Based fMRI Encoder Model

This repository contains a Transformer-based model trained on neuroimaging datasets to classify conditions like Autism Spectrum Disorder (ASD) and ADHD, and to analyze brain activity during movie-watching. The model combines fMRI data with demographic features (age and gender) for binary classification tasks. The sections below describe the datasets, model architecture, and training process.

## **Model Architecture**

The model integrates multi-modal data and uses a Transformer backbone for feature extraction. Its components are:

### **1. Inputs**
- **fMRI ROI Data:** High-dimensional features representing brain activity.
- **Age Data:** Numerical input passed through a Multi-Layer Perceptron (MLP).
- **Gender Data:** Binary input (male/female) embedded into a dense representation.

### **2. Transformer Backbone**
- A Hugging Face Transformer (e.g., BERT) instantiated from a pretrained model's configuration, with:
  - Configurable number of attention heads, layers, and hidden size.
  - Dropout for regularization.
  - Hyperparameters overridden through `AutoConfig`.
- Because the backbone is created with `AutoModel.from_config`, its weights are initialized from the configuration rather than loaded from the pretrained checkpoint.

### **3. Pooling Mechanisms**
- Aggregates the Transformer's sequence outputs into a single vector using one of:
  - **Mean Pooling:** Averages hidden states.
  - **Max Pooling:** Selects the maximum value for each feature.
  - **Attention Pooling:** Learns attention weights to emphasize important sequence elements.

### **4. Output**
- A fully connected layer maps the pooled output to a scalar logit for binary classification.

---

## **Training Process**

### **Key Details:**
- **Loss Function:** Binary Cross Entropy with Logits (`BCEWithLogitsLoss`), with class imbalance handled using positive weights.
- **Optimizer:** Ranger (combines RAdam and Lookahead for stable convergence).
- **Learning Rate Scheduler:** Cosine Annealing for gradual learning rate reduction.
- **Gradient Clipping:** Prevents exploding gradients with a clipping threshold of 1.0.
- **Early Stopping:** Stops training after 250 epochs without validation loss improvement.

### **Datasets Used:**
1. **ABIDE:** Autism vs. control classification.
2. **ADHD-200:** ADHD vs. control classification.
3. **Pixar Movie Dataset (Nilearn):** Brain activity analysis during movie-watching.

### **Output:**
The model's state dictionary is saved as `fmri_encoder_model.pth`.

---

## **How to Use This Model**

### 1. **Import Required Libraries**
Import the libraries needed to define and train the model.

```python
from transformers import AutoModel, AutoConfig
import torch
import torch.nn as nn
from torch.optim.lr_scheduler import CosineAnnealingLR
```

---

### 2. **Define the Transformer Model**
The `TransformerModel` is implemented as a PyTorch `nn.Module`. It integrates a Hugging Face transformer (e.g., `bert-base-uncased`) and custom embedding layers for ROI, age, and gender features.

- **Embedding Layers:**
  - `age_mlp` processes age as a continuous input.
  - `gender_embed` embeds the binary gender feature.
- **ROI Encoder:** Encodes the ROI feature input using a small feedforward network.
- **Transformer:** Initializes a Hugging Face transformer with a custom configuration.
- **Pooling Mechanism:** Supports three pooling strategies:
  - Mean pooling
  - Max pooling
  - Attention pooling (using an additional attention mechanism).
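To make the three pooling strategies concrete, here is a minimal standalone sketch that applies each one to a random tensor shaped like the Transformer output. The tensor sizes and the `scorer` name are illustrative assumptions; the sequence length of 3 corresponds to the ROI, age, and gender tokens.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical sizes: 4 subjects, 3 tokens (ROI, age, gender), 8 hidden features.
batch_size, seq_len, embed_dim = 4, 3, 8
hidden = torch.randn(batch_size, seq_len, embed_dim)

# Mean pooling: average over the sequence dimension.
mean_pooled = hidden.mean(dim=1)

# Max pooling: per-feature maximum over the sequence dimension.
max_pooled = hidden.max(dim=1).values

# Attention pooling: score each token, normalize the scores over the
# sequence dimension, then take the weighted sum of the hidden states.
scorer = nn.Sequential(
    nn.Linear(embed_dim, embed_dim // 2),
    nn.Mish(),
    nn.Linear(embed_dim // 2, 1),
    nn.Softmax(dim=1),
)
weights = scorer(hidden)                     # (batch, seq_len, 1)
attn_pooled = (weights * hidden).sum(dim=1)  # (batch, embed_dim)

print(mean_pooled.shape, max_pooled.shape, attn_pooled.shape)
```

All three strategies reduce a `(batch, seq_len, embed_dim)` tensor to `(batch, embed_dim)`, so the output layer is the same regardless of which one is selected.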
The full model definition:

```python
class TransformerModel(nn.Module):
    def __init__(self, roi_input_dim, embed_dim, num_heads, num_layers, dropout_rate,
                 pretrained_model_name="bert-base-uncased", pooling="mean"):
        super().__init__()

        # Ensure embed_dim is divisible by num_heads
        if embed_dim % num_heads != 0:
            embed_dim = (embed_dim // num_heads) * num_heads
            print(f"Adjusted embed_dim to {embed_dim} to ensure divisibility by num_heads.")

        # Embedding layers for age and gender
        self.age_mlp = nn.Sequential(
            nn.Linear(1, embed_dim),
            nn.GELU(),
            nn.BatchNorm1d(embed_dim),
            nn.Dropout(dropout_rate),
            nn.Linear(embed_dim, embed_dim),
        )
        self.gender_embed = nn.Embedding(2, embed_dim)

        # ROI encoder
        self.roi_encoder = nn.Sequential(
            nn.Linear(roi_input_dim, embed_dim),
            nn.GELU(),
            nn.BatchNorm1d(embed_dim),
            nn.Linear(embed_dim, embed_dim),
            nn.Dropout(dropout_rate),
        )

        # Hugging Face Transformer model built from a modified configuration
        # (from_config does not load the pretrained weights)
        config = AutoConfig.from_pretrained(pretrained_model_name)
        config.hidden_size = embed_dim
        config.num_attention_heads = num_heads
        config.num_hidden_layers = num_layers
        config.hidden_dropout_prob = dropout_rate
        config.attention_probs_dropout_prob = dropout_rate
        self.transformer = AutoModel.from_config(config)

        # Pooling mechanism
        assert pooling in ["mean", "max", "attention"]
        self.pooling = pooling
        if pooling == "attention":
            self.attention_pool = nn.Sequential(
                nn.Linear(embed_dim, embed_dim // 2),
                nn.Mish(),
                nn.Linear(embed_dim // 2, 1),
                nn.Softmax(dim=1),  # normalize over the sequence dimension
            )

        # Output layer
        self.output_layer = nn.Linear(embed_dim, 1)

    def forward(self, roi_data, age, gender):
        # age and gender are 1-D tensors of shape (batch,)
        age_embed = self.age_mlp(age.unsqueeze(-1))
        gender_embed = self.gender_embed(gender.long())
        roi_encoded = self.roi_encoder(roi_data)

        # Stack the three embeddings into a length-3 sequence: (batch, 3, embed_dim)
        combined_input = torch.stack((roi_encoded, age_embed, gender_embed), dim=1)
        transformer_output = self.transformer(inputs_embeds=combined_input).last_hidden_state

        if self.pooling == "mean":
            pooled_output = transformer_output.mean(dim=1)
        elif self.pooling == "max":
            pooled_output = transformer_output.max(dim=1).values
        elif self.pooling == "attention":
            attention_weights = self.attention_pool(transformer_output)
            pooled_output = (attention_weights * transformer_output).sum(dim=1)

        # One logit per sample: (batch,)
        return self.output_layer(pooled_output).squeeze(1)
```

---

### 3. **Initialize the Model**
The model is initialized using the best hyperparameters obtained (e.g., from an Optuna study or prior experimentation).

- **Parameters:**
  - `roi_input_dim`: Number of input features for the ROI data.
  - `embed_dim`, `num_heads`, `num_layers`: Transformer model hyperparameters.
  - `dropout_rate`: Dropout rate to prevent overfitting.
  - `pretrained_model_name`: Name of the Hugging Face model whose configuration is used as the starting point.
  - `pooling`: Pooling strategy to summarize the transformer outputs.

```python
# Select the GPU when one is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# `roi_features` is the 2-D ROI feature array (samples x features);
# `best_params` is listed under "Best Hyperparameters" below.
model = TransformerModel(
    roi_input_dim=roi_features.shape[1],
    embed_dim=best_params["embed_dim"],
    num_heads=best_params["num_heads"],
    num_layers=best_params["num_layers"],  # Optuna-tuned
    dropout_rate=best_params["dropout_rate"],
    pretrained_model_name="bert-base-uncased",  # Source of the base configuration
    pooling="attention",  # Optuna-tuned pooling strategy
).to(device)
```

---

### 4. **Support for Multi-GPU Training**
If multiple GPUs are available, the model is wrapped in `nn.DataParallel` for parallel training.

```python
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)
```

---

### 5. **Define Loss Function and Optimizer**
- **Loss Function:** `BCEWithLogitsLoss` is used for binary classification, and class imbalance is handled by a computed `pos_weight`.
- **Optimizer:** Uses the `Ranger` optimizer with a learning rate from the best hyperparameters.
- **Scheduler:** A cosine annealing learning rate scheduler adjusts the learning rate over training.
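The `pos_weight` passed to `BCEWithLogitsLoss` is the ratio of negative to positive samples, which is what `len(y) / y.sum() - 1` evaluates to for 0/1 labels. The sketch below works through that arithmetic in plain Python; the `compute_pos_weight` helper and the label list are hypothetical and only serve as an illustration.

```python
def compute_pos_weight(labels):
    """Return the negative-to-positive ratio for a list of 0/1 labels."""
    n_pos = sum(1 for y in labels if y == 1)
    n_neg = len(labels) - n_pos
    if n_pos == 0:
        raise ValueError("labels contain no positive examples")
    return n_neg / n_pos


# Hypothetical labels: 2 positives, 6 negatives.
labels = [1, 0, 0, 0, 1, 0, 0, 0]
weight = compute_pos_weight(labels)

# Same value as the len / sum - 1 form: 8 / 2 - 1 = 3.0
same_as_formula = len(labels) / sum(labels) - 1

print(weight, same_as_formula)  # 3.0 3.0
```

A weight above 1 makes each positive example count more in the loss, which compensates for positives being the minority class.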
The loss, optimizer, and scheduler are set up as follows:

```python
# `optimizers` is not imported in the snippets above. The `torch_optimizer`
# package provides a `Ranger` class and is assumed here.
import torch_optimizer as optimizers

# Compute class weights for imbalance.
# `new_y` holds the binary training labels; the result is (#negatives / #positives).
pos_weight = torch.tensor([len(new_y) / new_y.sum() - 1], dtype=torch.float32).to(device)
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

# Optimizer and Scheduler
optimizer = optimizers.Ranger(model.parameters(), lr=best_params["lr"], weight_decay=1e-4)
scheduler = CosineAnnealingLR(optimizer, T_max=10, eta_min=0)
```

---

### 6. **Load Pretrained Model Weights**
The model can load previously trained weights to resume training or perform inference.

```python
# map_location lets weights saved on a GPU load on a CPU-only machine
state_dict = torch.load(
    "/kaggle/input/bert-encoder-fmri/pytorch/default/1/fmri_encoder_model.pth",
    map_location=device,
)
model.load_state_dict(state_dict)
```

---

### 7. **Best Hyperparameters**
These are the best hyperparameters used for model initialization and training.

```python
best_params = {
    "embed_dim": 768,
    "num_heads": 32,
    "num_layers": 12,
    "dropout_rate": 0.119,
    "lr": 3.66e-5,
}
```

---
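The early-stopping rule described under Training Process (stop after 250 epochs without validation-loss improvement) has no accompanying snippet. A minimal sketch of that logic follows; the `EarlyStopping` class, its `min_delta` parameter, and the loss values are hypothetical and not taken from the original training script.

```python
class EarlyStopping:
    """Signal a stop once validation loss has not improved for `patience` epochs."""

    def __init__(self, patience=250, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best_loss = float("inf")
        self.epochs_without_improvement = 0

    def step(self, val_loss):
        """Record one epoch's validation loss; return True when training should stop."""
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.epochs_without_improvement = 0
        else:
            self.epochs_without_improvement += 1
        return self.epochs_without_improvement >= self.patience


# Hypothetical validation losses; patience is shortened to 2 for the demo.
val_losses = [1.00, 0.90, 0.95, 0.97, 0.80]
stopper = EarlyStopping(patience=2)
stopped_at = None
for epoch, loss in enumerate(val_losses):
    if stopper.step(loss):
        stopped_at = epoch
        break

print(stopped_at, stopper.best_loss)  # 3 0.9
```

In a real training loop, `stopper.step(val_loss)` would be called once per epoch after validation. If the clipping threshold of 1.0 refers to the gradient norm (an assumption), `torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)` would be called between `loss.backward()` and `optimizer.step()`.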