Spaces:
Running
Running
File size: 5,812 Bytes
"""
models/blip_tuner.py
====================
Baseline 3 — Multimodal Mixture Attention (BLIP)
Architecture: BLIP's MED (Multimodal mixture of Encoder-Decoder) architecture injects
specialized cross-attention between the self-attention and feed-forward layers.
The visual encoder output (image patch embeddings) is queried by the text decoder
via this cross-attention at each decoder layer.
This module also provides `generate_with_mask()` for inference-time ablation
experiments that manipulate the encoder_attention_mask to test spatial restrictions.
"""
import os
import torch
from transformers import BlipProcessor, BlipForConditionalGeneration
def get_blip_model(cfg, device):
    """
    Build the BLIP captioning model and its processor, then move the model to `device`.

    Reads `cfg.model_id` (checkpoint passed to `from_pretrained`) and
    `cfg.image_size` (square edge length written into the image processor's
    size setting). The two optional tweaks below are best-effort: a failure
    is reported on stdout and loading continues.

    Returns:
        A `(model, processor)` tuple.
    """
    blip_processor = BlipProcessor.from_pretrained(cfg.model_id, use_fast=True)
    blip_model = BlipForConditionalGeneration.from_pretrained(cfg.model_id)

    # Override the processor's target resolution with the configured square size
    # (smaller images are cheaper, which matters most on Mac/MPS).
    try:
        edge = cfg.image_size
        blip_processor.image_processor.size = {"height": edge, "width": edge}
        print(f"✅ Image size set to {edge}px")
    except Exception as err:
        print(f"⚠️ Could not set image size: {err}")

    # Trade compute for memory during training.
    try:
        blip_model.gradient_checkpointing_enable()
        print("✅ Gradient checkpointing enabled (BLIP)")
    except Exception as err:
        print(f"⚠️ Gradient checkpointing failed: {err}")

    # Must be False with gradient checkpointing
    blip_model.config.use_cache = False
    blip_model.to(device)

    # Report the parameter count in millions.
    total_elements = 0
    for param in blip_model.parameters():
        total_elements += param.numel()
    millions = total_elements / 1e6
    print(f"✅ BLIP loaded on {device}: {cfg.model_id} ({millions:.1f}M params)")

    return blip_model, blip_processor
def generate_with_mask(model, processor, image_pil=None, device=None,
                       pixel_values=None,
                       encoder_hidden_states=None,
                       encoder_attention_mask=None,
                       max_new_tokens=32, num_beams=4):
    """
    Generate captions while controlling which visual tokens the decoder may attend to.

    The visual input can be supplied in one of three forms, checked in this
    order of precedence for feature extraction: pre-computed
    `encoder_hidden_states`, pre-processed `pixel_values`, or a raw
    `image_pil` that is run through `processor`.

    The ablation itself is expressed by the caller through the arguments.
    Modes the original author lists for this function:
      - Baseline: 197 patches visible
      - Random Dropout: 50% spatial patches masked
      - Center-Focus: Inner 8x8 patches visible
      - Squint: Requires passing pre-pooled `encoder_hidden_states` of shape (B, 2, C).

    Args:
        model: BLIP captioning model; must expose `vision_model`,
            `text_decoder`, `decoder_input_ids` and `config.text_config`.
        processor: Processor used to preprocess `image_pil` and to decode
            the generated ids.
        image_pil: Image(s) to preprocess when `pixel_values` is not given.
        device: Device the processor output is moved to. Only used when
            `image_pil` is processed; defaults to the device of the model's
            first parameter.
        pixel_values: Pre-processed image tensor; its first dimension is
            taken as the batch size.
        encoder_hidden_states: Pre-computed visual features; when given the
            vision encoder is skipped.
        encoder_attention_mask: Mask over the visual tokens. Defaults to all
            ones with the shape of `encoder_hidden_states` minus its last
            dimension.
        max_new_tokens: Forwarded to `text_decoder.generate`.
        num_beams: Forwarded to `text_decoder.generate`.

    Returns:
        The decoded captions, as returned by `processor.batch_decode`.

    Raises:
        ValueError: If none of `image_pil`, `pixel_values` or
            `encoder_hidden_states` is provided.

    Side effects:
        Puts `model` into eval mode and leaves it there.
    """
    # Fail fast with a clear message instead of an AttributeError on None below.
    if image_pil is None and pixel_values is None and encoder_hidden_states is None:
        raise ValueError(
            "generate_with_mask() requires one of image_pil, pixel_values, "
            "or encoder_hidden_states."
        )

    model.eval()

    # 1. Get pixel values
    if pixel_values is None and image_pil is not None:
        if device is None:
            # The processor output cannot be moved to `None`; follow the model.
            device = next(model.parameters()).device
        inputs = processor(images=image_pil, return_tensors="pt").to(device)
        pixel_values = inputs["pixel_values"]

    # Batch size and working device come from whichever visual tensor exists.
    if pixel_values is not None:
        batch_size = pixel_values.shape[0]
        dev = pixel_values.device
    else:
        batch_size = encoder_hidden_states.shape[0]
        dev = encoder_hidden_states.device

    # Inference only: keep the vision forward pass and generation out of autograd.
    with torch.no_grad():
        # 2. Extract visual features if not pre-provided (e.g., Squint mode provides them)
        if encoder_hidden_states is None:
            vision_outputs = model.vision_model(pixel_values=pixel_values)
            encoder_hidden_states = vision_outputs[0]

        # 3. Handle encoder_attention_mask default (Baseline = all ones)
        if encoder_attention_mask is None:
            encoder_attention_mask = torch.ones(
                encoder_hidden_states.size()[:-1], dtype=torch.long, device=dev
            )
        else:
            encoder_attention_mask = encoder_attention_mask.to(dev)

        # 4. Prepare decoder input IDs: build a two-token row, overwrite the
        #    first token with BOS, then drop the last token below so each
        #    sequence starts from BOS alone.
        text_cfg = model.config.text_config
        input_ids = (
            torch.LongTensor([[model.decoder_input_ids, text_cfg.eos_token_id]])
            .repeat(batch_size, 1)
            .to(dev)
        )
        input_ids[:, 0] = text_cfg.bos_token_id

        # 5. Bypass the outer model.generate() to avoid hardcoded mask conflicts
        output_ids = model.text_decoder.generate(
            input_ids=input_ids[:, :-1],
            eos_token_id=text_cfg.sep_token_id,
            pad_token_id=text_cfg.pad_token_id,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            max_new_tokens=max_new_tokens,
            num_beams=num_beams,
        )

    captions = processor.batch_decode(output_ids, skip_special_tokens=True)
    return captions
