# Source: project_02_DS/models/blip_tuner.py
# Uploaded by griddev — commit c374021 ("first push")
"""
models/blip_tuner.py
====================
Baseline 3 — Multimodal Mixture Attention (BLIP)

Architecture: BLIP's MED (Multimodal Encoder-Decoder) architecture injects
specialized gated cross-attention between the self-attention and feed-forward
sub-layers. The text decoder queries the visual encoder output (image patch
embeddings) through this cross-attention at each decoder layer.

This module also provides `generate_with_mask()` for inference-time ablation
experiments that manipulate the `encoder_attention_mask` to test spatial
restrictions.
"""
import os
import torch
from transformers import BlipProcessor, BlipForConditionalGeneration
def get_blip_model(cfg, device):
    """
    Build the BLIP captioning model and its processor, then move the model to `device`.

    Reads `cfg.model_id` (checkpoint to load) and `cfg.image_size` (side length,
    in pixels, of the square input the image processor should produce).

    Overriding the image size and enabling gradient checkpointing are both
    best-effort: if either step raises, the error is printed and loading continues.

    Returns:
        A `(model, processor)` tuple.
    """
    checkpoint = cfg.model_id
    processor = BlipProcessor.from_pretrained(checkpoint, use_fast=True)
    model = BlipForConditionalGeneration.from_pretrained(checkpoint)

    # Override the processor's output resolution (smaller inputs are cheaper,
    # which matters most on Mac/MPS).
    try:
        side = cfg.image_size
        processor.image_processor.size = {"height": side, "width": side}
        print(f"✅ Image size set to {cfg.image_size}px")
    except Exception as err:
        print(f"⚠️ Could not set image size: {err}")

    # Trade compute for memory during training.
    try:
        model.gradient_checkpointing_enable()
        print("✅ Gradient checkpointing enabled (BLIP)")
    except Exception as err:
        print(f"⚠️ Gradient checkpointing failed: {err}")

    # The cache flag has to be off when gradient checkpointing is in use.
    model.config.use_cache = False
    model.to(device)

    total_params = 0
    for param in model.parameters():
        total_params += param.numel()
    n_params = total_params / 1e6
    print(f"✅ BLIP loaded on {device}: {checkpoint} ({n_params:.1f}M params)")

    return model, processor
def generate_with_mask(model, processor, image_pil=None, device=None,
                       pixel_values=None,
                       encoder_hidden_states=None,
                       encoder_attention_mask=None,
                       max_new_tokens=32, num_beams=4):
    """
    Generate captions while controlling which visual tokens the decoder may attend to.

    Exactly one visual input is consumed, in this order of precedence:
    `encoder_hidden_states` (used as-is), then `pixel_values` (run through
    `model.vision_model`), then `image_pil` (pre-processed with `processor` first).

    Ablation modes described by the original author:
        - Baseline: 197 patches visible (mask of all ones, the default)
        - Random Dropout: 50% spatial patches masked
        - Center-Focus: inner 8x8 patches visible
        - Squint: requires passing pre-pooled `encoder_hidden_states`
          of shape (B, 2, C)

    Args:
        model: BLIP captioning model exposing `vision_model`, `text_decoder`
            and `config.text_config`.
        processor: Matching processor; used for image pre-processing and for
            decoding the generated token ids.
        image_pil: PIL image (or whatever `processor(images=...)` accepts).
        device: Device for the pre-processed image. When None, the device of
            the model's first parameter is used (CPU if the model has none).
        pixel_values: Already pre-processed image tensor.
        encoder_hidden_states: Pre-computed visual features; skips the vision
            encoder entirely.
        encoder_attention_mask: Mask over the visual tokens; defaults to all
            ones with the leading dimensions of `encoder_hidden_states`.
        max_new_tokens: Generation length limit passed to `generate`.
        num_beams: Beam width passed to `generate`.

    Returns:
        List of decoded caption strings, one per batch element.

    Raises:
        ValueError: If none of `image_pil`, `pixel_values` or
            `encoder_hidden_states` is provided.

    Note:
        The model is switched to eval mode and left in that mode.
    """
    if image_pil is None and pixel_values is None and encoder_hidden_states is None:
        raise ValueError(
            "generate_with_mask requires one of image_pil, pixel_values "
            "or encoder_hidden_states"
        )

    model.eval()

    text_config = model.config.text_config

    # Everything below is inference only, so the vision encoder pass must not
    # build an autograd graph either.
    with torch.no_grad():
        # 1-2. Obtain visual features unless the caller supplied them
        #      (e.g. Squint mode passes pooled features directly).
        if encoder_hidden_states is None:
            if pixel_values is None:
                if device is None:
                    first_param = next(model.parameters(), None)
                    device = (
                        first_param.device if first_param is not None
                        else torch.device("cpu")
                    )
                inputs = processor(images=image_pil, return_tensors="pt").to(device)
                pixel_values = inputs["pixel_values"]
            vision_outputs = model.vision_model(pixel_values=pixel_values)
            encoder_hidden_states = vision_outputs[0]

        # Batch size and device follow the tensor the decoder actually consumes.
        batch_size = encoder_hidden_states.shape[0]
        dev = encoder_hidden_states.device

        # 3. Default mask (Baseline): every visual token is visible.
        if encoder_attention_mask is None:
            encoder_attention_mask = torch.ones(
                encoder_hidden_states.shape[:-1], dtype=torch.long, device=dev
            )
        else:
            encoder_attention_mask = encoder_attention_mask.to(dev)

        # 4. Decoder prompt: a single BOS token per batch element.
        input_ids = torch.full(
            (batch_size, 1), text_config.bos_token_id, dtype=torch.long, device=dev
        )

        # 5. Call the text decoder directly so the custom mask is the one used,
        #    rather than going through the outer model.generate().
        output_ids = model.text_decoder.generate(
            input_ids=input_ids,
            eos_token_id=text_config.sep_token_id,
            pad_token_id=text_config.pad_token_id,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            max_new_tokens=max_new_tokens,
            num_beams=num_beams,
        )

    return processor.batch_decode(output_ids, skip_special_tokens=True)
