"""
step1_load_model.py
====================
STEP 1 – Load the BLIP model and processor.

Responsibilities:
- Detect the best available device (MPS / CUDA / CPU).
- Load base BLIP weights via the project's get_blip_model().
- Optionally patch in fine-tuned weights from outputs/blip/best/.
- Disable gradient checkpointing (it is incompatible with backward hooks).
- Return a ready-to-use (model, processor, device) triplet.

This module is intentionally tiny and self-contained so it can be
called independently from a notebook, a Streamlit app, or a HuggingFace Space.
"""

import os
import sys
import torch

# ── project path resolution ─────────────────────────────────────────────────
_THIS_DIR       = os.path.dirname(os.path.abspath(__file__))   # .../<root>/<task>/<step dir>
_TASK_DIR       = os.path.dirname(_THIS_DIR)                    # .../<root>/<task>
_PROJECT_ROOT   = os.path.dirname(_TASK_DIR)                    # .../<root>

if _PROJECT_ROOT not in sys.path:
    sys.path.insert(0, _PROJECT_ROOT)

from config import CFG
from models.blip_tuner import get_blip_model

FINETUNED_PATH = os.path.join(_PROJECT_ROOT, "outputs", "blip", "best")


# ────────────────────────────────────────────────────────────────────────────
def get_device() -> torch.device:
    """Return the best available compute device."""
    if torch.backends.mps.is_available():
        return torch.device("mps")
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")


def load_model(use_finetuned: bool = True, verbose: bool = True):
    """
    Load the BLIP model and processor.

    Args:
        use_finetuned: If True, patches base weights with our fine-tuned
                       COCO checkpoint from outputs/blip/best/.
        verbose:       Print loading progress.

    Returns:
        model     – BlipForConditionalGeneration, eval mode.
        processor – BlipProcessor.
        device    – torch.device.
    """
    from transformers import BlipForConditionalGeneration

    device = get_device()
    cfg = CFG.load_for_model("blip")
    model, processor = get_blip_model(cfg, device)

    # Optionally load fine-tuned weights
    if use_finetuned and os.path.isdir(FINETUNED_PATH):
        if verbose:
            print(f"πŸ”„ Loading fine-tuned weights from {FINETUNED_PATH} …")
        ft = BlipForConditionalGeneration.from_pretrained(FINETUNED_PATH)
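        # strict=False tolerates key mismatches between the checkpoint and the
        # freshly built model (e.g., buffer names that differ across
        # transformers versions); unmatched keys are skipped rather than raising.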
        model.load_state_dict(ft.state_dict(), strict=False)
        model.to(device)
        if verbose:
            print("βœ… Fine-tuned weights loaded")
    else:
        if verbose:
            print("⚠️  Fine-tuned weights not found (or disabled), using base model")

    # Disable gradient checkpointing – incompatible with backward hooks
    try:
        model.gradient_checkpointing_disable()
    except Exception:
        pass
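    # Also disable the decoder KV cache: with use_cache=True, generation would
    # reuse past key/values, so hooks would only observe the newest token's
    # computation instead of the full forward pass.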
    model.config.use_cache = False
    model.eval()

    if verbose:
        print(f"βœ… Model ready on device: {device}")

    return model, processor, device