"""
step1_load_model.py
====================
Task 5 – Component 1: Load BLIP model + toxicity classifier.
Loads:
- BLIP (Salesforce/blip-image-captioning-base) with fine-tuned checkpoint
fallback to the base HuggingFace weights.
- unitary/toxic-bert for toxicity scoring (reusing the exact same model
    used in app.py – single source of truth).
Public API
----------
load_model(checkpoint_dir=None) -> (model, processor, device)
load_toxicity_model() -> (tokenizer, model)
Standalone usage
----------------
export PYTHONPATH=.
venv/bin/python task/task_05/step1_load_model.py
"""
import os
import sys
# Prepend the repo root (three directory levels above this file) to sys.path
# so project-local packages resolve when this script is run standalone
# (see "Standalone usage" in the module docstring).
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
# ─────────────────────────────────────────────────────────────────────────────
# BLIP caption model
# ─────────────────────────────────────────────────────────────────────────────
# Base HuggingFace model ID, used when no fine-tuned checkpoint is found.
_BLIP_BASE = "Salesforce/blip-image-captioning-base"
# Candidate fine-tuned checkpoint directories (relative to the repo root),
# searched in priority order by load_model().
_CKPT_SEARCH = [
    "outputs/blip/best",
    "outputs/blip/latest",
]
def load_model(checkpoint_dir: str | None = None):
    """
    Load BLIP for caption generation.

    Args:
        checkpoint_dir: path to a fine-tuned checkpoint directory.
            If None, searches _CKPT_SEARCH (relative to the repo root) and
            falls back to the base HuggingFace weights.

    Returns:
        (model, processor, device) — model is on `device` and in eval mode.
    """
    import torch
    from transformers import BlipProcessor, BlipForConditionalGeneration

    # Prefer Apple MPS, then CUDA, then CPU.
    if torch.backends.mps.is_available():
        device = torch.device("mps")
    elif torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
    print(f" [step1] Device: {device}")

    processor = BlipProcessor.from_pretrained(_BLIP_BASE, use_fast=True)

    # Resolve the checkpoint directory: explicit argument wins; otherwise
    # search the known output locations under the repo root.
    ckpt = checkpoint_dir
    if ckpt is None:
        root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
        for rel in _CKPT_SEARCH:
            cand = os.path.join(root, rel)
            if os.path.isdir(cand) and os.listdir(cand):  # non-empty dirs only
                ckpt = cand
                break

    # Load the fine-tuned weights directly instead of instantiating a second
    # full model and copying its state_dict — same result, half the peak
    # memory. On any failure we fall back to the base weights, as before.
    model = None
    if ckpt and os.path.isdir(ckpt):
        try:
            model = BlipForConditionalGeneration.from_pretrained(ckpt)
            print(f" [step1] Loaded fine-tuned weights from: {ckpt}")
        except Exception as e:
            print(f" [step1] Warning: could not load checkpoint ({e}). Using base weights.")
    else:
        print(f" [step1] Using base HuggingFace weights ({_BLIP_BASE})")
    if model is None:
        model = BlipForConditionalGeneration.from_pretrained(_BLIP_BASE)

    model.to(device).eval()
    return model, processor, device
# ─────────────────────────────────────────────────────────────────────────────
# Toxicity classifier (reuses app.py's load_toxicity_filter pattern)
# ─────────────────────────────────────────────────────────────────────────────
_TOX_MODEL_ID = "unitary/toxic-bert"
def load_toxicity_model():
    """Load the unitary/toxic-bert multi-label toxicity classifier.

    Returns:
        tuple: (tokenizer, model) — both on CPU, model in eval mode.

    Note:
        Identical to the model app.py's load_toxicity_filter() uses, kept as
        the single source of truth for toxicity scoring.
    """
    from transformers import AutoModelForSequenceClassification, AutoTokenizer

    print(f" [step1] Loading toxicity model: {_TOX_MODEL_ID}")
    tokenizer = AutoTokenizer.from_pretrained(_TOX_MODEL_ID)
    classifier = AutoModelForSequenceClassification.from_pretrained(_TOX_MODEL_ID)
    classifier.eval()
    return tokenizer, classifier
# ─────────────────────────────────────────────────────────────────────────────
# Standalone
# ─────────────────────────────────────────────────────────────────────────────
if __name__ == "__main__":
    # Smoke test: load both models and report what came up.
    # (Banner text repaired: the dash had been mojibake'd into "β".)
    print("=" * 60)
    print(" Task 5 – Step 1: Load Models")
    print("=" * 60)
    model, processor, device = load_model()
    print(f" BLIP loaded on {device}")
    tox_tok, tox_mdl = load_toxicity_model()
    # id2label maps class index -> label name for the multi-label head.
    labels = tox_mdl.config.id2label
    print(f" Toxicity labels: {list(labels.values())}")
    print(" Step 1 OK.")