cartoon-captioner / model_utils.py
github-actions
Automated deployment from GitHub
ae66859
"""
model_utils.py β€” Model loading and caption generation utilities
for the Cartoon Image Captioning App.
"""
import os
# Block TF/Flax before any transformers import to prevent libmetal_plugin crash
os.environ["TRANSFORMERS_NO_TF"] = "1"
os.environ["TRANSFORMERS_NO_FLAX"] = "1"
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
import numpy as np
from PIL import Image, ImageEnhance, ImageFilter
from typing import List, Tuple, Optional
import warnings
warnings.filterwarnings("ignore")
# ── Safe torch import ─────────────────────────────────────────────────────────
try:
import torch
_TORCH_OK = True
DEVICE = (
"cuda" if torch.cuda.is_available()
else "mps" if (hasattr(torch.backends, "mps") and torch.backends.mps.is_available())
else "cpu"
)
except Exception as _e:
torch = None # type: ignore
_TORCH_OK = False
DEVICE = "cpu"
warnings.warn(
f"PyTorch could not be imported ({_e}). "
"Model inference will be unavailable until torch is fixed.\n"
"Fix with: conda install pytorch torchvision torchaudio -c pytorch --force-reinstall",
RuntimeWarning
)
# ─────────────────────────────────────────────────────────────────────────────
# Preprocessing
# ─────────────────────────────────────────────────────────────────────────────
def preprocess_cartoon(image: Image.Image, target_size: int = 224) -> Image.Image:
"""
Apply cartoon-specific preprocessing:
- Convert to RGB
- Resize to target_size x target_size
- Mild sharpening to enhance cartoon edges
- Normalize brightness/contrast
"""
if image.mode != "RGB":
image = image.convert("RGB")
# Resize with high-quality Lanczos resampling
image = image.resize((target_size, target_size), Image.LANCZOS)
# Mild edge enhancement for cartoon-style images
enhancer = ImageEnhance.Sharpness(image)
image = enhancer.enhance(1.4)
# Slight contrast boost
contrast_enhancer = ImageEnhance.Contrast(image)
image = contrast_enhancer.enhance(1.1)
return image
# ─────────────────────────────────────────────────────────────────────────────
# ViT-GPT2 Model
# ─────────────────────────────────────────────────────────────────────────────
def load_vit_gpt2():
"""Load and return ViT-GPT2 image captioning model."""
from transformers import VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer
model_name = "nlpconnect/vit-gpt2-image-captioning"
processor = ViTImageProcessor.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = VisionEncoderDecoderModel.from_pretrained(model_name)
model = model.to(DEVICE)
model.eval()
return model, processor, tokenizer
def generate_vit_gpt2(
model, processor, tokenizer,
image: Image.Image,
num_captions: int = 3,
max_length: int = 60
) -> List[Tuple[str, float]]:
"""
Generate multiple diverse captions using ViT-GPT2.
Returns list of (caption, score) tuples.
"""
img = preprocess_cartoon(image)
pixel_values = processor(images=img, return_tensors="pt").pixel_values.to(DEVICE)
results = []
with torch.no_grad():
# Beam search for best caption
output_ids = model.generate(
pixel_values,
max_length=max_length,
num_beams=5,
num_return_sequences=num_captions,
early_stopping=True,
return_dict_in_generate=True,
output_scores=True,
do_sample=True,
temperature=1.2,
top_p=0.9,
repetition_penalty=1.2,
)
sequences = output_ids.sequences if hasattr(output_ids, "sequences") else output_ids
if hasattr(sequences, "shape") and sequences.dim() == 2:
for i, seq in enumerate(sequences):
caption = tokenizer.decode(seq, skip_special_tokens=True).strip()
# Compute approximate confidence from sequence length
score = max(0.3, 1.0 - i * 0.12)
if caption:
results.append((caption, round(score, 3)))
else:
caption = tokenizer.decode(sequences, skip_special_tokens=True).strip()
results.append((caption, 0.85))
# Ensure we return num_captions results
while len(results) < num_captions:
if results:
results.append((results[0][0], max(0.1, results[0][1] - 0.1)))
else:
results.append(("A cartoon scene.", 0.5))
return results[:num_captions]
# ─────────────────────────────────────────────────────────────────────────────
# BLIP-2 Model
# ─────────────────────────────────────────────────────────────────────────────
def load_blip2():
"""Load and return BLIP-2 model."""
from transformers import Blip2Processor, Blip2ForConditionalGeneration
model_name = "Salesforce/blip2-opt-2.7b"
dtype = torch.float16 if DEVICE != "cpu" else torch.float32
processor = Blip2Processor.from_pretrained(model_name)
model = Blip2ForConditionalGeneration.from_pretrained(
model_name, torch_dtype=dtype, device_map="auto"
)
model.eval()
return model, processor
BLIP2_CARTOON_PROMPTS = [
"Write a witty, sarcastic, and funny punchline for this cartoon:",
"A humorous and clever New Yorker comic caption:",
"Question: What is the funniest possible joke to describe this scene? Answer:",
]
def generate_blip2(
model, processor,
image: Image.Image,
num_captions: int = 3,
max_new_tokens: int = 60
) -> List[Tuple[str, float]]:
"""
Generate captions using BLIP-2 with multiple prompts.
Returns list of (caption, score) tuples.
"""
img = preprocess_cartoon(image)
dtype = torch.float16 if DEVICE != "cpu" else torch.float32
results = []
prompts_to_use = BLIP2_CARTOON_PROMPTS[:num_captions]
if len(prompts_to_use) < num_captions:
prompts_to_use += [None] * (num_captions - len(prompts_to_use))
for i, prompt in enumerate(prompts_to_use):
try:
if prompt:
inputs = processor(
img, text=prompt, return_tensors="pt"
).to(DEVICE, dtype)
else:
inputs = processor(img, return_tensors="pt").to(DEVICE, dtype)
with torch.no_grad():
generated_ids = model.generate(
**inputs,
max_new_tokens=max_new_tokens,
num_beams=4,
repetition_penalty=1.3,
temperature=1.1,
do_sample=True,
top_p=0.9,
)
caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
# Remove the prompt echo if present
if prompt and caption.startswith(prompt):
caption = caption[len(prompt):].strip()
score = round(0.92 - i * 0.08, 3)
results.append((caption or "Caption generation failed.", score))
except Exception as e:
results.append((f"[Generation error: {str(e)[:40]}]", 0.1))
return results[:num_captions]
# ─────────────────────────────────────────────────────────────────────────────
# BLIP-1 (Lightweight fallback β€” faster for CPU)
# ─────────────────────────────────────────────────────────────────────────────
def load_blip():
"""Load lightweight BLIP model (good CPU fallback)."""
from transformers import BlipProcessor, BlipForConditionalGeneration
model_name = "Salesforce/blip-image-captioning-large"
processor = BlipProcessor.from_pretrained(model_name)
model = BlipForConditionalGeneration.from_pretrained(model_name).to(DEVICE)
model.eval()
return model, processor
def generate_blip(
model, processor,
image: Image.Image,
num_captions: int = 3,
max_new_tokens: int = 60
) -> List[Tuple[str, float]]:
"""Generate captions using BLIP with conditional prompts."""
img = preprocess_cartoon(image)
prompts = [
"a witty, humorous, and funny cartoon punchline:",
"a sarcastic joke about this image:",
"a clever and funny New Yorker caption:",
][:num_captions]
results = []
for i, prompt in enumerate(prompts):
try:
inputs = processor(img, text=prompt, return_tensors="pt").to(DEVICE)
with torch.no_grad():
out = model.generate(
**inputs,
max_new_tokens=max_new_tokens,
num_beams=4,
temperature=1.1,
do_sample=True,
top_p=0.9,
repetition_penalty=1.2
)
caption = processor.decode(out[0], skip_special_tokens=True).strip()
if caption.lower().startswith(prompt.lower()):
caption = caption[len(prompt):].strip()
score = round(0.88 - i * 0.06, 3)
results.append((caption or "No caption generated.", score))
except Exception as e:
results.append((f"Error: {str(e)[:40]}", 0.1))
return results[:num_captions]
# ─────────────────────────────────────────────────────────────────────────────
# Humor Scorer (lightweight approx. using perplexity + lexical cues)
# ─────────────────────────────────────────────────────────────────────────────
HUMOR_KEYWORDS = {
"positive": [
"irony", "twist", "surprise", "unexpected", "ironic",
"sarcastic", "absurd", "bizarre", "ridiculous", "clever",
"witty", "irony", "pun", "joke", "funny", "laugh", "hilarious",
"bizarre", "awkward", "ridiculous", "paradox"
],
"structural": [
"but", "however", "except", "unless", "despite", "although",
"even though", "turns out", "actually", "wait", "suddenly"
]
}
def score_humor(caption: str) -> float:
"""
Lightweight humor scoring based on:
- Lexical humor markers
- Caption length (good captions are medium length)
- Structural incongruity markers
Returns score in [0, 1].
"""
caption_lower = caption.lower()
words = caption_lower.split()
n_words = len(words)
# Base score
score = 0.3
# Keyword score
kw_hits = sum(1 for kw in HUMOR_KEYWORDS["positive"] if kw in caption_lower)
struct_hits = sum(1 for kw in HUMOR_KEYWORDS["structural"] if kw in caption_lower)
score += min(kw_hits * 0.08, 0.24)
score += min(struct_hits * 0.06, 0.18)
# Length penalty: too short or too long is bad
if 5 <= n_words <= 20:
score += 0.15
elif n_words < 3:
score -= 0.1
# Punctuation markers (question marks, exclamation for humor)
if "?" in caption:
score += 0.05
if "!" in caption:
score += 0.03
# Clip to [0, 1]
return round(min(max(score, 0.0), 1.0), 3)
def analyze_captions(captions: List[Tuple[str, float]]) -> List[dict]:
"""
Full analysis of generated captions.
Returns list of dicts with caption, confidence, humor_score, word_count.
"""
results = []
for caption, confidence in captions:
humor = score_humor(caption)
results.append({
"caption": caption,
"confidence": confidence,
"humor_score": humor,
"word_count": len(caption.split()),
"char_count": len(caption),
})
return results