def save_ckpt(model, processor, optimizer, scheduler, step, epoch, cfg_dict, path):
    """
    Write a resumable checkpoint into the directory `path`.

    The directory is created if needed. The model and processor are stored
    via their own `save_pretrained`, and the training bookkeeping (step,
    epoch, optimizer/scheduler state, config dict) goes into
    `train_state.pt` inside the same directory. A falsy `optimizer` or
    `scheduler` is recorded as None.
    """
    os.makedirs(path, exist_ok=True)

    for component in (model, processor):
        component.save_pretrained(path)

    optimizer_state = None
    if optimizer:
        optimizer_state = optimizer.state_dict()

    scheduler_state = None
    if scheduler:
        scheduler_state = scheduler.state_dict()

    train_state = {
        "step": step,
        "epoch": epoch,
        "optimizer": optimizer_state,
        "scheduler": scheduler_state,
        "cfg": cfg_dict,
    }
    state_file = os.path.join(path, "train_state.pt")
    torch.save(train_state, state_file)

    print(f"✅ BLIP checkpoint saved: {path}")
def load_ckpt(model, optimizer, scheduler, path):
    """
    Restore weights (and, when present, training state) from the directory `path`.

    The weights are read by instantiating a fresh BLIP model from `path` and
    copying its state dict into `model`. If `train_state.pt` exists there, the
    optimizer and scheduler states are restored as well, provided both the
    object and its saved state are truthy.

    Returns:
        `(step, epoch)` to resume from; `(0, 1)` when no training state file
        is found or the file lacks those keys.
    """
    source_model = BlipForConditionalGeneration.from_pretrained(path)
    model.load_state_dict(source_model.state_dict())

    state_path = os.path.join(path, "train_state.pt")
    if not os.path.exists(state_path):
        print("✅ Model weights loaded, no training state found.")
        return 0, 1

    state = torch.load(state_path, map_location="cpu")

    # Restore each stateful helper only if it was both passed in and saved.
    if optimizer and state.get("optimizer"):
        optimizer.load_state_dict(state["optimizer"])
    if scheduler and state.get("scheduler"):
        scheduler.load_state_dict(state["scheduler"])

    shown_step = state.get("step", "?")
    shown_epoch = state.get("epoch", "?")
    print(f"✅ Resumed from step {shown_step}, epoch {shown_epoch}")

    resume_step = state.get("step", 0)
    resume_epoch = state.get("epoch", 1)
    return resume_step, resume_epoch