# project_02_DS/models/git_tuner.py
# Hub page header: griddev — "first push" (commit c374021)
"""
models/git_tuner.py
===================
Baseline 2 — Zero Cross-Attention / Self-Attention Prefix (GIT)
Architecture: GIT (Generative Image-to-Text) abandons cross-attention entirely.
It concatenates image patch embeddings directly in front of the text tokens and
runs a single causal self-attention Transformer over the combined sequence.
There is NO cross-attention block. The model learns to fuse modalities purely
through self-attention across a unified image+text token sequence. This makes
the ablation masks work differently — we control which image tokens are
prepended to the sequence rather than using encoder_attention_mask.
"""
import os
import torch
from transformers import AutoProcessor, AutoModelForCausalLM
def get_git_model(cfg, device):
    """
    Load a GIT captioning model and its processor and move the model to ``device``.

    GIT is loaded through the ``AutoModelForCausalLM`` interface.

    Args:
        cfg: Configuration object. Only ``cfg.git_model_id`` is read; it is
            passed to ``from_pretrained`` (e.g. ``microsoft/git-base-coco``).
        device: Target handed to ``model.to``.

    Returns:
        tuple: ``(model, processor)``.
    """
    model_id = cfg.git_model_id
    processor = AutoProcessor.from_pretrained(model_id, use_fast=True)
    model = AutoModelForCausalLM.from_pretrained(model_id)

    # Gradient checkpointing is best-effort: a failure is reported and loading
    # continues. Only the call that can fail sits inside the ``try``.
    try:
        model.gradient_checkpointing_enable()
    except Exception as e:
        print(f"⚠️ Gradient checkpointing failed: {e}")
    else:
        print("✅ Gradient checkpointing enabled (GIT)")

    # Set unconditionally. NOTE(review): presumably because the KV cache
    # conflicts with gradient checkpointing during training — confirm.
    model.config.use_cache = False
    model.to(device)

    # Parameter count in millions, for the log line only.
    n_params = sum(p.numel() for p in model.parameters()) / 1e6
    print(f"✅ GIT loaded on {device}: {model_id} ({n_params:.1f}M params)")
    return model, processor
def generate_caption(model, processor, image_pil, device,
                     max_new_tokens=32, num_beams=4):
    """
    Generate a caption for a single PIL image using GIT.

    The model is switched to eval mode for the duration of the call and put
    back into training mode afterwards if that is the mode it arrived in, so
    calling this mid-training (e.g. to log sample captions) does not silently
    leave dropout disabled for the rest of the run.

    Note: GIT has no encoder_attention_mask concept (no cross-attention).
    Ablation for GIT is handled upstream by modifying the pixel_values
    (e.g., masking image regions) before passing to the model, OR by
    returning a note that GIT is not compatible with encoder-mask ablations.

    Args:
        model: Model exposing ``eval()``, ``train()`` and ``generate(...)``.
        processor: Callable that turns ``images=`` into model inputs and
            provides ``batch_decode``.
        image_pil: The image to caption.
        device: Target the processed inputs are moved to.
        max_new_tokens: Forwarded to ``model.generate``.
        num_beams: Forwarded to ``model.generate``.

    Returns:
        The first decoded sequence, with special tokens skipped.
    """
    # Remember the caller's mode; objects without a ``training`` flag are
    # treated as already being in eval mode.
    was_training = getattr(model, "training", False)
    model.eval()
    try:
        inputs = processor(images=image_pil, return_tensors="pt").to(device)
        with torch.no_grad():
            output_ids = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                num_beams=num_beams,
            )
    finally:
        # Restore training mode even if preprocessing or generation raised.
        if was_training:
            model.train()

    # A single image was processed, so only the first decoded sequence is returned.
    caption = processor.batch_decode(output_ids, skip_special_tokens=True)[0]
    return caption
def save_ckpt(model, processor, optimizer, scheduler,
              step, epoch, cfg_dict, path):
    """
    Write a training checkpoint into the directory ``path``.

    The directory is created if missing. The model and processor are saved
    with their own ``save_pretrained`` methods, and the remaining training
    state goes into ``train_state.pt`` inside the same directory.

    Args:
        model: Object exposing ``save_pretrained(path)``.
        processor: Object exposing ``save_pretrained(path)``.
        optimizer: Object exposing ``state_dict()``, or a falsy value to
            store ``None`` instead.
        scheduler: Object exposing ``state_dict()``, or a falsy value to
            store ``None`` instead.
        step: Stored under the ``"step"`` key.
        epoch: Stored under the ``"epoch"`` key.
        cfg_dict: Stored as-is under the ``"cfg"`` key.
        path: Destination directory.
    """
    os.makedirs(path, exist_ok=True)
    model.save_pretrained(path)
    processor.save_pretrained(path)

    # Everything needed to resume training besides the weights themselves.
    train_state = {
        "step": step,
        "epoch": epoch,
        "optimizer": optimizer.state_dict() if optimizer else None,
        "scheduler": scheduler.state_dict() if scheduler else None,
        "cfg": cfg_dict,
    }
    torch.save(train_state, os.path.join(path, "train_state.pt"))
    print(f"✅ GIT checkpoint saved: {path}")