---
license: other
license_name: j.lacoma
license_link: LICENSE
base_model:
- google-bert/bert-base-uncased
---
# Transformer-Based fMRI Encoder Model
This repository contains a Transformer-based model trained on neuroimaging datasets to classify conditions such as Autism Spectrum Disorder (ASD) and ADHD, and to analyze brain activity during movie-watching. The model combines fMRI data with demographic features (age and gender) for binary classification. The sections below describe the model architecture, the training process, the datasets, and how to use the model.
## **Model Architecture**
The model integrates multi-modal data and leverages a Transformer backbone for feature extraction. Below is a breakdown of its components:
### **1. Inputs**
- **fMRI ROI Data:** High-dimensional features representing brain activity.
- **Age Data:** Numerical input passed through a Multi-Layer Perceptron (MLP).
- **Gender Data:** Binary input (male/female) embedded into a dense representation.
### **2. Transformer Backbone**
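- A Hugging Face Transformer with the BERT architecture. The configuration is loaded from `bert-base-uncased` with `AutoConfig`, and the model is built with `AutoModel.from_config`, so its weights are initialized from scratch rather than loaded from the pretrained checkpoint. It has:
  - A configurable number of attention heads, layers, and hidden size.
  - Dropout for regularization.
  - Hyperparameters set by overriding fields of the loaded `AutoConfig`.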
### **3. Pooling Mechanisms**
- Aggregates the Transformer’s sequence outputs into a single vector using one of:
  - **Mean Pooling:** Averages the hidden states.
  - **Max Pooling:** Takes the maximum of each feature across the sequence.
  - **Attention Pooling:** Learns attention weights to emphasize important sequence elements.
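The three pooling strategies can be compared on a toy tensor. This minimal sketch uses random values in place of real Transformer outputs; the toy sizes are assumptions, while the attention-pooling layers mirror the ones in the `TransformerModel` code further down.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

batch, seq_len, embed_dim = 2, 3, 8  # toy sizes; the model's sequence has 3 tokens (ROI, age, gender)
hidden = torch.randn(batch, seq_len, embed_dim)  # stand-in for last_hidden_state

# Mean pooling: average over the sequence dimension
mean_pooled = hidden.mean(dim=1)

# Max pooling: element-wise maximum over the sequence dimension
max_pooled = hidden.max(dim=1).values

# Attention pooling: score each token, softmax over the sequence, then a weighted sum
attention_pool = nn.Sequential(
    nn.Linear(embed_dim, embed_dim // 2),
    nn.Mish(),
    nn.Linear(embed_dim // 2, 1),
    nn.Softmax(dim=1),  # normalize across the sequence dimension
)
weights = attention_pool(hidden)             # shape (batch, seq_len, 1)
attn_pooled = (weights * hidden).sum(dim=1)  # shape (batch, embed_dim)

# All three strategies reduce (batch, seq_len, embed_dim) to (batch, embed_dim)
print(mean_pooled.shape, max_pooled.shape, attn_pooled.shape)
```

Mean and max pooling add no parameters; attention pooling adds a small scoring network whose softmax weights for each sample sum to 1.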
### **4. Output**
- A fully connected layer maps the pooled output to a single logit for binary classification (a sigmoid converts it to a probability).
---
## **Training Process**
### **Key Details:**
- **Loss Function:** Binary Cross Entropy with Logits (`BCEWithLogitsLoss`), with class imbalance handled by a positive-class weight (`pos_weight`, the ratio of negative to positive samples).
- **Optimizer:** Ranger (combines RAdam and Lookahead for stable convergence).
- **Learning Rate Scheduler:** Cosine Annealing for gradual learning rate reduction.
- **Gradient Clipping:** Prevents exploding gradients with a clipping threshold of 1.0.
- **Early Stopping:** Stops training after 250 epochs without validation loss improvement.
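The training settings above can be combined into one loop. This is a minimal sketch under several assumptions: a toy `nn.Linear` model and random data stand in for `TransformerModel` and the fMRI features, `torch.optim.AdamW` replaces Ranger so the example needs only PyTorch, each epoch is a single full-batch step for brevity, and the patience is shortened from 250 to 5. Clipping the total gradient norm and restoring the best weights at the end are common choices, not details stated in this README.

```python
import copy
import torch
import torch.nn as nn
from torch.optim.lr_scheduler import CosineAnnealingLR

torch.manual_seed(0)

# Hypothetical stand-ins: random features, deterministic imbalanced labels (25% positive)
x_train, y_train = torch.randn(64, 4), (torch.arange(64) % 4 == 0).float()
x_val, y_val = torch.randn(32, 4), (torch.arange(32) % 4 == 0).float()
model = nn.Linear(4, 1)  # stand-in for TransformerModel

# pos_weight = #negatives / #positives, as in the README's formula
pos_weight = torch.tensor([len(y_train) / y_train.sum().item() - 1])
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)  # stand-in for Ranger
scheduler = CosineAnnealingLR(optimizer, T_max=10, eta_min=0)

patience = 5  # the README reports a patience of 250 epochs
best_val, best_state, epochs_without_improvement = float("inf"), None, 0
for epoch in range(100):
    model.train()
    optimizer.zero_grad()
    loss = criterion(model(x_train).squeeze(1), y_train)
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # clip total grad norm to 1.0
    optimizer.step()
    scheduler.step()

    # Early stopping on validation loss
    model.eval()
    with torch.no_grad():
        val_loss = criterion(model(x_val).squeeze(1), y_val).item()
    if val_loss < best_val:
        best_val, best_state, epochs_without_improvement = val_loss, copy.deepcopy(model.state_dict()), 0
    else:
        epochs_without_improvement += 1
        if epochs_without_improvement >= patience:
            break

model.load_state_dict(best_state)  # restore the weights with the lowest validation loss
```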
### **Datasets Used:**
1. **ABIDE:** Autism vs. control classification.
2. **ADHD-200:** ADHD vs. control classification.
3. **Pixar Movie Dataset (Nilearn):** Brain activity analysis during movie-watching.
### **Output:**
The model’s state dictionary is saved as `fmri_encoder_model.pth`.
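A state dictionary holds only parameter tensors, so loading it requires rebuilding the same architecture (with the same hyperparameters) first. The round trip below is a minimal sketch with an `nn.Linear` stand-in and a temporary directory; only the file name comes from this README.

```python
import os
import tempfile
import torch
import torch.nn as nn

model = nn.Linear(4, 1)  # stand-in for the trained TransformerModel

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "fmri_encoder_model.pth")
    torch.save(model.state_dict(), path)  # saves parameters only, not the class definition

    restored = nn.Linear(4, 1)  # the architecture must be rebuilt before loading
    restored.load_state_dict(torch.load(path, map_location="cpu"))

# Every restored tensor matches the original
same = all(
    torch.equal(a, b)
    for a, b in zip(model.state_dict().values(), restored.state_dict().values())
)
```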
---
## **How to Use This Model**
### 1. **Import Required Libraries**
Import the libraries needed to define and train the model:
```python
from transformers import AutoModel, AutoConfig
import torch
import torch.nn as nn
from torch.optim.lr_scheduler import CosineAnnealingLR
import torch_optimizer as optimizers  # assumption: Ranger comes from the torch-optimizer package
```
---
### 2. **Define the Transformer Model**
The `TransformerModel` is implemented as a PyTorch `nn.Module`. It combines a Hugging Face transformer, built from a model configuration such as `bert-base-uncased`, with custom embedding layers for the ROI, age, and gender features.
- **Embedding Layers:**
  - `age_mlp` processes age as a continuous input.
  - `gender_embed` embeds the binary gender feature.
- **ROI Encoder:**
  Encodes the ROI feature input with a small feedforward network.
- **Transformer:**
  Builds a Hugging Face transformer from a custom configuration with `AutoModel.from_config`, so its weights start randomly initialized.
- **Pooling Mechanism:**
  Supports three pooling strategies:
  - Mean pooling
  - Max pooling
  - Attention pooling (a small learned network scores each sequence element, followed by a softmax over the sequence).
```python
class TransformerModel(nn.Module):
    def __init__(self, roi_input_dim, embed_dim, num_heads, num_layers, dropout_rate,
                 pretrained_model_name="bert-base-uncased", pooling="mean"):
        super(TransformerModel, self).__init__()
        # Ensure embed_dim is divisible by num_heads (round down if necessary)
        if embed_dim % num_heads != 0:
            embed_dim = (embed_dim // num_heads) * num_heads
            print(f"Adjusted embed_dim to {embed_dim} to ensure divisibility by num_heads.")

        # Embedding layers for age (continuous) and gender (binary)
        self.age_mlp = nn.Sequential(
            nn.Linear(1, embed_dim),
            nn.GELU(),
            nn.BatchNorm1d(embed_dim),
            nn.Dropout(dropout_rate),
            nn.Linear(embed_dim, embed_dim),
        )
        self.gender_embed = nn.Embedding(2, embed_dim)

        # ROI encoder: projects the ROI features to embed_dim
        self.roi_encoder = nn.Sequential(
            nn.Linear(roi_input_dim, embed_dim),
            nn.GELU(),
            nn.BatchNorm1d(embed_dim),
            nn.Linear(embed_dim, embed_dim),
            nn.Dropout(dropout_rate),
        )

        # Hugging Face Transformer built from the configuration only (random weights)
        config = AutoConfig.from_pretrained(pretrained_model_name)
        config.hidden_size = embed_dim
        config.num_attention_heads = num_heads
        config.num_hidden_layers = num_layers
        config.hidden_dropout_prob = dropout_rate
        config.attention_probs_dropout_prob = dropout_rate
        self.transformer = AutoModel.from_config(config)

        # Pooling mechanism
        assert pooling in ["mean", "max", "attention"]
        self.pooling = pooling
        if pooling == "attention":
            self.attention_pool = nn.Sequential(
                nn.Linear(embed_dim, embed_dim // 2),
                nn.Mish(),
                nn.Linear(embed_dim // 2, 1),
                nn.Softmax(dim=1),  # softmax over the sequence dimension
            )

        # Output layer: one logit per sample
        self.output_layer = nn.Linear(embed_dim, 1)

    def forward(self, roi_data, age, gender):
        # roi_data: (batch, roi_input_dim); age (float) and gender (0/1): (batch,)
        age_embed = self.age_mlp(age.unsqueeze(-1))        # (batch, embed_dim)
        gender_embed = self.gender_embed(gender.long())   # (batch, embed_dim)
        roi_encoded = self.roi_encoder(roi_data)           # (batch, embed_dim)
        # Treat the three embeddings as a 3-token sequence: (batch, 3, embed_dim)
        combined_input = torch.stack((roi_encoded, age_embed, gender_embed), dim=1)
        transformer_output = self.transformer(inputs_embeds=combined_input).last_hidden_state

        if self.pooling == "mean":
            pooled_output = transformer_output.mean(dim=1)
        elif self.pooling == "max":
            pooled_output = transformer_output.max(dim=1).values
        elif self.pooling == "attention":
            attention_weights = self.attention_pool(transformer_output)  # (batch, 3, 1)
            pooled_output = (attention_weights * transformer_output).sum(dim=1)

        return self.output_layer(pooled_output).squeeze(1)  # (batch,) logits
```
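The `forward` method treats the ROI, age, and gender embeddings as a three-token sequence passed through `inputs_embeds`. The sketch below reproduces that step with a small, randomly initialized `BertModel` built directly from a `BertConfig`; the toy sizes are assumptions chosen so it runs quickly. As in `TransformerModel`, no pretrained checkpoint is downloaded: building a model from a configuration (here `BertModel(config)`, in the README `AutoModel.from_config`) gives random weights.

```python
import torch
from transformers import BertConfig, BertModel

torch.manual_seed(0)
batch, embed_dim = 4, 32  # toy sizes; the README's best configuration uses 768

# Small BERT-style encoder built from a config, i.e. randomly initialized
config = BertConfig(
    hidden_size=embed_dim,
    num_attention_heads=4,
    num_hidden_layers=2,
    intermediate_size=64,
)
encoder = BertModel(config)
encoder.eval()

# Three per-subject "tokens": ROI, age and gender embeddings (random stand-ins)
roi_encoded = torch.randn(batch, embed_dim)
age_embed = torch.randn(batch, embed_dim)
gender_embed = torch.randn(batch, embed_dim)
combined_input = torch.stack((roi_encoded, age_embed, gender_embed), dim=1)  # (batch, 3, embed_dim)

with torch.no_grad():
    hidden = encoder(inputs_embeds=combined_input).last_hidden_state

print(hidden.shape)  # torch.Size([4, 3, 32])
```

Passing `inputs_embeds` skips BERT's word-embedding lookup, while position and token-type embeddings are still added, so each of the three tokens keeps a distinct position.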
---
### 3. **Initialize the Model**
The model is initialized with the best hyperparameters found during tuning (e.g., an Optuna study or prior experimentation). The `best_params` dictionary is defined in step 7 and must exist before this code runs.
- **Parameters:**
  - `roi_input_dim`: Number of input features in the ROI data.
  - `embed_dim`, `num_heads`, `num_layers`: Transformer hyperparameters.
  - `dropout_rate`: Dropout rate used to reduce overfitting.
  - `pretrained_model_name`: Name of the Hugging Face model whose configuration is used (its pretrained weights are not loaded).
  - `pooling`: Pooling strategy used to summarize the transformer outputs.
```python
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# roi_features: 2-D array of shape (n_subjects, n_roi_features), prepared beforehand
model = TransformerModel(
    roi_input_dim=roi_features.shape[1],
    embed_dim=best_params["embed_dim"],
    num_heads=best_params["num_heads"],
    num_layers=best_params["num_layers"],       # Optuna-tuned
    dropout_rate=best_params["dropout_rate"],
    pretrained_model_name="bert-base-uncased",  # only the configuration is used
    pooling="attention",                        # Optuna-tuned pooling strategy
).to(device)
```
---
### 4. **Support for Multi-GPU Training**
If multiple GPUs are available, the model is wrapped in `nn.DataParallel` for parallel training.
```python
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)
```
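`nn.DataParallel` stores the wrapped network as `model.module`, so every key in its state dictionary gains a `module.` prefix. Whether the published checkpoint contains that prefix is not stated here; if loading fails with unexpected `module.` keys, a helper like the hypothetical `strip_module_prefix` below can rename them. Saving `model.module.state_dict()` instead avoids the prefix in the first place.

```python
from collections import OrderedDict
import torch.nn as nn

def strip_module_prefix(state_dict):
    """Hypothetical helper: remove the 'module.' prefix that nn.DataParallel adds to keys."""
    prefix = "module."
    return OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in state_dict.items()
    )

model = nn.Linear(4, 1)  # stand-in for TransformerModel
wrapped = nn.DataParallel(model)  # construction also works on a CPU-only machine
wrapped_keys = list(wrapped.state_dict().keys())  # keys carry the 'module.' prefix
plain_state = strip_module_prefix(wrapped.state_dict())

fresh = nn.Linear(4, 1)
fresh.load_state_dict(plain_state)  # loads into an unwrapped model
```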
---
### 5. **Define Loss Function and Optimizer**
- **Loss Function:**
  `BCEWithLogitsLoss` is used for binary classification, and class imbalance is handled by a computed `pos_weight`.
- **Optimizer:**
  The `Ranger` optimizer is used with the learning rate from the best hyperparameters.
- **Scheduler:**
  `CosineAnnealingLR` (`T_max=10`, `eta_min=0`) lowers the learning rate along a cosine curve, reaching 0 after 10 scheduler steps.
```python
# new_y: binary training labels (1 = positive class)
# pos_weight = #negatives / #positives
pos_weight = torch.tensor([len(new_y) / new_y.sum() - 1], dtype=torch.float32).to(device)
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

# Optimizer and scheduler
# `optimizers` is assumed to be the torch-optimizer package (import torch_optimizer as optimizers)
optimizer = optimizers.Ranger(model.parameters(), lr=best_params["lr"], weight_decay=1e-4)
scheduler = CosineAnnealingLR(optimizer, T_max=10, eta_min=0)
```
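A small worked example of the loss weighting and the schedule, using hypothetical labels (6 negatives, 2 positives) and plain SGD as a stand-in optimizer so only PyTorch is needed. With `pos_weight = 8 / 2 - 1 = 3`, each positive sample's loss counts three times; at logit 0 every unweighted per-sample loss is ln 2, so the mean loss is (6 · ln 2 + 2 · 3 · ln 2) / 8 = 1.5 ln 2 ≈ 1.04. The learning-rate trace shows `CosineAnnealingLR` falling from the initial rate to half of it after 5 steps and to 0 after `T_max=10` steps.

```python
import torch
import torch.nn as nn
from torch.optim.lr_scheduler import CosineAnnealingLR

# Hypothetical labels: 6 negatives (0) and 2 positives (1)
new_y = torch.tensor([0, 0, 0, 0, 0, 0, 1, 1], dtype=torch.float32)
pos_weight = torch.tensor([len(new_y) / new_y.sum().item() - 1])  # 8 / 2 - 1 = 3.0

criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
loss_at_zero = criterion(torch.zeros(8), new_y)  # positives weighted 3x; equals 1.5 * ln 2

# Learning-rate trace of cosine annealing with T_max=10, eta_min=0
param = nn.Parameter(torch.zeros(1))
optimizer = torch.optim.SGD([param], lr=3.66e-5)  # stand-in optimizer; lr taken from best_params
scheduler = CosineAnnealingLR(optimizer, T_max=10, eta_min=0)
lrs = []
for _ in range(10):
    lrs.append(optimizer.param_groups[0]["lr"])
    optimizer.step()
    scheduler.step()
lrs.append(optimizer.param_groups[0]["lr"])  # learning rate after T_max steps
```

After `T_max` steps the cosine schedule starts rising again, so a longer run sees the learning rate cycle rather than stay at 0.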
---
### 6. **Load Pretrained Model Weights**
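Load previously trained weights to resume training or run inference. The model must first be built with the same hyperparameters used during training (see step 7).
```python
# map_location lets a checkpoint saved on GPU load on a CPU-only machine
model.load_state_dict(torch.load("/kaggle/input/bert-encoder-fmri/pytorch/default/1/fmri_encoder_model.pth", map_location=device))
```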
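For inference, the model should be switched to evaluation mode, which disables dropout and makes `BatchNorm1d` use its running statistics (in training mode it raises an error on a batch of one sample). Because training used `BCEWithLogitsLoss`, the model outputs raw logits, and a sigmoid turns them into probabilities. This minimal sketch uses an `nn.Linear` stand-in and random inputs; the real model takes `(roi_data, age, gender)`, and the 0.5 decision threshold is an assumption.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(4, 1)  # stand-in for the loaded TransformerModel
x = torch.randn(5, 4)    # hypothetical batch of 5 subjects

model.eval()  # disable dropout; BatchNorm layers use running statistics
with torch.no_grad():
    logits = model(x).squeeze(1)                 # the real model already returns shape (batch,)
    probabilities = torch.sigmoid(logits)        # logits -> probabilities in [0, 1]
    predictions = (probabilities >= 0.5).long()  # assumed 0.5 decision threshold
```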
---
### 7. **Best Hyperparameters**
These are the best hyperparameters used for model initialization and training. Because step 3 reads from this dictionary, define it before initializing the model.
```python
best_params = {
    "embed_dim": 768,
    "num_heads": 32,
    "num_layers": 12,
    "dropout_rate": 0.119,
    "lr": 3.66e-5,
}
```
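With `embed_dim=768` and `num_heads=32`, the divisibility check in `TransformerModel.__init__` passes (768 = 32 × 24), so each attention head works on 24 dimensions. The hypothetical helper `adjust_embed_dim` below mirrors that rounding rule so other hyperparameter combinations can be checked before building the model.

```python
def adjust_embed_dim(embed_dim, num_heads):
    """Hypothetical helper mirroring TransformerModel: round embed_dim down to a multiple of num_heads."""
    if embed_dim % num_heads != 0:
        embed_dim = (embed_dim // num_heads) * num_heads
    return embed_dim

best_params = {"embed_dim": 768, "num_heads": 32}
embed_dim = adjust_embed_dim(best_params["embed_dim"], best_params["num_heads"])
head_dim = embed_dim // best_params["num_heads"]  # per-head dimension

# 768 is already divisible by 32, so no adjustment happens.
# A non-divisible value such as 100 with 32 heads would be rounded down to 96.
rounded = adjust_embed_dim(100, 32)
```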
---