File size: 3,900 Bytes
a745a5e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
import argparse
import os
import torch
import torch.nn.functional as F
from transformers import (
    BlipProcessor,
    BlipForConditionalGeneration,
    AutoTokenizer,
    AutoModelForSequenceClassification
)
from PIL import Image

# ---------------------------------------
# Load Models
# ---------------------------------------
def load_models(caption_model_path="saved_model_phase2",
                toxicity_model_name="unitary/toxic-bert"):
    """Load the captioning and toxicity models and move them to a device.

    Args:
        caption_model_path: Directory or hub id passed to ``from_pretrained``
            for both the BLIP model and its processor.
        toxicity_model_name: Directory or hub id passed to ``from_pretrained``
            for both the toxicity classifier and its tokenizer.

    Returns:
        The tuple ``(caption_model, caption_processor, tox_model,
        tox_tokenizer, device)``, in the order ``evaluate_image`` unpacks it.
    """
    # Prefer CUDA, then Apple MPS, then CPU, so the script also uses an
    # accelerator on machines that are not Apple Silicon.
    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")

    print("Using device:", device)

    caption_model = BlipForConditionalGeneration.from_pretrained(caption_model_path)
    caption_processor = BlipProcessor.from_pretrained(caption_model_path)

    caption_model.to(device)
    caption_model.eval()  # inference only

    # Toxicity model
    tox_tokenizer = AutoTokenizer.from_pretrained(toxicity_model_name)
    tox_model = AutoModelForSequenceClassification.from_pretrained(toxicity_model_name)

    tox_model.to(device)
    tox_model.eval()  # inference only

    return caption_model, caption_processor, tox_model, tox_tokenizer, device


# ---------------------------------------
# Generate Caption + Confidence
# ---------------------------------------
def generate_caption(model, processor, image, device):
    """Caption ``image`` with beam search and derive a confidence value.

    Args:
        model: Object exposing ``generate(**inputs, ...)``.
        processor: Callable that turns the image into model inputs and also
            provides ``decode`` for token ids.
        image: The image handed to ``processor`` via ``images=``.
        device: Device the processed inputs are moved to.

    Returns:
        ``(caption, confidence)`` where ``caption`` is the decoded first
        returned sequence and ``confidence`` is ``exp`` of that sequence's
        ``sequences_scores`` entry.
    """
    model_inputs = processor(images=image, return_tensors="pt").to(device)

    beam_search_options = {
        "num_beams": 5,
        "max_length": 20,
        "length_penalty": 1.0,
        "output_scores": True,
        "return_dict_in_generate": True,
    }

    with torch.no_grad():
        generation = model.generate(**model_inputs, **beam_search_options)

    top_sequence = generation.sequences[0]
    text = processor.decode(top_sequence, skip_special_tokens=True)

    # NOTE(review): sequences_scores is presumably the length-normalised
    # log-probability of the beam, so exp() reads more like a per-token
    # average probability than a calibrated confidence -- confirm.
    top_score = generation.sequences_scores[0]
    return text, torch.exp(top_score).item()


# ---------------------------------------
# Toxicity Score
# ---------------------------------------
def check_toxicity(tox_model, tox_tokenizer, caption, device):
    """Return the probability that ``caption`` is toxic.

    The classifier loaded elsewhere in this file (``unitary/toxic-bert``) is
    multi-label: each logit is an independent score for one category. The
    score is therefore a per-label sigmoid, not a softmax across labels, and
    it is read from the "toxic" label rather than from a fixed position.

    Args:
        tox_model: Sequence-classification model; called with the tokenizer
            output and expected to expose ``.logits`` on its result.
        tox_tokenizer: Tokenizer matching ``tox_model``.
        caption: Text to score.
        device: Device the tokenized inputs are moved to.

    Returns:
        A float in ``[0, 1]``: the sigmoid of the "toxic" logit.
    """
    inputs = tox_tokenizer(
        caption,
        return_tensors="pt",
        truncation=True
    ).to(device)

    with torch.no_grad():
        logits = tox_model(**inputs).logits

    # Locate the "toxic" label from the model config; fall back to index 0,
    # which is where toxic-bert places it, when no label map is available.
    id2label = getattr(getattr(tox_model, "config", None), "id2label", None) or {}
    toxic_index = next(
        (int(idx) for idx, label in id2label.items()
         if str(label).lower() == "toxic"),
        0,
    )

    return torch.sigmoid(logits[0, toxic_index]).item()


# ---------------------------------------
# Evaluate Single Image
# ---------------------------------------
def evaluate_image(image_path, models, toxicity_threshold=0.6):
    """Caption one image, score the caption for toxicity and print a report.

    Args:
        image_path: Path of the image file to evaluate.
        models: The 5-tuple ``(caption_model, caption_processor, tox_model,
            tox_tokenizer, device)``.
        toxicity_threshold: A warning line is printed when the toxicity score
            is strictly greater than this value.
    """
    caption_model, caption_processor, tox_model, tox_tokenizer, device = models

    # convert() returns a new, fully loaded image, so the underlying file can
    # be closed immediately instead of staying open until garbage collection
    # (which matters when a whole folder is processed).
    with Image.open(image_path) as source:
        image = source.convert("RGB")

    caption, confidence = generate_caption(
        caption_model,
        caption_processor,
        image,
        device
    )

    toxic_score = check_toxicity(
        tox_model,
        tox_tokenizer,
        caption,
        device
    )

    print("\n===================================")
    print("Image:", image_path)
    print("Caption:", caption)
    print(f"Confidence: {confidence:.3f}")
    print(f"Toxicity Score: {toxic_score:.3f}")

    if toxic_score > toxicity_threshold:
        print("⚠️ WARNING: Caption flagged as toxic")
    print("===================================\n")


# ---------------------------------------
# Main
# ---------------------------------------
def main():
    """Command-line entry point: evaluate ``--image`` and/or ``--folder``.

    Arguments are validated before the models are loaded so that a bad path
    fails immediately. Misuse is reported through ``parser.error``, which
    prints the usage text and exits with status 2.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--image", type=str, help="Path to single image")
    parser.add_argument("--folder", type=str, help="Path to folder of images")

    args = parser.parse_args()

    if not args.image and not args.folder:
        parser.error("Please provide --image or --folder")
    if args.image and not os.path.isfile(args.image):
        parser.error(f"--image is not a file: {args.image}")
    if args.folder and not os.path.isdir(args.folder):
        parser.error(f"--folder is not a directory: {args.folder}")

    models = load_models()

    if args.image:
        evaluate_image(args.image, models)

    if args.folder:
        image_extensions = (".jpg", ".jpeg", ".png")
        # os.listdir() order is arbitrary; sort for reproducible output.
        for name in sorted(os.listdir(args.folder)):
            if not name.lower().endswith(image_extensions):
                continue
            path = os.path.join(args.folder, name)
            if not os.path.isfile(path):
                continue
            try:
                evaluate_image(path, models)
            except OSError as err:
                # An unreadable or corrupt file should not abort the batch.
                print(f"Skipping {path}: {err}")


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()