File size: 5,812 Bytes
c374021
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
"""
models/blip_tuner.py
====================
Baseline 3 — Multimodal Mixture Attention (BLIP)

Architecture: BLIP's MED (Multimodal Encoder-Decoder) architecture injects
specialized gated cross-attention between self-attention and feed-forward layers.
The visual encoder output (image patch embeddings) is queried by the text decoder
via cross-attention that is applied carefully at each decoder layer.

This module also provides `generate_with_mask()` for inference-time ablation
experiments that manipulate the encoder_attention_mask to test spatial restrictions.
"""

import os
import torch
from transformers import BlipProcessor, BlipForConditionalGeneration


def get_blip_model(cfg, device):
    """
    Build a BLIP captioning model and its processor, tuned for low-memory runs.

    Loads `cfg.model_id` from the hub, forces the processor to a square
    `cfg.image_size` resolution, enables gradient checkpointing (with the
    matching `use_cache=False`), and moves the model to `device`.

    Returns:
        (model, processor) tuple.
    """
    processor = BlipProcessor.from_pretrained(cfg.model_id, use_fast=True)
    model = BlipForConditionalGeneration.from_pretrained(cfg.model_id)

    # Smaller square inputs keep memory in check, especially on Mac/MPS.
    try:
        square = {"height": cfg.image_size, "width": cfg.image_size}
        processor.image_processor.size = square
        print(f"✅ Image size set to {cfg.image_size}px")
    except Exception as e:
        print(f"⚠️  Could not set image size: {e}")

    # Trade recompute for activation memory during training.
    try:
        model.gradient_checkpointing_enable()
        print("✅ Gradient checkpointing enabled (BLIP)")
    except Exception as e:
        print(f"⚠️  Gradient checkpointing failed: {e}")

    # KV caching is incompatible with checkpointed backward passes.
    model.config.use_cache = False
    model.to(device)

    total_m = sum(p.numel() for p in model.parameters()) / 1e6
    print(f"✅ BLIP loaded on {device}: {cfg.model_id} ({total_m:.1f}M params)")

    return model, processor


