# image_captioning/evaluate.py
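"""Evaluate BLIP captions offline: generate a caption per image, report the
beam-search confidence, and flag captions that the toxicity classifier scores
highly.

Usage (paths are illustrative):
    python evaluate.py --image path/to/image.jpg
    python evaluate.py --folder path/to/images/
"""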
import argparse
import os
import torch
import torch.nn.functional as F
from transformers import (
BlipProcessor,
BlipForConditionalGeneration,
AutoTokenizer,
AutoModelForSequenceClassification
)
from PIL import Image
# ---------------------------------------
# Load Models
# ---------------------------------------
def load_models():
    # Prefer the Apple-silicon GPU (MPS) when available; otherwise run on CPU.
    device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
    print("Using device:", device)

    # Captioning model: fine-tuned BLIP weights saved locally after phase-2 training.
    caption_model = BlipForConditionalGeneration.from_pretrained("saved_model_phase2")
    caption_processor = BlipProcessor.from_pretrained("saved_model_phase2")
    caption_model.to(device)
    caption_model.eval()

    # Toxicity model: pretrained multi-label classifier from the Hugging Face Hub.
    tox_tokenizer = AutoTokenizer.from_pretrained("unitary/toxic-bert")
    tox_model = AutoModelForSequenceClassification.from_pretrained("unitary/toxic-bert")
    tox_model.to(device)
    tox_model.eval()

    return caption_model, caption_processor, tox_model, tox_tokenizer, device

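
# Programmatic usage (a sketch; "sample.jpg" is a placeholder path): load the
# models once and reuse them across images, as main() does below.
#     models = load_models()
#     evaluate_image("sample.jpg", models)
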
# ---------------------------------------
# Generate Caption + Confidence
# ---------------------------------------
def generate_caption(model, processor, image, device):
    inputs = processor(images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            num_beams=5,
            max_length=20,
            length_penalty=1.0,
            output_scores=True,
            return_dict_in_generate=True
        )
    generated_ids = outputs.sequences
    caption = processor.decode(
        generated_ids[0],
        skip_special_tokens=True
    )
    # Beam-search confidence: sequences_scores holds the length-normalized
    # log-probability of the winning beam, so exp() yields a geometric-mean
    # per-token probability in (0, 1].
    seq_score = outputs.sequences_scores[0]
    confidence = torch.exp(seq_score).item()
    return caption, confidence

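
# A finer-grained alternative (a sketch, not used above; assumes a transformers
# version that provides compute_transition_scores, and the helper name is ours):
# recover a probability for each generated token from the recorded beam scores.
def per_token_confidences(model, outputs):
    # Requires output_scores=True and return_dict_in_generate=True, as in
    # generate_caption() above, so that outputs carries scores and beam_indices.
    transition_scores = model.compute_transition_scores(
        outputs.sequences,
        outputs.scores,
        beam_indices=outputs.beam_indices,
        normalize_logits=False,
    )
    # exp() turns per-step log-probabilities into per-token probabilities.
    return torch.exp(transition_scores[0])
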
# ---------------------------------------
# Toxicity Score
# ---------------------------------------
def check_toxicity(tox_model, tox_tokenizer, caption, device):
    inputs = tox_tokenizer(
        caption,
        return_tensors="pt",
        truncation=True
    ).to(device)
    with torch.no_grad():
        outputs = tox_model(**inputs)
    # unitary/toxic-bert is a multi-label classifier (toxic, severe_toxic,
    # obscene, threat, insult, identity_hate) trained with per-label sigmoid
    # outputs, so a softmax over its logits would wrongly force the labels to
    # compete. Index 0 is the general "toxic" label.
    probs = torch.sigmoid(outputs.logits)
    toxic_score = probs[0][0].item()
    return toxic_score

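
# Sanity check (optional): the label order assumed above can be verified from
# the model config, e.g. print(tox_model.config.id2label); for
# unitary/toxic-bert this should map index 0 to "toxic".
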
# ---------------------------------------
# Evaluate Single Image
# ---------------------------------------
def evaluate_image(image_path, models):
    caption_model, caption_processor, tox_model, tox_tokenizer, device = models
    image = Image.open(image_path).convert("RGB")

    caption, confidence = generate_caption(
        caption_model,
        caption_processor,
        image,
        device
    )
    toxic_score = check_toxicity(
        tox_model,
        tox_tokenizer,
        caption,
        device
    )

    print("\n===================================")
    print("Image:", image_path)
    print("Caption:", caption)
    print(f"Confidence: {confidence:.3f}")
    print(f"Toxicity Score: {toxic_score:.3f}")
    # 0.6 is a heuristic cutoff, not a calibrated threshold; adjust as needed.
    if toxic_score > 0.6:
        print("⚠️ WARNING: Caption flagged as toxic")
    print("===================================\n")

# ---------------------------------------
# Main
# ---------------------------------------
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--image", type=str, help="Path to a single image")
    parser.add_argument("--folder", type=str, help="Path to a folder of images")
    args = parser.parse_args()

    if not args.image and not args.folder:
        print("Please provide --image or --folder")
        return

    models = load_models()

    if args.image:
        evaluate_image(args.image, models)

    if args.folder:
        # Sort for a deterministic evaluation order; os.listdir order is arbitrary.
        for file in sorted(os.listdir(args.folder)):
            if file.lower().endswith((".jpg", ".jpeg", ".png")):
                path = os.path.join(args.folder, file)
                evaluate_image(path, models)


if __name__ == "__main__":
    main()