"""
step1_load_model.py
====================
Task 3 – Component 1: Load BLIP model with fine-tuned weights.

This module loads the BLIP image-captioning model and attempts to restore
the best fine-tuned checkpoint from `outputs/blip/best/`. If no checkpoint
is found it falls back gracefully to the pretrained HuggingFace weights.

Public API
----------
load_model(weights_dir="outputs/blip/best") -> (model, processor, device)

Standalone usage
----------------
export PYTHONPATH=.
venv/bin/python task/task_03/step1_load_model.py
"""
import os
import torch
from transformers import BlipForConditionalGeneration, BlipProcessor

# ─────────────────────────────────────────────────────────────────────────────
# Device helper
# ─────────────────────────────────────────────────────────────────────────────
def get_device() -> torch.device:
    """Return the best available device: MPS → CUDA → CPU."""
    if torch.backends.mps.is_available():
        return torch.device("mps")
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")

# ─────────────────────────────────────────────────────────────────────────────
# Main loader
# ─────────────────────────────────────────────────────────────────────────────
BLIP_BASE_ID = "Salesforce/blip-image-captioning-base"

def load_model(weights_dir: str = "outputs/blip/best"):
    """
    Load BLIP for conditional generation.

    1. Downloads/caches base weights from HuggingFace (first run only).
    2. Loads fine-tuned checkpoint from `weights_dir` if it exists.

    Args:
        weights_dir: Path to a directory containing a BLIP checkpoint saved
            by `train.py` (e.g. ``outputs/blip/best``). Can be relative to
            the *project root*.

    Returns:
        (model, processor, device)
            model     : BlipForConditionalGeneration (eval mode)
            processor : BlipProcessor
            device    : torch.device
    """
    device = get_device()
    print("=" * 60)
    print(" Task 3 – Step 1: Load BLIP Model")
    print("=" * 60)
    print(f" Device : {device}")

    # ── Load processor ──────────────────────────────────────────────────────
    processor = BlipProcessor.from_pretrained(BLIP_BASE_ID)
    print(f" ✅ Processor loaded ({BLIP_BASE_ID})")

    # ── Try fine-tuned checkpoint first ─────────────────────────────────────
    abs_weights = os.path.abspath(weights_dir)
    if os.path.isdir(abs_weights) and os.listdir(abs_weights):
        print(f" Loading fine-tuned weights from: {abs_weights}")
        model = BlipForConditionalGeneration.from_pretrained(abs_weights)
        print(" ✅ Fine-tuned checkpoint loaded")
        weights_source = f"fine-tuned ({weights_dir})"
    else:
        print(f" ⚠️ No checkpoint at {abs_weights}. Using base HuggingFace weights.")
        model = BlipForConditionalGeneration.from_pretrained(BLIP_BASE_ID)
        print(" ✅ Base pretrained weights loaded")
        weights_source = "base (pretrained)"

    model.to(device).eval()
    n_params = sum(p.numel() for p in model.parameters())
    print(f" Parameters: {n_params:,} | Weights: {weights_source}")
    print("=" * 60)
    return model, processor, device
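
# ─────────────────────────────────────────────────────────────────────────────
# Example usage (illustrative sketch)
# ─────────────────────────────────────────────────────────────────────────────
# The helper below is NOT part of the original pipeline. It is a minimal,
# hedged sketch of how the (model, processor, device) triple returned by
# `load_model()` is typically used to caption one image with a standard BLIP
# `generate()` call. The function name, image path, and generation settings
# are placeholders chosen for illustration.
def caption_image(image_path: str, model, processor, device,
                  max_new_tokens: int = 30) -> str:
    """Generate a caption for a single image (illustrative sketch)."""
    from PIL import Image  # local import: only needed for this example

    image = Image.open(image_path).convert("RGB")
    inputs = processor(images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        output_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
    return processor.decode(output_ids[0], skip_special_tokens=True)

# Example (hypothetical image path):
#   model, processor, device = load_model()
#   print(caption_image("example.jpg", model, processor, device))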

# ─────────────────────────────────────────────────────────────────────────────
# Standalone entrypoint
# ─────────────────────────────────────────────────────────────────────────────
if __name__ == "__main__":
    import sys

    # Allow running from the task folder directly
    sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))

    model, processor, device = load_model()
    print("\n✅ load_model() returned successfully.")
    print(f" model type : {type(model).__name__}")
    print(f" device : {device}")
    print("\nYou can now import this in any notebook:")
    print(" from task.task_03.step1_load_model import load_model")
    print(" model, processor, device = load_model()")