File size: 2,803 Bytes
cd3cf65
 
0e857c8
2164c4f
0e857c8
 
 
cd3cf65
 
 
 
0e857c8
 
 
 
cd3cf65
0e857c8
2164c4f
0e857c8
 
 
 
2164c4f
cd3cf65
0e857c8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2164c4f
0e857c8
 
 
 
 
 
cd3cf65
 
 
 
 
 
0e857c8
 
 
 
 
 
 
 
cd3cf65
0e857c8
 
 
 
 
 
 
 
2164c4f
0e857c8
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
from __future__ import annotations

import json
import logging
import os
import re
from pathlib import Path
from typing import Collection, List, Optional

from PIL import Image
from transformers import BlipForConditionalGeneration, BlipProcessor

# -------------------- config --------------------
MODEL_ID = "Salesforce/blip-image-captioning-base"
# Directory that receives the per-image sidecar JSON files; override it with
# the DATA_DIR environment variable.
DATA_DIR = Path(os.getenv("DATA_DIR", "/app/data"))
try:
    DATA_DIR.mkdir(parents=True, exist_ok=True)  # safe if already exists
except OSError as exc:
    # Sidecar persistence is best-effort, so an unwritable or misconfigured
    # DATA_DIR must not make the whole module fail to import; later sidecar
    # writes will simply fail and be skipped.
    logging.getLogger(__name__).warning(
        "could not create DATA_DIR %s (%s); sidecar JSON files may not be written",
        DATA_DIR,
        exc,
    )

# light, built-in stopword list (keeps us NLTK-free); tokens in this set are
# dropped by caption_to_tags() unless the caller supplies its own list.
_STOP = {
    "a", "an", "the", "and", "or", "of", "to", "in", "on", "with", "near",
    "at", "over", "under", "by", "from", "for", "into", "along", "through",
    "is", "are", "be", "being", "been", "it", "its", "this", "that",
    "as", "while", "than", "then", "there", "here",
}

# -------------------- model cache --------------------
# Both stay None until init_models() loads them.
_processor: Optional[BlipProcessor] = None
_model: Optional[BlipForConditionalGeneration] = None


def init_models() -> None:
    """Load the BLIP processor and captioning model into the module cache.

    Safe to call repeatedly: once both objects are cached, later calls
    return immediately without reloading anything.
    """
    global _processor, _model
    already_loaded = _processor is not None and _model is not None
    if already_loaded:
        return
    _processor = BlipProcessor.from_pretrained(MODEL_ID)
    _model = BlipForConditionalGeneration.from_pretrained(MODEL_ID)


# -------------------- core functionality --------------------
def caption_image(img: Image.Image, max_len: int = 30) -> str:
    """Generate a short caption for *img* using the cached BLIP model.

    Args:
        img: PIL image to caption.
        max_len: Forwarded to the model's ``generate`` call as ``max_length``.

    Returns:
        The decoded caption for the first generated sequence, with special
        tokens removed.

    Raises:
        RuntimeError: if init_models() has not been called yet.
    """
    # Explicit check rather than `assert`: asserts are stripped under
    # `python -O`, which would turn a missing init into an opaque TypeError
    # on a None object further down.
    if _processor is None or _model is None:
        raise RuntimeError("Call init_models() first")
    inputs = _processor(images=img, return_tensors="pt")
    ids = _model.generate(**inputs, max_length=max_len)
    return _processor.decode(ids[0], skip_special_tokens=True)


# A token is a run of [a-z0-9] optionally joined by single internal hyphens
# ("well-lit", "t-shirt"). A bare "-" or a leading/trailing hyphen never
# becomes part of a tag.
_TAG_RE = re.compile(r"[a-z0-9]+(?:-[a-z0-9]+)*")


def caption_to_tags(
    caption: str,
    top_k: int = 5,
    *,
    stopwords: Optional[Collection[str]] = None,
) -> List[str]:
    """
    Convert a caption into up to K simple tags:
    - normalize to lowercase alphanumeric tokens (internal hyphens allowed)
    - remove stopwords (the built-in _STOP list unless *stopwords* is given)
    - keep order of appearance, dedup

    Args:
        caption: Free-text caption, e.g. the output of caption_image().
        top_k: Maximum number of tags to return; values <= 0 yield [].
        stopwords: Tokens to drop. Defaults to the module's _STOP set.

    Returns:
        At most *top_k* distinct tags, in order of first appearance.
    """
    if top_k <= 0:
        # The loop below appends before it checks the length, so without this
        # guard top_k=0 would still return one tag.
        return []
    stop = _STOP if stopwords is None else stopwords
    tags: List[str] = []
    seen = set()
    for tok in _TAG_RE.findall(caption.lower()):
        if tok in stop or tok in seen:
            continue
        seen.add(tok)
        tags.append(tok)
        if len(tags) >= top_k:
            break
    return tags


def tag_pil_image(
    img: Image.Image,
    stem: str,
    *,
    top_k: int = 5,
    write_sidecar: bool = True,
) -> List[str]:
    """
    Caption *img* and return ONLY the derived tags list.

    When *write_sidecar* is true, ``{"caption": ..., "tags": [...]}`` is also
    persisted to ``DATA_DIR/<stem>.json`` on a best-effort basis: a failed or
    refused write is logged, never raised, so tagging still succeeds.

    Args:
        img: PIL image to tag.
        stem: Base name (without extension) for the sidecar JSON file.
        top_k: Maximum number of tags, forwarded to caption_to_tags().
        write_sidecar: Whether to persist the sidecar JSON.

    Returns:
        The tags list (the same value written to the sidecar).
    """
    cap = caption_image(img)
    tags = caption_to_tags(cap, top_k=top_k)
    if write_sidecar:
        _write_sidecar(stem, {"caption": cap, "tags": tags})
    return tags


def _write_sidecar(stem: str, payload: dict) -> None:
    """Best-effort write of *payload* as indented JSON to DATA_DIR/<stem>.json."""
    log = logging.getLogger(__name__)
    sidecar = DATA_DIR / f"{stem}.json"
    # An absolute stem or one containing ".." would make pathlib place the
    # file outside DATA_DIR; refuse such targets instead of writing them.
    try:
        sidecar.resolve().relative_to(DATA_DIR.resolve())
    except ValueError:
        log.warning("skipping sidecar with unsafe path: %s", sidecar)
        return
    try:
        sidecar.write_text(json.dumps(payload, indent=2), encoding="utf-8")
    except (OSError, ValueError) as err:
        # best-effort; tagging should still succeed
        log.warning("could not write sidecar %s: %s", sidecar, err)