# deepfake-explainer-2 / inference_helper.py
# Uploaded by saakshigupta: fine-tuned deepfake explainer vision model
# with multiple image support. Revision: 9cb4953 (verified)
# Deepfake Explainer Inference Helper (Multi-Image Version)
import torch
import os
from PIL import Image
from transformers import AutoProcessor, AutoModelForCausalLM
from peft import PeftModel
import matplotlib.pyplot as plt
# Function to analyze images for deepfakes
def _load_model():
    """Load the base vision-language model and attach the fine-tuned LoRA adapter.

    Returns:
        (model, processor): the PEFT-wrapped model (device-mapped, fp16) and
        its matching processor.

    NOTE(review): this downloads/loads the full model on every call of
    analyze_deepfake — expensive; consider caching across calls.
    """
    base_model_id = "unsloth/llama-3.2-11b-vision-instruct"
    adapter_id = "saakshigupta/deepfake-explainer-2"
    processor = AutoProcessor.from_pretrained(base_model_id)
    model = AutoModelForCausalLM.from_pretrained(
        base_model_id,
        torch_dtype=torch.float16,
        device_map="auto",
    )
    # Attach the fine-tuned deepfake-explainer adapter on top of the base model.
    model = PeftModel.from_pretrained(model, adapter_id)
    return model, processor


def _load_images(original_image_path, cam_image_path, cam_overlay_path, comparison_image_path):
    """Open the required original image plus any optional auxiliary images.

    Optional paths that are falsy (None or empty string) are skipped.

    Returns:
        (images, image_titles): parallel lists of PIL RGB images and their
        display titles; the original image is always first.
    """
    images = [Image.open(original_image_path).convert("RGB")]
    image_titles = ["Original Image"]
    optional = (
        (cam_image_path, "CAM Image"),
        (cam_overlay_path, "CAM Overlay"),
        (comparison_image_path, "Comparison Image"),
    )
    for path, title in optional:
        if path:
            images.append(Image.open(path).convert("RGB"))
            image_titles.append(title)
    return images, image_titles


def _display_images(images, image_titles):
    """Show the loaded images in a grid of at most two columns."""
    rows = (len(images) + 1) // 2
    cols = min(2, len(images))
    fig, axs = plt.subplots(rows, cols, figsize=(12, 6 * rows))
    if len(images) == 1:
        # A single subplot comes back as a bare Axes, not an array — normalize.
        axs = [axs]
    elif len(images) > 2:
        axs = axs.flatten()
    for ax, img, title in zip(axs, images, image_titles):
        ax.imshow(img)
        ax.set_title(title)
        ax.axis('off')
    # Hide any leftover empty subplot (e.g. 3 images in a 2x2 grid); the
    # original left the unused axes visible.
    for ax in axs[len(images):]:
        ax.axis('off')
    plt.tight_layout()
    plt.show()


def analyze_deepfake(original_image_path, cam_image_path=None, cam_overlay_path=None, comparison_image_path=None, custom_prompt=None):
    """Analyze multiple images for signs of deepfakes with detailed explanation.

    Args:
        original_image_path: Path to the image under analysis (required).
        cam_image_path: Optional path to a class-activation-map (CAM) image.
        cam_overlay_path: Optional path to a CAM overlay image.
        comparison_image_path: Optional path to a comparison image.
        custom_prompt: Optional text overriding the default analysis prompt.

    Returns:
        str: the model's generated analysis, with the prompt text stripped
        from the response when it can be located.
    """
    model, processor = _load_model()
    images, image_titles = _load_images(
        original_image_path, cam_image_path, cam_overlay_path, comparison_image_path
    )
    original_image = images[0]
    _display_images(images, image_titles)

    # Create prompt
    if custom_prompt is None:
        prompt = "Analyze these images carefully and determine if there's a deepfake. Provide both a technical explanation and a simple explanation anyone can understand."
    else:
        prompt = custom_prompt

    # Process the image(s) and text; some processor versions may reject a list
    # of images, so fall back to the original image alone.
    try:
        inputs = processor(text=prompt, images=images, return_tensors="pt")
    except Exception as e:
        print(f"Warning: Unable to process multiple images ({e}). Falling back to original image only.")
        inputs = processor(text=prompt, images=original_image, return_tensors="pt")

    # Fix cross-attention mask: a zero-sized dimension would break generation,
    # so rebuild an all-ones mask with the visual-feature size used in training.
    if 'cross_attention_mask' in inputs and 0 in inputs['cross_attention_mask'].shape:
        batch_size, seq_len, _, num_tiles = inputs['cross_attention_mask'].shape
        visual_features = 6404  # Critical dimension from training
        new_mask = torch.ones(
            (batch_size, seq_len, visual_features, num_tiles),
            device=inputs['cross_attention_mask'].device if torch.is_tensor(inputs['cross_attention_mask']) else None,
        )
        inputs['cross_attention_mask'] = new_mask
        print("Fixed cross-attention mask dimensions")

    # Move tensors to the model's device, dropping any non-tensor entries.
    inputs = {k: v.to(model.device) for k, v in inputs.items() if isinstance(v, torch.Tensor)}

    # Generate output
    print("Generating analysis...")
    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=512,
            temperature=0.7,
            top_p=0.9,
        )

    # Decode and strip the echoed prompt so only the model's answer remains.
    response = processor.decode(output_ids[0], skip_special_tokens=True)
    if prompt in response:
        result = response.split(prompt)[-1].strip()
    else:
        result = response
    return result
# Example usage
# Example usage
if __name__ == "__main__":
    def _optional_path(prompt_text):
        """Prompt for a path; return None when the user presses Enter to skip."""
        value = input(prompt_text)
        return None if value.strip() == "" else value

    # Get image paths from user (only the original image is required).
    original_image_path = input("Enter path to original image: ")
    cam_image_path = _optional_path("Enter path to CAM image (or press Enter to skip): ")
    cam_overlay_path = _optional_path("Enter path to CAM overlay image (or press Enter to skip): ")
    comparison_image_path = _optional_path("Enter path to comparison image (or press Enter to skip): ")

    # Analyze images
    analysis = analyze_deepfake(original_image_path, cam_image_path, cam_overlay_path, comparison_image_path)

    # Print result. BUG FIX: the original broke this string literal across
    # physical lines (a SyntaxError); the intended newlines are now explicit.
    print("\n===== DEEPFAKE ANALYSIS RESULT =====\n")
    print(analysis)