Spaces:
Running
Running
| """ | |
| models/blip_tuner.py | |
| ==================== | |
| Baseline 3 — Multimodal Mixture Attention (BLIP) | |
| Architecture: BLIP's MED (Multimodal Encoder-Decoder) architecture injects | |
| specialized gated cross-attention between self-attention and feed-forward layers. | |
| The visual encoder output (image patch embeddings) is queried by the text decoder | |
| via cross-attention that is applied carefully at each decoder layer. | |
| This module also provides `generate_with_mask()` for inference-time ablation | |
| experiments that manipulate the encoder_attention_mask to test spatial restrictions. | |
| """ | |
| import os | |
| import torch | |
| from transformers import BlipProcessor, BlipForConditionalGeneration | |
def get_blip_model(cfg, device):
    """
    Build the BLIP captioning model and its processor and move the model to `device`.

    Reads `cfg.model_id` (checkpoint passed to `from_pretrained`) and
    `cfg.image_size` (square edge length written into the image processor's
    `size`). Resizing and gradient checkpointing are best-effort: if either
    step raises, a warning is printed and loading continues.

    Returns:
        A `(model, processor)` tuple.
    """
    model_id = cfg.model_id
    processor = BlipProcessor.from_pretrained(model_id, use_fast=True)
    model = BlipForConditionalGeneration.from_pretrained(model_id)

    # Shrink the processor's target resolution to keep memory use low
    # (the original author calls out Mac/MPS in particular).
    try:
        edge = cfg.image_size
        processor.image_processor.size = {"height": edge, "width": edge}
        print(f"✅ Image size set to {edge}px")
    except Exception as err:
        print(f"⚠️ Could not set image size: {err}")

    # Trade recomputation for activation memory during training.
    try:
        model.gradient_checkpointing_enable()
        print("✅ Gradient checkpointing enabled (BLIP)")
    except Exception as err:
        print(f"⚠️ Gradient checkpointing failed: {err}")

    # The KV cache must be off while gradient checkpointing is active.
    model.config.use_cache = False
    model.to(device)

    total_weights = sum(weight.numel() for weight in model.parameters())
    millions = total_weights / 1e6
    print(f"✅ BLIP loaded on {device}: {model_id} ({millions:.1f}M params)")
    return model, processor
def generate_with_mask(model, processor, image_pil=None, device=None,
                       pixel_values=None,
                       encoder_hidden_states=None,
                       encoder_attention_mask=None,
                       max_new_tokens=32, num_beams=4):
    """
    Generate captions while controlling which visual tokens the decoder may attend to.

    The visual input is resolved in this order of precedence:
    `encoder_hidden_states` (used as-is), then `pixel_values` (encoded with
    `model.vision_model`), then `image_pil` (preprocessed with `processor`
    and then encoded).

    Ablation modes listed by the original author (the mask / pooled features
    themselves are built by the caller and passed in):
      - Baseline: 197 patches visible (no mask passed -> all ones)
      - Random Dropout: 50% spatial patches masked
      - Center-Focus: Inner 8x8 patches visible
      - Squint: Requires passing pre-pooled `encoder_hidden_states` of shape (B, 2, C).

    Args:
        model: BLIP model exposing `vision_model`, `text_decoder` and
            `config.text_config`. It is switched to eval mode and left there.
        processor: Processor used for image preprocessing and `batch_decode`.
        image_pil: Image(s) to preprocess when no tensors are supplied.
        device: Target device for preprocessed inputs. When None, falls back
            to `model.device` if the model exposes one.
        pixel_values: Pre-processed image tensor; used as given (not moved).
        encoder_hidden_states: Pre-computed visual features; skips the vision
            encoder entirely.
        encoder_attention_mask: Mask over the visual tokens; defaults to all
            ones with shape `encoder_hidden_states.shape[:-1]`.
        max_new_tokens: Forwarded to `text_decoder.generate`.
        num_beams: Forwarded to `text_decoder.generate`.

    Returns:
        List of decoded caption strings, one per batch element.

    Raises:
        ValueError: If none of `image_pil`, `pixel_values` or
            `encoder_hidden_states` is provided.
    """
    if image_pil is None and pixel_values is None and encoder_hidden_states is None:
        raise ValueError(
            "generate_with_mask requires one of image_pil, pixel_values "
            "or encoder_hidden_states"
        )

    model.eval()
    text_cfg = model.config.text_config

    # Everything below is inference only; keep the vision forward pass inside
    # no_grad as well so no autograd graph is built for the image encoder.
    with torch.no_grad():
        # 1-2. Resolve visual features (skipped when the caller pre-computed them).
        if encoder_hidden_states is None:
            if pixel_values is None:
                if device is None:
                    device = getattr(model, "device", None)
                inputs = processor(images=image_pil, return_tensors="pt")
                if device is not None:
                    inputs = inputs.to(device)
                pixel_values = inputs["pixel_values"]
            encoder_hidden_states = model.vision_model(pixel_values=pixel_values)[0]

        # Batch size and device follow the features the decoder actually consumes.
        batch_size = encoder_hidden_states.shape[0]
        dev = encoder_hidden_states.device

        # 3. Default mask (Baseline) lets the decoder see every visual token.
        if encoder_attention_mask is None:
            encoder_attention_mask = torch.ones(
                encoder_hidden_states.shape[:-1], dtype=torch.long, device=dev
            )
        else:
            encoder_attention_mask = encoder_attention_mask.to(dev)

        # 4. Decoder prompt: a single BOS token per sample.
        input_ids = torch.full(
            (batch_size, 1), text_cfg.bos_token_id, dtype=torch.long, device=dev
        )

        # 5. Call the text decoder directly so the custom mask is not replaced
        #    by the all-ones mask the outer model.generate() builds internally.
        output_ids = model.text_decoder.generate(
            input_ids=input_ids,
            eos_token_id=text_cfg.sep_token_id,
            pad_token_id=text_cfg.pad_token_id,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            max_new_tokens=max_new_tokens,
            num_beams=num_beams,
        )

    return processor.batch_decode(output_ids, skip_special_tokens=True)
