# project_02_DS / models / custom_vlm.py
# (Hub page residue: uploader "griddev", commit "first push", hash c374021)
"""
models/custom_vlm.py
=====================
Advanced Master-Hack β€” Visual Prefix-Tuning (Shakespeare + ViT)
Architecture: A frozen pre-trained ViT (google/vit-base-patch16-224-in21k)
is fused with a custom character-level causal Transformer decoder trained on
Shakespeare text. A trainable MLP projection layer bridges the ViT's
768-dim output to the decoder's 384-dim embedding space.
MODALITY FUSION:
ViT β†’ Project(768β†’384) β†’ [visual_prefix | char_embeddings] β†’ CausalSelfAttention
TRAINING REGIME:
- ViT: FROZEN (always)
- Shakespeare Decoder: UNFROZEN during fine-tuning (adapts to COCO captions)
- visual_projection: TRAINABLE (learned bridge)
Weight Loading Strategy:
The Shakespeare checkpoint uses a custom per-head architecture with keys like:
blocks.N.sa_head.heads.M.{key,query,value}.weight
These are remapped to PyTorch nn.TransformerEncoder's fused format:
decoder_blocks.layers.N.self_attn.in_proj_weight
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import ViTModel
# ─────────────────────────────────────────────────────────────────────────────
# Character Vocabulary Helper
# ─────────────────────────────────────────────────────────────────────────────
def build_char_vocab(text_corpus: str):
    """
    Construct character-level vocabulary mappings from a raw text string.

    Returns:
        chars       : sorted list of the distinct characters in *text_corpus*
        char_to_idx : dict mapping character -> integer index
        idx_to_char : dict mapping integer index -> character
        vocab_size  : number of distinct characters
    """
    chars = sorted(set(text_corpus))
    # Build the index->char table first, then invert it for char->index.
    idx_to_char = dict(enumerate(chars))
    char_to_idx = {ch: ix for ix, ch in idx_to_char.items()}
    return chars, char_to_idx, idx_to_char, len(chars)
# ─────────────────────────────────────────────────────────────────────────────
# Model Definition
# ─────────────────────────────────────────────────────────────────────────────
class CustomVLM(nn.Module):
    """
    Visual Prefix-Tuning VLM.

    Combines:
      1. Frozen ViT image encoder (768-dim output)
      2. Trainable MLP projection (768 -> text_embed_dim)
      3. Character-level causal Transformer decoder
         (initialized from shakespeare_transformer.pt, then fine-tuned)
    """

    NUM_VISUAL_TOKENS = 197  # ViT: 196 patches + 1 [CLS]

    def __init__(self, vocab_size, text_embed_dim=384, n_heads=8, n_layers=8,
                 block_size=256, dropout=0.1):
        super().__init__()

        # -- 1. Vision encoder: pre-trained ViT, permanently frozen ----------
        self.vit = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")
        for p in self.vit.parameters():
            p.requires_grad = False
        vit_dim = self.vit.config.hidden_size  # 768 for vit-base

        # -- 2. Trainable bridge: two-layer MLP (LLaVA-style) ----------------
        self.visual_projection = nn.Sequential(
            nn.Linear(vit_dim, 2 * vit_dim),
            nn.GELU(),
            nn.Linear(2 * vit_dim, text_embed_dim),
        )

        # -- 3. Character-level causal Transformer decoder -------------------
        self.token_embedding_table = nn.Embedding(vocab_size, text_embed_dim)
        # One position slot per visual-prefix token plus per text token.
        self.position_embedding_table = nn.Embedding(
            self.NUM_VISUAL_TOKENS + block_size, text_embed_dim
        )
        self.decoder_blocks = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=text_embed_dim,
                nhead=n_heads,
                dim_feedforward=4 * text_embed_dim,
                dropout=dropout,
                batch_first=True,
            ),
            num_layers=n_layers,
        )
        self.ln_f = nn.LayerNorm(text_embed_dim)
        self.lm_head = nn.Linear(text_embed_dim, vocab_size)

        # Hyper-parameters kept for weight loading and forward().
        self.block_size = block_size
        self.text_embed_dim = text_embed_dim
        self.vocab_size = vocab_size
        self.n_heads = n_heads
        self.n_layers = n_layers
    # ─────────────────────────────────────────────────────────────────────────
    # Weight Loading — with architecture remapping
    # ─────────────────────────────────────────────────────────────────────────
    def load_shakespeare_weights(self, path: str, device: str = "cpu") -> dict:
        """
        Load pre-trained Shakespeare Transformer weights with full key remapping.

        The Shakespeare checkpoint uses a custom per-head architecture:
            blocks.N.sa_head.heads.M.{key,query,value}.weight  (head_dim, embed_dim)
            blocks.N.sa_head.proj.{weight,bias}
            blocks.N.ffwd.net.{0,2}.{weight,bias}
            blocks.N.ln{1,2}.{weight,bias}
        These are remapped into PyTorch nn.TransformerEncoder's fused format:
            decoder_blocks.layers.N.self_attn.in_proj_weight  (3*embed_dim, embed_dim)
            decoder_blocks.layers.N.self_attn.out_proj.{weight,bias}
            decoder_blocks.layers.N.linear1.{weight,bias}
            decoder_blocks.layers.N.linear2.{weight,bias}
            decoder_blocks.layers.N.norm1.{weight,bias}
            decoder_blocks.layers.N.norm2.{weight,bias}

        Args:
            path   : filesystem path of the checkpoint (.pt) file.
            device : map_location passed to torch.load (default "cpu").

        Returns:
            dict with keys "loaded" (names of remapped tensors that passed
            the shape check), "missing" and "unexpected" (as reported by
            load_state_dict(strict=False)).

        NOTE(review): the remap copies raw weights only — it assumes the
        checkpoint's block structure matches nn.TransformerEncoderLayer's
        defaults (norm placement, activation). If the Shakespeare blocks were
        pre-norm/GELU the loaded weights will still load but behave
        differently; confirm against the original training code.
        """
        print(f"πŸ“– Loading Shakespeare weights from: {path}")
        # NOTE(review): torch.load unpickles arbitrary objects unless
        # weights_only=True is passed — only load trusted checkpoints here.
        raw = torch.load(path, map_location=device)
        # Unwrap common checkpoint wrapper layouts; fall through to treating
        # the dict itself as the state_dict.
        if isinstance(raw, dict):
            if "model_state" in raw:
                state_dict = raw["model_state"]
            elif "model" in raw:
                state_dict = raw["model"]
            elif "state_dict" in raw:
                state_dict = raw["state_dict"]
            else:
                state_dict = raw
        else:
            raise TypeError(f"Unexpected checkpoint type: {type(raw)}")
        # ── Discover Shakespeare architecture ────────────────────────────────
        # Infer block / head counts from key names, e.g.
        # "blocks.3.sa_head.heads.5.query.weight" → block 3, head 5 (parts[4]).
        shk_blocks = set()
        shk_heads = set()
        for key in state_dict:
            if key.startswith("blocks."):
                parts = key.split(".")
                shk_blocks.add(int(parts[1]))
                if "heads" in key:
                    shk_heads.add(int(parts[4]))
        n_shk_blocks = len(shk_blocks)
        n_shk_heads = len(shk_heads) if shk_heads else self.n_heads
        # head_dim is derived from THIS model's config, not the checkpoint —
        # the two must agree for the per-head concatenation below to line up.
        head_dim = self.text_embed_dim // self.n_heads
        print(f" πŸ“Š Shakespeare arch: {n_shk_blocks} blocks, {n_shk_heads} heads, "
              f"head_dim={head_dim}")
        print(f" πŸ“Š Model arch: {self.n_layers} layers, {self.n_heads} heads")
        # How many layers / heads to load (min of checkpoint and model).
        n_load = min(n_shk_blocks, self.n_layers)
        n_heads_load = min(n_shk_heads, self.n_heads)
        remapped = {}
        # ── Remap decoder blocks ─────────────────────────────────────────────
        for layer_idx in range(n_load):
            prefix_src = f"blocks.{layer_idx}"
            prefix_dst = f"decoder_blocks.layers.{layer_idx}"
            # 1. Self-Attention: fuse per-head Q, K, V into in_proj_weight
            #    Shakespeare: heads.M.query.weight (head_dim, embed_dim)
            #    Target:      self_attn.in_proj_weight (3*embed_dim, embed_dim)
            q_parts, k_parts, v_parts = [], [], []
            for h in range(n_heads_load):
                qk = f"{prefix_src}.sa_head.heads.{h}.query.weight"
                kk = f"{prefix_src}.sa_head.heads.{h}.key.weight"
                vk = f"{prefix_src}.sa_head.heads.{h}.value.weight"
                # Only fuse heads whose full Q/K/V triple is present, so the
                # three part-lists stay aligned.
                if qk in state_dict and kk in state_dict and vk in state_dict:
                    q_parts.append(state_dict[qk])
                    k_parts.append(state_dict[kk])
                    v_parts.append(state_dict[vk])
            if q_parts:
                # Concatenate heads: each (head_dim, embed_dim) → (embed_dim, embed_dim)
                Q_full = torch.cat(q_parts, dim=0)  # (n_heads*head_dim, embed_dim)
                K_full = torch.cat(k_parts, dim=0)
                V_full = torch.cat(v_parts, dim=0)
                # Fuse into in_proj_weight: [Q; K; V] → (3*embed_dim, embed_dim)
                in_proj_weight = torch.cat([Q_full, K_full, V_full], dim=0)
                remapped[f"{prefix_dst}.self_attn.in_proj_weight"] = in_proj_weight
                # Zero bias: the Shakespeare attention has no Q/K/V bias, but
                # nn.MultiheadAttention expects an in_proj_bias tensor.
                remapped[f"{prefix_dst}.self_attn.in_proj_bias"] = torch.zeros(
                    3 * self.text_embed_dim
                )
            # 2. Output projection (per-block attention output mix).
            proj_w = f"{prefix_src}.sa_head.proj.weight"
            proj_b = f"{prefix_src}.sa_head.proj.bias"
            if proj_w in state_dict:
                remapped[f"{prefix_dst}.self_attn.out_proj.weight"] = state_dict[proj_w]
            if proj_b in state_dict:
                remapped[f"{prefix_dst}.self_attn.out_proj.bias"] = state_dict[proj_b]
            # 3. Feed-Forward Network
            #    Shakespeare: ffwd.net.0 → linear1, ffwd.net.2 → linear2
            for shk_idx, tgt_name in [("0", "linear1"), ("2", "linear2")]:
                wk = f"{prefix_src}.ffwd.net.{shk_idx}.weight"
                bk = f"{prefix_src}.ffwd.net.{shk_idx}.bias"
                if wk in state_dict:
                    remapped[f"{prefix_dst}.{tgt_name}.weight"] = state_dict[wk]
                if bk in state_dict:
                    remapped[f"{prefix_dst}.{tgt_name}.bias"] = state_dict[bk]
            # 4. Layer Norms: ln1 → norm1, ln2 → norm2
            for shk_ln, tgt_ln in [("ln1", "norm1"), ("ln2", "norm2")]:
                for suffix in ("weight", "bias"):
                    sk = f"{prefix_src}.{shk_ln}.{suffix}"
                    if sk in state_dict:
                        remapped[f"{prefix_dst}.{tgt_ln}.{suffix}"] = state_dict[sk]
        # ── Non-decoder module weights ───────────────────────────────────────
        # token_embedding_table: copy directly when shapes match; on a vocab
        # size mismatch, copy the overlapping rows and keep our init for the rest.
        if "token_embedding_table.weight" in state_dict:
            shk_emb = state_dict["token_embedding_table.weight"]
            own_emb = self.token_embedding_table.weight
            if shk_emb.shape == own_emb.shape:
                remapped["token_embedding_table.weight"] = shk_emb
            elif shk_emb.shape[1] == own_emb.shape[1]:
                # Vocab size difference: copy what fits
                n_copy = min(shk_emb.shape[0], own_emb.shape[0])
                new_emb = own_emb.data.clone()
                new_emb[:n_copy] = shk_emb[:n_copy]
                remapped["token_embedding_table.weight"] = new_emb
        # position_embedding_table: Shakespeare (block_size, D) → Model
        # (NUM_VISUAL_TOKENS + block_size, D).
        if "position_embedding_table.weight" in state_dict:
            shk_pos = state_dict["position_embedding_table.weight"]  # (256, 384)
            own_pos = self.position_embedding_table.weight  # (197+block_size, 384)
            if shk_pos.shape == own_pos.shape:
                remapped["position_embedding_table.weight"] = shk_pos
            else:
                # Expand: zero-init the full table, then copy Shakespeare
                # positions into the TEXT portion (slots 197..197+256).
                new_pos = torch.zeros_like(own_pos.data)
                # Visual positions (0..196) get small random init — they had
                # no counterpart in the text-only checkpoint.
                nn.init.normal_(new_pos[:self.NUM_VISUAL_TOKENS], std=0.02)
                # Text positions: copy Shakespeare's first N positions.
                n_text_slots = own_pos.shape[0] - self.NUM_VISUAL_TOKENS
                n_copy = min(shk_pos.shape[0], n_text_slots)
                new_pos[self.NUM_VISUAL_TOKENS:self.NUM_VISUAL_TOKENS + n_copy] = shk_pos[:n_copy]
                remapped["position_embedding_table.weight"] = new_pos
                print(f" πŸ“ Position embeddings expanded: {shk_pos.shape} β†’ {own_pos.shape}")
        # ln_f (final layer norm): exact-shape match only.
        for suffix in ("weight", "bias"):
            k = f"ln_f.{suffix}"
            if k in state_dict:
                own_shape = getattr(self.ln_f, suffix).shape
                if state_dict[k].shape == own_shape:
                    remapped[k] = state_dict[k]
        # lm_head: same partial-copy strategy as the token embeddings.
        if "lm_head.weight" in state_dict:
            shk_lm = state_dict["lm_head.weight"]
            own_lm = self.lm_head.weight
            if shk_lm.shape == own_lm.shape:
                remapped["lm_head.weight"] = shk_lm
            elif shk_lm.shape[1] == own_lm.shape[1]:
                n_copy = min(shk_lm.shape[0], own_lm.shape[0])
                new_lm = own_lm.data.clone()
                new_lm[:n_copy] = shk_lm[:n_copy]
                remapped["lm_head.weight"] = new_lm
        if "lm_head.bias" in state_dict:
            shk_b = state_dict["lm_head.bias"]
            own_b = self.lm_head.bias
            if own_b is not None and shk_b.shape == own_b.shape:
                remapped["lm_head.bias"] = shk_b
            elif own_b is not None:
                n_copy = min(shk_b.shape[0], own_b.shape[0])
                new_b = own_b.data.clone()
                new_b[:n_copy] = shk_b[:n_copy]
                remapped["lm_head.bias"] = new_b
        # ── Load remapped weights ─────────────────────────────────────────────
        # Final defensive pass: verify every remapped tensor against the live
        # state_dict shape before handing it to load_state_dict.
        own_state = self.state_dict()
        valid_remapped = {}
        shape_mismatches = []
        for k, v in remapped.items():
            if k in own_state:
                if own_state[k].shape == v.shape:
                    valid_remapped[k] = v
                else:
                    shape_mismatches.append(
                        f" {k}: ckpt={v.shape} vs model={own_state[k].shape}"
                    )
            else:
                shape_mismatches.append(f" {k}: not in model state_dict")
        # strict=False: keys we did not remap (e.g. ViT, visual_projection)
        # simply keep their current values.
        result = self.load_state_dict(valid_remapped, strict=False)
        print(f" βœ… Successfully loaded {len(valid_remapped)} weight tensors (of {len(state_dict)} in checkpoint)")
        if shape_mismatches:
            print(f" ⚠️ {len(shape_mismatches)} shape mismatches (skipped):")
            for msg in shape_mismatches[:5]:
                print(msg)
        # Report decoder coverage so a silent remap failure is visible.
        decoder_loaded = sum(1 for k in valid_remapped if k.startswith("decoder_blocks"))
        total_decoder = sum(1 for k in own_state if k.startswith("decoder_blocks"))
        print(f" πŸ“Š Decoder coverage: {decoder_loaded}/{total_decoder} tensors loaded")
        return {
            "loaded": list(valid_remapped.keys()),
            "missing": result.missing_keys,
            "unexpected": result.unexpected_keys,
        }
# ─────────────────────────────────────────────────────────────────────────
# Freezing / Unfreezing / Parameter Counting
# ─────────────────────────────────────────────────────────────────────────
def freeze_decoder(self):
"""Freeze the Shakespeare decoder so only visual_projection trains."""
for name, param in self.named_parameters():
if not name.startswith("visual_projection"):
param.requires_grad = False
# Ensure ViT is frozen
for param in self.vit.parameters():
param.requires_grad = False
def unfreeze_decoder(self):
"""
Unfreeze the decoder for fine-tuning while keeping ViT frozen.
This allows the decoder to adapt from Shakespeare text to COCO captions.
The visual_projection is also trainable.
"""
# First, freeze everything
for param in self.parameters():
param.requires_grad = False
# Unfreeze visual_projection (always trainable)
for param in self.visual_projection.parameters():
param.requires_grad = True
# Unfreeze ALL decoder components
for param in self.token_embedding_table.parameters():
param.requires_grad = True
for param in self.position_embedding_table.parameters():
param.requires_grad = True
for param in self.decoder_blocks.parameters():
param.requires_grad = True
for param in self.ln_f.parameters():
param.requires_grad = True
for param in self.lm_head.parameters():
param.requires_grad = True
# ViT stays FROZEN
for param in self.vit.parameters():
param.requires_grad = False
def get_param_groups(self, projection_lr=1e-4, decoder_lr=5e-5):
"""
Return optimizer param groups with discriminative learning rates.
- visual_projection: higher LR (learning from scratch)
- decoder: lower LR (gentle adaptation from Shakespeare)
"""
projection_params = []
decoder_params = []
for name, param in self.named_parameters():
if not param.requires_grad:
continue
if name.startswith("visual_projection"):
projection_params.append(param)
else:
decoder_params.append(param)
return [
{"params": projection_params, "lr": projection_lr},
{"params": decoder_params, "lr": decoder_lr},
]
def trainable_params(self):
"""Return count of trainable parameters."""
return sum(p.numel() for p in self.parameters() if p.requires_grad)
    # ─────────────────────────────────────────────────────────────────────────
    # Forward Pass
    # ─────────────────────────────────────────────────────────────────────────
    def forward(self, pixel_values, text_input_ids, text_targets=None):
        """
        Fused visual-prefix + character-sequence forward pass.

        Args:
            pixel_values   : image batch fed straight to the ViT — presumably
                             (B, 3, 224, 224); TODO confirm with the caller.
            text_input_ids : (B, T) integer character ids; clipped to block_size.
            text_targets   : optional (B, T) next-character target ids.

        Returns:
            (logits, loss): logits is (B, min(T, block_size), vocab_size);
            loss is None when text_targets is None.
        """
        B, T = text_input_ids.shape
        # ── Image Encoding (frozen ViT) ──────────────────────────────────────
        # no_grad: the ViT is frozen, so skip building its autograd graph.
        with torch.no_grad():
            vit_outputs = self.vit(pixel_values=pixel_values)
            image_embeds = vit_outputs.last_hidden_state  # (B, 197, 768)
        # ── Project to text embedding space ──────────────────────────────────
        # The projection runs OUTSIDE no_grad — it must receive gradients.
        visual_prefix = self.visual_projection(image_embeds)  # (B, 197, 384)
        num_visual = visual_prefix.shape[1]  # 197
        # ── Text Embeddings ───────────────────────────────────────────────────
        # Silently truncate over-long captions to the position-table capacity.
        T_clipped = min(T, self.block_size)
        text_in = text_input_ids[:, :T_clipped]
        tok_emb = self.token_embedding_table(text_in)  # (B, T, 384)
        # ── Positional Embeddings (covers full combined sequence) ─────────────
        # Positions 0..196 → visual prefix, 197..197+T → text tokens
        total_len = num_visual + T_clipped
        pos_ids = torch.arange(total_len, device=text_in.device)
        pos_emb = self.position_embedding_table(pos_ids)  # (num_visual+T, 384)
        vis_pos = pos_emb[:num_visual]  # (197, 384)
        txt_pos = pos_emb[num_visual:]  # (T, 384)
        visual_emb = visual_prefix + vis_pos  # (B, 197, 384)
        text_emb = tok_emb + txt_pos  # (B, T, 384)
        # ── Fusion: [visual_prefix | text_emb] ───────────────────────────────
        combined = torch.cat([visual_emb, text_emb], dim=1)  # (B, 197+T, 384)
        tot = combined.shape[1]
        # ── Causal Attention Mask ─────────────────────────────────────────────
        # Additive float mask: 0.0 = may attend, -inf = blocked.
        # Visual tokens attend to each other freely; their rows stay -inf
        # toward text columns, so the prefix never attends to the caption.
        # Text tokens attend to all visual tokens + causally to previous text.
        mask = torch.full((tot, tot), float("-inf"), device=text_in.device)
        mask[:num_visual, :num_visual] = 0.0  # visual→visual: free
        mask[num_visual:, :num_visual] = 0.0  # text→visual: free
        causal = torch.triu(
            torch.full((T_clipped, T_clipped), float("-inf"), device=text_in.device),
            diagonal=1,
        )
        mask[num_visual:, num_visual:] = causal  # text→text: causal
        # ── Decoder ───────────────────────────────────────────────────────────
        # is_causal=False: causality is fully encoded in the explicit mask.
        x = self.decoder_blocks(combined, mask=mask, is_causal=False)
        # Only the text positions produce next-character predictions.
        text_out = x[:, num_visual:, :]
        text_out = self.ln_f(text_out)
        logits = self.lm_head(text_out)  # (B, T, vocab)
        # ── Loss (ignore padding index 0) ─────────────────────────────────────
        # NOTE(review): ignore_index=0 assumes id 0 is padding, but a sorted
        # char vocab maps a real character to index 0 (generate() even uses 0
        # as its BOS fallback) — confirm the padding id with the data pipeline.
        loss = None
        if text_targets is not None:
            tgt = text_targets[:, :T_clipped]
            loss = F.cross_entropy(
                logits.reshape(B * T_clipped, -1),
                tgt.reshape(B * T_clipped),
                ignore_index=0,
            )
        return logits, loss
# ─────────────────────────────────────────────────────────────────────────
# Generation
# ─────────────────────────────────────────────────────────────────────────
@torch.no_grad()
def generate(self, pixel_values, char_to_idx, idx_to_char,
max_new_tokens=100, temperature=0.8):
"""
Autoregressive character-level caption generation (temperature sampling).
Args:
pixel_values : (1, 3, H, W) pre-processed image tensor
char_to_idx : character β†’ index mapping
idx_to_char : index β†’ character mapping
max_new_tokens : how many characters to generate
temperature : sampling temperature (0.8 = slightly sharper than uniform)
Returns:
generated_text : str
"""
self.eval()
device = pixel_values.device
bos_idx = char_to_idx.get("\n", 0)
idx_seq = torch.tensor([[bos_idx]], dtype=torch.long, device=device)
for _ in range(max_new_tokens):
# Clip text to block_size β€” the forward method handles the visual
# prefix separately, so we only need to limit the text portion.
idx_cond = idx_seq[:, -self.block_size:]
logits, _ = self(pixel_values, idx_cond)
# Take the last time step
logits_last = logits[:, -1, :] / max(temperature, 1e-5)
probs = F.softmax(logits_last, dim=-1)
next_idx = torch.multinomial(probs, num_samples=1)
idx_seq = torch.cat([idx_seq, next_idx], dim=1)
# Decode, skip the leading BOS
generated = "".join(
idx_to_char.get(i.item(), "?") for i in idx_seq[0, 1:]
)
return generated
@torch.no_grad()
def generate_beam(self, pixel_values, char_to_idx, idx_to_char,
max_new_tokens=100, num_beams=4, length_penalty=1.0):
"""
Beam-search character-level caption generation.
At each step we keep the top `num_beams` partial sequences ranked by
cumulative log-probability (with optional length penalty).
Args:
pixel_values : (1, 3, H, W) image tensor
char_to_idx : char β†’ idx mapping
idx_to_char : idx β†’ char mapping
max_new_tokens : max characters to generate
num_beams : beam width (1 = greedy)
length_penalty : >1 favors longer sequences; <1 favors shorter
Returns:
generated_text : str (best beam)
"""
self.eval()
device = pixel_values.device
bos_idx = char_to_idx.get("\n", 0)
# Each beam: (score, token_sequence_tensor)
beams = [(0.0, torch.tensor([[bos_idx]], dtype=torch.long, device=device))]
for _ in range(max_new_tokens):
candidates = []
for score, seq in beams:
idx_cond = seq[:, -self.block_size:]
logits, _ = self(pixel_values, idx_cond)
log_probs = F.log_softmax(logits[:, -1, :], dim=-1) # (1, vocab)
topk_probs, topk_ids = log_probs.topk(num_beams, dim=-1)
for k in range(num_beams):
new_score = score + topk_probs[0, k].item()
new_seq = torch.cat(
[seq, topk_ids[:, k:k+1]], dim=1
)
candidates.append((new_score, new_seq))
# Apply length penalty and keep top beams
candidates.sort(
key=lambda x: x[0] / (x[1].shape[1] ** length_penalty),
reverse=True,
)
beams = candidates[:num_beams]
best_seq = beams[0][1]
return "".join(idx_to_char.get(i.item(), "?") for i in best_seq[0, 1:])