gpt2-stackformer-vision_V2

A GPT-2 backbone rebuilt on the stackformer library, extended with a frozen ViT-B/16 vision encoder and sparse cross-attention layers, fine-tuned on Flickr8k for image captioning. Also supports plain text-to-text continuation (GPT-2-style) when no image is given.

Model summary

Text backbone: GPT-2 small (12 layers, 768 dim, 12 heads), frozen during fine-tuning
Vision encoder: torchvision vit_b_16, ImageNet-pretrained, frozen
Visual tokens: 128 (compressed from ViT patch tokens via a Perceiver-style resampler)
Cross-attention: sparse; inserted before GPT-2 blocks 3, 6, and 9
Trainable params: ~16.6M (vision projector + resampler + 3 cross-attention blocks)
Total params: ~265M
Training data: Flickr8k (~8K images × 5 captions)
Training budget: single T4 GPU, ~5 hours
Context length: 128 tokens
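The minimal PyTorch sketch below illustrates the resampler and sparse cross-attention rows. It is not the PerceiverResamplerSF or SparseCrossAttnBlock code from app.py. The class and function names are hypothetical, and a single nn.MultiheadAttention layer stands in for each module. The sketch also treats the insertion indices 3, 6 and 9 as 0-based, which is an assumption.

```python
import torch
import torch.nn as nn


class ResamplerSketch(nn.Module):
    """Hypothetical Perceiver-style resampler: learned latent queries cross-attend to ViT tokens."""

    def __init__(self, dim: int = 768, num_latents: int = 128, num_heads: int = 12):
        super().__init__()
        self.latents = nn.Parameter(torch.randn(num_latents, dim) * 0.02)
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.norm = nn.LayerNorm(dim)

    def forward(self, vit_tokens: torch.Tensor) -> torch.Tensor:
        # vit_tokens: (B, N, dim) -> fixed-size summary of shape (B, num_latents, dim)
        q = self.latents.unsqueeze(0).expand(vit_tokens.size(0), -1, -1)
        out, _ = self.attn(q, vit_tokens, vit_tokens)
        return self.norm(q + out)


class CrossAttnSketch(nn.Module):
    """Hypothetical cross-attention block: text hidden states attend to the visual tokens."""

    def __init__(self, dim: int = 768, num_heads: int = 12):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)

    def forward(self, x: torch.Tensor, visual: torch.Tensor) -> torch.Tensor:
        out, _ = self.attn(self.norm(x), visual, visual)
        return x + out  # residual add on top of the backbone's hidden states


def run_backbone(x, blocks, xattn_blocks, visual=None, xattn_at=(3, 6, 9)):
    """Run the text blocks, inserting cross-attention before the listed (0-based) block indices."""
    xattn_by_index = dict(zip(xattn_at, xattn_blocks))
    for i, block in enumerate(blocks):
        if visual is not None and i in xattn_by_index:
            x = xattn_by_index[i](x, visual)
        x = block(x)
    return x


vit_tokens = torch.randn(2, 197, 768)  # e.g. ViT-B/16 at 224x224: 196 patch tokens + 1 class token
visual = ResamplerSketch()(vit_tokens)  # -> (2, 128, 768)
blocks = nn.ModuleList([nn.Identity() for _ in range(12)])  # stand-ins for the 12 frozen GPT-2 blocks
xattn = nn.ModuleList([CrossAttnSketch() for _ in range(3)])
text = torch.randn(2, 10, 768)
with_image = run_backbone(text, blocks, xattn, visual)
text_only = run_backbone(text, blocks, xattn, visual=None)  # no image: cross-attention is skipped
```

The number of latents fixes the resampler's output length. As a result, each cross-attention call always attends over 128 visual tokens, however many patch tokens the ViT produces.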

GPT-2's pretrained weights were loaded directly into stackformer's GPT_2 class and kept frozen throughout training; only the vision-side modules were fine-tuned. Two bugs in stackformer's attention implementation were identified and patched for both training and inference: inverted causal masking, and attention dropout that does not respect .eval(). See Known issues below.

Usage

This model is built from custom stackformer-based classes (GPT2VL, SparseCrossAttnBlock, PerceiverResamplerSF, TorchvisionViTEncoder) rather than a standard transformers architecture, so it cannot be loaded with AutoModel. To load and run the model, use the reference implementation in this Space (or copy app.py from it). It reconstructs the exact architecture, applies the required stackformer bug patches, and loads this checkpoint.

Minimal loading sketch (see the Space's app.py for the full, working version including the model classes and bug patches):

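import torch
from huggingface_hub import snapshot_download
from transformers import GPT2TokenizerFast

# Download the repo (checkpoint + tokenizer) into the local Hugging Face cache
local_dir = snapshot_download("gurumurthy3/gpt2-stackformer-vision_V2")
# Note: PyTorch >= 2.6 defaults torch.load to weights_only=True; pass weights_only=False
# only if loading fails on the stored config and you trust the file.
ckpt = torch.load(f"{local_dir}/model_checkpoint.pth", map_location="cpu")
cfg = ckpt["config"]

# GPT2VL and the stackformer patches are defined in the Space's app.py; uncomment once they are in scope
# model = GPT2VL(cfg, device="cpu", dtype=torch.float32)
# model.load_state_dict(ckpt["model_state_dict"])
# model.eval()

tokenizer = GPT2TokenizerFast.from_pretrained(f"{local_dir}/tokenizer")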

Text-to-text

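# images=None -> text-only continuation; the GPT-2 backbone weights are the original ones (frozen)
input_ids = tokenizer("A dog runs across the grass", return_tensors="pt").input_ids
logits = model(input_ids, images=None)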

Image-to-text (captioning)

# image_tensor: a preprocessed image batch (see app.py for the exact preprocessing)
visual_context = model.encode_image(image_tensor)  # run the vision encoder once
logits = model(input_ids, visual_context=visual_context)  # reuse it for every decoding step
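The sketch below shows a greedy decoding loop built on the two calls above. It rests on two assumptions: that model(...) returns logits shaped (batch, seq_len, vocab), and that captions end with GPT-2's <|endoftext|> token (id 50256). The helper greedy_caption and the DummyVL stand-in used to exercise it are hypothetical. The running sequence is cropped to the last 128 tokens to respect the context length.

```python
import torch


@torch.no_grad()
def greedy_caption(model, input_ids, image_tensor, max_new_tokens=20, eos_id=50256, max_ctx=128):
    """Greedy decoding that encodes the image once and reuses the visual context at every step."""
    visual_context = model.encode_image(image_tensor)  # vision side runs a single time
    ids = input_ids
    for _ in range(max_new_tokens):
        logits = model(ids[:, -max_ctx:], visual_context=visual_context)  # crop to the 128-token window
        next_id = logits[:, -1, :].argmax(dim=-1, keepdim=True)  # most likely next token per sequence
        ids = torch.cat([ids, next_id], dim=1)
        if (next_id == eos_id).all():  # assumed stop token: GPT-2's <|endoftext|>
            break
    return ids


class DummyVL(torch.nn.Module):
    """Hypothetical stand-in with the same call pattern; always predicts token 42."""

    def __init__(self, vocab_size=50257):
        super().__init__()
        self.vocab_size = vocab_size
        self.encode_calls = 0

    def encode_image(self, images):
        self.encode_calls += 1
        return torch.zeros(images.size(0), 128, 768)

    def forward(self, input_ids, visual_context=None):
        logits = torch.zeros(*input_ids.shape, self.vocab_size)
        logits[..., 42] = 1.0
        return logits


dummy = DummyVL()
out = greedy_caption(dummy, torch.tensor([[50256]]), torch.zeros(1, 3, 224, 224), max_new_tokens=5)
print(out.tolist(), dummy.encode_calls)  # [[50256, 42, 42, 42, 42, 42]] 1
```

With the real model, the tokenized prompt and the image preprocessed as in app.py would replace the dummy inputs. The result would then be turned back into text with tokenizer.decode.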

Known issues

  • Causal masking bug: stackformer's boolean causal mask is inverted relative to what torch.nn.functional.scaled_dot_product_attention expects. Uncorrected, this lets attention see future tokens instead of past ones.
  • Attention dropout ignores .eval(): stackformer passes dropout_p directly into scaled_dot_product_attention, which (unlike nn.Dropout) applies it unconditionally, regardless of model.train()/model.eval().

Both bugs are patched at import time in the reference app.py. If you load this checkpoint yourself, apply the same patches, or generation quality will be degraded.
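Both fixes can be illustrated on a generic causal self-attention layer written against PyTorch's scaled_dot_product_attention. The sketch below shows the idea only; it is neither stackformer's class nor the patch in app.py, and the class name and dimensions are illustrative.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class PatchedCausalSelfAttention(nn.Module):
    """Illustrative causal self-attention with both fixes applied (not stackformer's actual class)."""

    def __init__(self, dim: int = 768, num_heads: int = 12, dropout: float = 0.1):
        super().__init__()
        self.num_heads = num_heads
        self.qkv = nn.Linear(dim, 3 * dim)
        self.proj = nn.Linear(dim, dim)
        self.dropout = dropout

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, T, C = x.shape
        q, k, v = self.qkv(x).split(C, dim=-1)
        # (B, T, C) -> (B, heads, T, head_dim)
        q, k, v = [t.view(B, T, self.num_heads, C // self.num_heads).transpose(1, 2) for t in (q, k, v)]
        # Fix 1: for a boolean attn_mask, SDPA treats True as "may attend", so the causal mask
        # is the lower triangle. A mask built as "True = blocked" (upper triangle) is inverted.
        allowed = torch.tril(torch.ones(T, T, device=x.device)).bool()
        # Fix 2: SDPA applies dropout_p even in eval mode, so gate it on self.training.
        out = F.scaled_dot_product_attention(
            q, k, v, attn_mask=allowed, dropout_p=self.dropout if self.training else 0.0
        )
        return self.proj(out.transpose(1, 2).reshape(B, T, C))


torch.manual_seed(0)
attn = PatchedCausalSelfAttention().eval()
x = torch.randn(1, 5, 768)
x_changed = x.clone()
x_changed[:, -1] += 1.0  # perturb only the last position
y, y_again, y_changed = attn(x), attn(x), attn(x_changed)
```

In eval mode, repeated calls give the same output, and changing the last token leaves earlier positions untouched. For plain causal self-attention, passing is_causal=True instead of an explicit mask is an equivalent alternative. The two should not be combined.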

Limitations

  • The text backbone was frozen during fine-tuning, so its language modeling ability is essentially that of GPT-2 small: fluent but not state-of-the-art.
  • Trained on Flickr8k only (~8K natural images, mostly people/animals/everyday scenes), for a short, single-GPU budget. Expect short, simple, sometimes generic or repetitive captions, and weaker performance on image domains far from Flickr8k's distribution (e.g. diagrams, text-heavy images, illustrations).
  • 128-token context length β€” long prompts or captions will be truncated.
  • Vision context is a soft conditioning signal via cross-attention, not a hard constraint, so generated text can occasionally ignore or misdescribe image content.

Training details

[Figure: training loss vs. step]

Loss drops sharply over the first ~1,000 steps (14.7 → ~4) as the randomly initialized vision projector, resampler, and cross-attention blocks begin to align with the frozen GPT-2/ViT representations. It then decreases gradually over the remaining ~5 epochs, ending around 2.7.
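This conversion assumes the plotted loss is the usual mean per-token cross-entropy in nats (the default for PyTorch's F.cross_entropy). Under that assumption, the loss values quoted above map to perplexities as follows.

```python
import math

VOCAB_SIZE = 50257  # GPT-2 BPE vocabulary size


def perplexity(loss_nats: float) -> float:
    """Perplexity corresponding to a mean per-token cross-entropy loss measured in nats."""
    return math.exp(loss_nats)


uniform_loss = math.log(VOCAB_SIZE)  # loss of a model that guesses uniformly over the vocabulary
checkpoints = {"start": 14.7, "~step 1,000": 4.0, "end": 2.7}  # values quoted from the loss curve

print(f"uniform baseline: loss {uniform_loss:.2f}, perplexity {perplexity(uniform_loss):,.0f}")
for name, loss in checkpoints.items():
    print(f"{name}: loss {loss:.1f}, perplexity {perplexity(loss):,.1f}")
```

A final loss of about 2.7 corresponds to a perplexity of roughly 15. The starting value of 14.7 is above ln(50257) ≈ 10.82, the loss of a uniform guess over GPT-2's vocabulary.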

  • Base text weights: openai-community/gpt2 (small, 124M), loaded into stackformer's GPT_2 and frozen.
  • Vision encoder: torchvision vit_b_16, ViT_B_16_Weights.IMAGENET1K_V1, frozen.
  • Optimizer: AdamW, lr 2e-4, weight decay 0.01, effective batch size 32 (batch 16 × grad accumulation 2).
  • Mixed precision (bf16/fp16 depending on GPU support), pin_memory=True + non_blocking=True data transfer.
  • Stopped via a wall-clock time budget (~4.5 hours) rather than a fixed epoch count, to fit a single T4 session.
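The sketch below ties the training settings above together: a frozen module kept out of the optimizer, AdamW with the stated hyperparameters, 2-step gradient accumulation, mixed precision and a wall-clock stop. The tiny nn.Linear stand-ins, the MSE loss (in place of the captioning loss) and the train helper are all hypothetical. It is not the training script used for this checkpoint.

```python
import time

import torch
import torch.nn as nn

# Hypothetical stand-ins: `backbone` plays the frozen GPT-2/ViT, `adapter` the trainable vision-side modules.
device = "cuda" if torch.cuda.is_available() else "cpu"
backbone = nn.Linear(16, 16).to(device).requires_grad_(False)  # frozen: no gradients, not in the optimizer
adapter = nn.Linear(16, 16).to(device)

# Only parameters that still require grad go to AdamW (lr 2e-4, weight decay 0.01).
optimizer = torch.optim.AdamW([p for p in adapter.parameters() if p.requires_grad], lr=2e-4, weight_decay=0.01)

ACCUM_STEPS = 2  # micro-batch 16 x 2 accumulation steps = effective batch 32
use_amp = device == "cuda"
# bf16 where supported, else fp16; fp16 needs loss scaling via GradScaler.
amp_dtype = torch.bfloat16 if not use_amp or torch.cuda.is_bf16_supported() else torch.float16
scaler = torch.cuda.amp.GradScaler(enabled=use_amp and amp_dtype == torch.float16)


def train(batches, time_budget_s: float) -> int:
    """Accumulate gradients over ACCUM_STEPS micro-batches; stop once the wall-clock budget is spent."""
    deadline = time.monotonic() + time_budget_s
    updates = 0
    optimizer.zero_grad(set_to_none=True)
    for step, (x, y) in enumerate(batches, start=1):
        if time.monotonic() > deadline:
            break
        x = x.to(device, non_blocking=True)  # overlaps the copy only when the source is pinned memory
        y = y.to(device, non_blocking=True)
        with torch.autocast(device_type=device, dtype=amp_dtype, enabled=use_amp):
            loss = nn.functional.mse_loss(adapter(backbone(x)), y)  # stand-in for the captioning loss
        scaler.scale(loss / ACCUM_STEPS).backward()  # divide so the accumulated gradient is an average
        if step % ACCUM_STEPS == 0:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad(set_to_none=True)
            updates += 1
    return updates


torch.manual_seed(0)
frozen_before = backbone.weight.detach().clone()
adapter_before = adapter.weight.detach().clone()
micro_batches = [(torch.randn(16, 16), torch.randn(16, 16)) for _ in range(4)]  # 4 micro-batches of 16
n_updates = train(micro_batches, time_budget_s=60.0)
```

Four micro-batches with two accumulation steps yield two optimizer updates. The frozen module's weights stay unchanged while the trainable one moves.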