Update scene_captioner.py
Browse files- scene_captioner.py +78 -153
scene_captioner.py
CHANGED
|
@@ -1,206 +1,131 @@
|
|
| 1 |
"""
|
| 2 |
scene_captioner.py
|
| 3 |
──────────────────
|
| 4 |
-
|
| 5 |
|
| 6 |
-
Model
|
| 7 |
-
1.
|
| 8 |
-
2. Salesforce/
|
| 9 |
-
3.
|
| 10 |
-
|
| 11 |
-
HF Spaces free tier = CPU only, 16 GB RAM.
|
| 12 |
"""
|
| 13 |
|
| 14 |
-
import os
|
| 15 |
import io
|
| 16 |
-
import logging
|
| 17 |
import hashlib
|
|
|
|
|
|
|
| 18 |
|
| 19 |
logger = logging.getLogger(__name__)
|
| 20 |
|
| 21 |
-
# ──
|
| 22 |
try:
|
| 23 |
import torch
|
| 24 |
TORCH_OK = True
|
| 25 |
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
| 26 |
-
logger.info(f"PyTorch {torch.__version__}
|
| 27 |
-
except Exception as
|
| 28 |
TORCH_OK = False
|
| 29 |
DEVICE = "cpu"
|
| 30 |
-
logger.error(f"PyTorch
|
| 31 |
|
| 32 |
from PIL import Image, ImageStat
|
| 33 |
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
"
|
| 37 |
-
"Describe the scene clearly in 2–3 sentences covering: "
|
| 38 |
-
"(1) main subjects and actions, (2) setting/environment, "
|
| 39 |
-
"(3) any safety hazards if visible."
|
| 40 |
)
|
| 41 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
|
| 43 |
|
| 44 |
class SceneCaptioner:
|
| 45 |
-
"""
|
| 46 |
-
Loads a vision-language model once and exposes `describe(image: PIL.Image) -> str`.
|
| 47 |
-
Falls back gracefully through 3 model tiers if earlier ones fail.
|
| 48 |
-
"""
|
| 49 |
|
| 50 |
def __init__(self):
|
| 51 |
-
self.
|
| 52 |
-
self.
|
| 53 |
-
self._backend = "mock"
|
| 54 |
|
| 55 |
if not TORCH_OK:
|
| 56 |
-
logger.
|
| 57 |
return
|
| 58 |
|
| 59 |
-
# Try models
|
| 60 |
-
for loader in
|
|
|
|
|
|
|
|
|
|
| 61 |
try:
|
| 62 |
-
loader()
|
|
|
|
| 63 |
break
|
| 64 |
except Exception as exc:
|
| 65 |
-
logger.warning(f"
|
| 66 |
|
| 67 |
if self._backend == "mock":
|
| 68 |
-
logger.warning("All models failed —
|
| 69 |
-
|
| 70 |
-
# ── Model loaders ─────────────────────────────────────────────────────────
|
| 71 |
|
| 72 |
-
|
| 73 |
-
"""Qwen2-VL-2B-Instruct — best quality."""
|
| 74 |
-
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
|
| 75 |
-
model_id = "Qwen/Qwen2-VL-2B-Instruct"
|
| 76 |
-
logger.info(f"Loading {model_id} …")
|
| 77 |
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
model_id,
|
| 86 |
-
torch_dtype=torch.float32, # CPU — float32 required
|
| 87 |
-
trust_remote_code=True,
|
| 88 |
-
low_cpu_mem_usage=True,
|
| 89 |
-
)
|
| 90 |
-
self.model.eval()
|
| 91 |
-
self._backend = "qwen"
|
| 92 |
-
logger.info(f"✅ Loaded Qwen2-VL on {DEVICE}")
|
| 93 |
-
|
| 94 |
-
def _try_blip2(self):
|
| 95 |
-
"""BLIP-2 OPT-2.7B — fallback."""
|
| 96 |
-
from transformers import Blip2Processor, Blip2ForConditionalGeneration
|
| 97 |
-
model_id = "Salesforce/blip2-opt-2.7b"
|
| 98 |
-
logger.info(f"Loading {model_id} …")
|
| 99 |
-
|
| 100 |
-
# use_fast=False avoids BlipImageProcessorFast/torch mismatch
|
| 101 |
-
self.processor = Blip2Processor.from_pretrained(
|
| 102 |
-
model_id,
|
| 103 |
-
use_fast=False,
|
| 104 |
-
)
|
| 105 |
-
self.model = Blip2ForConditionalGeneration.from_pretrained(
|
| 106 |
-
model_id,
|
| 107 |
-
torch_dtype=torch.float32,
|
| 108 |
-
low_cpu_mem_usage=True,
|
| 109 |
)
|
| 110 |
-
self.model.eval()
|
| 111 |
-
self._backend = "blip2"
|
| 112 |
-
logger.info(f"✅ Loaded BLIP-2 on {DEVICE}")
|
| 113 |
-
|
| 114 |
-
def _try_vitgpt2(self):
|
| 115 |
-
"""ViT-GPT2 — tiny and fast, CPU-friendly (~1 GB)."""
|
| 116 |
-
from transformers import VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer
|
| 117 |
-
model_id = "nlpconnect/vit-gpt2-image-captioning"
|
| 118 |
-
logger.info(f"Loading {model_id} …")
|
| 119 |
-
|
| 120 |
-
self._vit_processor = ViTImageProcessor.from_pretrained(model_id)
|
| 121 |
-
self._vit_tokenizer = AutoTokenizer.from_pretrained(model_id)
|
| 122 |
-
self.model = VisionEncoderDecoderModel.from_pretrained(model_id)
|
| 123 |
-
self.model.eval()
|
| 124 |
self._backend = "vitgpt2"
|
| 125 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 126 |
|
| 127 |
# ── Inference ─────────────────────────────────────────────────────────────
|
| 128 |
|
| 129 |
def describe(self, image: Image.Image) -> str:
|
| 130 |
image = image.convert("RGB")
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
)
|
| 148 |
-
image_inputs, video_inputs = process_vision_info(messages)
|
| 149 |
-
inputs = self.processor(
|
| 150 |
-
text=[text], images=image_inputs, videos=video_inputs,
|
| 151 |
-
padding=True, return_tensors="pt",
|
| 152 |
-
)
|
| 153 |
-
with torch.no_grad():
|
| 154 |
-
gen = self.model.generate(**inputs, max_new_tokens=200, do_sample=False)
|
| 155 |
-
trimmed = [out[len(inp):] for inp, out in zip(inputs["input_ids"], gen)]
|
| 156 |
-
return self.processor.batch_decode(trimmed, skip_special_tokens=True)[0].strip()
|
| 157 |
-
|
| 158 |
-
def _infer_blip2(self, image: Image.Image) -> str:
|
| 159 |
-
prompt = f"Question: {USER_PROMPT} Answer:"
|
| 160 |
-
inputs = self.processor(image, text=prompt, return_tensors="pt")
|
| 161 |
-
with torch.no_grad():
|
| 162 |
-
ids = self.model.generate(**inputs, max_new_tokens=200, do_sample=False)
|
| 163 |
-
return self.processor.decode(ids[0], skip_special_tokens=True).strip()
|
| 164 |
-
|
| 165 |
-
def _infer_vitgpt2(self, image: Image.Image) -> str:
|
| 166 |
-
pixel_values = self._vit_processor(
|
| 167 |
-
images=[image], return_tensors="pt"
|
| 168 |
-
).pixel_values
|
| 169 |
-
with torch.no_grad():
|
| 170 |
-
ids = self.model.generate(
|
| 171 |
-
pixel_values, max_length=64, num_beams=4,
|
| 172 |
-
)
|
| 173 |
-
caption = self._vit_tokenizer.decode(ids[0], skip_special_tokens=True)
|
| 174 |
-
return caption.strip()
|
| 175 |
-
|
| 176 |
-
# ── Mock fallback (no model needed) ──────────────────────────────────────
|
| 177 |
-
|
| 178 |
-
def _infer_mock(self, image: Image.Image) -> str:
|
| 179 |
-
"""
|
| 180 |
-
Deterministic mock based on image brightness + colour.
|
| 181 |
-
Used when all model loads fail (e.g. OOM or no network).
|
| 182 |
-
"""
|
| 183 |
-
SAFE_CAPTIONS = [
|
| 184 |
-
"A well-lit indoor space with furniture and soft natural light. The area appears clean and organised with no visible hazards.",
|
| 185 |
-
"A sunny outdoor park with green grass, trees and a paved path. People are relaxing peacefully with no dangers present.",
|
| 186 |
-
"A modern office with rows of desks and computers. The walkways are clear and the environment is calm.",
|
| 187 |
-
"A kitchen counter with fresh vegetables and cooking utensils. The area is tidy and safe.",
|
| 188 |
-
"A quiet residential street lined with parked cars and houses. The road is clear with pedestrians on the pavement.",
|
| 189 |
-
]
|
| 190 |
-
DANGEROUS_CAPTIONS = [
|
| 191 |
-
"A building interior showing visible fire and thick smoke billowing from a burning structure. The area should be evacuated immediately.",
|
| 192 |
-
"A flooded street where rising water has reached parked vehicles. Pedestrians are wading through dangerous floodwater.",
|
| 193 |
-
"An electrical panel with exposed sparking wires visible. This presents a serious electrocution hazard.",
|
| 194 |
-
"A road accident scene with an overturned vehicle blocking traffic and debris scattered across the road.",
|
| 195 |
-
"Dark storm clouds and visible lightning strikes approaching over an open field. Immediate shelter is required.",
|
| 196 |
-
]
|
| 197 |
stat = ImageStat.Stat(image)
|
| 198 |
brightness = sum(stat.mean[:3]) / 3
|
| 199 |
r, g, b = stat.mean[:3]
|
| 200 |
buf = io.BytesIO()
|
| 201 |
image.resize((32, 32)).save(buf, format="PNG")
|
| 202 |
h = int(hashlib.md5(buf.getvalue()).hexdigest(), 16)
|
| 203 |
-
|
| 204 |
if brightness < 80 or r > g + 30:
|
| 205 |
return DANGEROUS_CAPTIONS[h % len(DANGEROUS_CAPTIONS)]
|
| 206 |
return SAFE_CAPTIONS[h % len(SAFE_CAPTIONS)]
|
|
|
|
| 1 |
"""
scene_captioner.py
──────────────────
Lightweight captioner that works reliably on HF Spaces free-tier CPU.

Model ladder (tries fastest/smallest first):
  1. nlpconnect/vit-gpt2-image-captioning  ~330 MB — default, CPU-fast
  2. Salesforce/blip-image-captioning-base ~990 MB — better quality
  3. Mock captions — last resort (no crash)
"""

import hashlib
import io
import logging
import os

logger = logging.getLogger(__name__)

# ── Safe torch import ─────────────────────────────────────────────────────────
# torch may be missing or broken on a constrained host; degrade to CPU/mock
# mode at import time instead of crashing the whole app.
try:
    import torch

    TORCH_OK = True
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    # Lazy %-style args: the message is only formatted if the level is enabled.
    logger.info("PyTorch %s on %s", torch.__version__, DEVICE)
except Exception as e:
    TORCH_OK = False
    DEVICE = "cpu"
    logger.error("PyTorch unavailable: %s", e)

from PIL import Image, ImageStat
|
| 31 |
|
| 32 |
+
# Instruction prompt for caption models that accept free-form text guidance.
USER_PROMPT = (
    "Describe this scene clearly for a visually-impaired person in 2-3 sentences. "
    "Mention the main subjects, setting, and any safety hazards if present."
)

# ── Mock caption banks ────────────────────────────────────────────────────────
# Canned fallback texts used when no real model could be loaded.  One entry
# per bank is selected deterministically from a hash of the input image, so
# repeated calls on the same image always produce the same caption.

# Scenes with nothing alarming in them.
SAFE_CAPTIONS = [
    "A well-lit indoor room with wooden furniture and soft natural light coming through a window. The space looks clean and organized with no visible hazards present.",
    "A sunny outdoor park scene with green grass and mature trees providing shade. Several people are relaxing peacefully with no dangers visible.",
    "A modern kitchen with a clean counter, sink, and cooking utensils neatly arranged. The environment looks safe and well-maintained.",
    "A quiet residential street lined with parked cars and houses. Pedestrians are visible on the pavement and the road is clear.",
    "An office with rows of desks, monitors, and overhead lighting. The walkways are unobstructed and the environment is calm.",
]

# Scenes describing an active hazard the listener should react to.
DANGEROUS_CAPTIONS = [
    "A room showing visible fire and thick smoke billowing from a burning structure in the background. The area poses serious danger and should be evacuated immediately.",
    "A flooded street where rising water has reached the doors of parked vehicles. Pedestrians attempting to wade through the dangerous floodwater face serious risk.",
    "An electrical panel with exposed and sparking wires hanging from the ceiling. This presents an immediate electrocution hazard.",
    "A road accident scene with an overturned vehicle blocking lanes and debris scattered across the road. Emergency services are needed.",
    "Dark storm clouds and lightning strikes approaching over an open area. Anyone outdoors should seek shelter immediately.",
]
|
| 53 |
|
| 54 |
|
| 55 |
class SceneCaptioner:
    """Caption a PIL image using a lightweight transformers pipeline.

    Tries the model ladder from the module docstring (smallest model first).
    If every load fails — or torch itself is unavailable — it falls back to
    deterministic mock captions so ``describe`` never raises.
    """

    def __init__(self):
        # `pipe` stays None and `_backend` stays "mock" until a model loads.
        self.pipe = None
        self._backend = "mock"

        if not TORCH_OK:
            logger.warning("PyTorch not available — using mock captions.")
            return

        # Try models smallest → larger; first successful load wins.
        for model_id, loader in [
            ("nlpconnect/vit-gpt2-image-captioning", self._load_vitgpt2),
            ("Salesforce/blip-image-captioning-base", self._load_blip),
        ]:
            try:
                loader(model_id)
                logger.info("✅ Captioner ready: %s [%s]", model_id, self._backend)
                break
            except Exception as exc:
                logger.warning("Failed to load %s: %s", model_id, exc)

        if self._backend == "mock":
            logger.warning("All models failed — using mock captions.")

    # ── Loaders ───────────────────────────────────────────────────────────────

    def _load_vitgpt2(self, model_id: str) -> None:
        """Load ViT-GPT2 (~330 MB) — the fastest CPU-friendly captioner."""
        from transformers import pipeline

        self.pipe = pipeline(
            "image-to-text",
            model=model_id,
            device=-1,  # -1 = CPU
            max_new_tokens=64,
        )
        self._backend = "vitgpt2"

    def _load_blip(self, model_id: str) -> None:
        """Load BLIP base (~990 MB) — better quality, slower on CPU."""
        from transformers import pipeline

        self.pipe = pipeline(
            "image-to-text",
            model=model_id,
            device=-1,  # -1 = CPU
            max_new_tokens=100,
        )
        self._backend = "blip"

    # ── Inference ─────────────────────────────────────────────────────────────

    def describe(self, image: Image.Image) -> str:
        """Return a caption for *image*.

        Never raises and never returns an empty string: any model failure
        (load-time or inference-time) falls through to the mock captions.
        """
        image = image.convert("RGB")

        if self.pipe is not None:
            try:
                result = self.pipe(image)
                # image-to-text pipelines return [{"generated_text": "..."}];
                # guard against an empty list or missing key explicitly rather
                # than letting IndexError/KeyError surface as an "error" log.
                if result:
                    caption = result[0].get("generated_text", "").strip()
                    if caption:
                        return caption
            except Exception as exc:
                logger.error("Inference error (%s): %s", self._backend, exc)

        # Fallback to deterministic mock captions.
        return self._mock_caption(image)

    # ── Deterministic mock ────────────────────────────────────────────────────

    def _mock_caption(self, image: Image.Image) -> str:
        """Pick a canned caption deterministically from image statistics.

        Dark images (mean brightness < 80) or strongly red-dominant images
        map to the "dangerous" bank; everything else to the "safe" bank.
        The entry within the bank is chosen from an MD5 hash of a 32×32
        thumbnail, so identical images always yield identical captions.
        """
        stat = ImageStat.Stat(image)
        red, green, _blue = stat.mean[:3]  # per-channel means (image is RGB here)
        brightness = (red + green + _blue) / 3

        # Content hash of a small thumbnail → stable, image-dependent index.
        buf = io.BytesIO()
        image.resize((32, 32)).save(buf, format="PNG")
        h = int(hashlib.md5(buf.getvalue()).hexdigest(), 16)

        if brightness < 80 or red > green + 30:
            return DANGEROUS_CAPTIONS[h % len(DANGEROUS_CAPTIONS)]
        return SAFE_CAPTIONS[h % len(SAFE_CAPTIONS)]