import argparse
import os
import torch
import torch.nn.functional as F
from transformers import (
BlipProcessor,
BlipForConditionalGeneration,
AutoTokenizer,
AutoModelForSequenceClassification
)
from PIL import Image
# ---------------------------------------
# Load Models
# ---------------------------------------
def load_models():
    """Load the captioning and toxicity models onto the best available device.

    Returns:
        tuple: (caption_model, caption_processor, tox_model, tox_tokenizer, device)
    """
    # Prefer CUDA, then Apple MPS, then CPU. The original checked only MPS,
    # so CUDA machines silently ran on CPU.
    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")
    print("Using device:", device)

    # Fine-tuned BLIP captioning model saved locally after phase-2 training.
    caption_model = BlipForConditionalGeneration.from_pretrained("saved_model_phase2")
    caption_processor = BlipProcessor.from_pretrained("saved_model_phase2")
    caption_model.to(device)
    caption_model.eval()

    # Toxicity model (multi-label classifier from the HF hub).
    tox_tokenizer = AutoTokenizer.from_pretrained("unitary/toxic-bert")
    tox_model = AutoModelForSequenceClassification.from_pretrained("unitary/toxic-bert")
    tox_model.to(device)
    tox_model.eval()

    return caption_model, caption_processor, tox_model, tox_tokenizer, device
# ---------------------------------------
# Generate Caption + Confidence
# ---------------------------------------
def generate_caption(model, processor, image, device):
    """Generate a caption for *image* and return it with a confidence score.

    The confidence is exp(sequence log-probability) of the best beam, i.e.
    the model's probability for the whole generated sequence.

    Returns:
        tuple[str, float]: (caption, confidence in [0, 1])
    """
    batch = processor(images=image, return_tensors="pt").to(device)

    gen_kwargs = dict(
        num_beams=5,
        max_length=20,
        length_penalty=1.0,
        output_scores=True,
        return_dict_in_generate=True,
    )
    with torch.no_grad():
        result = model.generate(**batch, **gen_kwargs)

    caption = processor.decode(result.sequences[0], skip_special_tokens=True)

    # sequences_scores holds the length-penalized log-prob of each beam;
    # exponentiate the top beam's score to get a probability-like confidence.
    confidence = result.sequences_scores[0].exp().item()
    return caption, confidence
# ---------------------------------------
# Toxicity Score
# ---------------------------------------
def check_toxicity(tox_model, tox_tokenizer, caption, device):
    """Return the probability that *caption* is toxic, in [0, 1].

    Args:
        tox_model: sequence-classification model ("unitary/toxic-bert").
        tox_tokenizer: its matching tokenizer.
        caption: text to score.
        device: torch device the model lives on.
    """
    inputs = tox_tokenizer(
        caption,
        return_tensors="pt",
        truncation=True
    ).to(device)
    with torch.no_grad():
        outputs = tox_model(**inputs)
    # BUG FIX: unitary/toxic-bert is a MULTI-LABEL classifier with 6
    # independent heads (toxic, severe_toxic, obscene, threat, insult,
    # identity_hate). Softmax across the labels is wrong — it forces the
    # labels to compete — and index 1 is "severe_toxic", not "toxic".
    # Per-label probabilities come from a sigmoid; "toxic" is index 0.
    probs = torch.sigmoid(outputs.logits)
    toxic_score = probs[0][0].item()
    return toxic_score
# ---------------------------------------
# Evaluate Single Image
# ---------------------------------------
def evaluate_image(image_path, models):
    """Caption a single image, score the caption's toxicity, and print a report.

    Args:
        image_path: path to the image file.
        models: 5-tuple from load_models().
    """
    cap_model, cap_processor, tox_model, tox_tokenizer, device = models

    img = Image.open(image_path).convert("RGB")
    caption, confidence = generate_caption(cap_model, cap_processor, img, device)
    toxic_score = check_toxicity(tox_model, tox_tokenizer, caption, device)

    print("\n===================================")
    print("Image:", image_path)
    print("Caption:", caption)
    print(f"Confidence: {confidence:.3f}")
    print(f"Toxicity Score: {toxic_score:.3f}")
    # Flag captions whose toxicity probability crosses the 0.6 threshold.
    if toxic_score > 0.6:
        print("⚠️ WARNING: Caption flagged as toxic")
    print("===================================\n")
# ---------------------------------------
# Main
# ---------------------------------------
def main():
    """CLI entry point: evaluate one image (--image) and/or a folder (--folder)."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--image", type=str, help="Path to single image")
    parser.add_argument("--folder", type=str, help="Path to folder of images")
    args = parser.parse_args()

    # Require at least one input source.
    if not (args.image or args.folder):
        print("Please provide --image or --folder")
        return

    models = load_models()

    if args.image:
        evaluate_image(args.image, models)

    if args.folder:
        image_exts = (".jpg", ".jpeg", ".png")
        for name in os.listdir(args.folder):
            if name.lower().endswith(image_exts):
                evaluate_image(os.path.join(args.folder, name), models)


if __name__ == "__main__":
    main()