Spaces:
Sleeping
Sleeping
File size: 3,158 Bytes
"""
step1_load_model.py
====================
STEP 1 — Load the BLIP model and processor.
Responsibilities:
- Detect the best available device (MPS / CUDA / CPU).
- Load base BLIP weights via project's get_blip_model().
- Optionally patch in fine-tuned weights from outputs/blip/best/.
- Disable gradient checkpointing (required for backward hooks).
- Return a ready-to-use (model, processor, device) triplet.
This module is intentionally tiny and self-contained so it can be
called independently from a notebook, a Streamlit app, or a HuggingFace Space.
"""
import os
import sys
import torch
# ββ project path resolution βββββββββββββββββββββββββββββββββββββββββββββββββ
# Make the project root importable no matter where this script is launched
# from.  Layout assumed: <project_root>/<task_dir>/<this_dir>/this_file.py
# — TODO confirm against the repository layout.
_THIS_DIR = os.path.dirname(os.path.abspath(__file__))
_TASK_DIR = os.path.dirname(_THIS_DIR)
_PROJECT_ROOT = os.path.dirname(_TASK_DIR)
if _PROJECT_ROOT not in sys.path:
    # Prepend (not append) so project modules shadow same-named installs.
    sys.path.insert(0, _PROJECT_ROOT)

# Project-local imports — resolvable only after the sys.path fix above.
from config import CFG
from models.blip_tuner import get_blip_model

# Directory expected to hold the fine-tuned BLIP checkpoint (see load_model).
FINETUNED_PATH = os.path.join(_PROJECT_ROOT, "outputs", "blip", "best")
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
def get_device() -> torch.device:
    """Return the preferred compute device: MPS, then CUDA, then CPU."""
    # Probe backends in priority order; fall through to CPU if none is up.
    backends = (
        ("mps", torch.backends.mps.is_available),
        ("cuda", torch.cuda.is_available),
    )
    for name, is_available in backends:
        if is_available():
            return torch.device(name)
    return torch.device("cpu")
def load_model(use_finetuned: bool = True, verbose: bool = True):
    """
    Load the BLIP model and processor.

    Args:
        use_finetuned: If True and the checkpoint directory at
            outputs/blip/best/ exists, overlay those fine-tuned weights
            on top of the base model; otherwise the base model is used.
        verbose: Print loading progress.

    Returns:
        model β BlipForConditionalGeneration, eval mode, on `device`.
        processor β BlipProcessor.
        device β torch.device.
    """
    # Lazy import keeps "import this module" cheap when load_model is unused.
    from transformers import BlipForConditionalGeneration

    device = get_device()
    cfg = CFG.load_for_model("blip")
    model, processor = get_blip_model(cfg, device)

    # Optionally overlay fine-tuned weights on the base model.
    if use_finetuned and os.path.isdir(FINETUNED_PATH):
        if verbose:
            print(f"π Loading fine-tuned weights from {FINETUNED_PATH} β¦")
        ft = BlipForConditionalGeneration.from_pretrained(FINETUNED_PATH)
        # strict=False tolerates key mismatches between the base architecture
        # and the saved checkpoint (e.g. missing/extra buffers).
        model.load_state_dict(ft.state_dict(), strict=False)
        model.to(device)
        if verbose:
            # FIX: this message was a string literal broken across two
            # physical lines (a syntax error); rejoined into one literal.
            print("β Fine-tuned weights loaded")
    else:
        if verbose:
            print("β οΈ Fine-tuned weights not found (or disabled), using base model")

    # Gradient checkpointing re-runs forward segments during backward, which
    # breaks backward hooks β force it off.  Older model classes may not
    # expose the method, hence the broad-but-contained try/except.
    try:
        model.gradient_checkpointing_disable()
    except Exception:
        pass
    # NOTE(review): use_cache is disabled alongside checkpointing β presumably
    # so hooks see the full graph during generation; confirm this is intended.
    model.config.use_cache = False
    model.eval()

    if verbose:
        # FIX: same rejoin as above β literal was split across two lines.
        print(f"β Model ready on device: {device}")
    return model, processor, device
|