def generate_with_mask(model, processor, image_pil=None, device=None,
                       pixel_values=None,
                       encoder_hidden_states=None,
                       encoder_attention_mask=None,
                       max_new_tokens=32, num_beams=4):
    """
    Generate a caption for a single PIL image (or pre-computed tensors) with an ablation mask.

    Bypasses the outer ``model.generate()`` and drives ``model.text_decoder``
    directly so a caller-supplied ``encoder_attention_mask`` is honored rather
    than overwritten by BLIP's hardcoded all-ones mask.

    Ablation modes supported:
      - Baseline:       197 patches visible
      - Random Dropout: 50% spatial patches masked
      - Center-Focus:   Inner 8x8 patches visible
      - Squint:         Requires passing pre-pooled `encoder_hidden_states` of shape (B, 2, C).

    Args:
        model: BlipForConditionalGeneration (or compatible) model.
        processor: Matching BlipProcessor used for image prep / decoding.
        image_pil: Optional PIL image; ignored if `pixel_values` is given.
        device: Device to place processed inputs on (used only with `image_pil`).
        pixel_values: Optional pre-computed pixel tensor (B, C, H, W).
        encoder_hidden_states: Optional pre-computed visual features (B, N, C).
        encoder_attention_mask: Optional mask over visual tokens; defaults to all ones.
        max_new_tokens: Decoding length budget.
        num_beams: Beam-search width.

    Returns:
        list[str]: One decoded caption per batch element.

    Raises:
        ValueError: If no visual input is provided at all.
    """
    model.eval()

    # 1. Get pixel values from the PIL image if tensors were not supplied.
    if pixel_values is None and image_pil is not None:
        inputs = processor(images=image_pil, return_tensors="pt").to(device)
        pixel_values = inputs["pixel_values"]

    # Fail loudly instead of dying later with an opaque AttributeError on None.
    if pixel_values is None and encoder_hidden_states is None:
        raise ValueError(
            "Provide one of image_pil, pixel_values, or encoder_hidden_states."
        )

    batch_size = pixel_values.shape[0] if pixel_values is not None else encoder_hidden_states.shape[0]
    dev = pixel_values.device if pixel_values is not None else encoder_hidden_states.device

    # 2. Extract visual features if not pre-provided (e.g., Squint mode provides them).
    # no_grad: this path is inference-only; tracking gradients through the
    # vision encoder would just waste memory on the activation graph.
    if encoder_hidden_states is None:
        with torch.no_grad():
            vision_outputs = model.vision_model(pixel_values=pixel_values)
            encoder_hidden_states = vision_outputs[0]

    # 3. Handle encoder_attention_mask default (Baseline = all ones).
    if encoder_attention_mask is None:
        encoder_attention_mask = torch.ones(
            encoder_hidden_states.size()[:-1], dtype=torch.long, device=dev
        )
    else:
        encoder_attention_mask = encoder_attention_mask.to(dev)

    # 4. Prepare decoder input IDs (BOS token), mirroring what
    # BlipForConditionalGeneration.generate builds internally.
    input_ids = (
        torch.LongTensor([[model.decoder_input_ids, model.config.text_config.eos_token_id]])
        .repeat(batch_size, 1)
        .to(dev)
    )
    input_ids[:, 0] = model.config.text_config.bos_token_id

    # 5. Bypass the outer model.generate() to avoid hardcoded mask conflicts.
    with torch.no_grad():
        output_ids = model.text_decoder.generate(
            input_ids=input_ids[:, :-1],
            eos_token_id=model.config.text_config.sep_token_id,
            pad_token_id=model.config.text_config.pad_token_id,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            max_new_tokens=max_new_tokens,
            num_beams=num_beams,
        )

    captions = processor.batch_decode(output_ids, skip_special_tokens=True)
    return captions


def save_ckpt(model, processor, optimizer, scheduler, step, epoch, cfg_dict, path):
    """
    Persist model weights, processor config, and training-loop state.

    The HF artifacts go to `path` via save_pretrained; step/epoch, the
    optimizer/scheduler state dicts (None when absent), and the config dict
    are bundled into `train_state.pt` inside the same directory.
    """
    os.makedirs(path, exist_ok=True)
    model.save_pretrained(path)
    processor.save_pretrained(path)

    state = {
        "step": step,
        "epoch": epoch,
        "optimizer": optimizer.state_dict() if optimizer else None,
        "scheduler": scheduler.state_dict() if scheduler else None,
        "cfg": cfg_dict,
    }
    torch.save(state, os.path.join(path, "train_state.pt"))
    print(f"✅ BLIP checkpoint saved: {path}")


def load_ckpt(model, optimizer, scheduler, path):
    """
    Restore model weights and, when available, optimizer/scheduler state.

    Weights are round-tripped through `from_pretrained(path)` so they match
    the saved format, then copied into `model` via `load_state_dict`.

    Returns:
        (step, epoch) tuple; (0, 1) when no `train_state.pt` is present.
    """
    pretrained = BlipForConditionalGeneration.from_pretrained(path)
    model.load_state_dict(pretrained.state_dict())

    state_file = os.path.join(path, "train_state.pt")
    if not os.path.exists(state_file):
        print("✅ Model weights loaded, no training state found.")
        return 0, 1

    state = torch.load(state_file, map_location="cpu")
    opt_state = state.get("optimizer")
    if optimizer and opt_state:
        optimizer.load_state_dict(opt_state)
    sched_state = state.get("scheduler")
    if scheduler and sched_state:
        scheduler.load_state_dict(sched_state)
    print(f"✅ Resumed from step {state.get('step', '?')}, epoch {state.get('epoch', '?')}")
    return state.get("step", 0), state.get("epoch", 1)