# Image_Tagger / tagger.py
# Author: stephenebert — commit a184b98 (verified), "Update tagger.py", 2.45 kB
# (Hugging Face page residue converted to comments so the module parses.)
from __future__ import annotations
"""
Caption with BLIP and derive simple tags (no POS/NLTK).
- Tags are first unique non-stopword tokens from the caption.
- Sidecar saved to ./data/<stem>.json
"""
import os
import datetime as _dt
import json as _json
import pathlib as _pl
import re as _re
from typing import List, Tuple
import torch
from PIL import Image
from transformers import BlipForConditionalGeneration, BlipProcessor
# Writable sidecar directory (writable on Spaces)
# CAP_TAG_DIR is overridable via the CAP_TAG_DIR env var; created eagerly at
# import time so later sidecar writes cannot fail on a missing directory.
CAP_TAG_DIR = _pl.Path(os.environ.get("CAP_TAG_DIR", "./data")).resolve()
CAP_TAG_DIR.mkdir(parents=True, exist_ok=True)
# Device + singletons
# NOTE: the BLIP processor/model are loaded once at import time (module-level
# singletons) — importing this module downloads/loads model weights.
_device = "cuda" if torch.cuda.is_available() else "cpu"
_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
_model = BlipForConditionalGeneration.from_pretrained(
"Salesforce/blip-image-captioning-base"
).to(_device)
# very small stopword set to clean tags
# Deliberately minimal (no NLTK dependency); only filters the most common
# English function words BLIP captions tend to produce.
_STOP = {
"a","an","the","and","or","but","if","then","so","to","from",
"of","in","on","at","by","for","with","without","into","out",
"is","are","was","were","be","being","been","it","its","this",
"that","these","those","as","over","under","near","above","below",
"up","down","left","right"
}
def _caption_to_tags(caption: str, k: int) -> List[str]:
    """Derive up to *k* unique tags from a caption string.

    Tokens are lowercase runs of ``[a-z0-9-]``; tokens of length <= 2 and
    stopwords in ``_STOP`` are dropped. First-seen order is preserved.

    Args:
        caption: Caption text (any case; lowercased internally).
        k: Maximum number of tags to return.

    Returns:
        At most ``k`` unique tag strings; ``[]`` when ``k <= 0`` or no
        token survives filtering.
    """
    # Guard: the original append-then-check loop emitted one tag even
    # when k <= 0.
    if k <= 0:
        return []
    tokens = _re.findall(r"[a-z0-9-]+", caption.lower())
    out: List[str] = []
    seen: set = set()
    for w in tokens:
        if len(w) <= 2 or w in _STOP:
            continue
        if w not in seen:
            out.append(w)
            seen.add(w)
            if len(out) >= k:
                break
    return out
def tag_pil_image(
    img: Image.Image,
    stem: str,
    *,
    top_k: int = 5,
) -> Tuple[str, List[str]]:
    """Caption *img* with BLIP and derive up to *top_k* tags.

    A JSON sidecar ``{caption, tags, timestamp}`` is written to
    ``CAP_TAG_DIR/<sanitized stem>.json``.

    Args:
        img: Source PIL image.
        stem: Basename for the sidecar file; unsafe characters become "_".
        top_k: Maximum number of tags derived from the caption.

    Returns:
        ``(caption, tags)`` tuple.
    """
    # Sanitize stem for the filesystem; fall back when nothing survives.
    safe_stem = _re.sub(r"[^A-Za-z0-9_.-]+", "_", stem) or "upload"
    # Caption: preprocess, move tensors to the model's device, generate.
    inputs = _processor(images=img, return_tensors="pt")
    if _device == "cuda":
        inputs = {k: v.to(_device) for k, v in inputs.items()}
    with torch.inference_mode():
        ids = _model.generate(**inputs, max_length=30)
    caption = _processor.decode(ids[0], skip_special_tokens=True)
    # Tags derived from the caption text.
    tags = _caption_to_tags(caption, top_k)
    # Sidecar: explicit UTF-8 avoids locale-dependent default encodings
    # (write_text would otherwise use the platform preference and could
    # raise UnicodeEncodeError on non-ASCII captions); ensure_ascii=False
    # keeps non-ASCII text human-readable in the JSON file.
    payload = {
        "caption": caption,
        "tags": tags,
        "timestamp": _dt.datetime.now(_dt.timezone.utc).isoformat(),
    }
    (CAP_TAG_DIR / f"{safe_stem}.json").write_text(
        _json.dumps(payload, ensure_ascii=False, indent=2),
        encoding="utf-8",
    )
    return caption, tags