project_02_DS / experiments /cross_attention_patterns.py
griddev's picture
first push
c374021
"""
experiments/cross_attention_patterns.py
========================================
Documents and compares the four distinct cross-attention (fusion) patterns
used by each architecture in this pipeline.
This module does NOT require loading any model β€” it produces a static
analysis table and inline architecture diagrams, and can optionally
compute the number of cross-attention parameter counts from loaded models.
Usage (standalone):
python -m experiments.cross_attention_patterns
Architecture Summary
--------------------
β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚ Architecture β”‚ Fusion Mechanism β”‚ Cross-Attention Exists? β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
β”‚ ViT-GPT2 β”‚ Standard Full CA β”‚ βœ… Yes β€” at every GPT-2 layer β”‚
β”‚ BLIP (MED) β”‚ Gated Cross-Attention MED β”‚ βœ… Yes β€” between SA and FFN β”‚
β”‚ GIT β”‚ Self-Attn Prefix β”‚ ❌ No β€” unified causal SA β”‚
β”‚ Custom VLM β”‚ Visual Prefix-Tuning β”‚ ❌ No β€” linear projection + SA β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
"""
# ─────────────────────────────────────────────────────────────────────────────
# Static Architecture Descriptions
# ─────────────────────────────────────────────────────────────────────────────
PATTERNS = [
{
"name": "ViT-GPT2",
"model_id": "nlpconnect/vit-gpt2-image-captioning",
"cross_attention": True,
"ca_type": "Standard Full Cross-Attention",
"description": (
"Every GPT-2 decoder layer has an explicit cross-attention block. "
"Each text token attends to ALL 197 ViT patch embeddings "
"(1 CLS + 196 spatial) at every layer. "
"This is the brute-force approach β€” maximum information, highest compute."
),
"fusion_formula": "h_text = CrossAttn(Q=h_text, K=h_vis, V=h_vis)",
"ablation_support": True,
"ablation_method": "encoder_attention_mask on generate()",
},
{
"name": "BLIP (MED)",
"model_id": "Salesforce/blip-image-captioning-base",
"cross_attention": True,
"ca_type": "Gated Multimodal Encoder-Decoder (MED)",
"description": (
"BLIP's MED architecture injects a cross-attention sub-layer "
"BETWEEN the self-attention and FFN sub-layers at each decoder block. "
"A learnable gate controls how much visual information passes through. "
"This is more targeted than ViT-GPT2's brute-force attention."
),
"fusion_formula": (
"h = SA(h_text) "
"β†’ h = h + gate * CrossAttn(Q=h, K=h_vis, V=h_vis) "
"β†’ h = FFN(h)"
),
"ablation_support": True,
"ablation_method": "encoder_attention_mask via generate_with_mask()",
},
{
"name": "GIT",
"model_id": "microsoft/git-base-coco",
"cross_attention": False,
"ca_type": "Zero Cross-Attention (Self-Attention Prefix)",
"description": (
"GIT concatenates image patch embeddings directly in front of text tokens "
"to form a flat joint sequence: [img_tokens | text_tokens]. "
"A single causal self-attention Transformer processes the whole thing. "
"There is NO dedicated cross-attention block. "
"Modality fusion is implicit via positional self-attention."
),
"fusion_formula": "h = CausalSA([h_vis; h_text])",
"ablation_support": False,
"ablation_method": "N/A β€” no encoder_attention_mask concept",
},
{
"name": "Custom VLM (Shakespeare)",
"model_id": "google/vit-base-patch16-224-in21k (ViT) + char-level decoder",
"cross_attention": False,
"ca_type": "Visual Prefix-Tuning (Linear Bridge + Causal SA)",
"description": (
"A frozen ViT extracts 197 patch embeddings (768-dim). "
"A single trainable Linear(768β†’384) projects these to the decoder's "
"embedding space. Projected visual tokens are prepended to character "
"embeddings and the Shakespeare causal decoder processes them jointly. "
"Only the linear projection is trained (~294K params, <0.2% of total). "
"\nKey insight: cross-attention is provably unnecessary when modalities "
"are aligned in the same embedding space via prefix concatenation."
),
"fusion_formula": (
"v = Linear(ViT(img)) "
"β†’ x = CausalSA([v; char_emb]) "
"β†’ logits = LMHead(x[len(v):])"
),
"ablation_support": False,
"ablation_method": "N/A β€” visual prefix is part of unified sequence",
},
]
# ─────────────────────────────────────────────────────────────────────────────
# Comparison Table Printer
# ─────────────────────────────────────────────────────────────────────────────
def print_comparison_table():
"""Print a formatted comparison table to stdout."""
print("\n" + "=" * 80)
print(" Cross-Attention Pattern Comparison")
print("=" * 80)
print(f" {'Architecture':<22} {'CA?':>5} {'Type':<35} {'Ablation?':>9}")
print(" " + "-" * 76)
for p in PATTERNS:
ca = " βœ…" if p["cross_attention"] else " ❌"
abl = " βœ…" if p["ablation_support"] else " ❌"
print(f" {p['name']:<22} {ca:>5} {p['ca_type']:<35} {abl:>9}")
print("=" * 80)
for p in PATTERNS:
print(f"\n ── {p['name']} ──────────────────────────────────────────────")
print(f" Model : {p['model_id']}")
print(f" CA Type: {p['ca_type']}")
print(f" Formula: {p['fusion_formula']}")
for line in p["description"].split("\n"):
print(f" {line.strip()}")
if p["ablation_support"]:
print(f" Ablation: {p['ablation_method']}")
else:
print(f" ⚠️ Ablation: {p['ablation_method']}")
print()
# ─────────────────────────────────────────────────────────────────────────────
# Optional: Parameter Count from Loaded Models
# ─────────────────────────────────────────────────────────────────────────────
def count_cross_attention_params(model, model_name: str) -> dict:
"""
Count parameters in cross-attention layers for BLIP or ViT-GPT2.
For GIT / Custom VLM (no CA), returns zero.
Args:
model : loaded PyTorch model
model_name : 'blip' | 'vit_gpt2' | 'git' | 'custom'
Returns:
dict with 'total', 'cross_attn', 'cross_attn_pct'
"""
total = sum(p.numel() for p in model.parameters())
ca_params = 0
if model_name == "blip":
for name, p in model.named_parameters():
if "crossattention" in name.lower():
ca_params += p.numel()
elif model_name == "vit_gpt2":
for name, p in model.named_parameters():
if "crossattention" in name.lower() or "cross_attn" in name.lower():
ca_params += p.numel()
# GIT / custom: 0 cross-attention params by design
return {
"model": model_name,
"total_params": total,
"cross_attn_params": ca_params,
"cross_attn_pct": ca_params / total * 100 if total > 0 else 0.0,
}
# ─────────────────────────────────────────────────────────────────────────────
# CLI
# ─────────────────────────────────────────────────────────────────────────────
def main():
print_comparison_table()
# Optionally count params for all four models
count_params = input(
"\nCount cross-attention parameters in all models? "
"(requires downloading BLIP+ViT-GPT2+GIT) [y/N]: "
).strip().lower()
if count_params == "y":
import torch
device = torch.device("cpu")
print("\nLoading models to count parameters...\n")
import sys, os
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from config import CFG
from models.blip_tuner import get_blip_model
from models.vit_gpt2_tuner import get_vit_gpt2_model
from models.git_tuner import get_git_model
from models.custom_vlm import CustomVLM, build_char_vocab
cfg = CFG()
rows = []
model_b, _ = get_blip_model(cfg, device)
rows.append(count_cross_attention_params(model_b, "blip"))
del model_b
model_v, _, _ = get_vit_gpt2_model(cfg, device)
rows.append(count_cross_attention_params(model_v, "vit_gpt2"))
del model_v
model_g, _ = get_git_model(cfg, device)
rows.append(count_cross_attention_params(model_g, "git"))
del model_g
with open(cfg.shakespeare_file, "r") as f:
text = f.read()
_, c2i, i2c, vs = build_char_vocab(text)
model_c = CustomVLM(vocab_size=vs)
rows.append(count_cross_attention_params(model_c, "custom"))
del model_c
print("\n" + "=" * 65)
print(" Cross-Attention Parameter Counts")
print("=" * 65)
print(f" {'Model':<15} {'Total':>12} {'CA Params':>12} {'CA %':>8}")
print(" " + "-" * 58)
for r in rows:
print(f" {r['model']:<15} {r['total_params']:>12,} "
f"{r['cross_attn_params']:>12,} {r['cross_attn_pct']:>7.2f}%")
print("=" * 65)
if __name__ == "__main__":
main()