File size: 5,161 Bytes
f9b8c32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
"""
step1_load_model.py
====================
Task 3 β€” Component 1: Load BLIP model with fine-tuned weights.

This module loads the BLIP image-captioning model and attempts to restore
the best fine-tuned checkpoint from `outputs/blip/best/`. If no checkpoint
is found it falls back gracefully to the pretrained HuggingFace weights.

Public API
----------
    load_model(weights_dir="outputs/blip/best") -> (model, processor, device)

Standalone usage
----------------
    export PYTHONPATH=.
    venv/bin/python task/task_03/step1_load_model.py
"""

import os
import torch
from transformers import BlipForConditionalGeneration, BlipProcessor


# ─────────────────────────────────────────────────────────────────────────────
# Device helper
# ─────────────────────────────────────────────────────────────────────────────

def get_device() -> torch.device:
    """Return the best available device: MPS β†’ CUDA β†’ CPU."""
    if torch.backends.mps.is_available():
        return torch.device("mps")
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")


# ─────────────────────────────────────────────────────────────────────────────
# Main loader
# ─────────────────────────────────────────────────────────────────────────────

BLIP_BASE_ID = "Salesforce/blip-image-captioning-base"


def load_model(weights_dir: str = "outputs/blip/best"):
    """
    Load BLIP for conditional generation.

    1. Downloads/caches base weights from HuggingFace (first run only).
    2. Loads fine-tuned checkpoint from `weights_dir` if it exists.

    Args:
        weights_dir: Path to a directory containing a BLIP checkpoint saved
                     by `train.py` (e.g. ``outputs/blip/best``).  Can be
                     relative to the *project root*.

    Returns:
        (model, processor, device)
            model     : BlipForConditionalGeneration (eval mode)
            processor : BlipProcessor
            device    : torch.device
    """
    device = get_device()
    print("=" * 60)
    print("  Task 3 β€” Step 1: Load BLIP Model")
    print("=" * 60)
    print(f"  Device  : {device}")

    # ── Load processor ────────────────────────────────────────────────────────
    processor = BlipProcessor.from_pretrained(BLIP_BASE_ID)
    print(f"  βœ… Processor loaded  ({BLIP_BASE_ID})")

    # ── Try fine-tuned checkpoint first ───────────────────────────────────────
    abs_weights = os.path.abspath(weights_dir)
    if os.path.isdir(abs_weights) and os.listdir(abs_weights):
        print(f"  Loading fine-tuned weights from: {abs_weights}")
        model = BlipForConditionalGeneration.from_pretrained(abs_weights)
        print("  βœ… Fine-tuned checkpoint loaded")
        weights_source = f"fine-tuned ({weights_dir})"
    else:
        print(f"  ⚠️  No checkpoint at {abs_weights}. Using base HuggingFace weights.")
        model = BlipForConditionalGeneration.from_pretrained(BLIP_BASE_ID)
        print("  βœ… Base pretrained weights loaded")
        weights_source = "base (pretrained)"

    model.to(device).eval()

    n_params = sum(p.numel() for p in model.parameters())
    print(f"  Parameters: {n_params:,}  |  Weights: {weights_source}")
    print("=" * 60)

    return model, processor, device


# ─────────────────────────────────────────────────────────────────────────────
# Standalone entrypoint
# ─────────────────────────────────────────────────────────────────────────────

if __name__ == "__main__":
    import sys
    import os
    # Allow running from the task folder directly
    sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))

    model, processor, device = load_model()
    print(f"\nβœ…  load_model() returned successfully.")
    print(f"   model type : {type(model).__name__}")
    print(f"   device     : {device}")
    print(f"\nYou can now import this in any notebook:")
    print("  from task.task_03.step1_load_model import load_model")
    print("  model, processor, device = load_model()")