def save_ckpt(model, processor, optimizer, scheduler, step, epoch, cfg_dict, path):
    """
    Write a resumable checkpoint into the directory `path`.

    The directory is created if needed. `model` and `processor` are written
    with their own `save_pretrained`; the training progress (step, epoch,
    optimizer/scheduler state dicts, and `cfg_dict`) goes into
    `train_state.pt` inside the same directory. A falsy `optimizer` or
    `scheduler` is stored as None.
    """
    os.makedirs(path, exist_ok=True)

    # Weights first, then the processor, both into the same directory.
    for artifact in (model, processor):
        artifact.save_pretrained(path)

    train_state = {"step": step, "epoch": epoch}
    train_state["optimizer"] = optimizer.state_dict() if optimizer else None
    train_state["scheduler"] = scheduler.state_dict() if scheduler else None
    train_state["cfg"] = cfg_dict

    state_file = os.path.join(path, "train_state.pt")
    torch.save(train_state, state_file)
    print(f"✅ BLIP checkpoint saved: {path}")
def load_ckpt(model, optimizer, scheduler, path):
    """
    Restore weights (and, when present, training state) from the directory `path`.

    The weights are read by instantiating a fresh
    `BlipForConditionalGeneration` from `path` and copying its state dict into
    `model` in place. If `train_state.pt` exists there, the optimizer and
    scheduler state dicts are restored when both the object and its saved
    state are truthy.

    Returns:
        `(step, epoch)` from the saved state, defaulting to `(0, 1)` when a
        key or the whole state file is missing.
    """
    reference = BlipForConditionalGeneration.from_pretrained(path)
    model.load_state_dict(reference.state_dict())

    state_file = os.path.join(path, "train_state.pt")
    if not os.path.exists(state_file):
        print("✅ Model weights loaded, no training state found.")
        return 0, 1

    # Load onto CPU; each load_state_dict handles its own device placement.
    state = torch.load(state_file, map_location="cpu")
    for target, key in ((optimizer, "optimizer"), (scheduler, "scheduler")):
        if target and state.get(key):
            target.load_state_dict(state[key])

    print(f"✅ Resumed from step {state.get('step', '?')}, epoch {state.get('epoch', '?')}")
    resumed_step = state.get("step", 0)
    resumed_epoch = state.get("epoch", 1)
    return resumed_step, resumed_epoch