"""
step1_load_model.py
====================
Task 4 — Component 1: Load BLIP model with fine-tuned weights.

Loads BLIP for caption generation and hidden-state extraction. The model is
used in two modes during Task 4:

- caption generation (nucleus sampling for diversity analysis)
- hidden-state extraction (steering vector computation)

If the fine-tuned checkpoint from Task 3 / main training is found at
``outputs/blip/best/`` it is loaded, otherwise the base HuggingFace weights
are used (works without any prior training).

Public API
----------
load_model(weights_dir="outputs/blip/best") -> (model, processor, device)

Standalone usage
----------------
export PYTHONPATH=.
venv/bin/python task/task_04/step1_load_model.py
"""
import os
import torch
from transformers import BlipForConditionalGeneration, BlipProcessor
# ─────────────────────────────────────────────────────────────────────────────
# Device helper
# ─────────────────────────────────────────────────────────────────────────────
def get_device() -> torch.device:
    """Select the preferred compute device, trying MPS, then CUDA, then CPU."""
    # Probe accelerators in priority order; fall through to CPU.
    candidates = (
        ("mps", torch.backends.mps.is_available),
        ("cuda", torch.cuda.is_available),
    )
    for name, is_available in candidates:
        if is_available():
            return torch.device(name)
    return torch.device("cpu")
# ─────────────────────────────────────────────────────────────────────────────
# Main loader
# ─────────────────────────────────────────────────────────────────────────────
BLIP_BASE_ID = "Salesforce/blip-image-captioning-base"


def load_model(weights_dir: str = "outputs/blip/best"):
    """
    Load BLIP for conditional generation.

    1. Downloads/caches base weights from HuggingFace (first run only).
    2. Loads the fine-tuned checkpoint from ``weights_dir`` if it exists,
       otherwise falls back to the base pretrained weights.

    Args:
        weights_dir: Path to a directory containing a BLIP checkpoint saved
                     by ``train.py`` (e.g. ``outputs/blip/best``). Can be
                     relative to the *project root*.

    Returns:
        (model, processor, device)
            model     : BlipForConditionalGeneration (eval mode, float32)
            processor : BlipProcessor
            device    : torch.device
    """
    device = get_device()
    print("=" * 62)
    print("  Task 4 — Step 1: Load BLIP Model")
    print("=" * 62)
    print(f"  Device : {device}")

    # ── Load processor ────────────────────────────────────────────────────────
    processor = BlipProcessor.from_pretrained(BLIP_BASE_ID)
    print(f"  ✅ Processor loaded ({BLIP_BASE_ID})")

    # ── Try fine-tuned checkpoint first ───────────────────────────────────────
    # A directory that exists but is empty is treated as "no checkpoint".
    abs_weights = os.path.abspath(weights_dir)
    if os.path.isdir(abs_weights) and os.listdir(abs_weights):
        print(f"  Loading fine-tuned weights from: {abs_weights}")
        model = BlipForConditionalGeneration.from_pretrained(abs_weights)
        print("  ✅ Fine-tuned checkpoint loaded")
        weights_source = f"fine-tuned ({weights_dir})"
    else:
        print(f"  ⚠️ No checkpoint at {abs_weights}. Using base HuggingFace weights.")
        model = BlipForConditionalGeneration.from_pretrained(BLIP_BASE_ID)
        print("  ✅ Base pretrained weights loaded")
        weights_source = "base (pretrained)"

    # Keep in float32 — needed for stable hidden-state arithmetic during steering
    model.to(device).eval()
    n_params = sum(p.numel() for p in model.parameters())
    print(f"  Parameters: {n_params:,} | Weights: {weights_source}")
    print("=" * 62)
    return model, processor, device
# ─────────────────────────────────────────────────────────────────────────────
# Standalone entrypoint
# ─────────────────────────────────────────────────────────────────────────────
if __name__ == "__main__":
import sys
sys.path.insert(
0,
os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
)
model, processor, device = load_model()
print(f"\nβ
load_model() returned successfully.")
print(f" model type : {type(model).__name__}")
print(f" device : {device}")
print(f"\nYou can now import this in any notebook:")
print(" from task.task_04.step1_load_model import load_model")
print(" model, processor, device = load_model()")