Spaces:
Running
Running
import argparse
import os
import sys

import torch
import torch.nn.functional as F
from PIL import Image
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    BlipForConditionalGeneration,
    BlipProcessor,
)
# ---------------------------------------
# Load Models
# ---------------------------------------
def _select_device():
    """Return the best available torch device: CUDA, then Apple MPS, then CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


def load_models(
    caption_model_dir="saved_model_phase2",
    tox_model_name="unitary/toxic-bert",
):
    """Load the captioning and toxicity models onto the best available device.

    Args:
        caption_model_dir: Directory (or hub id) holding the BLIP captioning
            weights and processor. Defaults to the local
            ``saved_model_phase2`` directory this script has always used.
        tox_model_name: Hub id (or directory) of the sequence-classification
            model used to score captions for toxicity.

    Returns:
        Tuple ``(caption_model, caption_processor, tox_model, tox_tokenizer,
        device)``. Both models are moved to ``device`` and set to eval mode.
    """
    device = _select_device()
    print("Using device:", device)

    caption_model = BlipForConditionalGeneration.from_pretrained(caption_model_dir)
    caption_processor = BlipProcessor.from_pretrained(caption_model_dir)
    caption_model.to(device)
    caption_model.eval()

    # Toxicity model
    tox_tokenizer = AutoTokenizer.from_pretrained(tox_model_name)
    tox_model = AutoModelForSequenceClassification.from_pretrained(tox_model_name)
    tox_model.to(device)
    tox_model.eval()

    return caption_model, caption_processor, tox_model, tox_tokenizer, device
# ---------------------------------------
# Generate Caption + Confidence
# ---------------------------------------
def generate_caption(model, processor, image, device, *, num_beams=5, max_length=20):
    """Caption ``image`` with beam search and return a confidence score.

    Args:
        model: BLIP conditional-generation model, already on ``device``.
        processor: Matching ``BlipProcessor``.
        image: PIL image to caption.
        device: Torch device the processed inputs are moved to.
        num_beams: Beam width. Must be at least 2, because the confidence
            is read from ``sequences_scores``, which ``generate`` only
            returns for beam search.
        max_length: Maximum generated length in tokens, passed straight to
            ``generate``.

    Returns:
        ``(caption, confidence)``. ``confidence`` is ``exp`` of the best
        beam's ``sequences_scores`` entry. In transformers' beam search that
        score is the summed token log-probability divided by
        ``length ** length_penalty``. With ``length_penalty=1.0`` the value
        is therefore roughly the geometric-mean per-token probability of the
        caption, not the probability of the whole sequence.

    Raises:
        ValueError: If ``num_beams`` is less than 2.
    """
    if num_beams < 2:
        raise ValueError(
            f"num_beams must be >= 2 to obtain sequences_scores, got {num_beams}"
        )

    inputs = processor(images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            num_beams=num_beams,
            max_length=max_length,
            length_penalty=1.0,
            output_scores=True,
            return_dict_in_generate=True,
        )

    caption = processor.decode(outputs.sequences[0], skip_special_tokens=True)

    # sequences_scores is a length-normalised log-probability; exp() maps it
    # back to a per-token probability in (0, 1].
    confidence = torch.exp(outputs.sequences_scores[0]).item()
    return caption, confidence
# ---------------------------------------
# Toxicity Score
# ---------------------------------------
def check_toxicity(tox_model, tox_tokenizer, caption, device):
    """Return the probability, in [0, 1], that ``caption`` is toxic.

    The model this script loads (``unitary/toxic-bert``) is a multi-label
    classifier: each label (toxic, severe_toxic, obscene, ...) is an
    independent yes/no decision, so its logits must go through a sigmoid.
    A softmax across labels would be wrong. The "toxic" column is located
    through ``config.label2id`` when available. For a plain two-class head,
    a softmax is applied and index 1 is used, as before.

    Args:
        tox_model: Sequence-classification model, already on ``device``.
        tox_tokenizer: Tokenizer matching ``tox_model``.
        caption: Text to score.
        device: Torch device the tokenized inputs are moved to.

    Returns:
        Toxicity probability as a Python float.
    """
    inputs = tox_tokenizer(caption, return_tensors="pt", truncation=True).to(device)
    with torch.no_grad():
        logits = tox_model(**inputs).logits[0]

    config = getattr(tox_model, "config", None)
    problem_type = getattr(config, "problem_type", None)
    # Case-insensitive label -> column map, if the config provides one.
    label2id = {
        str(name).lower(): int(idx)
        for name, idx in (getattr(config, "label2id", None) or {}).items()
    }
    n_labels = logits.shape[-1]

    use_sigmoid = (
        n_labels == 1
        or problem_type == "multi_label_classification"
        # No declared problem type: anything but a 2-way head is treated as
        # multi-label.
        or (problem_type is None and n_labels != 2)
    )
    if use_sigmoid:
        probs = torch.sigmoid(logits)
        toxic_idx = label2id.get("toxic", 0)
    else:
        # Mutually exclusive classes, e.g. {non-toxic, toxic}.
        probs = F.softmax(logits, dim=-1)
        toxic_idx = label2id.get("toxic", 1)
    return probs[toxic_idx].item()
# ---------------------------------------
# Evaluate Single Image
# ---------------------------------------
def evaluate_image(image_path, models, *, threshold=0.6):
    """Caption one image, score the caption for toxicity, and print a report.

    Args:
        image_path: Path to the image file.
        models: The tuple returned by ``load_models()``.
        threshold: A caption whose toxicity score is strictly above this
            value is flagged. Defaults to 0.6, the value this script has
            always used.

    Returns:
        Dict with keys ``caption``, ``confidence``, ``toxicity`` and
        ``flagged``.

    Raises:
        OSError: If the file cannot be opened or decoded as an image.
            Pillow's ``UnidentifiedImageError`` is a subclass.
    """
    caption_model, caption_processor, tox_model, tox_tokenizer, device = models

    # Pillow opens files lazily; the with-block guarantees the handle is
    # released once the independent RGB copy has been made.
    with Image.open(image_path) as img:
        image = img.convert("RGB")

    caption, confidence = generate_caption(
        caption_model, caption_processor, image, device
    )
    toxic_score = check_toxicity(tox_model, tox_tokenizer, caption, device)
    flagged = toxic_score > threshold

    print("\n===================================")
    print("Image:", image_path)
    print("Caption:", caption)
    print(f"Confidence: {confidence:.3f}")
    print(f"Toxicity Score: {toxic_score:.3f}")
    if flagged:
        print("⚠️ WARNING: Caption flagged as toxic")
    print("===================================\n")

    return {
        "caption": caption,
        "confidence": confidence,
        "toxicity": toxic_score,
        "flagged": flagged,
    }
# ---------------------------------------
# Main
# ---------------------------------------
# File suffixes (lower-case) picked up when scanning --folder.
IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".bmp", ".webp")


def main():
    """Command-line entry point: evaluate one image and/or a folder of images.

    All paths are validated before the (slow) model load. Folder contents
    are processed in sorted order. An image that cannot be opened or
    decoded is reported on stderr and skipped, so the rest of the batch
    still runs.

    Returns:
        Process exit status:
        0 if every image was evaluated;
        1 if any image failed or the folder held no images.
        Argument errors exit with status 2 via ``argparse``.
    """
    parser = argparse.ArgumentParser(
        description="Caption images with BLIP and score the captions for toxicity."
    )
    parser.add_argument("--image", type=str, help="Path to single image")
    parser.add_argument("--folder", type=str, help="Path to folder of images")
    args = parser.parse_args()

    if not args.image and not args.folder:
        parser.error("Please provide --image or --folder")
    if args.image and not os.path.isfile(args.image):
        parser.error(f"--image: no such file: {args.image}")
    if args.folder and not os.path.isdir(args.folder):
        parser.error(f"--folder: no such directory: {args.folder}")

    paths = []
    if args.image:
        paths.append(args.image)
    if args.folder:
        for name in sorted(os.listdir(args.folder)):
            path = os.path.join(args.folder, name)
            # isfile() skips sub-directories that happen to end in ".jpg" etc.
            if name.lower().endswith(IMAGE_EXTENSIONS) and os.path.isfile(path):
                paths.append(path)

    if not paths:
        print(f"No images found in {args.folder}", file=sys.stderr)
        return 1

    models = load_models()
    failures = 0
    for path in paths:
        try:
            evaluate_image(path, models)
        except OSError as exc:
            # Corrupt or unreadable image: report it and keep going.
            failures += 1
            print(f"Skipping {path}: {exc}", file=sys.stderr)
    return 1 if failures else 0


if __name__ == "__main__":
    sys.exit(main())