def save_ckpt(model, processor, optimizer, scheduler, step, epoch, cfg_dict, path):
    """
    Write a resumable checkpoint into the directory `path`.

    The directory is created if needed. The model and processor are written
    with their own `save_pretrained`, and the training bookkeeping (step,
    epoch, optimizer/scheduler state, config dict) goes into
    `train_state.pt` inside the same directory. A falsy `optimizer` or
    `scheduler` is stored as None.
    """
    os.makedirs(path, exist_ok=True)
    model.save_pretrained(path)
    processor.save_pretrained(path)

    if optimizer:
        optimizer_state = optimizer.state_dict()
    else:
        optimizer_state = None

    if scheduler:
        scheduler_state = scheduler.state_dict()
    else:
        scheduler_state = None

    train_state = {
        "step": step,
        "epoch": epoch,
        "optimizer": optimizer_state,
        "scheduler": scheduler_state,
        "cfg": cfg_dict,
    }
    state_file = os.path.join(path, "train_state.pt")
    torch.save(train_state, state_file)

    print(f"✅ BLIP checkpoint saved: {path}")
def load_ckpt(model, optimizer, scheduler, path):
    """
    Restore `model` (and, when available, optimizer/scheduler) from `path`.

    Weights are read by loading a fresh BLIP model from the checkpoint
    directory and copying its state dict into `model`. If the directory also
    contains `train_state.pt`, the optimizer and scheduler states stored
    there are applied to the given objects when both the object and the
    stored state are truthy.

    Returns:
        `(step, epoch)` from the training state, or `(0, 1)` when no
        training state file exists (missing keys fall back to 0 and 1).
    """
    weights_source = BlipForConditionalGeneration.from_pretrained(path)
    model.load_state_dict(weights_source.state_dict())

    train_state_file = os.path.join(path, "train_state.pt")

    # Weights-only checkpoint: nothing else to restore.
    if not os.path.exists(train_state_file):
        print("✅ Model weights loaded, no training state found.")
        return 0, 1

    # NOTE(review): torch.load unpickles the file; only resume from trusted checkpoints.
    state = torch.load(train_state_file, map_location="cpu")

    if optimizer:
        stored_optimizer = state.get("optimizer")
        if stored_optimizer:
            optimizer.load_state_dict(stored_optimizer)

    if scheduler:
        stored_scheduler = state.get("scheduler")
        if stored_scheduler:
            scheduler.load_state_dict(stored_scheduler)

    print(f"✅ Resumed from step {state.get('step', '?')}, epoch {state.get('epoch', '?')}")
    resumed_step = state.get("step", 0)
    resumed_epoch = state.get("epoch", 1)
    return resumed_step, resumed_epoch
|