---
license: cc-by-nc-4.0
datasets:
- clane9/NSD-Flat
---
# Model card for `boldgpt_small_patch10.cont`
![Example training predictions](example.png)
A Vision Transformer (ViT) model trained on BOLD activation maps from [NSD-Flat](https://huggingface.co/datasets/clane9/NSD-Flat). The model was trained to autoregressively predict the next patch, with patches presented in shuffled order, using an MSE loss. The `shared1000` subset was held out as the validation set.
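To make the training objective concrete, here is a minimal pure-Python sketch of next-patch prediction over a shuffled patch order with an MSE loss. It is illustrative only and is not the boldGPT implementation: the function names (`shuffled_next_patch_mse`, `mean_of_context`), the toy data, and the baseline predictor are all hypothetical, and the real model's tokenization, masking, and loss reduction may differ.

```python
import random


def shuffled_next_patch_mse(patches, predict, seed=0):
    """Mean squared error of next-patch predictions over a random visit order.

    patches: list of N patches, each a list of D floats.
    predict: callable that receives the patches seen so far (in visit order)
             and returns a D-length prediction for the next patch.
    """
    rng = random.Random(seed)
    order = list(range(len(patches)))
    rng.shuffle(order)  # random patch order instead of raster order

    total, count = 0.0, 0
    for step, idx in enumerate(order):
        # The predictor only sees patches that come earlier in the order.
        context = [patches[j] for j in order[:step]]
        pred = predict(context)
        target = patches[idx]
        total += sum((p - t) ** 2 for p, t in zip(pred, target))
        count += len(target)
    return total / count


def mean_of_context(context, dim=4):
    """Hypothetical baseline: mean of the patches seen so far, zeros if none."""
    if not context:
        return [0.0] * dim
    return [sum(col) / len(context) for col in zip(*context)]


# Toy "activation map": 6 patches with 4 values each.
toy_patches = [[float(i)] * 4 for i in range(6)]
loss = shuffled_next_patch_mse(toy_patches, mean_of_context)
```

In this sketch the first prediction is made from an empty context, which plays the role of a start token; a trained model would replace `mean_of_context` with a learned predictor.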
## Dependencies
- [boldGPT](https://github.com/clane9/boldGPT)
## Usage
```python
from boldgpt.data import ActivityTransform
from boldgpt.models import create_model
from datasets import load_dataset
# Load the pretrained model
model = create_model("boldgpt_small_patch10.cont", pretrained=True)

# Load the training split and return samples as torch tensors
dataset = load_dataset("clane9/NSD-Flat", split="train")
dataset.set_format("torch")

# Take a batch of one sample and transform its activity map
transform = ActivityTransform()
batch = dataset[:1]
batch["activity"] = transform(batch["activity"])

# output: (B, N + 1, D) predicted next patches
output, state = model(batch)
```
## Reproducing
- Training command:
```bash
torchrun --standalone --nproc_per_node=4 \
scripts/train.py \
--out_dir results \
--model boldgpt_small_patch10 \
--no_cat --shuffle --epochs 1000 --bs 512 \
--workers 0 --amp --compile --wandb
```
- Commit: `e0b29adc8d5b3ed2f1a555d7de4754ba96a3bb3e`