File size: 2,136 Bytes
e714957
 
 
658ebc4
e714957
 
f24f5b3
e714957
 
f24f5b3
e714957
 
 
94efef4
 
 
 
 
e714957
f24f5b3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e714957
 
 
 
 
 
f4a1ca1
e714957
 
 
 
 
 
 
d9e2463
f24f5b3
d9e2463
f24f5b3
 
 
d9e2463
 
 
f24f5b3
d9e2463
 
 
f24f5b3
d9e2463
f24f5b3
 
d9e2463
 
f24f5b3
 
 
d9e2463
e714957
 
 
 
d9e2463
e714957
 
d9e2463
e714957
 
 
d9e2463
f24f5b3
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import torch
from transformers import BlipProcessor, BlipForConditionalGeneration

# Hugging Face hub id of the BLIP large image-captioning checkpoint.
MODEL_ID = "Salesforce/blip-image-captioning-large"
# Inference is pinned to CPU; no CUDA/MPS device selection is attempted here.
DEVICE = torch.device("cpu")

# Prompt templates (kept short & stable for BLIP).
# Keys are the style names looked up by generate_caption(); values are the
# conditioning text prefixes passed to the processor.
PROMPTS = {
    "Short Caption": "a photo of",
    "Detailed Caption": "this image shows"
}

def load_model():
    """Load the BLIP captioning model and its processor.

    Downloads (or reads from cache) the checkpoint named by MODEL_ID,
    moves the model to DEVICE, and switches it to eval mode.

    Returns:
        tuple: ``(model, processor)`` ready for inference.
    """
    blip_processor = BlipProcessor.from_pretrained(MODEL_ID)
    blip_model = BlipForConditionalGeneration.from_pretrained(MODEL_ID).to(DEVICE)
    blip_model.eval()
    return blip_model, blip_processor

def _finalize_sentence(text: str) -> str:
    """
    Ensures:
    - no trailing commas / conjunctions
    - sentence ends with a dot
    """
    text = text.strip()

    # Remove dangling conjunctions
    for suffix in [",", "and", "and a", "and the"]:
        if text.lower().endswith(suffix):
            text = text[: -len(suffix)].strip()

    # Ensure final punctuation
    if not text.endswith((".", "!", "?")):
        text += "."

    return text


def generate_caption(
    model,
    processor,
    image,
    style
):
    """Generate one caption for *image* in the requested *style*.

    Looks up the conditioning prompt for *style* in PROMPTS (falling back
    to "this image shows" for unknown styles), runs beam-search decoding
    with style-specific generation settings, and post-processes the text
    into a finished sentence.

    Args:
        model: A loaded BlipForConditionalGeneration instance.
        processor: The matching BlipProcessor.
        image: Input image accepted by the processor (e.g. a PIL image).
        style: Style key, e.g. "Short Caption" or "Detailed Caption".

    Returns:
        str: The finalized caption text.
    """
    prompt_text = PROMPTS.get(style, "this image shows")

    batch = processor(
        images=image,
        text=prompt_text,
        return_tensors="pt"
    ).to(DEVICE)

    # Decoding configuration per style: the detailed preset allows longer
    # outputs with more beams and a mild length bonus.
    if style == "Detailed Caption":
        decode_cfg = {
            "min_length": 55,
            "max_length": 110,
            "num_beams": 4,
            "do_sample": False,
            "repetition_penalty": 1.25,
            "length_penalty": 1.1,
            "no_repeat_ngram_size": 3,
            "early_stopping": True,
        }
    else:  # Short Caption
        decode_cfg = {
            "min_length": 18,
            "max_length": 40,
            "num_beams": 3,
            "do_sample": False,
            "repetition_penalty": 1.15,
            "no_repeat_ngram_size": 3,
            "early_stopping": True,
        }

    with torch.inference_mode():
        token_ids = model.generate(**batch, **decode_cfg)

    raw_caption = processor.decode(token_ids[0], skip_special_tokens=True)
    return _finalize_sentence(raw_caption)