Spaces:
Running
Running
| """ | |
| models/git_tuner.py | |
| =================== | |
| Baseline 2 — Zero Cross-Attention / Self-Attention Prefix (GIT) | |
| Architecture: GIT (Generative Image-to-Text) abandons cross-attention entirely. | |
| It concatenates image patch embeddings directly in front of the text tokens and | |
| runs a single causal self-attention Transformer over the combined sequence. | |
| There is NO cross-attention block. The model learns to fuse modalities purely | |
| through self-attention across a unified image+text token sequence. This makes | |
| the ablation masks work differently — we control which image tokens are | |
| prepended to the sequence rather than using encoder_attention_mask. | |
| """ | |
| import os | |
| import torch | |
| from transformers import AutoProcessor, AutoModelForCausalLM | |
def get_git_model(cfg, device):
    """
    Build the GIT captioning model and its processor.

    The checkpoint name is read from ``cfg.git_model_id`` (the original note
    on this function mentions microsoft/git-base-coco) and is loaded through
    the AutoModelForCausalLM / AutoProcessor interfaces.

    Args:
        cfg: configuration object exposing a ``git_model_id`` attribute.
        device: target passed to ``model.to(...)``.

    Returns:
        A ``(model, processor)`` tuple.
    """
    checkpoint = cfg.git_model_id
    git_processor = AutoProcessor.from_pretrained(checkpoint, use_fast=True)
    git_model = AutoModelForCausalLM.from_pretrained(checkpoint)

    # Best effort only: a failure to enable checkpointing is reported on
    # stdout and loading continues.
    try:
        git_model.gradient_checkpointing_enable()
        print("✅ Gradient checkpointing enabled (GIT)")
    except Exception as err:
        print(f"⚠️ Gradient checkpointing failed: {err}")

    # NOTE(review): reconstructed as unconditional (after the try/except);
    # the flattened source does not show its original indentation - confirm.
    git_model.config.use_cache = False
    git_model.to(device)

    # Report the parameter count in millions.
    total_elements = 0
    for weight in git_model.parameters():
        total_elements += weight.numel()
    n_params = total_elements / 1e6
    print(f"✅ GIT loaded on {device}: {checkpoint} ({n_params:.1f}M params)")

    return git_model, git_processor
def generate_caption(model, processor, image_pil, device,
                     max_new_tokens=32, num_beams=4):
    """
    Generate a caption for a single PIL image using GIT.

    Note: GIT has no encoder_attention_mask concept (no cross-attention).
    Ablation for GIT is handled upstream by modifying the pixel_values
    (e.g., masking image regions) before passing to the model, OR by
    returning a note that GIT is not compatible with encoder-mask ablations.

    The model is switched to eval mode for the duration of the call. If it
    was in training mode on entry, training mode is restored before
    returning (also when generation raises), so sampling a caption in the
    middle of a training loop does not leave dropout disabled afterwards.

    Args:
        model: model exposing ``eval()``, ``train()``, ``training`` and
            ``generate(...)``.
        processor: callable that prepares the image inputs and offers
            ``batch_decode``.
        image_pil: the image to caption.
        device: target passed to ``.to(...)`` on the processed inputs.
        max_new_tokens: forwarded to ``model.generate``.
        num_beams: forwarded to ``model.generate``.

    Returns:
        The first decoded sequence, with special tokens skipped.
    """
    # Remember the caller's mode so this helper has no lasting side effect.
    was_training = model.training
    model.eval()
    try:
        inputs = processor(images=image_pil, return_tensors="pt").to(device)
        with torch.no_grad():
            output_ids = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                num_beams=num_beams,
            )
    finally:
        if was_training:
            model.train()

    caption = processor.batch_decode(output_ids, skip_special_tokens=True)[0]
    return caption
def save_ckpt(model, processor, optimizer, scheduler,
              step, epoch, cfg_dict, path):
    """
    Write a GIT checkpoint into the directory ``path``.

    The model and processor are written with their own ``save_pretrained``
    methods; the training progress is stored next to them in
    ``train_state.pt`` via ``torch.save``.

    Args:
        model: object exposing ``save_pretrained(path)``.
        processor: object exposing ``save_pretrained(path)``.
        optimizer: object exposing ``state_dict()``, or None to store None.
        scheduler: object exposing ``state_dict()``, or None to store None.
        step: stored under the "step" key as given.
        epoch: stored under the "epoch" key as given.
        cfg_dict: stored under the "cfg" key as given.
        path: checkpoint directory; created if it does not exist.
    """
    os.makedirs(path, exist_ok=True)
    model.save_pretrained(path)
    processor.save_pretrained(path)

    # Test against None explicitly: a supplied object that happens to be
    # falsy (e.g. a wrapper defining __len__/__bool__) must still be saved.
    train_state = {
        "step": step,
        "epoch": epoch,
        "optimizer": optimizer.state_dict() if optimizer is not None else None,
        "scheduler": scheduler.state_dict() if scheduler is not None else None,
        "cfg": cfg_dict,
    }
    torch.save(train_state, os.path.join(path, "train_state.pt"))
    print(f"✅ GIT checkpoint saved: {path}")