# AI-Image-Describer / core/model_loader.py
import torch
from transformers import BlipProcessor, BlipForConditionalGeneration
# Hugging Face hub id of the BLIP large captioning checkpoint.
MODEL_ID = "Salesforce/blip-image-captioning-large"
# Inference is pinned to CPU (no GPU assumed in the deployment target).
DEVICE = torch.device("cpu")
# Prompt templates (kept short & stable for BLIP)
# Keys are the user-facing style names accepted by generate_caption().
PROMPTS = {
    "Short Caption": "a photo of",
    "Detailed Caption": "this image shows"
}
def load_model():
    """
    Download/load the BLIP captioning checkpoint and its processor.

    Returns:
        (model, processor) tuple; the model is moved to ``DEVICE``
        and switched to eval mode, ready for inference.
    """
    blip_processor = BlipProcessor.from_pretrained(MODEL_ID)
    # nn.Module.to() returns the module itself, so the calls chain safely.
    blip_model = BlipForConditionalGeneration.from_pretrained(MODEL_ID).to(DEVICE)
    blip_model.eval()
    return blip_model, blip_processor
def _finalize_sentence(text: str) -> str:
"""
Ensures:
- no trailing commas / conjunctions
- sentence ends with a dot
"""
text = text.strip()
# Remove dangling conjunctions
for suffix in [",", "and", "and a", "and the"]:
if text.lower().endswith(suffix):
text = text[: -len(suffix)].strip()
# Ensure final punctuation
if not text.endswith((".", "!", "?")):
text += "."
return text
def generate_caption(
    model,
    processor,
    image,
    style
):
    """
    Produce a caption for *image* with the given BLIP model/processor.

    Args:
        model: A BlipForConditionalGeneration instance in eval mode.
        processor: The matching BlipProcessor.
        image: Input image accepted by the processor (e.g. a PIL image).
        style: Caption style key; unknown values fall back to the
            detailed prompt ("this image shows").

    Returns:
        A cleaned, sentence-final caption string.
    """
    prompt = PROMPTS.get(style, "this image shows")
    model_inputs = processor(
        images=image,
        text=prompt,
        return_tensors="pt"
    ).to(DEVICE)

    # Style-specific decoding configuration: the detailed style forces a
    # longer beam-searched output; everything else gets the short config.
    if style == "Detailed Caption":
        decode_opts = {
            "min_length": 55,
            "max_length": 110,
            "num_beams": 4,
            "do_sample": False,
            "repetition_penalty": 1.25,
            "length_penalty": 1.1,
            "no_repeat_ngram_size": 3,
            "early_stopping": True,
        }
    else:  # Short Caption
        decode_opts = {
            "min_length": 18,
            "max_length": 40,
            "num_beams": 3,
            "do_sample": False,
            "repetition_penalty": 1.15,
            "no_repeat_ngram_size": 3,
            "early_stopping": True,
        }

    with torch.inference_mode():
        token_ids = model.generate(**model_inputs, **decode_opts)

    raw_caption = processor.decode(token_ids[0], skip_special_tokens=True)
    return _finalize_sentence(raw_caption)