# Deepfake Explainer Inference Helper (Multi-Image Version)
import torch
import os
from PIL import Image
from transformers import AutoProcessor, AutoModelForCausalLM
from peft import PeftModel
import matplotlib.pyplot as plt
# Function to analyze images for deepfakes
def analyze_deepfake(original_image_path, cam_image_path=None, cam_overlay_path=None, comparison_image_path=None, custom_prompt=None):
    """Analyze one or more images for signs of deepfakes with a detailed explanation.

    Args:
        original_image_path: Path to the image under analysis (required).
        cam_image_path: Optional path to a CAM (class-activation map) image.
        cam_overlay_path: Optional path to a CAM overlay image.
        comparison_image_path: Optional path to a comparison image.
        custom_prompt: Optional prompt text; a default analysis prompt is used when None.

    Returns:
        The model's generated analysis as a string, with the prompt text
        stripped from the response when it appears verbatim.
    """
    # Model identifiers: base vision-instruct model + fine-tuned LoRA adapter
    base_model_id = "unsloth/llama-3.2-11b-vision-instruct"
    adapter_id = "saakshigupta/deepfake-explainer-2"

    # Load processor and base model (fp16, automatic device placement),
    # then attach the deepfake-explainer adapter on top
    processor = AutoProcessor.from_pretrained(base_model_id)
    model = AutoModelForCausalLM.from_pretrained(
        base_model_id,
        torch_dtype=torch.float16,
        device_map="auto"
    )
    model = PeftModel.from_pretrained(model, adapter_id)

    # Collect whichever images were provided, keeping titles in lockstep
    images = []
    image_titles = []
    for path, title in (
        (original_image_path, "Original Image"),
        (cam_image_path, "CAM Image"),
        (cam_overlay_path, "CAM Overlay"),
        (comparison_image_path, "Comparison Image"),
    ):
        if path:
            images.append(Image.open(path).convert("RGB"))
            image_titles.append(title)
    original_image = images[0]  # original is always first; used as fallback below

    # Display the images in a grid of up to 2 columns
    rows = (len(images) + 1) // 2
    cols = min(2, len(images))
    fig, axs = plt.subplots(rows, cols, figsize=(12, 6 * rows))
    if len(images) == 1:
        # subplots(1, 1) returns a single Axes, not an array
        axs.imshow(images[0])
        axs.set_title(image_titles[0])
        axs.axis('off')
    else:
        # Flatten the 2-D axes array when there is more than one row
        axs = axs.flatten() if len(images) > 2 else axs
        for i, (img, title) in enumerate(zip(images, image_titles)):
            axs[i].imshow(img)
            axs[i].set_title(title)
            axs[i].axis('off')
        # Hide any leftover empty axes (e.g. 3 images on a 2x2 grid)
        for j in range(len(images), rows * cols):
            axs[j].axis('off')
    plt.tight_layout()
    plt.show()

    # Build the prompt
    if custom_prompt is None:
        prompt = ("Analyze these images carefully and determine if there's a deepfake. "
                  "Provide both a technical explanation and a simple explanation anyone can understand.")
    else:
        prompt = custom_prompt

    # Process text + images; fall back to the original image alone if the
    # processor cannot handle multiple images
    try:
        inputs = processor(text=prompt, images=images, return_tensors="pt")
    except Exception as e:
        print(f"Warning: Unable to process multiple images ({e}). Falling back to original image only.")
        inputs = processor(text=prompt, images=original_image, return_tensors="pt")

    # Workaround: some processor versions emit a degenerate (zero-sized)
    # cross-attention mask; rebuild it with the visual-feature dimension
    # the adapter was trained with
    if 'cross_attention_mask' in inputs and 0 in inputs['cross_attention_mask'].shape:
        batch_size, seq_len, _, num_tiles = inputs['cross_attention_mask'].shape
        visual_features = 6404  # critical dimension from training
        # .shape was just indexed above, so this entry is guaranteed a tensor
        new_mask = torch.ones(
            (batch_size, seq_len, visual_features, num_tiles),
            device=inputs['cross_attention_mask'].device
        )
        inputs['cross_attention_mask'] = new_mask
        print("Fixed cross-attention mask dimensions")

    # Move all tensors to the model's device (non-tensor entries are dropped)
    inputs = {k: v.to(model.device) for k, v in inputs.items() if isinstance(v, torch.Tensor)}

    # Generate output. do_sample=True is required for temperature/top_p to
    # take effect — without it HF generate ignores them (greedy decoding).
    print("Generating analysis...")
    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=512,
            do_sample=True,
            temperature=0.7,
            top_p=0.9
        )

    # Decode and strip the prompt text from the response when present
    response = processor.decode(output_ids[0], skip_special_tokens=True)
    if prompt in response:
        return response.split(prompt)[-1].strip()
    return response
# Example usage
if __name__ == "__main__":
    def _optional_path(message):
        """Prompt for an optional path; return None when the user skips (whitespace-stripped)."""
        value = input(message).strip()
        return value or None

    # Gather image paths from the user (only the original is required)
    original_image_path = input("Enter path to original image: ")
    cam_image_path = _optional_path("Enter path to CAM image (or press Enter to skip): ")
    cam_overlay_path = _optional_path("Enter path to CAM overlay image (or press Enter to skip): ")
    comparison_image_path = _optional_path("Enter path to comparison image (or press Enter to skip): ")

    # Run the analysis and print the result
    analysis = analyze_deepfake(original_image_path, cam_image_path, cam_overlay_path, comparison_image_path)
    # Banner uses \n escapes — the original source had a literal newline
    # inside the string, which is a syntax error
    print("\n===== DEEPFAKE ANALYSIS RESULT =====\n")
    print(analysis)