---
license: other
license_name: j.lacoma
license_link: LICENSE
base_model:
- google-bert/bert-base-uncased
---

# Transformer-Based fMRI Encoder Model

This repository contains a Transformer-based model trained on neuroimaging datasets to classify conditions like Autism Spectrum Disorder (ASD) and ADHD, and to analyze brain activity during movie-watching. The model combines fMRI data with demographic features (age and gender) for binary classification tasks. Below is a detailed explanation of the datasets, model architecture, and training process.

## **Model Architecture**

The model integrates multi-modal data and leverages a Transformer backbone for feature extraction. Below is a breakdown of its components:

### **1. Inputs**

- **fMRI ROI Data:** High-dimensional region-of-interest (ROI) features representing brain activity.
- **Age Data:** Numerical input passed through a Multi-Layer Perceptron (MLP).
- **Gender Data:** Binary input (male/female) embedded into a dense representation.

### **2. Transformer Backbone**

- A pretrained Hugging Face Transformer (e.g., BERT) with:
  - A configurable number of attention heads, layers, and hidden size.
  - Dropout for regularization.
  - Hyperparameters adjusted dynamically via `AutoConfig`.

### **3. Pooling Mechanisms**

- Aggregates the Transformer's sequence outputs into a single vector using one of the following strategies (a minimal sketch follows this list):
  - **Mean Pooling:** Averages the hidden states across the sequence.
  - **Max Pooling:** Selects the maximum value for each feature.
  - **Attention Pooling:** Learns attention weights to emphasize important sequence elements.
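
For concreteness, here is a minimal, self-contained sketch of the three pooling strategies applied to a batch of Transformer outputs. The tensor names and sizes are illustrative assumptions; the attention head mirrors the `attention_pool` module defined in the model code further below.

```python
import torch
import torch.nn as nn

batch, seq_len, embed_dim = 4, 3, 16                  # illustrative sizes
hidden_states = torch.randn(batch, seq_len, embed_dim)

# Mean pooling: average the hidden states over the sequence dimension.
mean_pooled = hidden_states.mean(dim=1)               # (batch, embed_dim)

# Max pooling: per-feature maximum over the sequence dimension.
max_pooled = hidden_states.max(dim=1).values          # (batch, embed_dim)

# Attention pooling: learn a score per position, normalize with softmax,
# and take the weighted sum of the hidden states.
attention_pool = nn.Sequential(
    nn.Linear(embed_dim, embed_dim // 2),
    nn.Mish(),
    nn.Linear(embed_dim // 2, 1),
    nn.Softmax(dim=1),
)
weights = attention_pool(hidden_states)               # (batch, seq_len, 1)
attention_pooled = (weights * hidden_states).sum(dim=1)  # (batch, embed_dim)
```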

### **4. Output**

- A fully connected layer maps the pooled output to a single logit for binary classification.

---

## **Training Process**

### **Key Details:**

- **Loss Function:** Binary cross-entropy with logits (`BCEWithLogitsLoss`); class imbalance is handled with a positive-class weight.
- **Optimizer:** Ranger (RAdam combined with Lookahead for stable convergence).
- **Learning Rate Scheduler:** Cosine annealing for gradual learning-rate reduction.
- **Gradient Clipping:** Gradient norms are clipped at 1.0 to prevent exploding gradients.
- **Early Stopping:** Training stops after 250 epochs without improvement in validation loss (a sketch of the full loop follows this list).
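
The card does not include the training loop itself; the following is a minimal sketch of how the details above (gradient clipping at 1.0, cosine annealing, early stopping after 250 stagnant epochs) fit together. The `model`, `criterion`, `optimizer`, `scheduler`, and data loaders yielding `(roi, age, gender, label)` batches are assumed to be defined as in the later steps of this card.

```python
import torch

def train(model, criterion, optimizer, scheduler, train_loader, val_loader,
          device, max_epochs=1000, patience=250):
    """Illustrative loop: clip gradients at 1.0, anneal the LR, stop early."""
    best_val_loss, epochs_without_improvement = float("inf"), 0

    for epoch in range(max_epochs):
        model.train()
        for roi, age, gender, label in train_loader:
            roi, age, gender, label = (t.to(device) for t in (roi, age, gender, label))
            optimizer.zero_grad()
            loss = criterion(model(roi, age, gender), label.float())
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # gradient clipping
            optimizer.step()
        scheduler.step()  # cosine annealing of the learning rate

        # Validation loss drives early stopping.
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for roi, age, gender, label in val_loader:
                roi, age, gender, label = (t.to(device) for t in (roi, age, gender, label))
                val_loss += criterion(model(roi, age, gender), label.float()).item()

        if val_loss < best_val_loss:
            best_val_loss, epochs_without_improvement = val_loss, 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:  # 250 epochs without improvement
                break
```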

### **Datasets Used:**

1. **ABIDE:** Autism vs. control classification.
2. **ADHD-200:** ADHD vs. control classification.
3. **Pixar Movie Dataset (Nilearn):** Brain activity analysis during movie-watching (a Nilearn fetching example follows this list).
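
The card does not specify how the data were obtained or preprocessed. As a hedged pointer only, Nilearn ships public fetchers for preprocessed ABIDE, a 40-subject subset of ADHD-200, and the Pixar "Partly Cloudy" movie-watching (development fMRI) dataset; the exact ROI extraction used for this model is not documented here.

```python
from nilearn import datasets

# Preprocessed ABIDE data (autism vs. control).
abide = datasets.fetch_abide_pcp(n_subjects=5)

# A 40-subject subset of ADHD-200 distributed with Nilearn.
adhd = datasets.fetch_adhd(n_subjects=5)

# The Pixar "Partly Cloudy" movie-watching (development fMRI) dataset.
pixar = datasets.fetch_development_fmri(n_subjects=5)
```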

### **Output:**

The model's state dictionary is saved as `fmri_encoder_model.pth`.
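
Saving and restoring follow the standard PyTorch pattern; only the filename above comes from the card, the rest is boilerplate.

```python
# Save the trained weights.
torch.save(model.state_dict(), "fmri_encoder_model.pth")

# Later, restore them into a freshly constructed model.
model.load_state_dict(torch.load("fmri_encoder_model.pth"))
```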

---

## **How to Use This Model**

### 1. **Import Required Libraries**

Start by importing the libraries needed to build and train the model.

```python
from transformers import AutoModel, AutoConfig
import torch
import torch.nn as nn
from torch.optim.lr_scheduler import CosineAnnealingLR
```

---

### 2. **Define the Transformer Model**

The `TransformerModel` is implemented as a PyTorch `nn.Module`. It integrates a Hugging Face transformer (e.g., `bert-base-uncased`) and custom embedding layers for ROI, age, and gender features.

- **Embedding Layers:**
  - `age_mlp` processes age as a continuous input.
  - `gender_embed` embeds the binary gender feature.

- **ROI Encoder:**
  Encodes the ROI feature input using a small feedforward network.

- **Transformer:**
  Initializes a Hugging Face transformer with a custom configuration.

- **Pooling Mechanism:**
  Supports three pooling strategies:
  - Mean pooling
  - Max pooling
  - Attention pooling (using an additional attention mechanism).

```python
class TransformerModel(nn.Module):
    def __init__(self, roi_input_dim, embed_dim, num_heads, num_layers, dropout_rate,
                 pretrained_model_name="bert-base-uncased", pooling="mean"):
        super(TransformerModel, self).__init__()

        # Ensure embed_dim is divisible by num_heads
        if embed_dim % num_heads != 0:
            embed_dim = (embed_dim // num_heads) * num_heads
            print(f"Adjusted embed_dim to {embed_dim} to ensure divisibility by num_heads.")

        # Embedding layers for age and gender
        self.age_mlp = nn.Sequential(
            nn.Linear(1, embed_dim),
            nn.GELU(),
            nn.BatchNorm1d(embed_dim),
            nn.Dropout(dropout_rate),
            nn.Linear(embed_dim, embed_dim),
        )
        self.gender_embed = nn.Embedding(2, embed_dim)

        # ROI encoder
        self.roi_encoder = nn.Sequential(
            nn.Linear(roi_input_dim, embed_dim),
            nn.GELU(),
            nn.BatchNorm1d(embed_dim),
            nn.Linear(embed_dim, embed_dim),
            nn.Dropout(dropout_rate),
        )

        # Hugging Face Transformer model, randomly initialized from the adjusted config
        config = AutoConfig.from_pretrained(pretrained_model_name)
        config.hidden_size = embed_dim
        config.num_attention_heads = num_heads
        config.num_hidden_layers = num_layers
        config.hidden_dropout_prob = dropout_rate
        config.attention_probs_dropout_prob = dropout_rate
        self.transformer = AutoModel.from_config(config)

        # Pooling mechanism
        assert pooling in ["mean", "max", "attention"]
        self.pooling = pooling

        if pooling == "attention":
            self.attention_pool = nn.Sequential(
                nn.Linear(embed_dim, embed_dim // 2),
                nn.Mish(),
                nn.Linear(embed_dim // 2, 1),
                nn.Softmax(dim=1),
            )

        # Output layer
        self.output_layer = nn.Linear(embed_dim, 1)

    def forward(self, roi_data, age, gender):
        # Embed each modality into the shared embed_dim space
        age_embed = self.age_mlp(age.unsqueeze(-1))
        gender_embed = self.gender_embed(gender.long())
        roi_encoded = self.roi_encoder(roi_data)

        # Treat the three embeddings as a length-3 sequence for the transformer
        combined_input = torch.stack((roi_encoded, age_embed, gender_embed), dim=1)
        transformer_output = self.transformer(inputs_embeds=combined_input).last_hidden_state

        if self.pooling == "mean":
            pooled_output = transformer_output.mean(dim=1)
        elif self.pooling == "max":
            pooled_output = transformer_output.max(dim=1).values
        elif self.pooling == "attention":
            attention_weights = self.attention_pool(transformer_output)
            pooled_output = (attention_weights * transformer_output).sum(dim=1)

        # One logit per sample for binary classification
        return self.output_layer(pooled_output).squeeze(1)
```
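
As a quick sanity check, you can instantiate the class with small, illustrative hyperparameters and run a forward pass on random tensors. The values below are placeholders, not the tuned ones from step 7.

```python
model = TransformerModel(
    roi_input_dim=200,   # illustrative number of ROI features
    embed_dim=64,
    num_heads=4,
    num_layers=2,
    dropout_rate=0.1,
    pooling="attention",
)

roi = torch.randn(8, 200)            # batch of 8 subjects
age = torch.randn(8)                 # (standardized) ages
gender = torch.randint(0, 2, (8,))   # binary gender codes

model.eval()
with torch.no_grad():
    logits = model(roi, age, gender)

print(logits.shape)  # torch.Size([8]): one logit per subject
```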

---

### 3. **Initialize the Model**

The model is initialized with the best hyperparameters obtained from an Optuna study or prior experimentation (see step 7 for the values used here).

- **Parameters:**
  - `roi_input_dim`: Number of input features in the ROI data.
  - `embed_dim`, `num_heads`, `num_layers`: Transformer model hyperparameters.
  - `dropout_rate`: Dropout rate used to prevent overfitting.
  - `pretrained_model_name`: Name of the Hugging Face pretrained model.
  - `pooling`: Pooling strategy used to summarize the transformer outputs.

```python
# Assumes `roi_features` is the preprocessed ROI feature matrix (subjects x features)
# and `best_params` holds the tuned hyperparameters from step 7.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = TransformerModel(
    roi_input_dim=roi_features.shape[1],
    embed_dim=best_params["embed_dim"],
    num_heads=best_params["num_heads"],
    num_layers=best_params["num_layers"],        # Optuna-tuned
    dropout_rate=best_params["dropout_rate"],
    pretrained_model_name="bert-base-uncased",   # Pretrained transformer model
    pooling="attention",                         # Optuna-tuned pooling strategy
).to(device)
```

---

### 4. **Support for Multi-GPU Training**

If multiple GPUs are available, the model is wrapped in `nn.DataParallel` for parallel training.

```python
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)
```
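
`nn.DataParallel` replicates the model on each visible GPU and splits every batch across them. Note that a state dictionary saved from the wrapped model has its keys prefixed with `module.`, which matters when reloading the checkpoint later (see step 6).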

---

### 5. **Define Loss Function and Optimizer**

- **Loss Function:**
  `BCEWithLogitsLoss` is used for binary classification; class imbalance is handled with a computed `pos_weight` (the ratio of negative to positive samples).

- **Optimizer:**
  The `Ranger` optimizer with the learning rate from the best hyperparameters.

- **Scheduler:**
  A cosine annealing learning rate scheduler adjusts the learning rate over training.

```python
# Compute the positive-class weight for the imbalanced labels:
# len(new_y) / new_y.sum() - 1 == negatives / positives,
# where `new_y` is the vector of binary training labels.
pos_weight = torch.tensor([len(new_y) / new_y.sum() - 1], dtype=torch.float32).to(device)
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

# Optimizer and scheduler. `optimizers` refers to a Ranger implementation,
# e.g. the `torch_optimizer` package (`import torch_optimizer as optimizers`);
# adjust the import to whichever Ranger package you have installed.
optimizer = optimizers.Ranger(model.parameters(), lr=best_params["lr"], weight_decay=1e-4)
scheduler = CosineAnnealingLR(optimizer, T_max=10, eta_min=0)
```
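
For example (illustrative numbers only): with 1,000 training labels of which 200 are positive, `pos_weight = 1000 / 200 - 1 = 4.0`, i.e. the negative-to-positive ratio, so each positive sample contributes four times as much to the loss.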

---

### 6. **Load Pretrained Model Weights**

The model can load previously trained weights to resume training or perform inference.

```python
model.load_state_dict(torch.load("/kaggle/input/bert-encoder-fmri/pytorch/default/1/fmri_encoder_model.pth"))
```
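
Two common loading pitfalls, noted here as general PyTorch hints rather than anything stated in the card: on a CPU-only machine, pass `map_location`, and if the checkpoint was saved from an `nn.DataParallel`-wrapped model, strip the `module.` prefix from its keys (or load into the wrapped model) first.

```python
# Load on CPU if no GPU is available.
state_dict = torch.load("fmri_encoder_model.pth", map_location=torch.device("cpu"))

# Strip a possible "module." prefix left by nn.DataParallel (no-op otherwise).
state_dict = {key.removeprefix("module."): value for key, value in state_dict.items()}

model.load_state_dict(state_dict)
```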

---

### 7. **Best Hyperparameters**

These are the best hyperparameters used for model initialization and training.

```python
best_params = {
    "embed_dim": 768,
    "num_heads": 32,
    "num_layers": 12,
    "dropout_rate": 0.119,
    "lr": 3.66e-5,
}
```
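
Note that `embed_dim = 768` is evenly divisible by `num_heads = 32` (768 / 32 = 24), so the divisibility check in the model constructor leaves these values unchanged.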

---