gpt2-stackformer-vision_V2
A GPT-2 backbone rebuilt on the stackformer
library, extended with a frozen ViT-B/16 vision encoder and sparse cross-attention layers,
fine-tuned on Flickr8k for image
captioning. Also supports plain text-to-text continuation (GPT-2-style) when no image is
given.
Model summary
| Component | Details |
|---|---|
| text backbone | GPT-2 small (12 layers, 768 dim, 12 heads), frozen during fine-tuning |
| vision encoder | torchvision vit_b_16, ImageNet-pretrained, frozen |
| visual tokens | 128 (compressed from ViT patch tokens via a Perceiver-style resampler) |
| cross-attention | sparse: inserted before GPT-2 blocks 3, 6, and 9 |
| trainable params | ~16.6M (vision projector + resampler + 3 cross-attention blocks) |
| total params | ~265M |
| training data | Flickr8k (~8K images × 5 captions) |
| training budget | single T4 GPU, ~5 hours |
| context length | 128 tokens |
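To make the "visual tokens" row concrete, here is a minimal sketch of a Perceiver-style resampler: a fixed set of learned query vectors cross-attends to the ViT patch tokens, so the output length is always 128 regardless of how many patches come in. This is not the PerceiverResamplerSF class from app.py. The class name `TinyResampler`, the single attention layer, and the 197-token input (196 patches plus a class token for ViT-B/16 at 224×224) are assumptions made for illustration.

```python
import torch
import torch.nn as nn


class TinyResampler(nn.Module):
    """Hypothetical Perceiver-style resampler: learned queries attend to patch tokens."""

    def __init__(self, dim: int = 768, num_latents: int = 128, num_heads: int = 12):
        super().__init__()
        # One learned query vector per output visual token.
        self.latents = nn.Parameter(torch.randn(num_latents, dim) * 0.02)
        self.norm_latents = nn.LayerNorm(dim)
        self.norm_patches = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.norm_mlp = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))

    def forward(self, patch_tokens: torch.Tensor) -> torch.Tensor:
        # patch_tokens: (batch, num_patches, dim)
        batch = patch_tokens.shape[0]
        queries = self.latents.unsqueeze(0).expand(batch, -1, -1)
        keys_values = self.norm_patches(patch_tokens)
        attended, _ = self.attn(
            self.norm_latents(queries), keys_values, keys_values, need_weights=False
        )
        x = queries + attended                    # residual around cross-attention
        return x + self.mlp(self.norm_mlp(x))     # residual around the MLP


# Assumed input: ViT-B/16 at 224x224 yields 196 patch tokens + 1 class token.
patches = torch.randn(2, 197, 768)
resampler = TinyResampler()
visual_tokens = resampler(patches)
print(visual_tokens.shape)  # torch.Size([2, 128, 768])
```

Because the number of queries is fixed, the cross-attention layers in the text backbone always see the same number of visual tokens, which keeps their cost independent of image resolution.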
GPT-2's pretrained weights were loaded directly into stackformer's GPT_2 class and kept
frozen throughout training; only the vision-side modules were fine-tuned. Two bugs in
stackformer's attention implementation (inverted causal masking, and attention dropout
not respecting .eval()) were identified and patched for both training and inference; see
Known issues below.
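The freezing scheme described above can be sketched as follows. The helper names `freeze` and `count_params` and the toy `nn.Linear` stand-ins are hypothetical. The real modules and their sizes are in app.py.

```python
import torch.nn as nn


def freeze(module: nn.Module) -> nn.Module:
    """Disable gradients and switch to eval mode so dropout/normalization stay fixed."""
    for p in module.parameters():
        p.requires_grad = False
    return module.eval()


def count_params(module: nn.Module) -> tuple[int, int]:
    """Return (trainable, total) parameter counts."""
    total = sum(p.numel() for p in module.parameters())
    trainable = sum(p.numel() for p in module.parameters() if p.requires_grad)
    return trainable, total


# Toy stand-ins with made-up sizes, not the real GPT-2 / ViT modules.
model = nn.ModuleDict({
    "text_backbone": freeze(nn.Linear(64, 64)),
    "vision_encoder": freeze(nn.Linear(32, 64)),
    "vision_projector": nn.Linear(64, 64),  # the only part left trainable
})
trainable, total = count_params(model)
print(f"trainable={trainable:,} total={total:,}")
```

Note that calling `model.train()` on the parent later puts the frozen submodules back into training mode (gradients stay off, but dropout turns on again), so a training loop built this way would typically call `.eval()` on them again after each `model.train()`.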
Usage
This model is built from custom stackformer-based classes (GPT2VL,
SparseCrossAttnBlock, PerceiverResamplerSF, TorchvisionViTEncoder) rather than a
standard transformers architecture, so it cannot be loaded with AutoModel. Use the
reference implementation in this Space
(or copy app.py from it) to load and run the model: it reconstructs the exact
architecture, applies the required stackformer bug patches, and loads this checkpoint.
Minimal loading sketch (see the Space's app.py for the full, working version including
the model classes and bug patches):
import torch
from huggingface_hub import snapshot_download
from transformers import GPT2TokenizerFast
local_dir = snapshot_download("gurumurthy3/gpt2-stackformer-vision_V2")
ckpt = torch.load(f"{local_dir}/model_checkpoint.pth", map_location="cpu")
cfg = ckpt["config"]
# model = GPT2VL(cfg, device="cpu", dtype=torch.float32) # see app.py for the class
# model.load_state_dict(ckpt["model_state_dict"])
# model.eval()
tokenizer = GPT2TokenizerFast.from_pretrained(f"{local_dir}/tokenizer")
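The checkpoint layout used above (a dict with "config" and "model_state_dict" keys) can be exercised end to end without downloading anything. The sketch below saves and reloads such a dict with a toy module. `build_model` and the config keys `in_dim`/`out_dim` are hypothetical stand-ins for GPT2VL and its real config.

```python
import os
import tempfile

import torch
import torch.nn as nn


def build_model(cfg: dict) -> nn.Module:
    # Hypothetical stand-in for GPT2VL(cfg, ...); the real class lives in app.py.
    return nn.Linear(cfg["in_dim"], cfg["out_dim"])


cfg = {"in_dim": 4, "out_dim": 2}
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "model_checkpoint.pth")
    torch.save({"config": cfg, "model_state_dict": build_model(cfg).state_dict()}, path)

    ckpt = torch.load(path, map_location="cpu")
    model = build_model(ckpt["config"])  # rebuild the architecture from the stored config
    result = model.load_state_dict(ckpt["model_state_dict"], strict=True)
    model.eval()                         # inference mode before generating

print(result.missing_keys, result.unexpected_keys)
```

With `strict=True`, a mismatch between the rebuilt architecture and the checkpoint raises an error instead of silently leaving some weights randomly initialized.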
Text-to-text
# images=None -> text-only, GPT-2-style continuation (the text backbone is frozen, not fine-tuned)
logits = model(input_ids, images=None)
Image-to-text (captioning)
visual_context = model.encode_image(image_tensor) # run the vision encoder once
logits = model(input_ids, visual_context=visual_context) # reuse it for every decoding step
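A minimal greedy decoding loop around the call above might look like the sketch below. It assumes the model returns logits of shape (batch, seq, vocab) and accepts a `visual_context` keyword, as in the snippet. `greedy_decode` and `CountingModel` are hypothetical names, and the stand-in model exists only so the loop can run without the checkpoint. The reference app.py may use a different sampling strategy.

```python
import torch


@torch.no_grad()
def greedy_decode(model, input_ids, visual_context, eos_id, max_new_tokens=30, max_len=128):
    """Append the argmax token until EOS, the token budget, or the context limit is hit."""
    for _ in range(max_new_tokens):
        if input_ids.shape[1] >= max_len:  # 128-token context length
            break
        logits = model(input_ids, visual_context=visual_context)  # assumed (batch, seq, vocab)
        next_id = logits[:, -1, :].argmax(dim=-1, keepdim=True)
        input_ids = torch.cat([input_ids, next_id], dim=1)
        if (next_id == eos_id).all():      # simple stop rule, intended for batch size 1
            break
    return input_ids


class CountingModel:
    """Hypothetical stand-in: always predicts (last token + 1) modulo the vocabulary size."""

    vocab_size = 10

    def __call__(self, input_ids, visual_context=None):
        logits = torch.zeros(*input_ids.shape, self.vocab_size)
        target = (input_ids + 1) % self.vocab_size
        return logits.scatter(-1, target.unsqueeze(-1), 1.0)


out = greedy_decode(CountingModel(), torch.tensor([[3]]), visual_context=None, eos_id=7)
print(out.tolist())  # [[3, 4, 5, 6, 7]]
```

The visual context is computed once outside the loop and passed unchanged at every step, which is the reuse pattern the captioning snippet describes.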
Known issues
- Causal masking bug: stackformer's boolean causal mask is inverted relative to what torch.nn.functional.scaled_dot_product_attention expects. Uncorrected, this lets attention see future tokens instead of past ones.
- Attention dropout ignores .eval(): stackformer passes dropout_p directly into scaled_dot_product_attention, which (unlike nn.Dropout) applies it unconditionally regardless of model.train()/model.eval().

Both are patched at import time in the reference app.py; if you load this checkpoint yourself, apply the same patches or generation quality will be degraded.
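Both issues can be reproduced with plain PyTorch, independent of stackformer. The sketch below is not the patch from app.py. It only demonstrates the two conventions involved: a boolean attn_mask uses True for "may attend", and the functional call applies dropout_p whenever it is nonzero. The helper `attend` is a hypothetical name.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
q = torch.randn(1, 2, 5, 8)  # (batch, heads, seq, head_dim)
k = torch.randn(1, 2, 5, 8)
v = torch.randn(1, 2, 5, 8)
seq = q.shape[-2]

# Boolean attn_mask convention: True = "may attend", False = "masked out".
allowed = torch.tril(torch.ones(seq, seq, dtype=torch.bool))

reference = F.scaled_dot_product_attention(q, k, v, is_causal=True)
masked = F.scaled_dot_product_attention(q, k, v, attn_mask=allowed)
same_as_causal = torch.allclose(reference, masked, atol=1e-5)

# With a correct causal mask, token 0 attends only to itself, so its output is v[0].
first_token_clean = torch.allclose(masked[..., 0, :], v[..., 0, :], atol=1e-5)

# A "True = blocked" mask passed unchanged lets token 0 see only future tokens.
# (The last row is fully masked in this case, so only row 0 is inspected.)
leaky = F.scaled_dot_product_attention(q, k, v, attn_mask=~allowed)
first_token_leaks = not torch.allclose(leaky[..., 0, :], v[..., 0, :], atol=1e-5)


def attend(q, k, v, dropout_p: float, training: bool):
    # The functional call has no train/eval state, so dropout is gated by hand.
    p = dropout_p if training else 0.0
    return F.scaled_dot_product_attention(q, k, v, is_causal=True, dropout_p=p)


ungated = F.scaled_dot_product_attention(q, k, v, is_causal=True, dropout_p=0.5)
gated_eval = attend(q, k, v, dropout_p=0.5, training=False)
dropout_active_ungated = not torch.allclose(ungated, reference, atol=1e-5)
dropout_off_when_gated = torch.allclose(gated_eval, reference, atol=1e-5)

print(same_as_causal, first_token_clean, first_token_leaks)
print(dropout_active_ungated, dropout_off_when_gated)
```

Inside an nn.Module, the `training` argument would normally be `self.training`, which is what model.eval() toggles.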
Limitations
- The text backbone was frozen during fine-tuning, so language quality matches GPT-2 small: fluent but not state-of-the-art.
- Trained on Flickr8k only (~8K natural images, mostly people/animals/everyday scenes), for a short, single-GPU budget. Expect short, simple, sometimes generic or repetitive captions, and weaker performance on image domains far from Flickr8k's distribution (e.g. diagrams, text-heavy images, illustrations).
- 128-token context length: long prompts or captions will be truncated.
- Vision context is a soft conditioning signal via cross-attention, not a hard constraint, so generated text can occasionally ignore or misdescribe image content.
Training details
Loss drops sharply over the first ~1,000 steps (from 14.7 to ~4) as the randomly-initialized vision projector, resampler, and cross-attention blocks start aligning with the frozen GPT-2/ViT representations, then decreases gradually over the remaining ~5 epochs, ending around 2.7.
- Base text weights: openai-community/gpt2 (small, 124M), loaded into stackformer's GPT_2 and frozen.
- Vision encoder: torchvision vit_b_16, ViT_B_16_Weights.IMAGENET1K_V1, frozen.
- Optimizer: AdamW, lr 2e-4, weight decay 0.01, effective batch size 32 (batch 16 × grad accumulation 2).
- Mixed precision (bf16/fp16 depending on GPU support), pin_memory=True + non_blocking=True data transfer.
- Stopped via a wall-clock time budget (~4.5 hours) rather than a fixed epoch count, to fit a single T4 session.
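The optimizer settings, gradient accumulation, and wall-clock stopping rule listed above can be combined as in the sketch below. It is not the actual training script: the toy `nn.Linear` model, the random data, the 2-second budget, and the `MAX_MICRO_BATCHES` cap are assumptions so the loop finishes immediately, and mixed precision is left out for brevity.

```python
import time

import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(16, 1)  # toy stand-in for the trainable vision-side modules
params = [p for p in model.parameters() if p.requires_grad]  # frozen weights are skipped
optimizer = torch.optim.AdamW(params, lr=2e-4, weight_decay=0.01)

ACCUM_STEPS = 2          # micro-batch 16 x 2 accumulation steps = effective batch 32
TIME_BUDGET_S = 2.0      # the card reports ~4.5 hours; shortened for this demo
MAX_MICRO_BATCHES = 8    # hypothetical extra cap so the demo ends quickly

start = time.monotonic()
optimizer_steps = 0
optimizer.zero_grad(set_to_none=True)
for micro_step in range(1, MAX_MICRO_BATCHES + 1):
    if time.monotonic() - start > TIME_BUDGET_S:
        break  # stop on elapsed wall-clock time, not on an epoch count
    x, y = torch.randn(16, 16), torch.randn(16, 1)
    # Divide by ACCUM_STEPS so the summed gradients equal the mean over 32 samples.
    loss = nn.functional.mse_loss(model(x), y) / ACCUM_STEPS
    loss.backward()  # gradients accumulate across micro-batches
    if micro_step % ACCUM_STEPS == 0:
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)
        optimizer_steps += 1

print(optimizer_steps)
```

Checking the clock at the top of each micro-batch means a run can overshoot the budget by at most one micro-batch, which suits a session with a hard time limit.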