# Source: project_02_DS/task/task_04/step1_load_model.py
# Uploaded by griddev — commit 0710b5c "Deploy Streamlit Space app" (verified)
"""
step1_load_model.py
====================
Task 4 — Component 1: Load BLIP model with fine-tuned weights.
Loads BLIP for caption generation and hidden-state extraction. The model is
used in two modes during Task 4:
- caption generation (nucleus sampling for diversity analysis)
- hidden-state extraction (steering vector computation)
If the fine-tuned checkpoint from Task 3 / main training is found at
``outputs/blip/best/``, it is loaded; otherwise the base HuggingFace weights
are used, so the module works without any prior training.
Public API
----------
load_model(weights_dir="outputs/blip/best") -> (model, processor, device)
Standalone usage
----------------
export PYTHONPATH=.
venv/bin/python task/task_04/step1_load_model.py
"""
import os
import torch
from transformers import BlipForConditionalGeneration, BlipProcessor
# ─────────────────────────────────────────────────────────────────────────────
# Device helper
# ─────────────────────────────────────────────────────────────────────────────
def get_device() -> torch.device:
    """Pick the preferred compute backend available on this machine.

    Backends are probed in order — Apple-silicon MPS first, then CUDA — and
    the first one reporting availability wins; CPU is the fallback.
    """
    preference = (
        ("mps", torch.backends.mps.is_available),
        ("cuda", torch.cuda.is_available),
    )
    for backend_name, is_available in preference:
        if is_available():
            return torch.device(backend_name)
    return torch.device("cpu")
# ─────────────────────────────────────────────────────────────────────────────
# Main loader
# ─────────────────────────────────────────────────────────────────────────────
# HuggingFace Hub id of the base BLIP captioning model. The processor is always
# loaded from here, and so are the model weights when no fine-tuned checkpoint
# is available.
BLIP_BASE_ID: str = "Salesforce/blip-image-captioning-base"
def _resolve_checkpoint_dir(weights_dir: str) -> tuple:
    """Find the directory holding a fine-tuned BLIP checkpoint, if any.

    A relative ``weights_dir`` is tried against the current working directory
    first and then against the project root (three directory levels above
    this file), so the default works both when run from the project root and
    from notebooks launched inside ``task/task_04``. An absolute path is used
    as-is.

    A directory only counts as a checkpoint when it contains ``config.json``
    (written by ``save_pretrained`` and required by ``from_pretrained`` for a
    local directory). A directory holding only stray files (``.DS_Store``, a
    half-written run) is therefore treated as "no checkpoint" instead of
    crashing the load.

    Returns:
        (path, found): when ``found`` is True, ``path`` is the absolute
        checkpoint directory; otherwise it is ``weights_dir`` resolved against
        the current working directory (used only for the warning message).
    """
    primary = os.path.abspath(weights_dir)
    candidates = [primary]
    if not os.path.isabs(weights_dir):
        project_root = os.path.dirname(
            os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        )
        fallback = os.path.abspath(os.path.join(project_root, weights_dir))
        if fallback != primary:
            candidates.append(fallback)
    for path in candidates:
        if os.path.isfile(os.path.join(path, "config.json")):
            return path, True
    return primary, False


def load_model(weights_dir: str = "outputs/blip/best"):
    """
    Load BLIP for conditional generation.

    1. Loads the processor from the base HuggingFace id (downloaded and
       cached on first use).
    2. Loads the fine-tuned checkpoint from ``weights_dir`` if one exists
       (a directory containing ``config.json``); otherwise falls back to the
       base HuggingFace weights.

    Args:
        weights_dir: Path to a directory containing a BLIP checkpoint saved
                     by ``train.py`` (e.g. ``outputs/blip/best``). A relative
                     path is looked up from the current working directory
                     first, then from the project root.

    Returns:
        (model, processor, device)
        model     : BlipForConditionalGeneration (eval mode, float32, on ``device``)
        processor : BlipProcessor
        device    : torch.device
    """
    device = get_device()
    print("=" * 62)
    print(" Task 4 — Step 1: Load BLIP Model")
    print("=" * 62)
    print(f" Device : {device}")

    # ── Load processor ────────────────────────────────────────────────────────
    # The processor always comes from the base id; only the model weights are
    # swapped for the fine-tuned ones.
    processor = BlipProcessor.from_pretrained(BLIP_BASE_ID)
    print(f" ✅ Processor loaded ({BLIP_BASE_ID})")

    # ── Try fine-tuned checkpoint first ───────────────────────────────────────
    ckpt_dir, found = _resolve_checkpoint_dir(weights_dir)
    if found:
        print(f" Loading fine-tuned weights from: {ckpt_dir}")
        model = BlipForConditionalGeneration.from_pretrained(ckpt_dir)
        print(" ✅ Fine-tuned checkpoint loaded")
        weights_source = f"fine-tuned ({weights_dir})"
    else:
        print(f" ⚠️ No checkpoint at {ckpt_dir}. Using base HuggingFace weights.")
        model = BlipForConditionalGeneration.from_pretrained(BLIP_BASE_ID)
        print(" ✅ Base pretrained weights loaded")
        weights_source = "base (pretrained)"

    # Force float32 explicitly — needed for stable hidden-state arithmetic
    # during steering. A checkpoint saved in half precision (or a library
    # default that keeps the saved dtype) would otherwise leak through.
    model.to(device=device, dtype=torch.float32).eval()

    n_params = sum(p.numel() for p in model.parameters())
    print(f" Parameters: {n_params:,} | Weights: {weights_source}")
    print("=" * 62)
    return model, processor, device
# ─────────────────────────────────────────────────────────────────────────────
# Standalone entrypoint
# ─────────────────────────────────────────────────────────────────────────────
if __name__ == "__main__":
    import sys

    # Put the project root (three directory levels above this file) at the
    # front of sys.path so project-local imports resolve when run directly.
    # NOTE(review): nothing below imports project modules yet, so this only
    # matters if such imports are added to this entrypoint later.
    sys.path.insert(
        0,
        os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))),
    )

    model, processor, device = load_model()
    print("\n✅ load_model() returned successfully.")
    print(f" model type : {type(model).__name__}")
    print(f" device : {device}")
    print("\nYou can now import this in any notebook:")
    print(" from task.task_04.step1_load_model import load_model")
    print(" model, processor, device = load_model()")